1use std::{
2 future::Future,
3 hash::{BuildHasher, Hash},
4 hint::unreachable_unchecked,
5 mem, pin,
6 task::{self, Poll},
7 time::{Duration, Instant},
8};
9
10use crate::{
11 linked_slab::Token,
12 shard::CacheShard,
13 shim::{
14 rw_lock::{RwLock, RwLockWriteGuard},
15 sync::{
16 atomic::{AtomicBool, Ordering},
17 Arc,
18 },
19 thread, OnceLock,
20 },
21 Equivalent, Lifecycle, Weighter,
22};
23
/// Reference-counted handle to a [`Placeholder`].
///
/// The same allocation is held by the cache shard, by the loading
/// [`PlaceholderGuard`] and by pending waiters (`JoinFutureState::Pending`).
pub type SharedPlaceholder<Val> = Arc<Placeholder<Val>>;
25
26impl<Val> crate::shard::SharedPlaceholder for SharedPlaceholder<Val> {
27 fn new(hash: u64, idx: Token) -> Self {
28 Arc::new(Placeholder {
29 hash,
30 idx,
31 value: OnceLock::new(),
32 state: RwLock::new(State {
33 waiters: Default::default(),
34 loading: LoadingState::Loading,
35 }),
36 })
37 }
38
39 #[inline]
40 fn same_as(&self, other: &Self) -> bool {
41 Arc::ptr_eq(self, other)
42 }
43
44 #[inline]
45 fn hash(&self) -> u64 {
46 self.hash
47 }
48
49 #[inline]
50 fn idx(&self) -> Token {
51 self.idx
52 }
53}
54
/// Shared record for a key whose value is being loaded.
///
/// Exactly one caller at a time holds a [`PlaceholderGuard`] for it and is
/// responsible for producing the value. Other callers register in
/// `state.waiters` and are woken either when the value is inserted or when the
/// loading duty is handed to them.
#[derive(Debug)]
pub struct Placeholder<Val> {
    /// Hash supplied to `SharedPlaceholder::new` (the key's hash, as passed by
    /// the shard).
    hash: u64,
    /// Token supplied to `SharedPlaceholder::new`; presumably locates the
    /// placeholder entry in the shard — confirm.
    idx: Token,
    /// Loading progress and registered waiters.
    state: RwLock<State>,
    /// The loaded value. It is set by the guard holder in `insert_with_lifecycle`
    /// *before* `state.loading` becomes `Inserted`.
    value: OnceLock<Val>,
}
62
/// Lock-protected part of a [`Placeholder`].
#[derive(Debug)]
pub struct State {
    /// Threads/tasks blocked on this placeholder.
    ///
    /// An entry is removed when it is notified (drained on insert, or popped
    /// when a guard is dropped without inserting), or when the waiter gives up
    /// (timeout, or a dropped `JoinFuture`).
    waiters: Vec<Waiter>,
    /// Whether the value has been inserted yet.
    loading: LoadingState,
}
71
/// Progress of a placeholder's load.
#[derive(Debug)]
enum LoadingState {
    /// No value yet; the current guard holder is expected to produce it.
    Loading,
    /// `Placeholder::value` is set, and every waiter registered at insertion
    /// time has been notified.
    Inserted,
}
79
/// Exclusive right (and obligation) to load the value for a placeholder.
///
/// Call `insert` / `insert_with_lifecycle` to publish the value. Dropping the
/// guard without inserting hands the duty to a waiting caller, or removes the
/// placeholder from the shard when nobody is waiting.
pub struct PlaceholderGuard<'a, Key, Val, We, B, L> {
    /// Lifecycle hooks; a request is begun around the shard update on insert.
    lifecycle: &'a L,
    /// Shard that holds the placeholder entry.
    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
    /// The placeholder this guard is loading.
    shared: SharedPlaceholder<Val>,
    /// Set once the value has been published; `Drop` skips cleanup when true.
    inserted: bool,
}
86
/// A caller blocked on a placeholder.
///
/// `notified` points at a flag owned by the waiter itself: on the joining
/// thread's stack for `Thread`, and inside the `JoinFuture` for `Task`. It is
/// set to `true` just before the waiter is woken, and its address is what
/// identifies the waiter in the list.
#[derive(Debug)]
enum Waiter {
    /// A thread blocked in `PlaceholderGuard::join`, woken with `unpark`.
    Thread {
        notified: *const AtomicBool,
        thread: thread::Thread,
    },
    /// An async task blocked on a `JoinFuture`, woken through its `Waker`.
    Task {
        notified: *const AtomicBool,
        waker: task::Waker,
    },
}
98
// SAFETY: the only field that is not already Send/Sync is the raw `notified`
// pointer. It targets an `AtomicBool` (itself `Sync`) and is only dereferenced
// to `store` into it. Waiters are pushed/removed under the placeholder's state
// lock. The owner of the flag either unregisters itself under that lock, or
// observes `notified == true`, before the flag goes out of scope.
unsafe impl Send for Waiter {}
// SAFETY: see the `Send` impl; through `&Waiter` only pointer comparisons are
// performed (`is_waiter`).
unsafe impl Sync for Waiter {}
103
104impl Waiter {
105 #[inline]
106 fn notify(self) {
107 match self {
108 Waiter::Thread {
109 thread, notified, ..
110 } => {
111 unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
114 thread.unpark();
115 }
116 Waiter::Task { waker: t, notified } => {
117 unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
118 t.wake();
119 }
120 }
121 }
122
123 #[inline]
124 fn is_waiter(&self, other: *const AtomicBool) -> bool {
125 matches!(self, Waiter::Task { notified, .. } | Waiter::Thread { notified, .. } if std::ptr::eq(*notified, other))
126 }
127}
128
/// Outcome of a blocking [`PlaceholderGuard::join`].
#[derive(Debug)]
pub enum GuardResult<'a, Key, Val, We, B, L> {
    /// The value was already cached or was inserted while waiting.
    Value(Val),
    /// The caller must load the value and publish it through the guard.
    Guard(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// The timeout elapsed before a value or a guard was obtained.
    Timeout,
}
135
136impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> {
137 pub fn start_loading(
138 lifecycle: &'a L,
139 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
140 shared: SharedPlaceholder<Val>,
141 ) -> Self {
142 debug_assert!(matches!(
143 shared.state.write().loading,
144 LoadingState::Loading
145 ));
146 PlaceholderGuard {
147 lifecycle,
148 shard,
149 shared,
150 inserted: false,
151 }
152 }
153
154 #[inline]
157 fn handle_notification(
158 lifecycle: &'a L,
159 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
160 shared: SharedPlaceholder<Val>,
161 ) -> Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>
162 where
163 Val: Clone,
164 {
165 if let Some(v) = shared.value.get() {
168 Ok(v.clone())
169 } else {
170 Err(PlaceholderGuard::start_loading(lifecycle, shard, shared))
171 }
172 }
173
174 #[inline]
176 fn join_waiters(
177 _locked_shard: RwLockWriteGuard<'a, CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
179 shared: &SharedPlaceholder<Val>,
180 waiter_new: impl FnOnce() -> Option<Waiter>,
182 ) -> Option<Val>
183 where
184 Val: Clone,
185 {
186 let mut state = shared.state.write();
187 match state.loading {
192 LoadingState::Loading => {
193 if let Some(waiter) = waiter_new() {
194 state.waiters.push(waiter);
195 }
196 None
197 }
198 LoadingState::Inserted => unsafe {
199 drop(state); Some(shared.value.get().unwrap_unchecked().clone())
202 },
203 }
204 }
205}
206
207impl<
208 'a,
209 Key: Eq + Hash,
210 Val: Clone,
211 We: Weighter<Key, Val>,
212 B: BuildHasher,
213 L: Lifecycle<Key, Val>,
214 > PlaceholderGuard<'a, Key, Val, We, B, L>
215{
216 pub fn join<Q>(
217 lifecycle: &'a L,
218 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
219 hash: u64,
220 key: &Q,
221 mut timeout: Option<Duration>,
222 ) -> GuardResult<'a, Key, Val, We, B, L>
223 where
224 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
225 {
226 let mut shard_guard = shard.write();
227 let shared = match shard_guard.upsert_placeholder(hash, key) {
228 Ok((_, v)) => return GuardResult::Value(v.clone()),
229 Err((shared, true)) => {
230 return GuardResult::Guard(Self::start_loading(lifecycle, shard, shared));
231 }
232 Err((shared, false)) => shared,
233 };
234
235 let notified = pin::pin!(AtomicBool::new(false));
237 let mut parked_thread = None;
239 let maybe_val = Self::join_waiters(shard_guard, &shared, || {
240 if timeout.is_some_and(|t| t.is_zero()) {
241 None
242 } else {
243 let thread = thread::current();
244 let id = thread.id();
245 parked_thread = Some(id);
246 Some(Waiter::Thread {
247 thread,
248 notified: &*notified as *const AtomicBool,
249 })
250 }
251 });
252 if let Some(v) = maybe_val {
253 return GuardResult::Value(v);
254 }
255
256 let mut timeout_start = None;
258 loop {
259 if let Some(remaining) = timeout {
260 if remaining.is_zero() {
261 return Self::join_timeout(lifecycle, shard, shared, parked_thread, ¬ified);
262 }
263 let start = *timeout_start.get_or_insert_with(Instant::now);
264 #[cfg(not(fuzzing))]
265 thread::park_timeout(remaining);
266 timeout = Some(remaining.saturating_sub(start.elapsed()));
267 } else {
268 thread::park();
269 }
270 if notified.load(Ordering::Acquire) {
271 return match Self::handle_notification(lifecycle, shard, shared) {
272 Ok(v) => GuardResult::Value(v),
273 Err(g) => GuardResult::Guard(g),
274 };
275 }
276 }
277 }
278
279 #[cold]
280 fn join_timeout(
281 lifecycle: &'a L,
282 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, Arc<Placeholder<Val>>>>,
283 shared: Arc<Placeholder<Val>>,
284 parked_thread: Option<thread::ThreadId>,
286 notified: &AtomicBool,
287 ) -> GuardResult<'a, Key, Val, We, B, L> {
288 let mut state = shared.state.write();
289 match state.loading {
290 LoadingState::Loading if notified.load(Ordering::Acquire) => {
291 drop(state); GuardResult::Guard(PlaceholderGuard::start_loading(lifecycle, shard, shared))
293 }
294 LoadingState::Loading => {
295 if parked_thread.is_some() {
296 let waiter_idx = state
298 .waiters
299 .iter()
300 .position(|w| w.is_waiter(notified as _));
301 if let Some(idx) = waiter_idx {
302 state.waiters.swap_remove(idx);
303 } else {
304 unsafe { unreachable_unchecked() };
305 }
306 }
307 GuardResult::Timeout
308 }
309 LoadingState::Inserted => unsafe {
310 GuardResult::Value(shared.value.get().unwrap_unchecked().clone())
312 },
313 }
314 }
315}
316
317impl<
318 Key: Eq + Hash,
319 Val: Clone,
320 We: Weighter<Key, Val>,
321 B: BuildHasher,
322 L: Lifecycle<Key, Val>,
323 > PlaceholderGuard<'_, Key, Val, We, B, L>
324{
325 pub fn insert(self, value: Val) -> Result<(), Val> {
331 let lifecycle = self.lifecycle;
332 let lcs = self.insert_with_lifecycle(value)?;
333 lifecycle.end_request(lcs);
334 Ok(())
335 }
336
337 pub fn insert_with_lifecycle(mut self, value: Val) -> Result<L::RequestState, Val> {
343 unsafe { self.shared.value.set(value.clone()).unwrap_unchecked() };
344 let referenced;
345 {
346 let mut state = self.shared.state.write();
350 state.loading = LoadingState::Inserted;
351 referenced = !state.waiters.is_empty();
352 for w in state.waiters.drain(..) {
353 w.notify();
354 }
355 }
356
357 self.inserted = true;
362
363 let mut lcs = self.lifecycle.begin_request();
364 self.shard
365 .write()
366 .replace_placeholder(&mut lcs, &self.shared, referenced, value)?;
367 Ok(lcs)
368 }
369}
370
impl<Key, Val, We, B, L> PlaceholderGuard<'_, Key, Val, We, B, L> {
    /// Slow path of `Drop` for a guard that never inserted a value.
    ///
    /// If some caller is waiting, the last registered waiter (`Vec::pop`) is
    /// notified. It observes no value and becomes the new loader (see
    /// `handle_notification`). Otherwise the placeholder is removed from the
    /// shard.
    #[cold]
    fn drop_uninserted_slow(&mut self) {
        // First attempt: only the placeholder's state lock is taken.
        {
            let mut state = self.shared.state.write();
            debug_assert!(matches!(state.loading, LoadingState::Loading));
            if let Some(waiter) = state.waiters.pop() {
                waiter.notify();
                return;
            }
        }

        // Nobody was waiting. Take the shard lock, then the state lock (the same
        // order used when joining), and re-check. In this module new waiters
        // register only via `join_waiters`, which requires the shard write lock.
        // So a waiter may have appeared between the two critical sections, but
        // none can appear while we hold the shard lock.
        let mut shard_guard = self.shard.write();
        let mut state = self.shared.state.write();
        debug_assert!(matches!(state.loading, LoadingState::Loading));
        if let Some(waiter) = state.waiters.pop() {
            drop(shard_guard);
            waiter.notify();
        } else {
            shard_guard.remove_placeholder(&self.shared);
        }
    }
}
400
401impl<Key, Val, We, B, L> Drop for PlaceholderGuard<'_, Key, Val, We, B, L> {
402 #[inline]
403 fn drop(&mut self) {
404 if !self.inserted {
405 self.drop_uninserted_slow();
406 }
407 }
408}
409impl<Key, Val, We, B, L> std::fmt::Debug for PlaceholderGuard<'_, Key, Val, We, B, L> {
410 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411 f.debug_struct("PlaceholderGuard").finish_non_exhaustive()
412 }
413}
414
/// Async counterpart of [`PlaceholderGuard::join`].
///
/// Resolves to `Ok(value)`, or to `Err(guard)` when this caller must load the
/// value.
///
/// NOTE(review): once polled, a raw pointer to `notified` is stored in the
/// placeholder's waiter list. The struct has no `PhantomPinned`, so it is
/// `Unpin` (which `poll`'s `&mut *self` relies on). Moving it between polls
/// would therefore leave that pointer dangling. Callers presumably poll it in
/// place (e.g. `.await` it directly) — confirm before exposing it more widely.
pub struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> {
    lifecycle: &'a L,
    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
    state: JoinFutureState<'b, Q, Val>,
    /// Set by `Waiter::notify`; its address identifies this future's waiter entry.
    notified: AtomicBool,
}
422
/// State machine of a [`JoinFuture`].
enum JoinFutureState<'b, Q: ?Sized, Val> {
    /// Not polled yet: the lookup key and its precomputed hash.
    Created {
        hash: u64,
        key: &'b Q,
    },
    /// Registered as a waiter on `shared`; `waker` mirrors the waker currently
    /// stored in the waiter entry.
    Pending {
        shared: SharedPlaceholder<Val>,
        waker: task::Waker,
    },
    /// Completed, or cleaned up on drop; polling again panics.
    Done,
}
434
435impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
436 pub fn new(
437 lifecycle: &'a L,
438 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
439 hash: u64,
440 key: &'b Q,
441 ) -> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
442 Self {
443 lifecycle,
444 shard,
445 state: JoinFutureState::Created { hash, key },
446 notified: Default::default(),
447 }
448 }
449
450 #[cold]
451 fn drop_pending_waiter(&mut self) {
452 let JoinFutureState::Pending { shared, .. } =
453 mem::replace(&mut self.state, JoinFutureState::Done)
454 else {
455 unsafe { unreachable_unchecked() }
456 };
457 let mut state = shared.state.write();
458 match state.loading {
459 LoadingState::Loading if self.notified.load(Ordering::Acquire) => {
460 drop(state); let _ = PlaceholderGuard::start_loading(self.lifecycle, self.shard, shared);
464 }
465 LoadingState::Loading => {
466 let waiter_idx = state
468 .waiters
469 .iter()
470 .position(|w| w.is_waiter(&self.notified as _));
471 if let Some(idx) = waiter_idx {
472 state.waiters.swap_remove(idx);
473 } else {
474 unsafe { unreachable_unchecked() }
476 }
477 }
478 LoadingState::Inserted => (), }
480 }
481}
482
483impl<Q: ?Sized, Key, Val, We, B, L> Drop for JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
484 #[inline]
485 fn drop(&mut self) {
486 if matches!(self.state, JoinFutureState::Pending { .. }) {
487 self.drop_pending_waiter();
488 }
489 }
490}
491
492impl<
493 'a,
494 Key: Eq + Hash,
495 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
496 Val: Clone,
497 We: Weighter<Key, Val>,
498 B: BuildHasher,
499 L: Lifecycle<Key, Val>,
500 > Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L>
501{
502 type Output = Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>;
503
504 fn poll(mut self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
505 let this = &mut *self;
506 let lifecycle = this.lifecycle;
507 let shard = this.shard;
508 match &mut this.state {
509 JoinFutureState::Created { hash, key } => {
510 debug_assert!(!this.notified.load(Ordering::Acquire));
511 let mut shard_guard = shard.write();
512 match shard_guard.upsert_placeholder(*hash, *key) {
513 Ok((_, v)) => {
514 this.state = JoinFutureState::Done;
515 Poll::Ready(Ok(v.clone()))
516 }
517 Err((shared, true)) => {
518 let guard = PlaceholderGuard::start_loading(lifecycle, shard, shared);
519 this.state = JoinFutureState::Done;
520 Poll::Ready(Err(guard))
521 }
522 Err((shared, false)) => {
523 let mut waker = None;
524 let maybe_val =
525 PlaceholderGuard::join_waiters(shard_guard, &shared, || {
526 let waker_ = cx.waker().clone();
527 waker = Some(waker_.clone());
528 Some(Waiter::Task {
529 waker: waker_,
530 notified: &this.notified as *const AtomicBool,
531 })
532 });
533 if let Some(v) = maybe_val {
534 debug_assert!(waker.is_none());
535 debug_assert!(!this.notified.load(Ordering::Acquire));
536 this.state = JoinFutureState::Done;
537 Poll::Ready(Ok(v))
538 } else {
539 let waker = waker.unwrap();
540 this.state = JoinFutureState::Pending { shared, waker };
541 Poll::Pending
542 }
543 }
544 }
545 }
546 JoinFutureState::Pending { .. } if this.notified.load(Ordering::Acquire) => {
547 let JoinFutureState::Pending { shared, .. } =
548 mem::replace(&mut this.state, JoinFutureState::Done)
549 else {
550 unsafe { unreachable_unchecked() }
551 };
552 Poll::Ready(PlaceholderGuard::handle_notification(
553 lifecycle, shard, shared,
554 ))
555 }
556 JoinFutureState::Pending { waker, shared } => {
557 let new_waker = cx.waker();
559 if !waker.will_wake(new_waker) {
560 let mut state = shared.state.write();
561 if let Some(w) = state
562 .waiters
563 .iter_mut()
564 .find(|w| w.is_waiter(&this.notified as _))
565 {
566 *waker = new_waker.clone();
567 *w = Waiter::Task {
568 waker: new_waker.clone(),
569 notified: &this.notified as *const AtomicBool,
570 };
571 } else {
572 unsafe { unreachable_unchecked() };
573 }
574 }
575 Poll::Pending
576 }
577 JoinFutureState::Done => panic!("Polled after ready"),
578 }
579 }
580}