tokio/runtime/task/list.rs

//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement `Send`. The `LocalOwnedTasks` container is not thread-safe, but can
//! store non-`Send` tasks.
//!
//! The collections can be closed to prevent adding new tasks while the scheduler
//! that owns the collection is shutting down.

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, SpawnLocation, Task};
use crate::util::linked_list::{Link, LinkedList};
use crate::util::sharded_list;

use crate::loom::sync::atomic::{AtomicBool, Ordering};
use std::marker::PhantomData;
use std::num::NonZeroU64;

// The id generated below is used to verify whether a given task is stored in
// this OwnedTasks or in some other list. The counter starts at one so we can
// use `None` for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if two
// mixed-up runtimes happen to have the same id.

cfg_has_atomic_u64! {
    use std::sync::atomic::AtomicU64;

    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    fn get_next_id() -> NonZeroU64 {
        loop {
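            // `fetch_add` returns the previous value. The counter starts at 1,
            // so a zero can only be observed after the counter wraps around;
            // `NonZeroU64::new` then returns `None` and we simply try again.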
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(id) {
                return id;
            }
        }
    }
}

cfg_not_has_atomic_u64! {
    use std::sync::atomic::AtomicU32;

    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

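    // Same as above, but backed by a 32-bit counter on targets without 64-bit
    // atomics. The counter wraps sooner, which the comment at the top of the
    // file already accepts as a theoretical limitation of the id checks.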
    fn get_next_id() -> NonZeroU64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(u64::from(id)) {
                return id;
            }
        }
    }
}

pub(crate) struct OwnedTasks<S: 'static> {
    list: List<S>,
    pub(crate) id: NonZeroU64,
    closed: AtomicBool,
}

type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;

pub(crate) struct LocalOwnedTasks<S: 'static> {
    inner: UnsafeCell<OwnedTasksInner<S>>,
    pub(crate) id: NonZeroU64,
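    // `*const ()` is neither Send nor Sync, so this marker keeps
    // LocalOwnedTasks pinned to the thread it was created on.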
    _not_send_or_sync: PhantomData<*const ()>,
}

struct OwnedTasksInner<S: 'static> {
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    closed: bool,
}

impl<S: 'static> OwnedTasks<S> {
    pub(crate) fn new(num_cores: usize) -> Self {
        let shard_size = Self::gen_shared_list_size(num_cores);
        Self {
            list: List::new(shard_size),
            closed: AtomicBool::new(false),
            id: get_next_id(),
        }
    }

    /// Binds the provided task to this `OwnedTasks` instance. This fails if the
    /// `OwnedTasks` has been closed.
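    /// In that case the task is shut down immediately and `None` is returned in
    /// place of the `Notified`.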
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// Binds a task that isn't safe to transfer across thread boundaries.
    ///
    /// # Safety
    ///
    /// Only use this in `LocalRuntime`, where the task cannot move to another thread.
    pub(crate) unsafe fn bind_local<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// The part of `bind` that's the same for every type of future.
    unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
    where
        S: Schedule,
    {
        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        let shard = self.list.lock_shard(&task);
        // Check the closed flag while holding the shard lock to ensure that all
        // tasks will be shut down after the OwnedTasks has been closed.
        if self.closed.load(Ordering::Acquire) {
            drop(shard);
            task.shutdown();
            return None;
        }
        shard.push(task);
        Some(notified)
    }

    /// Asserts that the given task is owned by this `OwnedTasks` and converts it
    /// to a `LocalNotified`, giving the thread permission to poll this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
        // to poll them on this thread no matter what thread we are on.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    ///
    /// The parameter `start` determines which shard this method will start at.
    /// Using different values for each worker thread reduces contention.
    pub(crate) fn close_and_shutdown_all(&self, start: usize)
    where
        S: Schedule,
    {
        self.closed.store(true, Ordering::Release);
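        // The `Release` store above pairs with the `Acquire` load in
        // `bind_inner`: any task bound after this point observes the closed
        // flag and is shut down immediately instead of being pushed.
        // Shard indices wrap around inside the sharded list, so a non-zero
        // `start` only rotates which shard each worker drains first.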
        for i in start..self.get_shard_size() + start {
            while let Some(task) = self.list.pop_back(i) {
                task.shutdown();
            }
        }
    }

    #[inline]
    pub(crate) fn get_shard_size(&self) -> usize {
        self.list.shard_size()
    }

    pub(crate) fn num_alive_tasks(&self) -> usize {
        self.list.len()
    }

    cfg_64bit_metrics! {
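        // Cumulative count of tasks ever added to this collection, as opposed
        // to `num_alive_tasks`, which reports the current length.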
        pub(crate) fn spawned_tasks_count(&self) -> u64 {
            self.list.added()
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        // safety: We just checked that the provided task is not in some other
        // linked list.
        unsafe { self.list.remove(task.header_ptr()) }
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Generates the size of the sharded list based on the number of worker threads.
    ///
    /// Sharding the lock effectively alleviates the lock contention caused by
    /// high concurrency.
    ///
    /// However, as the number of shards increases, the memory locality between
    /// nodes of the intrusive linked list diminishes, and constructing the
    /// sharded list takes longer.
    ///
    /// For these reasons, the size is capped at `MAX_SHARED_LIST_SIZE`.
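    ///
    /// For example, with 8 worker threads this yields
    /// `min(1 << 16, 8.next_power_of_two() * 4) = 32` shards.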
    fn gen_shared_list_size(num_cores: usize) -> usize {
        const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
        usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
    }
}

cfg_taskdump! {
    impl<S: 'static> OwnedTasks<S> {
        /// Locks the tasks and calls `f` on each of them.
        pub(crate) fn for_each<F>(&self, f: F)
        where
            F: FnMut(&Task<S>),
        {
            self.list.for_each(f);
        }
    }
}

impl<S: 'static> LocalOwnedTasks<S> {
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        if self.is_closed() {
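            // The collection has already been closed: drop the `Notified` and
            // shut the task down immediately, so the caller gets nothing to
            // schedule.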
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);

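        // Pop a single task at a time and call `shutdown` outside of
        // `with_inner`: shutting a task down may re-enter this collection
        // (e.g. through `remove`), and `with_inner` must not be called
        // recursively.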
        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        self.with_inner(|inner|
            // safety: We just checked that the provided task is not in some
            // other linked list.
            unsafe { inner.list.remove(task.header_ptr()) })
    }

    /// Asserts that the given task is owned by this `LocalOwnedTasks` and converts
    /// it to a `LocalNotified`, giving the thread permission to poll this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), Some(self.id));

        // safety: The task was bound to this LocalOwnedTasks, and the
        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
        // for polling this task.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // safety: This type is not Sync, so concurrent calls of this method
        // can't happen. Furthermore, all uses of this method in this file make
        // sure that they don't call `with_inner` recursively.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests, so we only test that ids
    // come in increasing order.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
}