// quick_cache/sync_placeholder.rs

1use std::{
2    future::Future,
3    hash::{BuildHasher, Hash},
4    hint::unreachable_unchecked,
5    mem, pin,
6    task::{self, Poll},
7    time::{Duration, Instant},
8};
9
10use crate::{
11    linked_slab::Token,
12    shard::CacheShard,
13    shim::{
14        rw_lock::{RwLock, RwLockWriteGuard},
15        sync::{
16            atomic::{AtomicBool, Ordering},
17            Arc,
18        },
19        thread, OnceLock,
20    },
21    Equivalent, Lifecycle, Weighter,
22};
23
24pub type SharedPlaceholder<Val> = Arc<Placeholder<Val>>;
25
26impl<Val> crate::shard::SharedPlaceholder for SharedPlaceholder<Val> {
27    fn new(hash: u64, idx: Token) -> Self {
28        Arc::new(Placeholder {
29            hash,
30            idx,
31            value: OnceLock::new(),
32            state: RwLock::new(State {
33                waiters: Default::default(),
34                loading: LoadingState::Loading,
35            }),
36        })
37    }
38
39    #[inline]
40    fn same_as(&self, other: &Self) -> bool {
41        Arc::ptr_eq(self, other)
42    }
43
44    #[inline]
45    fn hash(&self) -> u64 {
46        self.hash
47    }
48
49    #[inline]
50    fn idx(&self) -> Token {
51        self.idx
52    }
53}
54
55#[derive(Debug)]
56pub struct Placeholder<Val> {
57    hash: u64,
58    idx: Token,
59    state: RwLock<State>,
60    value: OnceLock<Val>,
61}
62
63#[derive(Debug)]
64pub struct State {
65    /// The waiters list
66    /// Adding to the list requires holding the outer shard lock to avoid races between
67    /// removing the orphan placeholder from the cache and adding a new waiter to it.
68    waiters: Vec<Waiter>,
69    loading: LoadingState,
70}
71
/// Progress of a placeholder's load.
#[derive(Debug)]
enum LoadingState {
    /// A guard was or will be created, and the value might still get filled.
    Loading,
    /// The value was filled; no more waiters can be added.
    Inserted,
}
79
80pub struct PlaceholderGuard<'a, Key, Val, We, B, L> {
81    lifecycle: &'a L,
82    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
83    shared: SharedPlaceholder<Val>,
84    inserted: bool,
85}
86
/// A parked thread or a pending task registered on a placeholder.
///
/// `notified` points at a flag owned by the waiter. For a thread it lives on that
/// thread's stack; for a task it lives inside its `JoinFuture`. The flag is set right
/// before the waiter is woken.
#[derive(Debug)]
enum Waiter {
    Thread {
        notified: *const AtomicBool,
        thread: thread::Thread,
    },
    Task {
        notified: *const AtomicBool,
        waker: task::Waker,
    },
}

