// tokio/runtime/scheduler/multi_thread/queue.rs
//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    // Fallback for targets without 64-bit atomics. `UnsignedLong` must still
    // be exactly twice as wide as `UnsignedShort`, since `pack`/`unpack`
    // store two shorts in one long.
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
///
/// Push and pop go through this handle; its methods take `&mut self`, so a
/// single owner has exclusive producer access to `tail`.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
///
/// Shares the same `Inner` as the owning `Local` (see `local()`).
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The low half is the "real" head of
    /// the queue. The `UnsignedShort` in the high half is set by a stealer in
    /// process of stealing values. It represents the first value being stolen
    /// in the batch. The `UnsignedShort` indices are intentionally wider than
    /// strictly required for buffer indexing in order to provide ABA
    /// mitigation and make it possible to distinguish between full and empty
    /// buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements. Positions in `[head, tail)` (mapped through `MASK`) hold
    /// initialized tasks; all other slots are uninitialized.
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

// SAFETY: access to the `UnsafeCell` buffer is coordinated through the
// `head`/`tail` atomics (single producer, CAS-claiming consumers), so `Inner`
// may be sent and shared across threads.
// NOTE(review): the impls are unconditional in `T`; presumably
// `task::Notified<T>` enforces any `Send` requirements itself — confirm
// before relying on this elsewhere.
unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

// Maps a monotonically increasing position to a buffer index. Relies on
// `LOCAL_QUEUE_CAPACITY` being a power of two.
const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;

73// Constructing the fixed size array directly is very awkward. The only way to
74// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
75// the contents are not Copy. The trick with defining a const doesn't work for
76// generic types.
77fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
78    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);
79
80    // safety: We check that the length is correct.
81    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
82}
83
84/// Create a new local run-queue
85pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
86    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);
87
88    for _ in 0..LOCAL_QUEUE_CAPACITY {
89        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
90    }
91
92    let inner = Arc::new(Inner {
93        head: AtomicUnsignedLong::new(0),
94        tail: AtomicUnsignedShort::new(0),
95        buffer: make_fixed_size(buffer.into_boxed_slice()),
96    });
97
98    let local = Local {
99        inner: inner.clone(),
100    };
101
102    let remote = Steal(inner);
103
104    (remote, local)
105}
106
impl<T> Local<T> {
    /// Returns the number of entries in the queue.
    ///
    /// Uses the "real" head, so entries currently claimed by an in-progress
    /// stealer are not counted.
    pub(crate) fn len(&self) -> usize {
        let (_, head) = unpack(self.inner.head.load(Acquire));
        // safety: this is the **only** thread that updates this cell.
        let tail = unsafe { self.inner.tail.unsync_load() };
        len(head, tail)
    }

