1use std::{
2 future::Future,
3 hash::{BuildHasher, Hash},
4 hint::unreachable_unchecked,
5 marker::PhantomPinned,
6 mem, pin,
7 task::{self, Poll},
8 time::{Duration, Instant},
9};
10
11use crate::{
12 linked_slab::Token,
13 shard::CacheShard,
14 shim::{
15 rw_lock::{RwLock, RwLockWriteGuard},
16 sync::{
17 atomic::{AtomicBool, Ordering},
18 Arc,
19 },
20 thread, OnceLock,
21 },
22 Equivalent, Lifecycle, Weighter,
23};
24
25pub type SharedPlaceholder<Val> = Arc<Placeholder<Val>>;
26
27impl<Val> crate::shard::SharedPlaceholder for SharedPlaceholder<Val> {
28 fn new(hash: u64, idx: Token) -> Self {
29 Arc::new(Placeholder {
30 hash,
31 idx,
32 value: OnceLock::new(),
33 state: RwLock::new(State {
34 waiters: Default::default(),
35 loading: LoadingState::Loading,
36 }),
37 })
38 }
39
40 #[inline]
41 fn same_as(&self, other: &Self) -> bool {
42 Arc::ptr_eq(self, other)
43 }
44
45 #[inline]
46 fn hash(&self) -> u64 {
47 self.hash
48 }
49
50 #[inline]
51 fn idx(&self) -> Token {
52 self.idx
53 }
54}
55
/// Shared state for a key whose value is being loaded.
///
/// In this file the value cell is only written by `PlaceholderGuard::insert_with_lifecycle`.
#[derive(Debug)]
pub struct Placeholder<Val> {
    /// Hash of the key, as given to `SharedPlaceholder::new`.
    hash: u64,
    /// Opaque `linked_slab` token; presumably the placeholder's slot in the shard — TODO confirm.
    idx: Token,
    /// Loading state and waiter list. Notifications are delivered while this lock is held.
    state: RwLock<State>,
    /// The loaded value; empty until a loader inserts it.
    value: OnceLock<Val>,
}
63
64impl<Val> Placeholder<Val> {
65 #[inline]
67 pub(crate) fn value(&self) -> Option<&Val> {
68 self.value.get()
69 }
70}
71
/// Mutable part of a [`Placeholder`], protected by its `state` lock.
#[derive(Debug)]
pub struct State {
    /// Threads and tasks blocked on this placeholder.
    ///
    /// Waiters are only pushed while the shard write lock is also held (see
    /// `PlaceholderGuard::join_waiters`). A loader that re-checks this list under the shard lock
    /// before removing the placeholder therefore cannot miss a late waiter.
    waiters: Vec<Waiter>,
    /// Whether a value has been inserted yet.
    loading: LoadingState,
}

#[derive(Debug)]
enum LoadingState {
    /// No value yet. Either a guard is alive, or a notified waiter is about to create one.
    Loading,
    /// The value was inserted. `join_waiters` registers no further waiters in this state.
    Inserted,
}
88
/// The loader's handle on a placeholder.
///
/// Call `insert` (or `insert_with_lifecycle`) to publish the value. Dropping the guard without
/// inserting hands the placeholder to the last registered waiter. If nobody is waiting, it
/// removes the placeholder from the shard.
pub struct PlaceholderGuard<'a, Key, Val, We, B, L> {
    lifecycle: &'a L,
    /// Shard owning the placeholder. It is write-locked again to publish or remove it.
    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
    shared: SharedPlaceholder<Val>,
    /// Set once the value has been published. This disables the hand-off performed on drop.
    inserted: bool,
}
95
/// A thread or async task blocked on a placeholder.
///
/// `notified` points at a flag owned by the blocked side: a pinned local in
/// `wait_for_placeholder`, or the pinned `JoinFuture`. The blocked side either removes its entry
/// itself (under the state lock) or has it removed by a notifier before it stops waiting.
#[derive(Debug)]
enum Waiter {
    /// A thread parked in `wait_for_placeholder`.
    Thread {
        notified: *const AtomicBool,
        thread: thread::Thread,
    },
    /// An async task waiting on a `JoinFuture`.
    Task {
        notified: *const AtomicBool,
        waker: task::Waker,
    },
}

