1use std::{
2 future::Future,
3 hash::{BuildHasher, Hash},
4 hint::unreachable_unchecked,
5 time::Duration,
6};
7
8use crate::{
9 linked_slab::Token,
10 options::{Options, OptionsBuilder},
11 shard::{CacheShard, InsertStrategy},
12 shim::rw_lock::RwLock,
13 sync_placeholder::SharedPlaceholder,
14 DefaultHashBuilder, Equivalent, Lifecycle, MemoryUsed, UnitWeighter, Weighter,
15};
16
17pub use crate::sync_placeholder::{GuardResult, JoinFuture, PlaceholderGuard};
18
/// A concurrent cache split into independently locked shards.
///
/// Each key is hashed once; the hash picks the shard (masked with
/// `shards_mask`) and is then handed to that shard for its own lookup. Every
/// shard sits behind its own `RwLock`, so operations on different shards do
/// not contend on the same lock.
pub struct Cache<
    Key,
    Val,
    We = UnitWeighter,
    B = DefaultHashBuilder,
    L = DefaultLifecycle<Key, Val>,
> {
    // Hashes keys for shard selection; the resulting hash is also passed to the shard.
    hash_builder: B,
    // The shards; `with_options` always creates a power-of-two number of them.
    shards: Box<[RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>]>,
    // `shards.len() - 1`; masks a (rotated) hash down to a shard index.
    shards_mask: u64,
    // Used for `begin_request`/`end_request`; each shard also received a clone.
    lifecycle: L,
}
45
46impl<Key: Eq + Hash, Val: Clone> Cache<Key, Val> {
47 pub fn new(items_capacity: usize) -> Self {
49 Self::with(
50 items_capacity,
51 items_capacity as u64,
52 Default::default(),
53 Default::default(),
54 Default::default(),
55 )
56 }
57}
58
59impl<Key: Eq + Hash, Val: Clone, We: Weighter<Key, Val> + Clone> Cache<Key, Val, We> {
60 pub fn with_weighter(
61 estimated_items_capacity: usize,
62 weight_capacity: u64,
63 weighter: We,
64 ) -> Self {
65 Self::with(
66 estimated_items_capacity,
67 weight_capacity,
68 weighter,
69 Default::default(),
70 Default::default(),
71 )
72 }
73}
74
75impl<
76 Key: Eq + Hash,
77 Val: Clone,
78 We: Weighter<Key, Val> + Clone,
79 B: BuildHasher + Clone,
80 L: Lifecycle<Key, Val> + Clone,
81 > Cache<Key, Val, We, B, L>
82{
83 pub fn with(
87 estimated_items_capacity: usize,
88 weight_capacity: u64,
89 weighter: We,
90 hash_builder: B,
91 lifecycle: L,
92 ) -> Self {
93 Self::with_options(
94 OptionsBuilder::new()
95 .estimated_items_capacity(estimated_items_capacity)
96 .weight_capacity(weight_capacity)
97 .build()
98 .unwrap(),
99 weighter,
100 hash_builder,
101 lifecycle,
102 )
103 }
104
105 pub fn with_options(options: Options, weighter: We, hash_builder: B, lifecycle: L) -> Self {
124 let mut num_shards = options.shards.next_power_of_two() as u64;
125 let estimated_items_capacity = options.estimated_items_capacity as u64;
126 let weight_capacity = options.weight_capacity;
127 let mut shard_items_cap =
128 estimated_items_capacity.saturating_add(num_shards - 1) / num_shards;
129 let mut shard_weight_cap =
130 options.weight_capacity.saturating_add(num_shards - 1) / num_shards;
131 while shard_items_cap < 32 && num_shards > 1 {
133 num_shards /= 2;
134 shard_items_cap = estimated_items_capacity.saturating_add(num_shards - 1) / num_shards;
135 shard_weight_cap = weight_capacity.saturating_add(num_shards - 1) / num_shards;
136 }
137 let shards = (0..num_shards)
138 .map(|_| {
139 RwLock::new(CacheShard::new(
140 options.hot_allocation,
141 options.ghost_allocation,
142 shard_items_cap as usize,
143 shard_weight_cap,
144 weighter.clone(),
145 hash_builder.clone(),
146 lifecycle.clone(),
147 ))
148 })
149 .collect::<Vec<_>>();
150 Self {
151 shards: shards.into_boxed_slice(),
152 hash_builder,
153 shards_mask: num_shards - 1,
154 lifecycle,
155 }
156 }
157
158 #[cfg(fuzzing)]
159 pub fn validate(&self) {
160 for s in &*self.shards {
161 s.read().validate(false)
162 }
163 }
164
165 pub fn is_empty(&self) -> bool {
167 self.shards.iter().all(|s| s.read().len() == 0)
168 }
169
170 pub fn len(&self) -> usize {
172 self.shards.iter().map(|s| s.read().len()).sum()
173 }
174
175 pub fn weight(&self) -> u64 {
177 self.shards.iter().map(|s| s.read().weight()).sum()
178 }
179
180 pub fn capacity(&self) -> u64 {
184 self.shards.iter().map(|s| s.read().capacity()).sum()
185 }
186
187 pub fn shard_capacity(&self) -> u64 {
189 self.shards[0].read().capacity()
190 }
191
192 pub fn num_shards(&self) -> usize {
194 self.shards.len()
195 }
196
197 #[cfg(feature = "stats")]
199 pub fn misses(&self) -> u64 {
200 self.shards.iter().map(|s| s.read().misses()).sum()
201 }
202
203 #[cfg(feature = "stats")]
205 pub fn hits(&self) -> u64 {
206 self.shards.iter().map(|s| s.read().hits()).sum()
207 }
208
209 #[inline]
210 fn shard_for<Q>(
211 &self,
212 key: &Q,
213 ) -> Option<(
214 &RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
215 u64,
216 )>
217 where
218 Q: Hash + Equivalent<Key> + ?Sized,
219 {
220 let hash = self.hash_builder.hash_one(key);
221 let shard_idx = (hash.rotate_right(usize::BITS / 2) & self.shards_mask) as usize;
226 self.shards.get(shard_idx).map(|s| (s, hash))
227 }
228
229 pub fn reserve(&self, additional: usize) {
232 let additional_per_shard =
233 additional.saturating_add(self.shards.len() - 1) / self.shards.len();
234 for s in &*self.shards {
235 s.write().reserve(additional_per_shard);
236 }
237 }
238
239 pub fn contains_key<Q>(&self, key: &Q) -> bool
241 where
242 Q: Hash + Equivalent<Key> + ?Sized,
243 {
244 self.shard_for(key)
245 .is_some_and(|(shard, hash)| shard.read().contains(hash, key))
246 }
247
248 pub fn get<Q>(&self, key: &Q) -> Option<Val>
250 where
251 Q: Hash + Equivalent<Key> + ?Sized,
252 {
253 let (shard, hash) = self.shard_for(key)?;
254 shard.read().get(hash, key).cloned()
255 }
256
257 pub fn peek<Q>(&self, key: &Q) -> Option<Val>
260 where
261 Q: Hash + Equivalent<Key> + ?Sized,
262 {
263 let (shard, hash) = self.shard_for(key)?;
264 shard.read().peek(hash, key).cloned()
265 }
266
267 pub fn remove<Q>(&self, key: &Q) -> Option<(Key, Val)>
270 where
271 Q: Hash + Equivalent<Key> + ?Sized,
272 {
273 let (shard, hash) = self.shard_for(key).unwrap();
274 shard.write().remove(hash, key)
275 }
276
277 pub fn remove_if<Q, F>(&self, key: &Q, f: F) -> Option<(Key, Val)>
282 where
283 Q: Hash + Equivalent<Key> + ?Sized,
284 F: FnOnce(&Val) -> bool,
285 {
286 let (shard, hash) = self.shard_for(key).unwrap();
287 shard.write().remove_if(hash, key, f)
288 }
289
290 pub fn replace(&self, key: Key, value: Val, soft: bool) -> Result<(), (Key, Val)> {
296 let lcs = self.replace_with_lifecycle(key, value, soft)?;
297 self.lifecycle.end_request(lcs);
298 Ok(())
299 }
300
301 pub fn replace_with_lifecycle(
307 &self,
308 key: Key,
309 value: Val,
310 soft: bool,
311 ) -> Result<L::RequestState, (Key, Val)> {
312 let mut lcs = self.lifecycle.begin_request();
313 let (shard, hash) = self.shard_for(&key).unwrap();
314 shard
315 .write()
316 .insert(&mut lcs, hash, key, value, InsertStrategy::Replace { soft })?;
317 Ok(lcs)
318 }
319
320 pub fn retain<F>(&self, f: F)
324 where
325 F: Fn(&Key, &Val) -> bool,
326 {
327 for s in self.shards.iter() {
328 s.write().retain(&f);
329 }
330 }
331
332 pub fn insert(&self, key: Key, value: Val) {
334 let lcs = self.insert_with_lifecycle(key, value);
335 self.lifecycle.end_request(lcs);
336 }
337
338 pub fn insert_with_lifecycle(&self, key: Key, value: Val) -> L::RequestState {
340 let mut lcs = self.lifecycle.begin_request();
341 let (shard, hash) = self.shard_for(&key).unwrap();
342 let result = shard
343 .write()
344 .insert(&mut lcs, hash, key, value, InsertStrategy::Insert);
345 debug_assert!(result.is_ok());
347 lcs
348 }
349
350 pub fn clear(&self) {
352 for s in self.shards.iter() {
353 s.write().clear();
354 }
355 }
356
357 pub fn iter(&self) -> Iter<'_, Key, Val, We, B, L>
363 where
364 Key: Clone,
365 {
366 Iter {
367 shards: &self.shards,
368 current_shard: 0,
369 last: None,
370 }
371 }
372
373 pub fn drain(&self) -> Drain<'_, Key, Val, We, B, L> {
384 Drain {
385 shards: &self.shards,
386 current_shard: 0,
387 last: None,
388 }
389 }
390
391 pub fn set_capacity(&self, new_weight_capacity: u64) {
397 let shard_weight_cap = new_weight_capacity.saturating_add(self.shards.len() as u64 - 1)
398 / self.shards.len() as u64;
399 for shard in &*self.shards {
400 shard.write().set_capacity(shard_weight_cap);
401 }
402 }
403
404 pub fn get_value_or_guard<Q>(
416 &self,
417 key: &Q,
418 timeout: Option<Duration>,
419 ) -> GuardResult<'_, Key, Val, We, B, L>
420 where
421 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
422 {
423 let (shard, hash) = self.shard_for(key).unwrap();
424 if let Some(v) = shard.read().get(hash, key) {
425 return GuardResult::Value(v.clone());
426 }
427 PlaceholderGuard::join(&self.lifecycle, shard, hash, key, timeout)
428 }
429
430 pub fn get_or_insert_with<Q, E>(
434 &self,
435 key: &Q,
436 with: impl FnOnce() -> Result<Val, E>,
437 ) -> Result<Val, E>
438 where
439 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
440 {
441 match self.get_value_or_guard(key, None) {
442 GuardResult::Value(v) => Ok(v),
443 GuardResult::Guard(g) => {
444 let v = with()?;
445 let _ = g.insert(v.clone());
446 Ok(v)
447 }
448 GuardResult::Timeout => unsafe { unreachable_unchecked() },
449 }
450 }
451
452 pub async fn get_value_or_guard_async<'a, Q>(
460 &'a self,
461 key: &Q,
462 ) -> Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>
463 where
464 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
465 {
466 let (shard, hash) = self.shard_for(key).unwrap();
467 if let Some(v) = shard.read().get(hash, key) {
468 return Ok(v.clone());
469 }
470 JoinFuture::new(&self.lifecycle, shard, hash, key).await
471 }
472
473 pub async fn get_or_insert_async<Q, E>(
475 &self,
476 key: &Q,
477 with: impl Future<Output = Result<Val, E>>,
478 ) -> Result<Val, E>
479 where
480 Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
481 {
482 match self.get_value_or_guard_async(key).await {
483 Ok(v) => Ok(v),
484 Err(g) => {
485 let v = with.await?;
486 let _ = g.insert(v.clone());
487 Ok(v)
488 }
489 }
490 }
491
492 pub fn memory_used(&self) -> MemoryUsed {
497 let mut total = MemoryUsed { entries: 0, map: 0 };
498 self.shards.iter().for_each(|shard| {
499 let shard_memory = shard.read().memory_used();
500 total.entries += shard_memory.entries;
501 total.map += shard_memory.map;
502 });
503 total
504 }
505}
506
507impl<Key, Val, We, B, L> std::fmt::Debug for Cache<Key, Val, We, B, L> {
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 f.debug_struct("Cache").finish_non_exhaustive()
510 }
511}
512
/// Iterator returned by `Cache::iter`, yielding clones of entries shard by shard.
///
/// No lock is held between calls to `next`; each call read-locks the current
/// shard only while it locates and clones one entry.
pub struct Iter<'a, Key, Val, We, B, L> {
    // All shards of the cache being iterated.
    shards: &'a [RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>],
    // Index of the shard currently being walked.
    current_shard: usize,
    // Token of the last entry yielded from `current_shard`; `None` before the
    // first entry of each shard.
    last: Option<Token>,
}
521
522impl<Key, Val, We, B, L> Iterator for Iter<'_, Key, Val, We, B, L>
523where
524 Key: Clone,
525 Val: Clone,
526{
527 type Item = (Key, Val);
528
529 fn next(&mut self) -> Option<Self::Item> {
530 while self.current_shard < self.shards.len() {
531 let shard = &self.shards[self.current_shard];
532 let lock = shard.read();
533 if let Some((new_last, key, val)) = lock.iter_from(self.last).next() {
534 self.last = Some(new_last);
535 return Some((key.clone(), val.clone()));
536 }
537 self.last = None;
538 self.current_shard += 1;
539 }
540 None
541 }
542}
543
544impl<Key, Val, We, B, L> std::fmt::Debug for Iter<'_, Key, Val, We, B, L> {
545 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 f.debug_struct("Iter").finish_non_exhaustive()
547 }
548}
549
/// Iterator returned by `Cache::drain`, removing and yielding entries shard by
/// shard.
///
/// No lock is held between calls to `next`; each call write-locks the current
/// shard only while it removes one entry.
pub struct Drain<'a, Key, Val, We, B, L> {
    // All shards of the cache being drained.
    shards: &'a [RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>],
    // Index of the shard currently being drained.
    current_shard: usize,
    // Token returned by the last `remove_next` on `current_shard`; `None`
    // before the first removal from each shard.
    last: Option<Token>,
}
558
559impl<Key, Val, We, B, L> Iterator for Drain<'_, Key, Val, We, B, L>
560where
561 Key: Hash + Eq,
562 We: Weighter<Key, Val>,
563 B: BuildHasher,
564 L: Lifecycle<Key, Val>,
565{
566 type Item = (Key, Val);
567
568 fn next(&mut self) -> Option<Self::Item> {
569 while self.current_shard < self.shards.len() {
570 let shard = &self.shards[self.current_shard];
571 let mut lock = shard.write();
572 if let Some((new_last, key, value)) = lock.remove_next(self.last) {
573 self.last = Some(new_last);
574 return Some((key, value));
575 }
576 self.last = None;
577 self.current_shard += 1;
578 }
579 None
580 }
581}
582
583impl<Key, Val, We, B, L> std::fmt::Debug for Drain<'_, Key, Val, We, B, L> {
584 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585 f.debug_struct("Drain").finish_non_exhaustive()
586 }
587}
588
/// The lifecycle used when none is supplied.
///
/// It carries no data of its own (only a `PhantomData` marker for the key and
/// value types).
pub struct DefaultLifecycle<Key, Val>(std::marker::PhantomData<(Key, Val)>);