    /// How many tasks can be pushed into the queue.
    ///
    /// Uses the "steal" head: slots an in-progress stealer is still reading
    /// cannot be reused yet, so they do not count as free.
    pub(crate) fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.inner.head.load(Acquire));
        // safety: this is the **only** thread that updates this cell.
        let tail = unsafe { self.inner.tail.unsync_load() };

        LOCAL_QUEUE_CAPACITY - len(steal, tail)
    }

    /// Returns the fixed capacity of the local queue.
    pub(crate) fn max_capacity(&self) -> usize {
        LOCAL_QUEUE_CAPACITY
    }

    /// Returns `true` if there are any entries in the queue.
    ///
    /// Separate to `is_stealable` so that refactors of `is_stealable` to "protect"
    /// some tasks from stealing won't affect this.
    pub(crate) fn has_tasks(&self) -> bool {
        self.len() != 0
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit in
    /// the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity to fit in the queue.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= LOCAL_QUEUE_CAPACITY);

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, _) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        // Capacity is measured against the `steal` head: slots behind an
        // in-progress stealer are not yet free.
        if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
            // Yes, this if condition is structured a bit weird (first block
            // does nothing, second returns an error). It is this way to match
            // `push_back_or_overflow`.
        } else {
            panic!()
        }

        for task in tasks {
            let idx = tail as usize & MASK;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        // Publish the new tasks; synchronizes with `Acquire` loads of `tail`
        // by stealers.
        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue, if there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue is
    /// moved to the given Injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                overflow.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back`: writes a single task at `tail` and
    // publishes it.
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements are we taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {tail}; head = {head}"
        );

        // Claim all tasks.
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange_weak(pack(head, head), pack(tail, tail), Release, Relaxed)
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        // Add back the first half of tasks.
        //
        // Since the capacity is a power of two, position `tail` maps to the
        // same buffer index as the old `head`; advancing `tail` by
        // `NUM_TASKS_TAKEN` therefore re-exposes the slots holding the *first*
        // half of the queue under new position values, while the second half
        // (read by `BatchTaskIter` below) stays claimed.
        //
        // We are doing it this way instead of just taking half of the tasks because we want the
        // *second* half of the tasks, and if you just incremented `head` by `NUM_TASKS_TAKEN`,
        // then you would be taking the first half instead of the second half.
        //
        // Pushing the second half of the local queue to the injection queue is better because when
        // we take tasks *out* of the injection queue, we always place them in the first half. This
        // means that if a task is in the second half, then we know for sure that this task is not
        // a task we just got from the injection queue. This ensures that when we take a task out
        // of the injection queue, then it will not be moved back into the injection queue (at
        // least not until after we have polled it at least once).
        //
        // Note that if a concurrent worker tries to steal from us between these two operations and
        // sees that the worker queue is empty, then that worker may go to sleep, and we do not
        // notify it about these tasks becoming available for stealing again. Ordinarily this would
        // be a problem, but it isn't in this case because the worker will be notified about the
        // tasks we are adding to the injection queue instead, which ensures that the stealer wakes
        // up again to take the tasks from the injection queue.
        self.inner
            .tail
            .store(tail.wrapping_add(NUM_TASKS_TAKEN), Release);

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head.wrapping_add(NUM_TASKS_TAKEN) as UnsignedLong,
            i: 0,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Record the overflow operation.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange_weak(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        // safety: the CAS above claimed exclusive ownership of the slot at
        // `idx`; the slot was initialized by the producer before `tail`
        // advanced past it.
        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    /// Returns the number of entries in the queue.
    pub(crate) fn len(&self) -> usize {
        let (_, head) = unpack(self.0.head.load(Acquire));
        let tail = self.0.tail.load(Acquire);
        len(head, tail)
    }

    /// Return true if the queue is empty,
    /// false if there are any entries in the queue.
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Steals half the tasks from self and place them into `dst`.
    ///
    /// Returns one stolen task directly (for the caller to run); the rest are
    /// published in `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        // The returned task is the last one written into `dst`'s buffer.
        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            // Take half, rounding up: `n - n / 2 == ceil(n / 2)`.
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange_weak(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {n}"
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange_weak(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => prev_packed = actual,
            }
        }
    }
}

569impl<T> Clone for Steal<T> {
570    fn clone(&self) -> Steal<T> {
571        Steal(self.0.clone())
572    }
573}
574
575impl<T> Drop for Local<T> {
576    fn drop(&mut self) {
577        if !std::thread::panicking() {
578            assert!(self.pop().is_none(), "queue not empty");
579        }
580    }
581}
582
583/// Calculate the length of the queue using the head and tail.
584/// The `head` can be the `steal` or `real` head.
585fn len(head: UnsignedShort, tail: UnsignedShort) -> usize {
586    tail.wrapping_sub(head) as usize
587}
588
589/// Split the head value into the real head and the index a stealer is working
590/// on.
591fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
592    let real = n & UnsignedShort::MAX as UnsignedLong;
593    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);
594
595    (steal as UnsignedShort, real as UnsignedShort)
596}
597
598/// Join the two head values
599fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
600    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
601}
602
#[test]
fn test_local_queue_capacity() {
    // Guards that every in-queue position fits in a `u8`.
    // NOTE(review): the consumer of this bound is not visible in this file —
    // presumably some counter or index elsewhere is stored as `u8`; confirm
    // before raising `LOCAL_QUEUE_CAPACITY`.
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}