// SAFETY: the raw `notified` pointer is what opts `Waiter` out of the auto traits. It targets an
// `AtomicBool`, which is itself `Sync`. It is only dereferenced by `Waiter::notify`, which callers
// in this file invoke while holding the placeholder's state lock. The flag's owner removes its
// entry under that same lock, or finds it already removed by a notifier, before the flag is freed.
unsafe impl Send for Waiter {}
unsafe impl Sync for Waiter {}
112
113impl Waiter {
114 #[inline]
115 fn notify(self) {
116 match self {
117 Waiter::Thread {
118 thread, notified, ..
119 } => {
120 unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
123 thread.unpark();
124 }
125 Waiter::Task { waker: t, notified } => {
126 unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
127 t.wake();
128 }
129 }
130 }
131
132 #[inline]
133 fn is_waiter(&self, other: *const AtomicBool) -> bool {
134 matches!(self, Waiter::Task { notified, .. } | Waiter::Thread { notified, .. } if std::ptr::eq(*notified, other))
135 }
136}
137
/// Result of joining a placeholder (returned by `PlaceholderGuard::join`).
#[derive(Debug)]
pub enum GuardResult<'a, Key, Val, We, B, L> {
    /// The value was cached, or another loader inserted it while we waited.
    Value(Val),
    /// The caller must now load the value and insert it through the guard.
    Guard(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// The timeout elapsed before the value became available.
    Timeout,
}

pub use crate::shard::EntryAction;

/// Internal outcome of waiting on a placeholder.
///
/// Shared by the blocking path (`wait_for_placeholder`) and the async path (`JoinFuture`).
pub(crate) enum JoinResult<'a, Key, Val, We, B, L> {
    /// The value is available.
    ///
    /// In this file, `None` is only produced by `JoinFuture` when the key already had a value in
    /// the shard. `Some(p)` always carries a placeholder whose value is set.
    Filled(Option<SharedPlaceholder<Val>>),
    /// The caller must now load the value.
    Guard(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// The deadline passed while waiting.
    Timeout,
}