impl<Key, Val> std::fmt::Debug for DefaultLifecycle<Key, Val> {
    /// Writes just the type name, exactly as an empty `debug_tuple` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DefaultLifecycle")
    }
}

// `Default` and `Clone` are written by hand rather than derived so that `Key`
// and `Val` are not required to implement those traits themselves.
impl<Key, Val> Default for DefaultLifecycle<Key, Val> {
    #[inline]
    fn default() -> Self {
        DefaultLifecycle(std::marker::PhantomData)
    }
}

impl<Key, Val> Clone for DefaultLifecycle<Key, Val> {
    #[inline]
    fn clone(&self) -> Self {
        Self::default()
    }
}
612
613impl<Key, Val> Lifecycle<Key, Val> for DefaultLifecycle<Key, Val> {
614 type RequestState = [Option<(Key, Val)>; 2];
620
621 #[inline]
622 fn begin_request(&self) -> Self::RequestState {
623 [None, None]
624 }
625
626 #[inline]
627 fn on_evict(&self, state: &mut Self::RequestState, key: Key, val: Val) {
628 if std::mem::needs_drop::<(Key, Val)>() {
629 if state[0].is_none() {
630 state[0] = Some((key, val));
631 } else if state[1].is_none() {
632 state[1] = Some((key, val));
633 }
634 }
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    /// Builds a `usize -> usize` cache with the given item and weight capacity,
    /// requesting two shards.
    fn two_shard_cache(capacity: usize) -> Cache<usize, usize> {
        let options = OptionsBuilder::new()
            .estimated_items_capacity(capacity)
            .weight_capacity(capacity as u64)
            .shards(2)
            .build()
            .unwrap();
        Cache::with_options(
            options,
            UnitWeighter,
            DefaultHashBuilder::default(),
            DefaultLifecycle::default(),
        )
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_multiple_threads() {
        const N_THREAD_PAIRS: usize = 8;
        const N_ROUNDS: usize = 1_000;
        const ITEMS_PER_THREAD: usize = 1_000;
        let barrier = Arc::new(Barrier::new(N_THREAD_PAIRS * 2));
        let cache = Arc::new(Cache::new(N_THREAD_PAIRS * ITEMS_PER_THREAD / 10));

        // Writers: each repeatedly inserts its own disjoint key range.
        let writers: Vec<_> = (0..N_THREAD_PAIRS)
            .map(|t| {
                let barrier = Arc::clone(&barrier);
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let keys = ITEMS_PER_THREAD * t..ITEMS_PER_THREAD * (t + 1);
                    barrier.wait();
                    for _ in 0..N_ROUNDS {
                        for i in keys.clone() {
                            cache.insert(i, i);
                        }
                    }
                })
            })
            .collect();

        // Readers: any value that is found must match its key.
        let readers: Vec<_> = (0..N_THREAD_PAIRS)
            .map(|t| {
                let barrier = Arc::clone(&barrier);
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let keys = ITEMS_PER_THREAD * t..ITEMS_PER_THREAD * (t + 1);
                    barrier.wait();
                    for _ in 0..N_ROUNDS {
                        for i in keys.clone() {
                            if let Some(cached) = cache.get(&i) {
                                assert_eq!(cached, i);
                            }
                        }
                    }
                })
            })
            .collect();

        for handle in writers.into_iter().chain(readers) {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_iter() {
        let capacity = if cfg!(miri) { 100 } else { 100000 };
        let cache = two_shard_cache(capacity);
        let items = capacity / 2;
        (0..items).for_each(|i| cache.insert(i, i));
        assert_eq!(cache.len(), items);

        let mut collected: Vec<(usize, usize)> = cache.iter().collect();
        assert_eq!(collected.len(), items);
        collected.sort();
        let expected: Vec<(usize, usize)> = (0..items).map(|i| (i, i)).collect();
        assert_eq!(collected, expected);
    }

    #[test]
    fn test_drain() {
        let capacity = if cfg!(miri) { 100 } else { 100000 };
        let cache = two_shard_cache(capacity);
        let items = capacity / 2;
        (0..items).for_each(|i| cache.insert(i, i));
        assert_eq!(cache.len(), items);

        let mut drained: Vec<(usize, usize)> = cache.drain().collect();
        assert_eq!(cache.len(), 0);
        assert_eq!(drained.len(), items);
        drained.sort();
        let expected: Vec<(usize, usize)> = (0..items).map(|i| (i, i)).collect();
        assert_eq!(drained, expected);
    }

    #[test]
    fn test_set_capacity() {
        let cache = Cache::new(100);
        (0..80).for_each(|i| cache.insert(i, i));
        assert!(cache.len() <= 80);

        // After shrinking, length and weight must fit the new limit.
        cache.set_capacity(50);
        assert!(cache.len() <= 50);
        assert!(cache.weight() <= 50);

        // After growing, the reported capacity matches the new limit.
        cache.set_capacity(200);
        assert_eq!(cache.capacity(), 200);

        (100..180).for_each(|i| cache.insert(i, i));
        assert!(cache.len() <= 180);
        assert!(cache.weight() <= 200);
    }

    #[test]
    fn test_remove_if() {
        let cache = Cache::new(100);
        for (k, v) in [(1, 10), (2, 20), (3, 30)] {
            cache.insert(k, v);
        }

        // Predicate accepts: the entry is removed and returned.
        assert_eq!(cache.remove_if(&2, |v| *v == 20), Some((2, 20)));
        assert_eq!(cache.get(&2), None);

        // Predicate rejects: the entry stays.
        assert_eq!(cache.remove_if(&3, |v| *v == 999), None);
        assert_eq!(cache.get(&3), Some(30));

        // Missing key: nothing to remove.
        assert_eq!(cache.remove_if(&999, |_| true), None);
    }
}