1use std::{
2 future::Future,
3 hash::{BuildHasher, Hash},
4 hint::unreachable_unchecked,
5 marker::PhantomPinned,
6 mem, pin,
7 task::{self, Poll},
8 time::{Duration, Instant},
9};
10
11use crate::{
12 linked_slab::Token,
13 shard::CacheShard,
14 shim::{
15 rw_lock::{RwLock, RwLockWriteGuard},
16 sync::{
17 atomic::{AtomicBool, Ordering},
18 Arc,
19 },
20 thread, OnceLock,
21 },
22 Equivalent, Lifecycle, Weighter,
23};
24
/// Reference-counted handle to a [`Placeholder`].
///
/// The shard and every waiter hold clones of the same `Arc`; placeholder
/// identity is compared by pointer (`Arc::ptr_eq`), not by contents.
pub type SharedPlaceholder<Val> = Arc<Placeholder<Val>>;
26
27impl<Val> crate::shard::SharedPlaceholder for SharedPlaceholder<Val> {
28 fn new(hash: u64, idx: Token) -> Self {
29 Arc::new(Placeholder {
30 hash,
31 idx,
32 value: OnceLock::new(),
33 state: RwLock::new(State {
34 waiters: Default::default(),
35 loading: LoadingState::Loading,
36 }),
37 })
38 }
39
40 #[inline]
41 fn same_as(&self, other: &Self) -> bool {
42 Arc::ptr_eq(self, other)
43 }
44
45 #[inline]
46 fn hash(&self) -> u64 {
47 self.hash
48 }
49
50 #[inline]
51 fn idx(&self) -> Token {
52 self.idx
53 }
54}
55
/// Shared state for a key whose value is being (or has been) loaded.
///
/// A placeholder starts in the `Loading` state. The caller holding the
/// `PlaceholderGuard` eventually stores the value and switches the state to
/// `Inserted`, waking every registered waiter.
#[derive(Debug)]
pub struct Placeholder<Val> {
    /// Hash supplied at creation (presumably the key's hash — confirm in `CacheShard`).
    hash: u64,
    /// Token supplied by the shard at creation; reported back through `idx()`.
    idx: Token,
    /// Loading progress and the waiter list, behind a lock.
    state: RwLock<State>,
    /// The loaded value; set exactly once, by `insert_with_lifecycle`.
    value: OnceLock<Val>,
}
63
64impl<Val> Placeholder<Val> {
65 #[inline]
67 pub(crate) fn value(&self) -> Option<&Val> {
68 self.value.get()
69 }
70}
71
/// Mutable part of a [`Placeholder`], protected by the placeholder's `state` lock.
#[derive(Debug)]
pub struct State {
    /// Threads/tasks blocked on this placeholder. All are drained and woken on
    /// insert; a single one is popped when a loader gives up without inserting.
    waiters: Vec<Waiter>,
    /// Whether a value has been inserted yet.
    loading: LoadingState,
}
80
/// Progress of a placeholder's load.
#[derive(Debug)]
enum LoadingState {
    /// Initial state: no value yet. New joiners register as waiters.
    Loading,
    /// The placeholder's value has been set. New joiners get it immediately.
    Inserted,
}
88
/// Exclusive right (and obligation) to load the value for a placeholder.
///
/// Call `insert` (or `insert_with_lifecycle`) to publish the value. Dropping
/// the guard without inserting hands the load to one waiting caller. When
/// nobody is waiting, it removes the placeholder from the shard instead.
pub struct PlaceholderGuard<'a, Key, Val, We, B, L> {
    /// Lifecycle hooks used to begin/end the insertion request.
    lifecycle: &'a L,
    /// The shard that owns the placeholder.
    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
    /// The placeholder this guard is loading.
    shared: SharedPlaceholder<Val>,
    /// Set once the value has been inserted, so `Drop` does no hand-off/cleanup.
    inserted: bool,
}
95
#[cfg(test)]
impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> {
    /// Test-only accessor for the placeholder this guard is loading.
    pub fn shared(&self) -> &SharedPlaceholder<Val> {
        let Self { shared, .. } = self;
        shared
    }
}
102
/// A blocked caller waiting for a placeholder to be filled or handed over.
///
/// `notified` points to a flag owned by the waiter. For threads, this is a
/// pinned local in `wait_for_placeholder`; for tasks, it is the `notified`
/// field of a pinned `JoinFuture`. The flag is set before the waiter is woken.
#[derive(Debug)]
enum Waiter {
    /// A synchronous caller, parked with `thread::park`/`park_timeout`.
    Thread {
        notified: *const AtomicBool,
        thread: thread::Thread,
    },
    /// An async caller, woken through its `Waker`.
    Task {
        notified: *const AtomicBool,
        waker: task::Waker,
    },
}
114
// SAFETY: the only members that are not automatically Send/Sync are the
// `notified` raw pointers. They are dereferenced only in `Waiter::notify`.
// Notifiers call it while holding the placeholder's state lock, right after
// removing the entry from the waiter list. A waiter only lets its flag go away
// once its entry is gone, either popped by a notifier or removed by the waiter
// itself under that same lock. The pointee is therefore alive whenever it is
// accessed.
unsafe impl Send for Waiter {}
unsafe impl Sync for Waiter {}
119
120impl Waiter {
121 #[inline]
122 fn notify(self) {
123 match self {
124 Waiter::Thread {
125 thread, notified, ..
126 } => {
127 unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
130 thread.unpark();
131 }
132 Waiter::Task { waker: t, notified } => {
133 unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
134 t.wake();
135 }
136 }
137 }
138
139 #[inline]
140 fn is_waiter(&self, other: *const AtomicBool) -> bool {
141 matches!(self, Waiter::Task { notified, .. } | Waiter::Thread { notified, .. } if std::ptr::eq(*notified, other))
142 }
143}
144
/// Outcome of joining a placeholder synchronously (e.g. [`PlaceholderGuard::join`]).
#[derive(Debug)]
pub enum GuardResult<'a, Key, Val, We, B, L> {
    /// A value was already cached or was loaded by another caller.
    Value(Val),
    /// No value is available; the caller must load it and insert through the guard.
    Guard(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// The timeout elapsed while another caller was still loading.
    Timeout,
}
158
159pub use crate::shard::EntryAction;
161
/// Internal outcome of waiting on a placeholder, shared by the blocking and async paths.
pub(crate) enum JoinResult<'a, Key, Val, We, B, L> {
    /// A value is available. `Some` carries the filled placeholder. `None`
    /// means the key already had a value in the shard when `JoinFuture` did its
    /// lookup (presumably the caller re-reads it from there).
    Filled(Option<SharedPlaceholder<Val>>),
    /// The caller became responsible for loading the value.
    Guard(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// The deadline passed before a value became available.
    Timeout,
}
172
/// Result of an entry-style operation on a key.
///
/// NOTE(review): this type is constructed outside this chunk. The variant
/// descriptions below are inferred from their payloads — confirm against the
/// producing code.
#[derive(Debug)]
pub enum EntryResult<'a, Key, Val, We, B, L, T> {
    /// The existing entry was kept; presumably carries the callback's result.
    Retained(T),
    /// The entry was removed; carries the removed key and value.
    Removed(Key, Val),
    /// Presumably: the old value was taken out and a guard issued to load its replacement.
    Replaced(PlaceholderGuard<'a, Key, Val, We, B, L>, Val),
    /// No entry existed; the guard is used to load one.
    Vacant(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// Waiting for an in-flight load timed out.
    Timeout,
}
195
196impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> {
197 #[inline]
198 pub fn start_loading(
199 lifecycle: &'a L,
200 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
201 shared: SharedPlaceholder<Val>,
202 ) -> Self {
203 debug_assert!(matches!(
204 shared.state.write().loading,
205 LoadingState::Loading
206 ));
207 PlaceholderGuard {
208 lifecycle,
209 shard,
210 shared,
211 inserted: false,
212 }
213 }
214
215 #[inline]
218 fn handle_notification(
219 lifecycle: &'a L,
220 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
221 shared: SharedPlaceholder<Val>,
222 ) -> Result<SharedPlaceholder<Val>, PlaceholderGuard<'a, Key, Val, We, B, L>> {
223 if shared.value().is_some() {
226 Ok(shared)
227 } else {
228 Err(PlaceholderGuard::start_loading(lifecycle, shard, shared))
229 }
230 }
231
232 #[inline]
234 fn join_waiters(
235 _locked_shard: RwLockWriteGuard<'a, CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
237 shared: &SharedPlaceholder<Val>,
238 waiter_new: impl FnOnce() -> Option<Waiter>,
240 ) -> bool {
241 let mut state = shared.state.write();
242 match state.loading {
247 LoadingState::Loading => {
248 if let Some(waiter) = waiter_new() {
249 state.waiters.push(waiter);
250 }
251 false
252 }
253 LoadingState::Inserted => true,
254 }
255 }
256}
257
258impl<
259 'a,
260 Key: Eq + Hash,
261 Val: Clone,
262 We: Weighter<Key, Val>,
263 B: BuildHasher,
264 L: Lifecycle<Key, Val>,
265 > PlaceholderGuard<'a, Key, Val, We, B, L>
266{
267 pub fn join<Q>(
268 lifecycle: &'a L,
269 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
270 hash: u64,
271 key: &Q,
272 timeout: Option<Duration>,
273 ) -> GuardResult<'a, Key, Val, We, B, L>
274 where
275 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
276 {
277 let mut shard_guard = shard.write();
278 let shared = match shard_guard.get_or_placeholder(hash, key) {
279 Ok((_, v)) => return GuardResult::Value(v.clone()),
280 Err((shared, true)) => {
281 return GuardResult::Guard(Self::start_loading(lifecycle, shard, shared));
282 }
283 Err((shared, false)) => shared,
284 };
285 let mut deadline = timeout.map(Ok);
286 match Self::wait_for_placeholder(lifecycle, shard, shard_guard, shared, deadline.as_mut()) {
287 JoinResult::Filled(shared) => unsafe {
288 GuardResult::Value(shared.unwrap_unchecked().value().unwrap_unchecked().clone())
290 },
291 JoinResult::Guard(g) => GuardResult::Guard(g),
292 JoinResult::Timeout => GuardResult::Timeout,
293 }
294 }
295
296 pub(crate) fn wait_for_placeholder(
305 lifecycle: &'a L,
306 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
307 shard_guard: RwLockWriteGuard<'a, CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
308 shared: SharedPlaceholder<Val>,
309 deadline: Option<&mut Result<Duration, Instant>>,
310 ) -> JoinResult<'a, Key, Val, We, B, L> {
311 let notified = pin::pin!(AtomicBool::new(false));
312 let mut parked_thread = None;
313 let already_filled = Self::join_waiters(shard_guard, &shared, || {
314 if matches!(deadline.as_deref(), Some(Ok(d)) if d.is_zero()) {
318 None
319 } else {
320 let thread = thread::current();
321 parked_thread = Some(thread.id());
322 Some(Waiter::Thread {
323 thread,
324 notified: &*notified as *const AtomicBool,
325 })
326 }
327 });
328 if already_filled {
329 return JoinResult::Filled(Some(shared));
330 }
331
332 let deadline = deadline.and_then(|d| match *d {
335 Ok(dur) => match Instant::now().checked_add(dur) {
336 Some(instant) => {
337 *d = Err(instant);
338 Some(instant)
339 }
340 None => None, },
342 Err(instant) => Some(instant),
343 });
344 loop {
345 if let Some(instant) = deadline {
346 let remaining = instant.saturating_duration_since(Instant::now());
347 if remaining.is_zero() {
348 return Self::join_timeout(lifecycle, shard, shared, parked_thread, ¬ified);
349 }
350 #[cfg(not(fuzzing))]
351 thread::park_timeout(remaining);
352 } else {
353 #[cfg(not(fuzzing))]
354 thread::park();
355 }
356 if notified.load(Ordering::Acquire) {
357 return match Self::handle_notification(lifecycle, shard, shared) {
358 Ok(shared) => JoinResult::Filled(Some(shared)),
359 Err(g) => JoinResult::Guard(g),
360 };
361 }
362 }
363 }
364
365 #[cold]
366 fn join_timeout(
367 lifecycle: &'a L,
368 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, Arc<Placeholder<Val>>>>,
369 shared: Arc<Placeholder<Val>>,
370 parked_thread: Option<thread::ThreadId>,
372 notified: &AtomicBool,
373 ) -> JoinResult<'a, Key, Val, We, B, L> {
374 let mut state = shared.state.write();
375 match state.loading {
376 LoadingState::Loading if notified.load(Ordering::Acquire) => {
377 drop(state); JoinResult::Guard(PlaceholderGuard::start_loading(lifecycle, shard, shared))
379 }
380 LoadingState::Loading => {
381 if parked_thread.is_some() {
382 let waiter_idx = state
384 .waiters
385 .iter()
386 .position(|w| w.is_waiter(notified as _));
387 if let Some(idx) = waiter_idx {
388 state.waiters.swap_remove(idx);
389 } else {
390 unsafe { unreachable_unchecked() };
391 }
392 }
393 JoinResult::Timeout
394 }
395 LoadingState::Inserted => {
396 drop(state);
397 JoinResult::Filled(Some(shared))
398 }
399 }
400 }
401}
402
403impl<
404 Key: Eq + Hash,
405 Val: Clone,
406 We: Weighter<Key, Val>,
407 B: BuildHasher,
408 L: Lifecycle<Key, Val>,
409 > PlaceholderGuard<'_, Key, Val, We, B, L>
410{
411 pub fn insert(self, value: Val) -> Result<(), Val> {
417 let lifecycle = self.lifecycle;
418 let lcs = self.insert_with_lifecycle(value)?;
419 lifecycle.end_request(lcs);
420 Ok(())
421 }
422
423 pub fn insert_with_lifecycle(mut self, value: Val) -> Result<L::RequestState, Val> {
429 unsafe { self.shared.value.set(value.clone()).unwrap_unchecked() };
430 let referenced;
431 {
432 let mut state = self.shared.state.write();
436 state.loading = LoadingState::Inserted;
437 referenced = !state.waiters.is_empty();
438 for w in state.waiters.drain(..) {
439 w.notify();
440 }
441 }
442
443 self.inserted = true;
448
449 let mut lcs = self.lifecycle.begin_request();
450 self.shard
451 .write()
452 .replace_placeholder(&mut lcs, &self.shared, referenced, value)?;
453 Ok(lcs)
454 }
455}
456
impl<Key, Val, We, B, L> PlaceholderGuard<'_, Key, Val, We, B, L> {
    /// Gives up the load without inserting a value.
    ///
    /// If anyone is waiting, exactly one waiter is notified. Because it finds no
    /// value, that waiter becomes the new loader. If nobody is waiting, the
    /// placeholder is removed from the shard. Locks are taken shard first, then
    /// state, matching `join_waiters`.
    #[cold]
    fn drop_uninserted_slow(&mut self) {
        // Fast path: hand off to a waiter while holding only the state lock.
        {
            let mut state = self.shared.state.write();
            debug_assert!(matches!(state.loading, LoadingState::Loading));
            if let Some(waiter) = state.waiters.pop() {
                waiter.notify();
                return;
            }
        }

        // No waiter was seen. Take the shard lock, then the state lock, and
        // re-check: a waiter may have joined after the state lock was released.
        let mut shard_guard = self.shard.write();
        let mut state = self.shared.state.write();
        debug_assert!(matches!(state.loading, LoadingState::Loading));
        if let Some(waiter) = state.waiters.pop() {
            // A late joiner: hand off to it. The shard lock is no longer needed,
            // but the state lock stays held while notifying.
            drop(shard_guard);
            waiter.notify();
        } else {
            // Nobody is waiting, and new joiners need the shard lock we hold.
            shard_guard.remove_placeholder(&self.shared);
        }
    }
}
486
487impl<Key, Val, We, B, L> Drop for PlaceholderGuard<'_, Key, Val, We, B, L> {
488 #[inline]
489 fn drop(&mut self) {
490 if !self.inserted {
491 self.drop_uninserted_slow();
492 }
493 }
494}
495impl<Key, Val, We, B, L> std::fmt::Debug for PlaceholderGuard<'_, Key, Val, We, B, L> {
496 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497 f.debug_struct("PlaceholderGuard").finish_non_exhaustive()
498 }
499}
500
/// Future for the async join path.
///
/// It resolves to a [`JoinResult`] once the key is found already loaded, its
/// placeholder is filled, or the load is handed to this caller.
///
/// While pending, the address of `notified` is registered in the placeholder's
/// waiter list. The future must therefore stay pinned (hence `PhantomPinned`).
pub(crate) struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> {
    lifecycle: &'a L,
    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
    /// Hash passed to `get_or_placeholder` together with `key`.
    hash: u64,
    key: &'b Q,
    state: JoinFutureState<Val>,
    /// Set by a notifier before it wakes this task.
    notified: AtomicBool,
    _pin: PhantomPinned,
}
521
/// State machine of a [`JoinFuture`].
enum JoinFutureState<Val> {
    /// Not polled yet; the shard has not been consulted.
    Created,
    /// Registered as a waiter on `shared`. `waker` is a copy of the waker
    /// stored in the waiter entry, kept to detect waker changes on re-poll.
    Pending {
        shared: SharedPlaceholder<Val>,
        waker: task::Waker,
    },
    /// Completed; polling again panics.
    Done,
}
530
531impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
532 pub(crate) fn new(
533 lifecycle: &'a L,
534 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
535 hash: u64,
536 key: &'b Q,
537 ) -> Self {
538 Self {
539 lifecycle,
540 shard,
541 hash,
542 key,
543 state: JoinFutureState::Created,
544 notified: Default::default(),
545 _pin: PhantomPinned,
546 }
547 }
548}
549
550impl<Q: ?Sized, Key, Val, We, B, L> JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
551 #[cold]
552 fn drop_pending_waiter(&mut self) {
553 let JoinFutureState::Pending { shared, .. } =
554 mem::replace(&mut self.state, JoinFutureState::Done)
555 else {
556 unsafe { unreachable_unchecked() }
557 };
558 let mut state = shared.state.write();
559 match state.loading {
560 LoadingState::Loading if self.notified.load(Ordering::Acquire) => {
561 drop(state); let _ = PlaceholderGuard::start_loading(self.lifecycle, self.shard, shared);
565 }
566 LoadingState::Loading => {
567 let waiter_idx = state
569 .waiters
570 .iter()
571 .position(|w| w.is_waiter(&self.notified as _));
572 if let Some(idx) = waiter_idx {
573 state.waiters.swap_remove(idx);
574 } else {
575 unsafe { unreachable_unchecked() }
577 }
578 }
579 LoadingState::Inserted => (), }
581 }
582}
583
584impl<Q: ?Sized, Key, Val, We, B, L> Drop for JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
585 #[inline]
586 fn drop(&mut self) {
587 if matches!(self.state, JoinFutureState::Pending { .. }) {
588 self.drop_pending_waiter();
589 }
590 }
591}
592
impl<
        'a,
        Key: Eq + Hash,
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
        Val,
        We: Weighter<Key, Val>,
        B: BuildHasher,
        L: Lifecycle<Key, Val>,
    > Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L>
{
    type Output = JoinResult<'a, Key, Val, We, B, L>;

    /// Drives the join state machine.
    ///
    /// - `Created`: look up the key. Resolve immediately if a value exists or a
    ///   fresh placeholder makes us the loader. Otherwise register this task as
    ///   a waiter and go `Pending`.
    /// - `Pending`: stay pending until `notified` is set, refreshing the
    ///   registered waker if the executor supplies a different one. Once
    ///   notified, resolve via `handle_notification`.
    /// - `Done`: polling again panics.
    fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`. `notified` keeps its address
        // while a pointer to it is registered in a waiter entry.
        let this = unsafe { self.get_unchecked_mut() };
        let lifecycle = this.lifecycle;
        let shard = this.shard;
        match &mut this.state {
            JoinFutureState::Created => {
                let mut shard_guard = shard.write();
                match shard_guard.get_or_placeholder(this.hash, this.key) {
                    // The key already has a value in the shard. It is not
                    // returned here (presumably the caller re-reads it).
                    Ok(_) => {
                        this.state = JoinFutureState::Done;
                        Poll::Ready(JoinResult::Filled(None))
                    }
                    // A fresh placeholder: we are the loader. Release the shard
                    // lock before building the guard.
                    Err((shared, true)) => {
                        this.state = JoinFutureState::Done;
                        drop(shard_guard);
                        Poll::Ready(JoinResult::Guard(PlaceholderGuard::start_loading(
                            lifecycle, shard, shared,
                        )))
                    }
                    // Someone else is loading: register this task as a waiter.
                    Err((shared, false)) => {
                        let mut waker = None;
                        let already_filled =
                            PlaceholderGuard::join_waiters(shard_guard, &shared, || {
                                let waker_ = cx.waker().clone();
                                waker = Some(waker_.clone());
                                Some(Waiter::Task {
                                    waker: waker_,
                                    notified: &this.notified as *const AtomicBool,
                                })
                            });
                        if already_filled {
                            this.state = JoinFutureState::Done;
                            Poll::Ready(JoinResult::Filled(Some(shared)))
                        } else {
                            // `join_waiters` only skips the closure when the value
                            // is already inserted, so `waker` is `Some` here.
                            this.state = JoinFutureState::Pending {
                                shared,
                                waker: waker.unwrap(),
                            };
                            Poll::Pending
                        }
                    }
                }
            }
            JoinFutureState::Pending { waker, shared } => {
                if !this.notified.load(Ordering::Acquire) {
                    let new_waker = cx.waker();
                    // The registered waker is still valid: nothing to update.
                    if waker.will_wake(new_waker) {
                        return Poll::Pending;
                    }
                    // Re-check under the state lock. Notifiers pop the entry and
                    // set the flag while holding this lock.
                    let mut state = shared.state.write();
                    if !this.notified.load(Ordering::Acquire) {
                        // SAFETY: not yet notified, so our entry has not been
                        // popped and must still be in the list.
                        let w = unsafe {
                            state
                                .waiters
                                .iter_mut()
                                .find(|w| w.is_waiter(&this.notified as _))
                                .unwrap_unchecked()
                        };
                        *waker = new_waker.clone();
                        *w = Waiter::Task {
                            waker: new_waker.clone(),
                            notified: &this.notified as *const AtomicBool,
                        };
                        return Poll::Pending;
                    }
                }
                // Notified: our entry was already removed by the notifier.
                let JoinFutureState::Pending { shared, .. } =
                    mem::replace(&mut this.state, JoinFutureState::Done)
                else {
                    // SAFETY: this match arm is only entered in the `Pending` state.
                    unsafe { unreachable_unchecked() }
                };
                Poll::Ready(
                    match PlaceholderGuard::handle_notification(lifecycle, shard, shared) {
                        Ok(shared) => JoinResult::Filled(Some(shared)),
                        Err(g) => JoinResult::Guard(g),
                    },
                )
            }
            JoinFutureState::Done => panic!("Polled after ready"),
        }
    }
}