/// Outcome of an entry-style operation on the cache.
///
/// NOTE(review): this type is not constructed in this file. The variant descriptions below are
/// inferred from their names and payloads — confirm against the shard's entry implementation.
#[derive(Debug)]
pub enum EntryResult<'a, Key, Val, We, B, L, T> {
    /// Presumably: the entry was kept and the callback produced `T`.
    Retained(T),
    /// Presumably: the entry was removed; carries the removed key and value.
    Removed(Key, Val),
    /// Presumably: the old value was taken out and the guard is used to insert its replacement.
    Replaced(PlaceholderGuard<'a, Key, Val, We, B, L>, Val),
    /// Presumably: no value was present and the guard is used to load one.
    Vacant(PlaceholderGuard<'a, Key, Val, We, B, L>),
    /// Presumably: waiting for a concurrent loader timed out.
    Timeout,
}
188
189impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> {
190 #[inline]
191 pub fn start_loading(
192 lifecycle: &'a L,
193 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
194 shared: SharedPlaceholder<Val>,
195 ) -> Self {
196 debug_assert!(matches!(
197 shared.state.write().loading,
198 LoadingState::Loading
199 ));
200 PlaceholderGuard {
201 lifecycle,
202 shard,
203 shared,
204 inserted: false,
205 }
206 }
207
208 #[inline]
211 fn handle_notification(
212 lifecycle: &'a L,
213 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
214 shared: SharedPlaceholder<Val>,
215 ) -> Result<SharedPlaceholder<Val>, PlaceholderGuard<'a, Key, Val, We, B, L>> {
216 if shared.value().is_some() {
219 Ok(shared)
220 } else {
221 Err(PlaceholderGuard::start_loading(lifecycle, shard, shared))
222 }
223 }
224
225 #[inline]
227 fn join_waiters(
228 _locked_shard: RwLockWriteGuard<'a, CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
230 shared: &SharedPlaceholder<Val>,
231 waiter_new: impl FnOnce() -> Option<Waiter>,
233 ) -> bool {
234 let mut state = shared.state.write();
235 match state.loading {
240 LoadingState::Loading => {
241 if let Some(waiter) = waiter_new() {
242 state.waiters.push(waiter);
243 }
244 false
245 }
246 LoadingState::Inserted => true,
247 }
248 }
249}
250
251impl<
252 'a,
253 Key: Eq + Hash,
254 Val: Clone,
255 We: Weighter<Key, Val>,
256 B: BuildHasher,
257 L: Lifecycle<Key, Val>,
258 > PlaceholderGuard<'a, Key, Val, We, B, L>
259{
260 pub fn join<Q>(
261 lifecycle: &'a L,
262 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
263 hash: u64,
264 key: &Q,
265 timeout: Option<Duration>,
266 ) -> GuardResult<'a, Key, Val, We, B, L>
267 where
268 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
269 {
270 let mut shard_guard = shard.write();
271 let shared = match shard_guard.get_or_placeholder(hash, key) {
272 Ok((_, v)) => return GuardResult::Value(v.clone()),
273 Err((shared, true)) => {
274 return GuardResult::Guard(Self::start_loading(lifecycle, shard, shared));
275 }
276 Err((shared, false)) => shared,
277 };
278 let mut deadline = timeout.map(Ok);
279 match Self::wait_for_placeholder(lifecycle, shard, shard_guard, shared, deadline.as_mut()) {
280 JoinResult::Filled(shared) => unsafe {
281 GuardResult::Value(shared.unwrap_unchecked().value().unwrap_unchecked().clone())
283 },
284 JoinResult::Guard(g) => GuardResult::Guard(g),
285 JoinResult::Timeout => GuardResult::Timeout,
286 }
287 }
288
289 pub(crate) fn wait_for_placeholder(
298 lifecycle: &'a L,
299 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
300 shard_guard: RwLockWriteGuard<'a, CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
301 shared: SharedPlaceholder<Val>,
302 deadline: Option<&mut Result<Duration, Instant>>,
303 ) -> JoinResult<'a, Key, Val, We, B, L> {
304 let notified = pin::pin!(AtomicBool::new(false));
305 let mut parked_thread = None;
306 let already_filled = Self::join_waiters(shard_guard, &shared, || {
307 if matches!(deadline.as_deref(), Some(Ok(d)) if d.is_zero()) {
311 None
312 } else {
313 let thread = thread::current();
314 parked_thread = Some(thread.id());
315 Some(Waiter::Thread {
316 thread,
317 notified: &*notified as *const AtomicBool,
318 })
319 }
320 });
321 if already_filled {
322 return JoinResult::Filled(Some(shared));
323 }
324
325 let deadline = deadline.and_then(|d| match *d {
328 Ok(dur) => match Instant::now().checked_add(dur) {
329 Some(instant) => {
330 *d = Err(instant);
331 Some(instant)
332 }
333 None => None, },
335 Err(instant) => Some(instant),
336 });
337 loop {
338 if let Some(instant) = deadline {
339 let remaining = instant.saturating_duration_since(Instant::now());
340 if remaining.is_zero() {
341 return Self::join_timeout(lifecycle, shard, shared, parked_thread, ¬ified);
342 }
343 #[cfg(not(fuzzing))]
344 thread::park_timeout(remaining);
345 } else {
346 #[cfg(not(fuzzing))]
347 thread::park();
348 }
349 if notified.load(Ordering::Acquire) {
350 return match Self::handle_notification(lifecycle, shard, shared) {
351 Ok(shared) => JoinResult::Filled(Some(shared)),
352 Err(g) => JoinResult::Guard(g),
353 };
354 }
355 }
356 }
357
358 #[cold]
359 fn join_timeout(
360 lifecycle: &'a L,
361 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, Arc<Placeholder<Val>>>>,
362 shared: Arc<Placeholder<Val>>,
363 parked_thread: Option<thread::ThreadId>,
365 notified: &AtomicBool,
366 ) -> JoinResult<'a, Key, Val, We, B, L> {
367 let mut state = shared.state.write();
368 match state.loading {
369 LoadingState::Loading if notified.load(Ordering::Acquire) => {
370 drop(state); JoinResult::Guard(PlaceholderGuard::start_loading(lifecycle, shard, shared))
372 }
373 LoadingState::Loading => {
374 if parked_thread.is_some() {
375 let waiter_idx = state
377 .waiters
378 .iter()
379 .position(|w| w.is_waiter(notified as _));
380 if let Some(idx) = waiter_idx {
381 state.waiters.swap_remove(idx);
382 } else {
383 unsafe { unreachable_unchecked() };
384 }
385 }
386 JoinResult::Timeout
387 }
388 LoadingState::Inserted => {
389 drop(state);
390 JoinResult::Filled(Some(shared))
391 }
392 }
393 }
394}
395
396impl<
397 Key: Eq + Hash,
398 Val: Clone,
399 We: Weighter<Key, Val>,
400 B: BuildHasher,
401 L: Lifecycle<Key, Val>,
402 > PlaceholderGuard<'_, Key, Val, We, B, L>
403{
404 pub fn insert(self, value: Val) -> Result<(), Val> {
410 let lifecycle = self.lifecycle;
411 let lcs = self.insert_with_lifecycle(value)?;
412 lifecycle.end_request(lcs);
413 Ok(())
414 }
415
416 pub fn insert_with_lifecycle(mut self, value: Val) -> Result<L::RequestState, Val> {
422 unsafe { self.shared.value.set(value.clone()).unwrap_unchecked() };
423 let referenced;
424 {
425 let mut state = self.shared.state.write();
429 state.loading = LoadingState::Inserted;
430 referenced = !state.waiters.is_empty();
431 for w in state.waiters.drain(..) {
432 w.notify();
433 }
434 }
435
436 self.inserted = true;
441
442 let mut lcs = self.lifecycle.begin_request();
443 self.shard
444 .write()
445 .replace_placeholder(&mut lcs, &self.shared, referenced, value)?;
446 Ok(lcs)
447 }
448}
449
impl<Key, Val, We, B, L> PlaceholderGuard<'_, Key, Val, We, B, L> {
    /// Gives up a placeholder whose value was never inserted.
    ///
    /// If a waiter is registered, the last one is popped and notified. The value is still unset,
    /// so that waiter becomes the next loader (see `handle_notification`). Otherwise the
    /// placeholder is removed from the shard.
    #[cold]
    fn drop_uninserted_slow(&mut self) {
        // Fast path: hand off under the state lock alone. The lock is released at the end of this
        // scope, before the shard lock is taken below, so the shard -> state lock order holds.
        {
            let mut state = self.shared.state.write();
            debug_assert!(matches!(state.loading, LoadingState::Loading));
            if let Some(waiter) = state.waiters.pop() {
                waiter.notify();
                return;
            }
        }

        // Slow path: take the shard lock, then the state lock (the same order as `join_waiters`).
        // Waiters are only added while the shard lock is held, so the list cannot grow now. It is
        // re-checked because a waiter may have joined after the state lock above was released.
        let mut shard_guard = self.shard.write();
        let mut state = self.shared.state.write();
        debug_assert!(matches!(state.loading, LoadingState::Loading));
        if let Some(waiter) = state.waiters.pop() {
            // The placeholder stays in the shard, so the shard lock is released before notifying
            // (presumably to keep the shard critical section short).
            drop(shard_guard);
            waiter.notify();
        } else {
            shard_guard.remove_placeholder(&self.shared);
        }
    }
}
479
480impl<Key, Val, We, B, L> Drop for PlaceholderGuard<'_, Key, Val, We, B, L> {
481 #[inline]
482 fn drop(&mut self) {
483 if !self.inserted {
484 self.drop_uninserted_slow();
485 }
486 }
487}
488impl<Key, Val, We, B, L> std::fmt::Debug for PlaceholderGuard<'_, Key, Val, We, B, L> {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 f.debug_struct("PlaceholderGuard").finish_non_exhaustive()
491 }
492}
493
/// Future that joins the load of a key without a timeout.
///
/// It is the async counterpart of `PlaceholderGuard::wait_for_placeholder`. While pending, the
/// placeholder's waiter list holds a raw pointer to `notified`. The future must therefore stay
/// pinned (`_pin` makes it `!Unpin`), and its `Drop` unregisters the waiter.
pub(crate) struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> {
    lifecycle: &'a L,
    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
    /// Hash of `key`, passed to `get_or_placeholder` on the first poll.
    hash: u64,
    key: &'b Q,
    state: JoinFutureState<Val>,
    /// Set by `Waiter::notify` when this task is woken.
    notified: AtomicBool,
    _pin: PhantomPinned,
}

