tokio/runtime/scheduler/multi_thread/queue.rs
//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the `LSB`
    /// half is the "real" head of the queue. The `UnsignedShort` in the `MSB`
    /// half is set by a stealer in the process of stealing values. It
    /// represents the first value being stolen in the batch. The
    /// `UnsignedShort` indices are intentionally wider than strictly required
    /// for buffer indexing in order to provide ABA mitigation and make it
    /// possible to distinguish between full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by the producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

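// Mask used to map a position (head or tail) to a slot index in the buffer.
// This relies on `LOCAL_QUEUE_CAPACITY` being a power of two.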
const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

/// Create a new local run-queue
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}

impl<T> Local<T> {
    /// Returns the number of entries in the queue
    pub(crate) fn len(&self) -> usize {
        let (_, head) = unpack(self.inner.head.load(Acquire));
        // safety: this is the **only** thread that updates this cell.
        let tail = unsafe { self.inner.tail.unsync_load() };
        len(head, tail)
    }

    /// How many tasks can be pushed into the queue
    pub(crate) fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.inner.head.load(Acquire));
        // safety: this is the **only** thread that updates this cell.
        let tail = unsafe { self.inner.tail.unsync_load() };

        LOCAL_QUEUE_CAPACITY - len(steal, tail)
    }

    pub(crate) fn max_capacity(&self) -> usize {
        LOCAL_QUEUE_CAPACITY
    }

    /// Returns `true` if there are any entries in the queue
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this
    pub(crate) fn has_tasks(&self) -> bool {
        self.len() != 0
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit
    /// in the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity to fit the tasks in
    /// the queue.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= LOCAL_QUEUE_CAPACITY);

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, _) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
            // Yes, this `if` condition is structured a bit weirdly (the first
            // block does nothing, the second panics). It is written this way
            // to mirror `push_back_or_overflow`.
        } else {
            panic!()
        }

        for task in tasks {
            let idx = tail as usize & MASK;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue. If there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue is
    /// moved to the given injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                overflow.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back_or_overflow`.
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the capacity check in
            // the caller ensures the slot does not hold a value, thus no
            // consumer can be reading it.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements are we taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {tail}; head = {head}"
        );

        // Claim all tasks.
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange_weak(pack(head, head), pack(tail, tail), Release, Relaxed)
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        // Add back the first half of tasks.
        //
        // We are doing it this way instead of just taking half of the tasks because we want the
        // *second* half of the tasks, and if you just incremented `head` by `NUM_TASKS_TAKEN`,
        // then you would be taking the first half instead of the second half.
        //
        // Pushing the second half of the local queue to the injection queue is better because when
        // we take tasks *out* of the injection queue, we always place them in the first half. This
        // means that if a task is in the second half, then we know for sure that this task is not
        // a task we just got from the injection queue. This ensures that when we take a task out
        // of the injection queue, then it will not be moved back into the injection queue (at
        // least not until after we have polled it at least once).
        //
        // Note that if a concurrent worker tries to steal from us between these two operations and
        // sees that the worker queue is empty, then that worker may go to sleep, and we do not
        // notify it about these tasks becoming available for stealing again. Ordinarily this would
        // be a problem, but it isn't in this case because the worker will be notified about the
        // tasks we are adding to the injection queue instead, which ensures that the stealer wakes
        // up again to take the tasks from the injection queue.
        self.inner
            .tail
            .store(tail.wrapping_add(NUM_TASKS_TAKEN), Release);

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head.wrapping_add(NUM_TASKS_TAKEN) as UnsignedLong,
            i: 0,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Add 1 to factor in the task currently being scheduled.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange_weak(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

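        // safety: the CAS above advanced the "real" head past this slot,
        // claiming it for this call; stealers only take slots starting at the
        // new head, so no other thread will read this index.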
        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    /// Returns the number of entries in the queue
    pub(crate) fn len(&self) -> usize {
        let (_, head) = unpack(self.0.head.load(Acquire));
        let tail = self.0.tail.load(Acquire);
        len(head, tail)
    }

    /// Returns `true` if the queue is empty and `false` if there are any
    /// entries in the queue.
    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal fewer tasks here, but for simplicity,
            // we're just going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
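            // Take half of them, rounding up: `n - n / 2` equals `ceil(n / 2)`.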
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange_weak(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {n}"
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange_weak(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => prev_packed = actual,
            }
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

/// Calculate the length of the queue using the head and tail.
/// The `head` can be the `steal` or `real` head.
fn len(head: UnsignedShort, tail: UnsignedShort) -> usize {
    tail.wrapping_sub(head) as usize
}

/// Split the head value into its two halves: the index a stealer is working
/// on (`steal`) and the "real" head (`real`), returned in that order.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}
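
// Illustrative sanity check for the `pack`/`unpack` helpers above: packing a
// (`steal`, `real`) pair and unpacking it again yields the original values,
// including ones that exercise the upper half of the packed word.
#[test]
fn test_pack_unpack_roundtrip() {
    let steal: UnsignedShort = UnsignedShort::MAX - 1;
    let real: UnsignedShort = 3;
    assert_eq!(unpack(pack(steal, real)), (steal, real));
}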

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}