// tokio/runtime/task/raw.rs

// It doesn't make sense to enforce `unsafe_op_in_unsafe_fn` for this module because:
//
// * This module does the low-level task management, which requires tons of unsafe
//   operations.
// * Excessive `unsafe {}` blocks hurt readability significantly.
//
// TODO: replace with `#[expect(unsafe_op_in_unsafe_fn)]` after bumping
// the MSRV to 1.81.0.
8#![allow(unsafe_op_in_unsafe_fn)]
9
10use crate::future::Future;
11use crate::runtime::task::core::{Core, Trailer};
12use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State};
13#[cfg(tokio_unstable)]
14use std::panic::Location;
15use std::ptr::NonNull;
16use std::task::{Poll, Waker};
17
18/// Raw task handle
19#[derive(Clone)]
20pub(crate) struct RawTask {
21    ptr: NonNull<Header>,
22}
23
24pub(super) struct Vtable {
25    /// Polls the future.
26    pub(super) poll: unsafe fn(NonNull<Header>),
27
28    /// Schedules the task for execution on the runtime.
29    pub(super) schedule: unsafe fn(NonNull<Header>),
30
31    /// Deallocates the memory.
32    pub(super) dealloc: unsafe fn(NonNull<Header>),
33
34    /// Reads the task output, if complete.
35    pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),
36
37    /// The join handle has been dropped.
38    pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),
39
40    /// An abort handle has been dropped.
41    pub(super) drop_abort_handle: unsafe fn(NonNull<Header>),
42
43    /// Scheduler is being shutdown.
44    pub(super) shutdown: unsafe fn(NonNull<Header>),
45
46    /// The number of bytes that the `trailer` field is offset from the header.
47    pub(super) trailer_offset: usize,
48
49    /// The number of bytes that the `scheduler` field is offset from the header.
50    pub(super) scheduler_offset: usize,
51
52    /// The number of bytes that the `id` field is offset from the header.
53    pub(super) id_offset: usize,
54
55    /// The number of bytes that the `spawned_at` field is offset from the header.
56    #[cfg(tokio_unstable)]
57    pub(super) spawn_location_offset: usize,
58}
59
60/// Get the vtable for the requested `T` and `S` generics.
61pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
62    &Vtable {
63        poll: poll::<T, S>,
64        schedule: schedule::<S>,
65        dealloc: dealloc::<T, S>,
66        try_read_output: try_read_output::<T, S>,
67        drop_join_handle_slow: drop_join_handle_slow::<T, S>,
68        drop_abort_handle: drop_abort_handle::<T, S>,
69        shutdown: shutdown::<T, S>,
70        trailer_offset: OffsetHelper::<T, S>::TRAILER_OFFSET,
71        scheduler_offset: OffsetHelper::<T, S>::SCHEDULER_OFFSET,
72        id_offset: OffsetHelper::<T, S>::ID_OFFSET,
73        #[cfg(tokio_unstable)]
74        spawn_location_offset: OffsetHelper::<T, S>::SPAWN_LOCATION_OFFSET,
75    }
76}
77
78/// Calling `get_trailer_offset` directly in vtable doesn't work because it
79/// prevents the vtable from being promoted to a static reference.
80///
81/// See this thread for more info:
82/// <https://users.rust-lang.org/t/custom-vtables-with-integers/78508>
83struct OffsetHelper<T, S>(T, S);
84impl<T: Future, S: Schedule> OffsetHelper<T, S> {
85    // Pass `size_of`/`align_of` as arguments rather than calling them directly
86    // inside `get_trailer_offset` because trait bounds on generic parameters
87    // of const fn are unstable on our MSRV.
88    const TRAILER_OFFSET: usize = get_trailer_offset(
89        std::mem::size_of::<Header>(),
90        std::mem::size_of::<Core<T, S>>(),
91        std::mem::align_of::<Core<T, S>>(),
92        std::mem::align_of::<Trailer>(),
93    );
94
95    // The `scheduler` is the first field of `Core`, so it has the same
96    // offset as `Core`.
97    const SCHEDULER_OFFSET: usize = get_core_offset(
98        std::mem::size_of::<Header>(),
99        std::mem::align_of::<Core<T, S>>(),
100    );
101
102    const ID_OFFSET: usize = get_id_offset(
103        std::mem::size_of::<Header>(),
104        std::mem::align_of::<Core<T, S>>(),
105        std::mem::size_of::<S>(),
106        std::mem::align_of::<Id>(),
107    );
108
109    #[cfg(tokio_unstable)]
110    const SPAWN_LOCATION_OFFSET: usize = get_spawn_location_offset(
111        std::mem::size_of::<Header>(),
112        std::mem::align_of::<Core<T, S>>(),
113        std::mem::size_of::<S>(),
114        std::mem::align_of::<Id>(),
115        std::mem::size_of::<Id>(),
116        std::mem::align_of::<&'static Location<'static>>(),
117    );
118}
119
/// Computes the byte offset of the `Trailer` field in `Cell<T, S>` using the
/// `#[repr(C)]` layout algorithm. The layout is:
///
/// 1. the header, starting at offset zero;
/// 2. `Core<T, S>`, padded up to `core_align`;
/// 3. the trailer, padded up to `trailer_align`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_trailer_offset(
    header_size: usize,
    core_size: usize,
    core_align: usize,
    trailer_align: usize,
) -> usize {
    /// Rounds `pos` up to the next multiple of `align` (unchanged if it
    /// already is one).
    const fn pad_to(pos: usize, align: usize) -> usize {
        match pos % align {
            0 => pos,
            rem => pos + (align - rem),
        }
    }

    let core_start = pad_to(header_size, core_align);
    pad_to(core_start + core_size, trailer_align)
}
146
/// Computes the byte offset of the `Core<T, S>` field in `Cell<T, S>` using the
/// `#[repr(C)]` algorithm. The core starts at the first multiple of
/// `core_align` at or after the end of the header.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_core_offset(header_size: usize, core_align: usize) -> usize {
    let padding = match header_size % core_align {
        0 => 0,
        rem => core_align - rem,
    };
    header_size + padding
}
162
163/// Compute the offset of the `Id` field in `Cell<T, S>` using the
164/// `#[repr(C)]` algorithm.
165///
166/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
167/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
168const fn get_id_offset(
169    header_size: usize,
170    core_align: usize,
171    scheduler_size: usize,
172    id_align: usize,
173) -> usize {
174    let mut offset = get_core_offset(header_size, core_align);
175    offset += scheduler_size;
176
177    let id_misalign = offset % id_align;
178    if id_misalign > 0 {
179        offset += id_align - id_misalign;
180    }
181
182    offset
183}
184
/// Computes the byte offset of the `&'static Location<'static>` field in
/// `Cell<T, S>` using the `#[repr(C)]` algorithm.
///
/// The location follows the `Id` field: start at the id's offset, skip
/// `id_size` bytes, then round up to `spawn_location_align`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
#[cfg(tokio_unstable)]
const fn get_spawn_location_offset(
    header_size: usize,
    core_align: usize,
    scheduler_size: usize,
    id_align: usize,
    id_size: usize,
    spawn_location_align: usize,
) -> usize {
    let id_end = get_id_offset(header_size, core_align, scheduler_size, id_align) + id_size;
    match id_end % spawn_location_align {
        0 => id_end,
        rem => id_end + (spawn_location_align - rem),
    }
}
209
210impl RawTask {
211    pub(super) fn new<T, S>(
212        task: T,
213        scheduler: S,
214        id: Id,
215        _spawned_at: super::SpawnLocation,
216    ) -> RawTask
217    where
218        T: Future,
219        S: Schedule,
220    {
221        let ptr = Box::into_raw(Cell::<_, S>::new(
222            task,
223            scheduler,
224            State::new(),
225            id,
226            #[cfg(tokio_unstable)]
227            _spawned_at.0,
228        ));
229        let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) };
230
231        RawTask { ptr }
232    }
233
234    /// # Safety
235    ///
236    /// `ptr` must be a valid pointer to a [`Header`].
237    pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask {
238        RawTask { ptr }
239    }
240
241    pub(super) fn header_ptr(&self) -> NonNull<Header> {
242        self.ptr
243    }
244
245    pub(super) fn trailer_ptr(&self) -> NonNull<Trailer> {
246        unsafe { Header::get_trailer(self.ptr) }
247    }
248
249    /// Returns a reference to the task's header.
250    pub(super) fn header(&self) -> &Header {
251        unsafe { self.ptr.as_ref() }
252    }
253
254    /// Returns a reference to the task's trailer.
255    pub(super) fn trailer(&self) -> &Trailer {
256        unsafe { &*self.trailer_ptr().as_ptr() }
257    }
258
259    /// Returns a reference to the task's state.
260    pub(super) fn state(&self) -> &State {
261        &self.header().state
262    }
263
264    /// Safety: mutual exclusion is required to call this function.
265    pub(crate) fn poll(self) {
266        let vtable = self.header().vtable;
267        unsafe { (vtable.poll)(self.ptr) }
268    }
269
270    pub(super) fn schedule(self) {
271        let vtable = self.header().vtable;
272        unsafe { (vtable.schedule)(self.ptr) }
273    }
274
275    pub(super) fn dealloc(self) {
276        let vtable = self.header().vtable;
277        unsafe {
278            (vtable.dealloc)(self.ptr);
279        }
280    }
281
282    /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T`
283    /// is the future stored by the task.
284    pub(super) unsafe fn try_read_output<O>(self, dst: *mut Poll<super::Result<O>>, waker: &Waker) {
285        let vtable = self.header().vtable;
286        (vtable.try_read_output)(self.ptr, dst as *mut _, waker);
287    }
288
289    pub(super) fn drop_join_handle_slow(self) {
290        let vtable = self.header().vtable;
291        unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
292    }
293
294    pub(super) fn drop_abort_handle(self) {
295        let vtable = self.header().vtable;
296        unsafe { (vtable.drop_abort_handle)(self.ptr) }
297    }
298
299    pub(super) fn shutdown(self) {
300        let vtable = self.header().vtable;
301        unsafe { (vtable.shutdown)(self.ptr) }
302    }
303
304    /// Increment the task's reference count.
305    ///
306    /// Currently, this is used only when creating an `AbortHandle`.
307    pub(super) fn ref_inc(self) {
308        self.header().state.ref_inc();
309    }
310
311    /// Get the queue-next pointer
312    ///
313    /// This is for usage by the injection queue
314    ///
315    /// Safety: make sure only one queue uses this and access is synchronized.
316    pub(crate) unsafe fn get_queue_next(self) -> Option<RawTask> {
317        self.header()
318            .queue_next
319            .with(|ptr| *ptr)
320            .map(|p| RawTask::from_raw(p))
321    }
322
323    /// Sets the queue-next pointer
324    ///
325    /// This is for usage by the injection queue
326    ///
327    /// Safety: make sure only one queue uses this and access is synchronized.
328    pub(crate) unsafe fn set_queue_next(self, val: Option<RawTask>) {
329        self.header().set_next(val.map(|task| task.ptr));
330    }
331}
332
333impl Copy for RawTask {}
334
335unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) {
336    let harness = Harness::<T, S>::from_raw(ptr);
337    harness.poll();
338}
339
340unsafe fn schedule<S: Schedule>(ptr: NonNull<Header>) {
341    use crate::runtime::task::{Notified, Task};
342
343    let scheduler = Header::get_scheduler::<S>(ptr);
344    scheduler
345        .as_ref()
346        .schedule(Notified(Task::from_raw(ptr.cast())));
347}
348
349unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) {
350    let harness = Harness::<T, S>::from_raw(ptr);
351    harness.dealloc();
352}
353
354unsafe fn try_read_output<T: Future, S: Schedule>(
355    ptr: NonNull<Header>,
356    dst: *mut (),
357    waker: &Waker,
358) {
359    let out = &mut *(dst as *mut Poll<super::Result<T::Output>>);
360
361    let harness = Harness::<T, S>::from_raw(ptr);
362    harness.try_read_output(out, waker);
363}
364
365unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
366    let harness = Harness::<T, S>::from_raw(ptr);
367    harness.drop_join_handle_slow();
368}
369
370unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) {
371    let harness = Harness::<T, S>::from_raw(ptr);
372    harness.drop_reference();
373}
374
375unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) {
376    let harness = Harness::<T, S>::from_raw(ptr);
377    harness.shutdown();
378}