1use crate::future::Future;
10use crate::loom::cell::UnsafeCell;
11use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, SpawnLocation, Task};
12use crate::util::linked_list::{Link, LinkedList};
13use crate::util::sharded_list;
14
15use crate::loom::sync::atomic::{AtomicBool, Ordering};
16use std::marker::PhantomData;
17use std::num::NonZeroU64;
18
19cfg_has_atomic_u64! {
29 use std::sync::atomic::AtomicU64;
30
31 static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);
32
33 fn get_next_id() -> NonZeroU64 {
34 loop {
35 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
36 if let Some(id) = NonZeroU64::new(id) {
37 return id;
38 }
39 }
40 }
41}
42
43cfg_not_has_atomic_u64! {
44 use std::sync::atomic::AtomicU32;
45
46 static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);
47
48 fn get_next_id() -> NonZeroU64 {
49 loop {
50 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
51 if let Some(id) = NonZeroU64::new(u64::from(id)) {
52 return id;
53 }
54 }
55 }
56}
57
58pub(crate) struct OwnedTasks<S: 'static> {
59 list: List<S>,
60 pub(crate) id: NonZeroU64,
61 closed: AtomicBool,
62}
63
64type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;
65
66pub(crate) struct LocalOwnedTasks<S: 'static> {
67 inner: UnsafeCell<OwnedTasksInner<S>>,
68 pub(crate) id: NonZeroU64,
69 _not_send_or_sync: PhantomData<*const ()>,
70}
71
72struct OwnedTasksInner<S: 'static> {
73 list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
74 closed: bool,
75}
76
77impl<S: 'static> OwnedTasks<S> {
78 pub(crate) fn new(num_cores: usize) -> Self {
79 let shard_size = Self::gen_shared_list_size(num_cores);
80 Self {
81 list: List::new(shard_size),
82 closed: AtomicBool::new(false),
83 id: get_next_id(),
84 }
85 }
86
87 pub(crate) fn bind<T>(
90 &self,
91 task: T,
92 scheduler: S,
93 id: super::Id,
94 spawned_at: SpawnLocation,
95 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
96 where
97 S: Schedule,
98 T: Future + Send + 'static,
99 T::Output: Send + 'static,
100 {
101 let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
102 let notified = unsafe { self.bind_inner(task, notified) };
103 (join, notified)
104 }
105
106 pub(crate) unsafe fn bind_local<T>(
112 &self,
113 task: T,
114 scheduler: S,
115 id: super::Id,
116 spawned_at: SpawnLocation,
117 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
118 where
119 S: Schedule,
120 T: Future + 'static,
121 T::Output: 'static,
122 {
123 let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
124 let notified = unsafe { self.bind_inner(task, notified) };
125 (join, notified)
126 }
127
128 unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
130 where
131 S: Schedule,
132 {
133 unsafe {
134 task.header().set_owner_id(self.id);
137 }
138
139 let shard = self.list.lock_shard(&task);
140 if self.closed.load(Ordering::Acquire) {
143 drop(shard);
144 task.shutdown();
145 return None;
146 }
147 shard.push(task);
148 Some(notified)
149 }
150
151 #[inline]
154 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
155 debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
156 LocalNotified {
159 task: task.0,
160 _not_send: PhantomData,
161 }
162 }
163
164 pub(crate) fn close_and_shutdown_all(&self, start: usize)
170 where
171 S: Schedule,
172 {
173 self.closed.store(true, Ordering::Release);
174 for i in start..self.get_shard_size() + start {
175 loop {
176 let task = self.list.pop_back(i);
177 match task {
178 Some(task) => {
179 task.shutdown();
180 }
181 None => break,
182 }
183 }
184 }
185 }
186
187 #[inline]
188 pub(crate) fn get_shard_size(&self) -> usize {
189 self.list.shard_size()
190 }
191
192 pub(crate) fn num_alive_tasks(&self) -> usize {
193 self.list.len()
194 }
195
196 cfg_64bit_metrics! {
197 pub(crate) fn spawned_tasks_count(&self) -> u64 {
198 self.list.added()
199 }
200 }
201
202 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
203 let task_id = task.header().get_owner_id()?;
206
207 assert_eq!(task_id, self.id);
208
209 unsafe { self.list.remove(task.header_ptr()) }
212 }
213
214 pub(crate) fn is_empty(&self) -> bool {
215 self.list.is_empty()
216 }
217
218 fn gen_shared_list_size(num_cores: usize) -> usize {
230 const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
231 usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
232 }
233}
234
235cfg_taskdump! {
236 impl<S: 'static> OwnedTasks<S> {
237 pub(crate) fn for_each<F>(&self, f: F)
239 where
240 F: FnMut(&Task<S>),
241 {
242 self.list.for_each(f);
243 }
244 }
245}
246
247impl<S: 'static> LocalOwnedTasks<S> {
248 pub(crate) fn new() -> Self {
249 Self {
250 inner: UnsafeCell::new(OwnedTasksInner {
251 list: LinkedList::new(),
252 closed: false,
253 }),
254 id: get_next_id(),
255 _not_send_or_sync: PhantomData,
256 }
257 }
258
259 pub(crate) fn bind<T>(
260 &self,
261 task: T,
262 scheduler: S,
263 id: super::Id,
264 spawned_at: SpawnLocation,
265 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
266 where
267 S: Schedule,
268 T: Future + 'static,
269 T::Output: 'static,
270 {
271 let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
272
273 unsafe {
274 task.header().set_owner_id(self.id);
277 }
278
279 if self.is_closed() {
280 drop(notified);
281 task.shutdown();
282 (join, None)
283 } else {
284 self.with_inner(|inner| {
285 inner.list.push_front(task);
286 });
287 (join, Some(notified))
288 }
289 }
290
291 pub(crate) fn close_and_shutdown_all(&self)
294 where
295 S: Schedule,
296 {
297 self.with_inner(|inner| inner.closed = true);
298
299 while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
300 task.shutdown();
301 }
302 }
303
304 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
305 let task_id = task.header().get_owner_id()?;
308
309 assert_eq!(task_id, self.id);
310
311 self.with_inner(|inner|
312 unsafe { inner.list.remove(task.header_ptr()) })
315 }
316
317 #[inline]
320 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
321 assert_eq!(task.header().get_owner_id(), Some(self.id));
322
323 LocalNotified {
327 task: task.0,
328 _not_send: PhantomData,
329 }
330 }
331
332 #[inline]
333 fn with_inner<F, T>(&self, f: F) -> T
334 where
335 F: FnOnce(&mut OwnedTasksInner<S>) -> T,
336 {
337 self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
341 }
342
343 pub(crate) fn is_closed(&self) -> bool {
344 self.with_inner(|inner| inner.closed)
345 }
346
347 pub(crate) fn is_empty(&self) -> bool {
348 self.with_inner(|inner| inner.list.is_empty())
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws a run of ids and checks that each one is strictly greater than
    /// the one before it. Only ordering is asserted, not consecutive values,
    /// because the counter is a shared static.
    #[test]
    fn test_id_not_broken() {
        let _ = (0..1000).fold(get_next_id(), |last_id, _| {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            next_id
        });
    }
}