/// Progress of a `JoinFuture`.
enum JoinFutureState<Val> {
    /// Not polled yet.
    Created,
    /// Registered as a waiter on `shared`.
    ///
    /// `waker` is a copy of the waker stored in the waiter list, used to detect waker changes
    /// on re-poll.
    Pending {
        shared: SharedPlaceholder<Val>,
        waker: task::Waker,
    },
    /// Completed, or dropped while pending. Polling in this state panics.
    Done,
}
523
524impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
525 pub(crate) fn new(
526 lifecycle: &'a L,
527 shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
528 hash: u64,
529 key: &'b Q,
530 ) -> Self {
531 Self {
532 lifecycle,
533 shard,
534 hash,
535 key,
536 state: JoinFutureState::Created,
537 notified: Default::default(),
538 _pin: PhantomPinned,
539 }
540 }
541}
542
impl<Q: ?Sized, Key, Val, We, B, L> JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
    /// Cleans up a future dropped while `Pending`, and leaves it in `Done`.
    ///
    /// Must only be called in the `Pending` state.
    #[cold]
    fn drop_pending_waiter(&mut self) {
        let JoinFutureState::Pending { shared, .. } =
            mem::replace(&mut self.state, JoinFutureState::Done)
        else {
            // SAFETY: the only caller (`Drop`) checks for `Pending` first.
            unsafe { unreachable_unchecked() }
        };
        let mut state = shared.state.write();
        match state.loading {
            LoadingState::Loading if self.notified.load(Ordering::Acquire) => {
                // Drop the state guard: it borrows `shared` (moved below), and `start_loading`
                // re-locks the state in debug builds.
                drop(state);
                // We were picked as the next loader but will never load. Creating the guard and
                // dropping it immediately passes the placeholder on to another waiter, or
                // removes it from the shard.
                let _ = PlaceholderGuard::start_loading(self.lifecycle, self.shard, shared);
            }
            LoadingState::Loading => {
                // Not notified, so our entry is still listed. Remove it before `self.notified`
                // is freed.
                let waiter_idx = state
                    .waiters
                    .iter()
                    .position(|w| w.is_waiter(&self.notified as _));
                if let Some(idx) = waiter_idx {
                    state.waiters.swap_remove(idx);
                } else {
                    // SAFETY: entries leave the list only under this lock, via `drain` (which sets
                    // `Inserted`) or via `pop` + `notify` (which sets `notified`). Neither happened.
                    unsafe { unreachable_unchecked() }
                }
            }
            // The value was published and our entry was drained; nothing to clean up.
            LoadingState::Inserted => (),
        }
    }
}
576
577impl<Q: ?Sized, Key, Val, We, B, L> Drop for JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
578 #[inline]
579 fn drop(&mut self) {
580 if matches!(self.state, JoinFutureState::Pending { .. }) {
581 self.drop_pending_waiter();
582 }
583 }
584}
585
impl<
        'a,
        Key: Eq + Hash,
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
        Val,
        We: Weighter<Key, Val>,
        B: BuildHasher,
        L: Lifecycle<Key, Val>,
    > Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L>
{
    type Output = JoinResult<'a, Key, Val, We, B, L>;

    /// Drives the join.
    ///
    /// The first poll looks the key up and either finishes immediately or registers this task as
    /// a waiter. Later polls finish once notified; otherwise they refresh the registered waker if
    /// it changed.
    fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the pinned struct is only mutated field-by-field, never moved as a whole.
        // `notified` keeps its address, which registered waiters rely on.
        let this = unsafe { self.get_unchecked_mut() };
        let lifecycle = this.lifecycle;
        let shard = this.shard;
        match &mut this.state {
            JoinFutureState::Created => {
                let mut shard_guard = shard.write();
                match shard_guard.get_or_placeholder(this.hash, this.key) {
                    Ok(_) => {
                        // Already cached. `Filled(None)` carries no placeholder; presumably the
                        // caller reads the value itself.
                        this.state = JoinFutureState::Done;
                        Poll::Ready(JoinResult::Filled(None))
                    }
                    Err((shared, true)) => {
                        // This task must load the value.
                        this.state = JoinFutureState::Done;
                        drop(shard_guard);
                        Poll::Ready(JoinResult::Guard(PlaceholderGuard::start_loading(
                            lifecycle, shard, shared,
                        )))
                    }
                    Err((shared, false)) => {
                        // Another loader owns the placeholder, so register as a waiter. A copy of
                        // the waker is kept for the `will_wake` check on later polls.
                        let mut waker = None;
                        let already_filled =
                            PlaceholderGuard::join_waiters(shard_guard, &shared, || {
                                let waker_ = cx.waker().clone();
                                waker = Some(waker_.clone());
                                Some(Waiter::Task {
                                    waker: waker_,
                                    notified: &this.notified as *const AtomicBool,
                                })
                            });
                        if already_filled {
                            this.state = JoinFutureState::Done;
                            Poll::Ready(JoinResult::Filled(Some(shared)))
                        } else {
                            // `join_waiters` ran the closure (state was `Loading`), so `waker` is set.
                            this.state = JoinFutureState::Pending {
                                shared,
                                waker: waker.unwrap(),
                            };
                            Poll::Pending
                        }
                    }
                }
            }
            JoinFutureState::Pending { waker, shared } => {
                if !this.notified.load(Ordering::Acquire) {
                    let new_waker = cx.waker();
                    if waker.will_wake(new_waker) {
                        // The registered waker already wakes this task.
                        return Poll::Pending;
                    }
                    let mut state = shared.state.write();
                    // Re-check under the state lock, since notifications are delivered while it
                    // is held.
                    if !this.notified.load(Ordering::Acquire) {
                        // SAFETY: we are not notified and hold the lock, so our entry is still
                        // listed. Entries leave only via `drain` or `pop` + `notify`.
                        let w = unsafe {
                            state
                                .waiters
                                .iter_mut()
                                .find(|w| w.is_waiter(&this.notified as _))
                                .unwrap_unchecked()
                        };
                        *waker = new_waker.clone();
                        *w = Waiter::Task {
                            waker: new_waker.clone(),
                            notified: &this.notified as *const AtomicBool,
                        };
                        return Poll::Pending;
                    }
                }
                // Notified: either the value was inserted, or the previous loader handed over.
                let JoinFutureState::Pending { shared, .. } =
                    mem::replace(&mut this.state, JoinFutureState::Done)
                else {
                    // SAFETY: this arm only runs in the `Pending` state.
                    unsafe { unreachable_unchecked() }
                };
                Poll::Ready(
                    match PlaceholderGuard::handle_notification(lifecycle, shard, shared) {
                        Ok(shared) => JoinResult::Filled(Some(shared)),
                        Err(g) => JoinResult::Guard(g),
                    },
                )
            }
            JoinFutureState::Done => panic!("Polled after ready"),
        }
    }
}