// SAFETY: The only non-thread-safe field is the raw `notified` pointer. The flag it
// points to lives on the waiting thread's stack or in the pinned future. The owner
// leaves only after its entry is removed from the waiters list, either by itself or
// by the notifier. Both happen under the placeholder's state lock.
unsafe impl Send for Waiter {}
unsafe impl Sync for Waiter {}
103
104impl Waiter {
105    #[inline]
106    fn notify(self) {
107        match self {
108            Waiter::Thread {
109                thread, notified, ..
110            } => {
111                // SAFETY: The AtomicBool is on the waiting thread's stack or pinned future
112                // and the thread/task will remove itself from waiters before returning
113                unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
114                thread.unpark();
115            }
116            Waiter::Task { waker: t, notified } => {
117                unsafe { notified.as_ref().unwrap().store(true, Ordering::Release) };
118                t.wake();
119            }
120        }
121    }
122
123    #[inline]
124    fn is_waiter(&self, other: *const AtomicBool) -> bool {
125        matches!(self, Waiter::Task { notified, .. } | Waiter::Thread { notified, .. } if std::ptr::eq(*notified, other))
126    }
127}
128
129#[derive(Debug)]
130pub enum GuardResult<'a, Key, Val, We, B, L> {
131    Value(Val),
132    Guard(PlaceholderGuard<'a, Key, Val, We, B, L>),
133    Timeout,
134}
135
136impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> {
137    pub fn start_loading(
138        lifecycle: &'a L,
139        shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
140        shared: SharedPlaceholder<Val>,
141    ) -> Self {
142        debug_assert!(matches!(
143            shared.state.write().loading,
144            LoadingState::Loading
145        ));
146        PlaceholderGuard {
147            lifecycle,
148            shard,
149            shared,
150            inserted: false,
151        }
152    }
153
154    // Check the state of the placeholder, returning the value if it was loaded
155    // or a guard if the caller got the guard.
156    #[inline]
157    fn handle_notification(
158        lifecycle: &'a L,
159        shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
160        shared: SharedPlaceholder<Val>,
161    ) -> Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>
162    where
163        Val: Clone,
164    {
165        // Check if the value was loaded, and if it wasn't it means we got the
166        // guard and need to start loading the value.
167        if let Some(v) = shared.value.get() {
168            Ok(v.clone())
169        } else {
170            Err(PlaceholderGuard::start_loading(lifecycle, shard, shared))
171        }
172    }
173
174    // Join the waiters list or return the value if it was already loaded
175    #[inline]
176    fn join_waiters(
177        // we require the shard lock to be held to add a new waiter
178        _locked_shard: RwLockWriteGuard<'a, CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
179        shared: &SharedPlaceholder<Val>,
180        // a function that returns a waiter if it should be added
181        waiter_new: impl FnOnce() -> Option<Waiter>,
182    ) -> Option<Val>
183    where
184        Val: Clone,
185    {
186        let mut state = shared.state.write();
187        // _locked_shard could be released here, it would be sufficient to synchronize with the holder
188        // of the guard trying to remove the placeholder from the cache. But if this placeholder is hot,
189        // anyone waiting on the shard will immediately hit the state lock. Since the cache is sharded
190        // we consider the latter more likely. So we keep the shard lock until we are done with the state.
191        match state.loading {
192            LoadingState::Loading => {
193                if let Some(waiter) = waiter_new() {
194                    state.waiters.push(waiter);
195                }
196                None
197            }
198            LoadingState::Inserted => unsafe {
199                // SAFETY: The value is guaranteed to be set at this point
200                drop(state); // Allow cloning outside the lock
201                Some(shared.value.get().unwrap_unchecked().clone())
202            },
203        }
204    }
205}
206
207impl<
208        'a,
209        Key: Eq + Hash,
210        Val: Clone,
211        We: Weighter<Key, Val>,
212        B: BuildHasher,
213        L: Lifecycle<Key, Val>,
214    > PlaceholderGuard<'a, Key, Val, We, B, L>
215{
216    pub fn join<Q>(
217        lifecycle: &'a L,
218        shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
219        hash: u64,
220        key: &Q,
221        mut timeout: Option<Duration>,
222    ) -> GuardResult<'a, Key, Val, We, B, L>
223    where
224        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
225    {
226        let mut shard_guard = shard.write();
227        let shared = match shard_guard.upsert_placeholder(hash, key) {
228            Ok((_, v)) => return GuardResult::Value(v.clone()),
229            Err((shared, true)) => {
230                return GuardResult::Guard(Self::start_loading(lifecycle, shard, shared));
231            }
232            Err((shared, false)) => shared,
233        };
234
235        // Create notified flag on stack - this will live for the entire duration of join
236        let notified = pin::pin!(AtomicBool::new(false));
237        // Set if the thread was added to the waiters list
238        let mut parked_thread = None;
239        let maybe_val = Self::join_waiters(shard_guard, &shared, || {
240            if timeout.is_some_and(|t| t.is_zero()) {
241                None
242            } else {
243                let thread = thread::current();
244                let id = thread.id();
245                parked_thread = Some(id);
246                Some(Waiter::Thread {
247                    thread,
248                    notified: &*notified as *const AtomicBool,
249                })
250            }
251        });
252        if let Some(v) = maybe_val {
253            return GuardResult::Value(v);
254        }
255
256        // Track the start time of the timeout, set lazily
257        let mut timeout_start = None;
258        loop {
259            if let Some(remaining) = timeout {
260                if remaining.is_zero() {
261                    return Self::join_timeout(lifecycle, shard, shared, parked_thread, &notified);
262                }
263                let start = *timeout_start.get_or_insert_with(Instant::now);
264                #[cfg(not(fuzzing))]
265                thread::park_timeout(remaining);
266                timeout = Some(remaining.saturating_sub(start.elapsed()));
267            } else {
268                thread::park();
269            }
270            if notified.load(Ordering::Acquire) {
271                return match Self::handle_notification(lifecycle, shard, shared) {
272                    Ok(v) => GuardResult::Value(v),
273                    Err(g) => GuardResult::Guard(g),
274                };
275            }
276        }
277    }
278
279    #[cold]
280    fn join_timeout(
281        lifecycle: &'a L,
282        shard: &'a RwLock<CacheShard<Key, Val, We, B, L, Arc<Placeholder<Val>>>>,
283        shared: Arc<Placeholder<Val>>,
284        // when timeout is zero, the thread may have not been added to the waiters list
285        parked_thread: Option<thread::ThreadId>,
286        notified: &AtomicBool,
287    ) -> GuardResult<'a, Key, Val, We, B, L> {
288        let mut state = shared.state.write();
289        match state.loading {
290            LoadingState::Loading if notified.load(Ordering::Acquire) => {
291                drop(state); // Drop state guard to avoid a deadlock with start_loading
292                GuardResult::Guard(PlaceholderGuard::start_loading(lifecycle, shard, shared))
293            }
294            LoadingState::Loading => {
295                if parked_thread.is_some() {
296                    // Remove ourselves from the waiters list
297                    let waiter_idx = state
298                        .waiters
299                        .iter()
300                        .position(|w| w.is_waiter(notified as _));
301                    if let Some(idx) = waiter_idx {
302                        state.waiters.swap_remove(idx);
303                    } else {
304                        unsafe { unreachable_unchecked() };
305                    }
306                }
307                GuardResult::Timeout
308            }
309            LoadingState::Inserted => unsafe {
310                // SAFETY: The value is guaranteed to be set at this point
311                GuardResult::Value(shared.value.get().unwrap_unchecked().clone())
312            },
313        }
314    }
315}
316
317impl<
318        Key: Eq + Hash,
319        Val: Clone,
320        We: Weighter<Key, Val>,
321        B: BuildHasher,
322        L: Lifecycle<Key, Val>,
323    > PlaceholderGuard<'_, Key, Val, We, B, L>
324{
325    /// Inserts the value into the placeholder
326    ///
327    /// Returns Err if the placeholder isn't in the cache anymore.
328    /// A placeholder can be removed as a result of a `remove` call
329    /// or a non-placeholder `insert` with the same key.
330    pub fn insert(self, value: Val) -> Result<(), Val> {
331        let lifecycle = self.lifecycle;
332        let lcs = self.insert_with_lifecycle(value)?;
333        lifecycle.end_request(lcs);
334        Ok(())
335    }
336
337    /// Inserts the value into the placeholder
338    ///
339    /// Returns Err if the placeholder isn't in the cache anymore.
340    /// A placeholder can be removed as a result of a `remove` call
341    /// or a non-placeholder `insert` with the same key.
342    pub fn insert_with_lifecycle(mut self, value: Val) -> Result<L::RequestState, Val> {
343        unsafe { self.shared.value.set(value.clone()).unwrap_unchecked() };
344        let referenced;
345        {
346            // Whoever is already waiting will get notified and hit the fast-path
347            // as they will see the value set. Anyone that races trying to add themselves
348            // to the waiters list will wait on the state lock.
349            let mut state = self.shared.state.write();
350            state.loading = LoadingState::Inserted;
351            referenced = !state.waiters.is_empty();
352            for w in state.waiters.drain(..) {
353                w.notify();
354            }
355        }
356
357        // Set flag to disable drop_uninserted_slow, it has no work to do:
358        //   - waiters have already been drained
359        //   - no waiters can be added because we set LoadingState::Inserted
360        //   - the placeholder will be removed here, if it still exists
361        self.inserted = true;
362
363        let mut lcs = self.lifecycle.begin_request();
364        self.shard
365            .write()
366            .replace_placeholder(&mut lcs, &self.shared, referenced, value)?;
367        Ok(lcs)
368    }
369}
370
371impl<Key, Val, We, B, L> PlaceholderGuard<'_, Key, Val, We, B, L> {
372    #[cold]
373    fn drop_uninserted_slow(&mut self) {
374        // Fast path: check if there are other waiters without the shard lock
375        // This may or may not be common, but the assumption is that the shard lock is hot
376        // and should be avoided if possible.
377        {
378            let mut state = self.shared.state.write();
379            debug_assert!(matches!(state.loading, LoadingState::Loading));
380            if let Some(waiter) = state.waiters.pop() {
381                waiter.notify();
382                return;
383            }
384        }
385
386        // Slow path: acquire shard lock and re-check
387        // By acquiring the shard lock we synchronize with any other threads that might be
388        // trying to add themselves to the waiters list.
389        let mut shard_guard = self.shard.write();
390        let mut state = self.shared.state.write();
391        debug_assert!(matches!(state.loading, LoadingState::Loading));
392        if let Some(waiter) = state.waiters.pop() {
393            drop(shard_guard);
394            waiter.notify();
395        } else {
396            shard_guard.remove_placeholder(&self.shared);
397        }
398    }
399}
400
401impl<Key, Val, We, B, L> Drop for PlaceholderGuard<'_, Key, Val, We, B, L> {
402    #[inline]
403    fn drop(&mut self) {
404        if !self.inserted {
405            self.drop_uninserted_slow();
406        }
407    }
408}
409impl<Key, Val, We, B, L> std::fmt::Debug for PlaceholderGuard<'_, Key, Val, We, B, L> {
410    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
411        f.debug_struct("PlaceholderGuard").finish_non_exhaustive()
412    }
413}
414
415/// Future that results in an Ok(Value) or Err(Guard)
416pub struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> {
417    lifecycle: &'a L,
418    shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
419    state: JoinFutureState<'b, Q, Val>,
420    notified: AtomicBool,
421}
422
423enum JoinFutureState<'b, Q: ?Sized, Val> {
424    Created {
425        hash: u64,
426        key: &'b Q,
427    },
428    Pending {
429        shared: SharedPlaceholder<Val>,
430        waker: task::Waker,
431    },
432    Done,
433}
434
435impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
436    pub fn new(
437        lifecycle: &'a L,
438        shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
439        hash: u64,
440        key: &'b Q,
441    ) -> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
442        Self {
443            lifecycle,
444            shard,
445            state: JoinFutureState::Created { hash, key },
446            notified: Default::default(),
447        }
448    }
449
450    #[cold]
451    fn drop_pending_waiter(&mut self) {
452        let JoinFutureState::Pending { shared, .. } =
453            mem::replace(&mut self.state, JoinFutureState::Done)
454        else {
455            unsafe { unreachable_unchecked() }
456        };
457        let mut state = shared.state.write();
458        match state.loading {
459            LoadingState::Loading if self.notified.load(Ordering::Acquire) => {
460                // The write guard was abandoned elsewhere, this future was notified but didn't get polled.
461                // So we get and drop the guard here to handle the side effects.
462                drop(state); // Drop state guard to avoid a deadlock with start_loading
463                let _ = PlaceholderGuard::start_loading(self.lifecycle, self.shard, shared);
464            }
465            LoadingState::Loading => {
466                // Remove ourselves from the waiters list
467                let waiter_idx = state
468                    .waiters
469                    .iter()
470                    .position(|w| w.is_waiter(&self.notified as _));
471                if let Some(idx) = waiter_idx {
472                    state.waiters.swap_remove(idx);
473                } else {
474                    // We didn't find ourselves in the waiters list!?
475                    unsafe { unreachable_unchecked() }
476                }
477            }
478            LoadingState::Inserted => (), // We were notified but didn't get polled - nothing to do
479        }
480    }
481}
482
483impl<Q: ?Sized, Key, Val, We, B, L> Drop for JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
484    #[inline]
485    fn drop(&mut self) {
486        if matches!(self.state, JoinFutureState::Pending { .. }) {
487            self.drop_pending_waiter();
488        }
489    }
490}
491
492impl<
493        'a,
494        Key: Eq + Hash,
495        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
496        Val: Clone,
497        We: Weighter<Key, Val>,
498        B: BuildHasher,
499        L: Lifecycle<Key, Val>,
500    > Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L>
501{
502    type Output = Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>;
503
504    fn poll(mut self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
505        let this = &mut *self;
506        let lifecycle = this.lifecycle;
507        let shard = this.shard;
508        match &mut this.state {
509            JoinFutureState::Created { hash, key } => {
510                debug_assert!(!this.notified.load(Ordering::Acquire));
511                let mut shard_guard = shard.write();
512                match shard_guard.upsert_placeholder(*hash, *key) {
513                    Ok((_, v)) => {
514                        this.state = JoinFutureState::Done;
515                        Poll::Ready(Ok(v.clone()))
516                    }
517                    Err((shared, true)) => {
518                        let guard = PlaceholderGuard::start_loading(lifecycle, shard, shared);
519                        this.state = JoinFutureState::Done;
520                        Poll::Ready(Err(guard))
521                    }
522                    Err((shared, false)) => {
523                        let mut waker = None;
524                        let maybe_val =
525                            PlaceholderGuard::join_waiters(shard_guard, &shared, || {
526                                let waker_ = cx.waker().clone();
527                                waker = Some(waker_.clone());
528                                Some(Waiter::Task {
529                                    waker: waker_,
530                                    notified: &this.notified as *const AtomicBool,
531                                })
532                            });
533                        if let Some(v) = maybe_val {
534                            debug_assert!(waker.is_none());
535                            debug_assert!(!this.notified.load(Ordering::Acquire));
536                            this.state = JoinFutureState::Done;
537                            Poll::Ready(Ok(v))
538                        } else {
539                            let waker = waker.unwrap();
540                            this.state = JoinFutureState::Pending { shared, waker };
541                            Poll::Pending
542                        }
543                    }
544                }
545            }
546            JoinFutureState::Pending { .. } if this.notified.load(Ordering::Acquire) => {
547                let JoinFutureState::Pending { shared, .. } =
548                    mem::replace(&mut this.state, JoinFutureState::Done)
549                else {
550                    unsafe { unreachable_unchecked() }
551                };
552                Poll::Ready(PlaceholderGuard::handle_notification(
553                    lifecycle, shard, shared,
554                ))
555            }
556            JoinFutureState::Pending { waker, shared } => {
557                // Update waker in case it changed
558                let new_waker = cx.waker();
559                if !waker.will_wake(new_waker) {
560                    let mut state = shared.state.write();
561                    if let Some(w) = state
562                        .waiters
563                        .iter_mut()
564                        .find(|w| w.is_waiter(&this.notified as _))
565                    {
566                        *waker = new_waker.clone();
567                        *w = Waiter::Task {
568                            waker: new_waker.clone(),
569                            notified: &this.notified as *const AtomicBool,
570                        };
571                    } else {
572                        unsafe { unreachable_unchecked() };
573                    }
574                }
575                Poll::Pending
576            }
577            JoinFutureState::Done => panic!("Polled after ready"),
578        }
579    }
580}