1use std::{
2 future::Future,
3 hash::{BuildHasher, Hash},
4 hint::unreachable_unchecked,
5 time::Duration,
6};
7
8use crate::{
9 linked_slab::Token,
10 options::{Options, OptionsBuilder},
11 shard::{CacheShard, InsertStrategy},
12 shim::rw_lock::RwLock,
13 sync_placeholder::SharedPlaceholder,
14 DefaultHashBuilder, Equivalent, Lifecycle, MemoryUsed, UnitWeighter, Weighter,
15};
16
17use crate::shard::EntryOrPlaceholder;
18pub use crate::sync_placeholder::{EntryAction, EntryResult, GuardResult, PlaceholderGuard};
19use crate::sync_placeholder::{JoinFuture, JoinResult};
20
/// A concurrent cache whose entries are split across several independently
/// locked shards.
///
/// A key is routed to one shard by hashing it with `hash_builder` (see
/// `shard_for`); each shard is a `CacheShard` behind its own `RwLock`, so
/// operations on keys that land in different shards do not contend.
///
/// Type parameters:
/// - `We`: weighs entries against the weight capacity (default `UnitWeighter`).
/// - `B`: hasher builder (default `DefaultHashBuilder`); cloned into every shard.
/// - `L`: lifecycle hooks, e.g. eviction callbacks (default `DefaultLifecycle`).
pub struct Cache<
    Key,
    Val,
    We = UnitWeighter,
    B = DefaultHashBuilder,
    L = DefaultLifecycle<Key, Val>,
> {
    // Hasher used to choose a shard; each shard also receives a clone in `with_options`.
    hash_builder: B,
    // The shards. `with_options` always creates a power-of-two number of them.
    shards: Box<[RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>]>,
    // `shards.len() - 1`; masking a (rotated) hash with it yields a shard index.
    shards_mask: u64,
    // Cache-level lifecycle instance used to begin/end requests and by placeholder guards.
    lifecycle: L,
}
47
48impl<Key: Eq + Hash, Val: Clone> Cache<Key, Val> {
49 pub fn new(items_capacity: usize) -> Self {
51 Self::with(
52 items_capacity,
53 items_capacity as u64,
54 Default::default(),
55 Default::default(),
56 Default::default(),
57 )
58 }
59}
60
61impl<Key: Eq + Hash, Val: Clone, We: Weighter<Key, Val> + Clone> Cache<Key, Val, We> {
62 pub fn with_weighter(
63 estimated_items_capacity: usize,
64 weight_capacity: u64,
65 weighter: We,
66 ) -> Self {
67 Self::with(
68 estimated_items_capacity,
69 weight_capacity,
70 weighter,
71 Default::default(),
72 Default::default(),
73 )
74 }
75}
76
impl<
        Key: Eq + Hash,
        Val: Clone,
        We: Weighter<Key, Val> + Clone,
        B: BuildHasher + Clone,
        L: Lifecycle<Key, Val> + Clone,
    > Cache<Key, Val, We, B, L>
{
    /// Creates a cache from an estimated entry count and a total weight capacity.
    ///
    /// Uses the supplied weighter, hasher and lifecycle. Every other option keeps
    /// its `OptionsBuilder::new()` default. See `with_options` for how capacity is
    /// split across shards.
    ///
    /// # Panics
    ///
    /// Panics if `OptionsBuilder::build` rejects these values (its result is unwrapped).
    pub fn with(
        estimated_items_capacity: usize,
        weight_capacity: u64,
        weighter: We,
        hash_builder: B,
        lifecycle: L,
    ) -> Self {
        Self::with_options(
            OptionsBuilder::new()
                .estimated_items_capacity(estimated_items_capacity)
                .weight_capacity(weight_capacity)
                .build()
                .unwrap(),
            weighter,
            hash_builder,
            lifecycle,
        )
    }

    /// Creates a cache from a fully specified [`Options`].
    ///
    /// The requested shard count is rounded up to a power of two. It is then
    /// halved while each shard would receive fewer than 32 estimated items,
    /// always keeping at least one shard. Item and weight capacities are divided
    /// across shards with ceiling division. As a result, the sum of the per-shard
    /// capacities can be slightly larger than the totals requested in `options`.
    pub fn with_options(options: Options, weighter: We, hash_builder: B, lifecycle: L) -> Self {
        // A power-of-two shard count lets `shard_for` pick a shard with a bit mask.
        let mut num_shards = options.shards.next_power_of_two() as u64;
        let estimated_items_capacity = options.estimated_items_capacity as u64;
        let weight_capacity = options.weight_capacity;
        // `(x + n - 1) / n` is ceiling division; `saturating_add` prevents overflow
        // for values close to `u64::MAX`.
        let mut shard_items_cap =
            estimated_items_capacity.saturating_add(num_shards - 1) / num_shards;
        let mut shard_weight_cap =
            options.weight_capacity.saturating_add(num_shards - 1) / num_shards;
        // Prefer fewer, larger shards: aim for at least 32 estimated items per shard.
        while shard_items_cap < 32 && num_shards > 1 {
            num_shards /= 2;
            shard_items_cap = estimated_items_capacity.saturating_add(num_shards - 1) / num_shards;
            shard_weight_cap = weight_capacity.saturating_add(num_shards - 1) / num_shards;
        }
        let shards = (0..num_shards)
            .map(|_| {
                RwLock::new(CacheShard::new(
                    options.hot_allocation,
                    options.ghost_allocation,
                    shard_items_cap as usize,
                    shard_weight_cap,
                    weighter.clone(),
                    hash_builder.clone(),
                    lifecycle.clone(),
                ))
            })
            .collect::<Vec<_>>();
        Self {
            shards: shards.into_boxed_slice(),
            hash_builder,
            // `num_shards` is a power of two, so this mask covers exactly the valid indices.
            shards_mask: num_shards - 1,
            lifecycle,
        }
    }

    /// Runs `CacheShard::validate(false)` on every shard (fuzzing builds only).
    ///
    /// NOTE(review): the meaning of the `false` argument is defined by
    /// `CacheShard::validate` — confirm there.
    #[cfg(fuzzing)]
    pub fn validate(&self) {
        for s in &*self.shards {
            s.read().validate(false)
        }
    }

    /// Returns `true` if every shard reports a length of zero.
    ///
    /// Shards are read-locked one at a time, so under concurrent writes the
    /// answer is not an atomic snapshot.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.read().len() == 0)
    }

    /// Returns the total number of entries, summed shard by shard.
    ///
    /// Each shard is read-locked separately, so the sum is not an atomic
    /// snapshot when other threads are writing.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.read().len()).sum()
    }

    /// Returns the total weight of all entries, summed shard by shard.
    pub fn weight(&self) -> u64 {
        self.shards.iter().map(|s| s.read().weight()).sum()
    }

    /// Returns the sum of the per-shard capacities.
    ///
    /// Because of the ceiling division in `with_options` and `set_capacity`,
    /// this can be slightly larger than the value originally requested.
    /// NOTE(review): presumably this is the weight capacity — confirm in `CacheShard`.
    pub fn capacity(&self) -> u64 {
        self.shards.iter().map(|s| s.read().capacity()).sum()
    }

    /// Returns the capacity of a single shard.
    ///
    /// `with_options` and `set_capacity` give every shard the same capacity,
    /// so shard 0 is representative.
    pub fn shard_capacity(&self) -> u64 {
        self.shards[0].read().capacity()
    }

    /// Returns the number of shards (always a power of two).
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Returns the total miss counter summed over all shards (`stats` feature only).
    #[cfg(feature = "stats")]
    pub fn misses(&self) -> u64 {
        self.shards.iter().map(|s| s.read().misses()).sum()
    }

    /// Returns the total hit counter summed over all shards (`stats` feature only).
    #[cfg(feature = "stats")]
    pub fn hits(&self) -> u64 {
        self.shards.iter().map(|s| s.read().hits()).sum()
    }

    /// Hashes `key` and returns the shard responsible for it, together with the hash.
    ///
    /// In practice this always returns `Some`. `shards_mask` is `shards.len() - 1`
    /// and the length is a power of two, so the masked index is always in bounds.
    /// This is why callers may `unwrap` the result.
    #[inline]
    fn shard_for<Q>(
        &self,
        key: &Q,
    ) -> Option<(
        &RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
        u64,
    )>
    where
        Q: Hash + Equivalent<Key> + ?Sized,
    {
        let hash = self.hash_builder.hash_one(key);
        // Rotate by `usize::BITS / 2` before masking, so the shard is chosen from
        // bits other than the hash's lowest ones.
        // NOTE(review): presumably this keeps shard choice independent of the bits
        // the per-shard table uses internally — confirm.
        let shard_idx = (hash.rotate_right(usize::BITS / 2) & self.shards_mask) as usize;
        self.shards.get(shard_idx).map(|s| (s, hash))
    }

    /// Reserves room for `additional` more entries.
    ///
    /// The amount is split evenly across shards, rounded up. Each shard's write
    /// lock is taken in turn.
    pub fn reserve(&self, additional: usize) {
        let additional_per_shard =
            additional.saturating_add(self.shards.len() - 1) / self.shards.len();
        for s in &*self.shards {
            s.write().reserve(additional_per_shard);
        }
    }

    /// Returns whether `key` is present. Only the shard's read lock is taken.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        Q: Hash + Equivalent<Key> + ?Sized,
    {
        self.shard_for(key)
            .is_some_and(|(shard, hash)| shard.read().contains(hash, key))
    }

    /// Returns a clone of the value stored for `key`, if any.
    ///
    /// The lookup goes through `CacheShard::get` under the read lock.
    /// NOTE(review): unlike `peek`, this presumably records the access for
    /// eviction purposes — confirm in `CacheShard`.
    pub fn get<Q>(&self, key: &Q) -> Option<Val>
    where
        Q: Hash + Equivalent<Key> + ?Sized,
    {
        let (shard, hash) = self.shard_for(key)?;
        shard.read().get(hash, key).cloned()
    }

    /// Returns a clone of the value stored for `key`, if any, via `CacheShard::peek`.
    ///
    /// NOTE(review): presumably this does not affect eviction order (in contrast
    /// to `get`) — confirm in `CacheShard`.
    pub fn peek<Q>(&self, key: &Q) -> Option<Val>
    where
        Q: Hash + Equivalent<Key> + ?Sized,
    {
        let (shard, hash) = self.shard_for(key)?;
        shard.read().peek(hash, key).cloned()
    }

    /// Removes `key`, returning the stored key and value if it was present.
    pub fn remove<Q>(&self, key: &Q) -> Option<(Key, Val)>
    where
        Q: Hash + Equivalent<Key> + ?Sized,
    {
        // `shard_for` never returns `None` (see its docs).
        let (shard, hash) = self.shard_for(key).unwrap();
        shard.write().remove(hash, key)
    }

    /// Removes `key` only if `f` returns `true` for its current value.
    ///
    /// Returns the removed key and value, or `None` if the key is absent or `f`
    /// declined. If `f` is called, it runs while the shard's write lock is held.
    /// It should therefore not access this cache; doing so would likely deadlock.
    pub fn remove_if<Q, F>(&self, key: &Q, f: F) -> Option<(Key, Val)>
    where
        Q: Hash + Equivalent<Key> + ?Sized,
        F: FnOnce(&Val) -> bool,
    {
        let (shard, hash) = self.shard_for(key).unwrap();
        shard.write().remove_if(hash, key, f)
    }

    /// Inserts `value` for `key` using `InsertStrategy::Replace { soft }`.
    ///
    /// Returns `Err((key, value))` if the shard rejects the insertion.
    /// NOTE(review): presumably `soft == true` only replaces an existing entry —
    /// confirm in `CacheShard::insert`.
    /// NOTE(review): on the error path `?` returns before `end_request`. Any
    /// request state is then simply dropped without going through the lifecycle.
    pub fn replace(&self, key: Key, value: Val, soft: bool) -> Result<(), (Key, Val)> {
        let lcs = self.replace_with_lifecycle(key, value, soft)?;
        self.lifecycle.end_request(lcs);
        Ok(())
    }

    /// Like `replace`, but returns the lifecycle request state instead of ending it.
    ///
    /// The caller is presumably expected to pass the returned state to
    /// `Lifecycle::end_request`, as `replace` does.
    pub fn replace_with_lifecycle(
        &self,
        key: Key,
        value: Val,
        soft: bool,
    ) -> Result<L::RequestState, (Key, Val)> {
        let mut lcs = self.lifecycle.begin_request();
        let (shard, hash) = self.shard_for(&key).unwrap();
        shard
            .write()
            .insert(&mut lcs, hash, key, value, InsertStrategy::Replace { soft })?;
        Ok(lcs)
    }

    /// Keeps only the entries for which `f(key, value)` returns `true`.
    ///
    /// Shards are processed one at a time under their write lock, so the
    /// operation is not atomic across the whole cache.
    pub fn retain<F>(&self, f: F)
    where
        F: Fn(&Key, &Val) -> bool,
    {
        for s in self.shards.iter() {
            s.write().retain(&f);
        }
    }

    /// Inserts (or overwrites) `key` with `value`, possibly evicting other entries.
    ///
    /// The shard's write guard is a temporary inside `insert_with_lifecycle`, so it
    /// is released before `end_request` receives the request state here.
    pub fn insert(&self, key: Key, value: Val) {
        let lcs = self.insert_with_lifecycle(key, value);
        self.lifecycle.end_request(lcs);
    }

    /// Like `insert`, but returns the lifecycle request state instead of ending it.
    pub fn insert_with_lifecycle(&self, key: Key, value: Val) -> L::RequestState {
        let mut lcs = self.lifecycle.begin_request();
        let (shard, hash) = self.shard_for(&key).unwrap();
        let result = shard
            .write()
            .insert(&mut lcs, hash, key, value, InsertStrategy::Insert);
        // `InsertStrategy::Insert` is expected never to fail; only debug builds check it.
        debug_assert!(result.is_ok());
        lcs
    }

    /// Removes all entries, one shard at a time.
    pub fn clear(&self) {
        for s in self.shards.iter() {
            s.write().clear();
        }
    }

    /// Returns an iterator over clones of all `(key, value)` pairs.
    ///
    /// Each `next` call read-locks one shard only for that call. No lock is held
    /// between calls, so the iteration is not a snapshot of the cache.
    pub fn iter(&self) -> Iter<'_, Key, Val, We, B, L>
    where
        Key: Clone,
    {
        Iter {
            shards: &self.shards,
            current_shard: 0,
            last: None,
        }
    }

    /// Returns an iterator that removes entries and yields them by value.
    ///
    /// Each `next` call write-locks one shard only for that call. Entries that
    /// are not yielded (because the iterator is dropped early) stay in the cache.
    pub fn drain(&self) -> Drain<'_, Key, Val, We, B, L> {
        Drain {
            shards: &self.shards,
            current_shard: 0,
            last: None,
        }
    }

    /// Sets the total weight capacity, split evenly across shards and rounded up.
    ///
    /// The tests in this file expect shrinking to evict entries down to the new limit.
    pub fn set_capacity(&self, new_weight_capacity: u64) {
        let shard_weight_cap = new_weight_capacity.saturating_add(self.shards.len() as u64 - 1)
            / self.shards.len() as u64;
        for shard in &*self.shards {
            shard.write().set_capacity(shard_weight_cap);
        }
    }

    /// Returns the cached value for `key`, or a placeholder guard for loading it.
    ///
    /// Fast path: a read-locked lookup. Otherwise `PlaceholderGuard::join` decides
    /// the result:
    /// - `GuardResult::Value` if another loader fills the placeholder.
    /// - `GuardResult::Guard` if this caller should load the value and insert it
    ///   through the guard.
    /// - `GuardResult::Timeout` if `timeout` elapses first.
    ///
    /// `get_or_insert_with` relies on `Timeout` never being returned when
    /// `timeout` is `None`.
    pub fn get_value_or_guard<Q>(
        &self,
        key: &Q,
        timeout: Option<Duration>,
    ) -> GuardResult<'_, Key, Val, We, B, L>
    where
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
    {
        let (shard, hash) = self.shard_for(key).unwrap();
        // The read guard is a temporary of this `if let`; it is released before `join`.
        if let Some(v) = shard.read().get(hash, key) {
            return GuardResult::Value(v.clone());
        }
        PlaceholderGuard::join(&self.lifecycle, shard, hash, key, timeout)
    }

    /// Returns the cached value for `key`, computing it with `with` if it is missing.
    ///
    /// If `with` returns `Err`, that error is returned and the guard is dropped
    /// without inserting. The tests show that a waiting caller then obtains a
    /// fresh `Vacant` guard. The result of `g.insert` is ignored.
    pub fn get_or_insert_with<Q, E>(
        &self,
        key: &Q,
        with: impl FnOnce() -> Result<Val, E>,
    ) -> Result<Val, E>
    where
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
    {
        match self.get_value_or_guard(key, None) {
            GuardResult::Value(v) => Ok(v),
            GuardResult::Guard(g) => {
                let v = with()?;
                let _ = g.insert(v.clone());
                Ok(v)
            }
            // SAFETY: `timeout` is `None`. This relies on `PlaceholderGuard::join`
            // never producing `Timeout` without a deadline; if that contract
            // changes, this becomes undefined behaviour.
            GuardResult::Timeout => unsafe { unreachable_unchecked() },
        }
    }

    /// Async form of `get_value_or_guard` with no timeout.
    ///
    /// Returns `Ok(value)` if the key is cached or another loader fills it.
    /// Returns `Err(guard)` when this caller should load the value.
    pub async fn get_value_or_guard_async<'a, Q>(
        &'a self,
        key: &Q,
    ) -> Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>
    where
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
    {
        let (shard, hash) = self.shard_for(key).unwrap();
        loop {
            // The read guard is a temporary of this `if let`, so it is released
            // before the `.await` below.
            if let Some(v) = shard.read().get(hash, key) {
                return Ok(v.clone());
            }
            match JoinFuture::new(&self.lifecycle, shard, hash, key).await {
                JoinResult::Filled(Some(shared)) => {
                    // SAFETY: relies on `JoinFuture` reporting `Filled(Some(_))` only
                    // for a placeholder whose value has been set, so `value()` is
                    // `Some`. NOTE(review): confirm this invariant in `sync_placeholder`.
                    return Ok(unsafe { shared.value().unwrap_unchecked().clone() });
                }
                // The placeholder completed without a value; look the key up again.
                JoinResult::Filled(None) => continue,
                JoinResult::Guard(g) => return Err(g),
                // SAFETY: no timeout is passed to `JoinFuture::new`. This relies on
                // it never yielding `Timeout` in that case.
                JoinResult::Timeout => unsafe { unreachable_unchecked() },
            }
        }
    }

    /// Async form of `get_or_insert_with`: awaits `with` only when this caller
    /// obtained the loading guard.
    ///
    /// On `Err`, the guard is dropped without inserting and the error is returned.
    pub async fn get_or_insert_async<Q, E>(
        &self,
        key: &Q,
        with: impl Future<Output = Result<Val, E>>,
    ) -> Result<Val, E>
    where
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
    {
        match self.get_value_or_guard_async(key).await {
            Ok(v) => Ok(v),
            Err(g) => {
                let v = with.await?;
                let _ = g.insert(v.clone());
                Ok(v)
            }
        }
    }

    /// Inspects or modifies the entry for `key` under the shard's write lock.
    ///
    /// If the key is present, `on_occupied(&key, &mut value)` runs while the write
    /// lock is held, and its `EntryAction` selects the result:
    /// - `Retain(t)` gives `EntryResult::Retained(t)`; in-place changes to the
    ///   value are kept.
    /// - `Remove` gives `EntryResult::Removed(key, value)`.
    /// - `ReplaceWithGuard` gives `EntryResult::Replaced(guard, old_value)`.
    ///
    /// If the key is absent, a placeholder is created and `EntryResult::Vacant(guard)`
    /// is returned. If another loader holds a placeholder, this waits for it (up to
    /// `timeout`) and retries. It returns `Vacant(guard)` if the wait hands over the
    /// guard, or `EntryResult::Timeout` if the deadline passes.
    ///
    /// # Panics
    ///
    /// `on_occupied` is invoked through an `Option::take().unwrap()`. A second
    /// invocation would therefore panic; it is expected to run at most once.
    pub fn entry<Q, T>(
        &self,
        key: &Q,
        timeout: Option<Duration>,
        on_occupied: impl FnOnce(&Key, &mut Val) -> EntryAction<T>,
    ) -> EntryResult<'_, Key, Val, We, B, L, T>
    where
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
    {
        let (shard, hash) = self.shard_for(key).unwrap();
        let mut on_occupied = Some(on_occupied);
        // Wrap the `FnOnce` so it can be passed by `&mut` on every loop iteration.
        let mut callback = |k: &Key, v: &mut Val| on_occupied.take().unwrap()(k, v);
        // Built once so that repeated waits share one deadline.
        // NOTE(review): presumably `wait_for_placeholder` converts `Ok(duration)` into
        // an absolute deadline on first use — confirm.
        let mut deadline = timeout.map(Ok);

        loop {
            let mut shard_guard = shard.write();
            match shard_guard.entry_or_placeholder(hash, key, &mut callback) {
                EntryOrPlaceholder::Kept(t) => return EntryResult::Retained(t),
                EntryOrPlaceholder::Removed(k, v) => return EntryResult::Removed(k, v),
                EntryOrPlaceholder::Replaced(shared, old_val) => {
                    // Release the shard lock before building the loading guard.
                    // NOTE(review): presumably `start_loading` or the guard itself
                    // needs to lock the shard — keep this ordering.
                    drop(shard_guard);
                    return EntryResult::Replaced(
                        PlaceholderGuard::start_loading(&self.lifecycle, shard, shared),
                        old_val,
                    );
                }
                EntryOrPlaceholder::NewPlaceholder(shared) => {
                    drop(shard_guard);
                    return EntryResult::Vacant(PlaceholderGuard::start_loading(
                        &self.lifecycle,
                        shard,
                        shared,
                    ));
                }
                EntryOrPlaceholder::ExistingPlaceholder(shared) => {
                    // The write guard is moved into the wait; it is not dropped here.
                    match PlaceholderGuard::wait_for_placeholder(
                        &self.lifecycle,
                        shard,
                        shard_guard,
                        shared,
                        deadline.as_mut(),
                    ) {
                        // The other loader finished; look the key up again.
                        JoinResult::Filled(_) => continue,
                        JoinResult::Guard(g) => return EntryResult::Vacant(g),
                        JoinResult::Timeout => return EntryResult::Timeout,
                    }
                }
            }
        }
    }

    /// Async form of `entry` without a timeout.
    ///
    /// The outcomes match `entry`, except that `EntryResult::Timeout` is never
    /// returned. The write guard is confined to an inner block so it is released
    /// before awaiting another loader's placeholder.
    pub async fn entry_async<'a, Q, T>(
        &'a self,
        key: &Q,
        on_occupied: impl FnOnce(&Key, &mut Val) -> EntryAction<T>,
    ) -> EntryResult<'a, Key, Val, We, B, L, T>
    where
        Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
    {
        let (shard, hash) = self.shard_for(key).unwrap();
        let mut on_occupied = Some(on_occupied);
        // Same `FnOnce` wrapper as in `entry`; expected to be called at most once.
        let mut callback = |k: &Key, v: &mut Val| on_occupied.take().unwrap()(k, v);

        loop {
            let result = {
                // Scoped so that the guard never lives across the `.await` below.
                let mut shard_guard = shard.write();
                match shard_guard.entry_or_placeholder(hash, key, &mut callback) {
                    EntryOrPlaceholder::Kept(t) => Ok(EntryResult::Retained(t)),
                    EntryOrPlaceholder::Removed(k, v) => Ok(EntryResult::Removed(k, v)),
                    EntryOrPlaceholder::Replaced(shared, old_val) => {
                        drop(shard_guard);
                        Ok(EntryResult::Replaced(
                            PlaceholderGuard::start_loading(&self.lifecycle, shard, shared),
                            old_val,
                        ))
                    }
                    EntryOrPlaceholder::NewPlaceholder(shared) => {
                        drop(shard_guard);
                        Ok(EntryResult::Vacant(PlaceholderGuard::start_loading(
                            &self.lifecycle,
                            shard,
                            shared,
                        )))
                    }
                    // Another loader is active: leave the block (releasing the
                    // lock) and wait for it asynchronously.
                    EntryOrPlaceholder::ExistingPlaceholder(_) => Err(()),
                }
            };
            match result {
                Ok(entry_result) => return entry_result,
                Err(()) => match JoinFuture::new(&self.lifecycle, shard, hash, key).await {
                    JoinResult::Filled(_) => continue,
                    JoinResult::Guard(g) => return EntryResult::Vacant(g),
                    // SAFETY: no timeout is passed to `JoinFuture::new`. This relies
                    // on it never yielding `Timeout` in that case.
                    JoinResult::Timeout => unsafe { unreachable_unchecked() },
                },
            }
        }
    }

    /// Returns the summed memory estimates (`entries` and `map`) of all shards.
    ///
    /// Shards are read-locked one at a time.
    pub fn memory_used(&self) -> MemoryUsed {
        let mut total = MemoryUsed { entries: 0, map: 0 };
        self.shards.iter().for_each(|shard| {
            let shard_memory = shard.read().memory_used();
            total.entries += shard_memory.entries;
            total.map += shard_memory.map;
        });
        total
    }
}
678
679impl<Key, Val, We, B, L> std::fmt::Debug for Cache<Key, Val, We, B, L> {
680 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
681 f.debug_struct("Cache").finish_non_exhaustive()
682 }
683}
684
685pub struct Iter<'a, Key, Val, We, B, L> {
689 shards: &'a [RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>],
690 current_shard: usize,
691 last: Option<Token>,
692}
693
694impl<Key, Val, We, B, L> Iterator for Iter<'_, Key, Val, We, B, L>
695where
696 Key: Clone,
697 Val: Clone,
698{
699 type Item = (Key, Val);
700
701 fn next(&mut self) -> Option<Self::Item> {
702 while self.current_shard < self.shards.len() {
703 let shard = &self.shards[self.current_shard];
704 let lock = shard.read();
705 if let Some((new_last, key, val)) = lock.iter_from(self.last).next() {
706 self.last = Some(new_last);
707 return Some((key.clone(), val.clone()));
708 }
709 self.last = None;
710 self.current_shard += 1;
711 }
712 None
713 }
714}
715
716impl<Key, Val, We, B, L> std::fmt::Debug for Iter<'_, Key, Val, We, B, L> {
717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718 f.debug_struct("Iter").finish_non_exhaustive()
719 }
720}
721
722pub struct Drain<'a, Key, Val, We, B, L> {
726 shards: &'a [RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>],
727 current_shard: usize,
728 last: Option<Token>,
729}
730
731impl<Key, Val, We, B, L> Iterator for Drain<'_, Key, Val, We, B, L>
732where
733 Key: Hash + Eq,
734 We: Weighter<Key, Val>,
735 B: BuildHasher,
736 L: Lifecycle<Key, Val>,
737{
738 type Item = (Key, Val);
739
740 fn next(&mut self) -> Option<Self::Item> {
741 while self.current_shard < self.shards.len() {
742 let shard = &self.shards[self.current_shard];
743 let mut lock = shard.write();
744 if let Some((new_last, key, value)) = lock.remove_next(self.last) {
745 self.last = Some(new_last);
746 return Some((key, value));
747 }
748 self.last = None;
749 self.current_shard += 1;
750 }
751 None
752 }
753}
754
755impl<Key, Val, We, B, L> std::fmt::Debug for Drain<'_, Key, Val, We, B, L> {
756 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
757 f.debug_struct("Drain").finish_non_exhaustive()
758 }
759}
760
/// Lifecycle used by `Cache` when no custom lifecycle is supplied.
///
/// It is a zero-sized marker: the `PhantomData` only ties the type to `Key` and
/// `Val` without storing either.
pub struct DefaultLifecycle<Key, Val>(std::marker::PhantomData<(Key, Val)>);

impl<Key, Val> std::fmt::Debug for DefaultLifecycle<Key, Val> {
    /// Writes just `DefaultLifecycle`, the same output a field-less `debug_tuple`
    /// produces, without requiring `Key` or `Val` to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DefaultLifecycle")
    }
}

// `Default` and `Clone` are hand-written because `#[derive]` would add
// `Key`/`Val` bounds that this marker type does not need.
impl<Key, Val> Default for DefaultLifecycle<Key, Val> {
    #[inline]
    fn default() -> Self {
        DefaultLifecycle(std::marker::PhantomData)
    }
}

impl<Key, Val> Clone for DefaultLifecycle<Key, Val> {
    #[inline]
    fn clone(&self) -> Self {
        Self::default()
    }
}
784
785impl<Key, Val> Lifecycle<Key, Val> for DefaultLifecycle<Key, Val> {
786 type RequestState = [Option<(Key, Val)>; 2];
792
793 #[inline]
794 fn begin_request(&self) -> Self::RequestState {
795 [None, None]
796 }
797
798 #[inline]
799 fn on_evict(&self, state: &mut Self::RequestState, key: Key, val: Val) {
800 if std::mem::needs_drop::<(Key, Val)>() {
801 if state[0].is_none() {
802 state[0] = Some((key, val));
803 } else if state[1].is_none() {
804 state[1] = Some((key, val));
805 }
806 }
807 }
808}
809
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{Arc, Barrier},
        thread,
    };

    /// Writers and readers on disjoint key ranges; any value read back must be
    /// the one written for that key.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_multiple_threads() {
        const N_THREAD_PAIRS: usize = 8;
        const N_ROUNDS: usize = 1_000;
        const ITEMS_PER_THREAD: usize = 1_000;
        let mut threads = Vec::new();
        let barrier = Arc::new(Barrier::new(N_THREAD_PAIRS * 2));
        let cache = Arc::new(Cache::new(N_THREAD_PAIRS * ITEMS_PER_THREAD / 10));
        for t in 0..N_THREAD_PAIRS {
            let barrier = barrier.clone();
            let cache = cache.clone();
            let handle = thread::spawn(move || {
                let start = ITEMS_PER_THREAD * t;
                barrier.wait();
                for _round in 0..N_ROUNDS {
                    for i in start..start + ITEMS_PER_THREAD {
                        cache.insert(i, i);
                    }
                }
            });
            threads.push(handle);
        }
        for t in 0..N_THREAD_PAIRS {
            let barrier = barrier.clone();
            let cache = cache.clone();
            let handle = thread::spawn(move || {
                let start = ITEMS_PER_THREAD * t;
                barrier.wait();
                for _round in 0..N_ROUNDS {
                    for i in start..start + ITEMS_PER_THREAD {
                        if let Some(cached) = cache.get(&i) {
                            assert_eq!(cached, i);
                        }
                    }
                }
            });
            threads.push(handle);
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    /// `iter` yields every entry exactly once across multiple shards.
    #[test]
    fn test_iter() {
        let capacity = if cfg!(miri) { 100 } else { 100000 };
        let options = OptionsBuilder::new()
            .estimated_items_capacity(capacity)
            .weight_capacity(capacity as u64)
            .shards(2)
            .build()
            .unwrap();
        let cache = Cache::with_options(
            options,
            UnitWeighter,
            DefaultHashBuilder::default(),
            DefaultLifecycle::default(),
        );
        let items = capacity / 2;
        for i in 0..items {
            cache.insert(i, i);
        }
        assert_eq!(cache.len(), items);
        let mut iter_collected = cache.iter().collect::<Vec<_>>();
        assert_eq!(iter_collected.len(), items);
        iter_collected.sort();
        for (i, v) in iter_collected.into_iter().enumerate() {
            assert_eq!((i, i), v);
        }
    }

    /// `drain` yields every entry exactly once and leaves the cache empty.
    #[test]
    fn test_drain() {
        let capacity = if cfg!(miri) { 100 } else { 100000 };
        let options = OptionsBuilder::new()
            .estimated_items_capacity(capacity)
            .weight_capacity(capacity as u64)
            .shards(2)
            .build()
            .unwrap();
        let cache = Cache::with_options(
            options,
            UnitWeighter,
            DefaultHashBuilder::default(),
            DefaultLifecycle::default(),
        );
        let items = capacity / 2;
        for i in 0..items {
            cache.insert(i, i);
        }
        assert_eq!(cache.len(), items);
        let mut drain_collected = cache.drain().collect::<Vec<_>>();
        assert_eq!(cache.len(), 0);
        assert_eq!(drain_collected.len(), items);
        drain_collected.sort();
        for (i, v) in drain_collected.into_iter().enumerate() {
            assert_eq!((i, i), v);
        }
    }

    /// Shrinking the capacity evicts down to the new limit; growing it allows more entries.
    #[test]
    fn test_set_capacity() {
        let cache = Cache::new(100);
        for i in 0..80 {
            cache.insert(i, i);
        }
        let initial_len = cache.len();
        assert!(initial_len <= 80);

        cache.set_capacity(50);
        assert!(cache.len() <= 50);
        assert!(cache.weight() <= 50);

        cache.set_capacity(200);
        assert_eq!(cache.capacity(), 200);

        for i in 100..180 {
            cache.insert(i, i);
        }
        assert!(cache.len() <= 180);
        assert!(cache.weight() <= 200);
    }

    /// `remove_if` removes only when the predicate holds and the key exists.
    #[test]
    fn test_remove_if() {
        let cache = Cache::new(100);

        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);

        let removed = cache.remove_if(&2, |v| *v == 20);
        assert_eq!(removed, Some((2, 20)));
        assert_eq!(cache.get(&2), None);

        let not_removed = cache.remove_if(&3, |v| *v == 999);
        assert_eq!(not_removed, None);
        assert_eq!(cache.get(&3), Some(30));

        let not_found = cache.remove_if(&999, |_| true);
        assert_eq!(not_found, None);
    }

    /// Each `EntryAction` maps to the expected `EntryResult` and cache state.
    #[test]
    fn test_entry_actions() {
        let cache = Cache::new(100);
        cache.insert(1, 10);
        cache.insert(2, 20);

        let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v));
        assert!(matches!(result, EntryResult::Retained(10)));
        assert_eq!(cache.get(&1), Some(10));

        let result = cache.entry(&1, None, |_k, v| {
            *v += 5;
            EntryAction::Retain(())
        });
        assert!(matches!(result, EntryResult::Retained(())));
        assert_eq!(cache.get(&1), Some(15));

        let result = cache.entry(&1, None, |_k, _v| EntryAction::<()>::Remove);
        assert!(matches!(result, EntryResult::Removed(1, 15)));
        assert_eq!(cache.get(&1), None);

        let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v));
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(99);
                assert_eq!(cache.get(&1), Some(99));
            }
            _ => panic!("expected Vacant for removed key"),
        }

        let mut old_val = 0;
        let result = cache.entry(&2, None, |_k, v| {
            old_val = *v;
            EntryAction::<()>::ReplaceWithGuard
        });
        assert_eq!(old_val, 20);
        match result {
            EntryResult::Replaced(g, old) => {
                assert_eq!(old, 20);
                let _ = g.insert(old_val + 100);
                assert_eq!(cache.get(&2), Some(120));
            }
            _ => panic!("expected Replaced"),
        }

        let result = cache.entry(&2, None, |_k, _v| EntryAction::<()>::ReplaceWithGuard);
        match result {
            EntryResult::Replaced(g, _old) => {
                drop(g);
                assert_eq!(cache.get(&2), None);
            }
            _ => panic!("expected Replaced"),
        }

        let result = cache.entry(&3, None, |_k, v| EntryAction::Retain(*v));
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(30);
                assert_eq!(cache.get(&3), Some(30));
            }
            _ => panic!("expected Vacant"),
        }
    }

    /// In-place mutation, removal and replacement via `entry` keep the total weight correct.
    #[test]
    fn test_entry_weight_tracking() {
        #[derive(Clone)]
        struct StringWeighter;
        impl crate::Weighter<u64, String> for StringWeighter {
            fn weight(&self, _key: &u64, val: &String) -> u64 {
                val.len() as u64
            }
        }

        let cache = Cache::with_weighter(100, 100_000, StringWeighter);
        cache.insert(1, "hello".to_string());
        cache.insert(2, "world".to_string());
        assert_eq!(cache.weight(), 10);

        let result = cache.entry(&1, None, |_k, _v| EntryAction::Retain(()));
        assert!(matches!(result, EntryResult::Retained(())));
        assert_eq!(cache.weight(), 10);

        let result = cache.entry(&1, None, |_k, v| {
            v.push_str(" world");
            EntryAction::Retain(())
        });
        assert!(matches!(result, EntryResult::Retained(())));
        assert_eq!(cache.weight(), 16);
        assert_eq!(cache.get(&1).unwrap(), "hello world");

        let result = cache.entry(&1, None, |_k, v| {
            v.clear();
            EntryAction::Retain(())
        });
        assert!(matches!(result, EntryResult::Retained(())));
        assert_eq!(cache.weight(), 5);
        assert_eq!(cache.get(&1).unwrap(), "");

        let result = cache.entry(&2, None, |_k, _v| EntryAction::<()>::Remove);
        assert!(matches!(result, EntryResult::Removed(2, _)));
        assert_eq!(cache.weight(), 0);
        assert_eq!(cache.len(), 1);

        cache.insert(3, "hello".to_string());
        assert_eq!(cache.weight(), 5);
        let result = cache.entry(&3, None, |_k, _v| EntryAction::<()>::ReplaceWithGuard);
        match result {
            EntryResult::Replaced(g, _old) => {
                assert_eq!(cache.weight(), 0);
                let _ = g.insert("hello world!!".to_string());
                assert_eq!(cache.weight(), 13);
            }
            _ => panic!("expected Replaced"),
        }
    }

    /// Inserting through a `Vacant` guard respects capacity, including zero capacity.
    #[test]
    fn test_entry_eviction() {
        let cache = Cache::new(2);
        cache.insert(1, 10);
        cache.insert(2, 20);
        assert_eq!(cache.len(), 2);

        let result = cache.entry(&3, None, |_k, v| EntryAction::Retain(*v));
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(30);
                assert!(cache.len() <= 2);
                assert_eq!(cache.get(&3), Some(30));
            }
            _ => panic!("expected Vacant"),
        }

        let cache = Cache::new(0);
        let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v));
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(10);
                assert_eq!(cache.get(&1), None);
            }
            _ => panic!("expected Vacant"),
        }
    }

    /// `entry` waits for another thread's placeholder and then sees the inserted value.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_entry_concurrent_placeholder_wait() {
        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));

        let cache2 = cache.clone();
        let barrier2 = barrier.clone();
        let handle = thread::spawn(move || match cache2.get_value_or_guard(&1, None) {
            GuardResult::Guard(g) => {
                barrier2.wait();
                std::thread::sleep(Duration::from_millis(50));
                let _ = g.insert(42);
            }
            _ => panic!("expected guard"),
        });

        barrier.wait();
        let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v));
        assert!(matches!(result, EntryResult::Retained(42)));
        handle.join().unwrap();
    }

    /// When another thread abandons its placeholder, the waiting `entry` gets a `Vacant` guard.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_entry_concurrent_placeholder_guard_abandoned() {
        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));

        let cache2 = cache.clone();
        let barrier2 = barrier.clone();
        let handle = thread::spawn(move || match cache2.get_value_or_guard(&1, None) {
            GuardResult::Guard(g) => {
                barrier2.wait();
                std::thread::sleep(Duration::from_millis(50));
                drop(g);
            }
            _ => panic!("expected guard"),
        });

        barrier.wait();
        let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v));
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(99);
                assert_eq!(cache.get(&1), Some(99));
            }
            _ => panic!("expected Vacant after abandoned placeholder"),
        }
        handle.join().unwrap();
    }

    /// `entry` returns `Timeout` with a zero timeout and with a timeout shorter than the load.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_entry_timeout() {
        let cache = Cache::new(100);

        let guard = match cache.get_value_or_guard(&1, None) {
            GuardResult::Guard(g) => g,
            _ => panic!("expected guard"),
        };
        let result = cache.entry(&1, Some(Duration::ZERO), |_k, v| EntryAction::Retain(*v));
        assert!(matches!(result, EntryResult::Timeout));
        let _ = guard.insert(1);

        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));
        let cache2 = cache.clone();
        let barrier2 = barrier.clone();
        let holder = thread::spawn(move || {
            let guard = match cache2.get_value_or_guard(&1, None) {
                GuardResult::Guard(g) => g,
                _ => panic!("expected guard"),
            };
            barrier2.wait();
            std::thread::sleep(Duration::from_millis(200));
            let _ = guard.insert(1);
        });

        barrier.wait();
        let result = cache.entry(&1, Some(Duration::from_millis(50)), |_k, v| {
            EntryAction::Retain(*v)
        });
        assert!(matches!(result, EntryResult::Timeout));
        holder.join().unwrap();
    }

    /// Several `entry` callers waiting on the same placeholder all observe the loaded value.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_entry_concurrent_multiple_waiters() {
        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(4));
        let cache1 = cache.clone();
        let barrier1 = barrier.clone();
        let loader = thread::spawn(move || match cache1.get_value_or_guard(&1, None) {
            GuardResult::Guard(g) => {
                barrier1.wait();
                std::thread::sleep(Duration::from_millis(50));
                let _ = g.insert(42);
            }
            _ => panic!("expected guard"),
        });

        let mut waiters = Vec::new();
        for _ in 0..3 {
            let cache_c = cache.clone();
            let barrier_c = barrier.clone();
            waiters.push(thread::spawn(move || {
                barrier_c.wait();
                let result = cache_c.entry(&1, None, |_k, v| EntryAction::Retain(*v));
                match result {
                    EntryResult::Retained(v) => v,
                    _ => panic!("expected Value"),
                }
            }));
        }

        loader.join().unwrap();
        for w in waiters {
            assert_eq!(w.join().unwrap(), 42);
        }
    }

    /// After waiting on a placeholder, `ReplaceWithGuard` and `Remove` act on the loaded value.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_entry_concurrent_action_after_wait() {
        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));

        let cache1 = cache.clone();
        let barrier1 = barrier.clone();
        let loader = thread::spawn(move || match cache1.get_value_or_guard(&1, None) {
            GuardResult::Guard(g) => {
                barrier1.wait();
                std::thread::sleep(Duration::from_millis(50));
                let _ = g.insert(42);
            }
            _ => panic!("expected guard"),
        });

        barrier.wait();
        let result = cache.entry(&1, None, |_k, _v| EntryAction::<()>::ReplaceWithGuard);
        match result {
            EntryResult::Replaced(g, old) => {
                assert_eq!(old, 42);
                let _ = g.insert(100);
                assert_eq!(cache.get(&1), Some(100));
            }
            _ => panic!("expected Replaced"),
        }
        loader.join().unwrap();

        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));

        let cache1 = cache.clone();
        let barrier1 = barrier.clone();
        let loader = thread::spawn(move || match cache1.get_value_or_guard(&1, None) {
            GuardResult::Guard(g) => {
                barrier1.wait();
                std::thread::sleep(Duration::from_millis(50));
                let _ = g.insert(42);
            }
            _ => panic!("expected guard"),
        });

        barrier.wait();
        let result = cache.entry(&1, None, |_k, _v| EntryAction::<()>::Remove);
        assert!(matches!(result, EntryResult::Removed(1, 42)));
        assert_eq!(cache.get(&1), None);
        loader.join().unwrap();
    }

    /// Many threads hammering `entry` on a small key set never store an inconsistent value.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_entry_concurrent_stress() {
        const N_THREADS: usize = 8;
        const N_KEYS: usize = 50;
        const N_OPS: usize = 500;

        let cache = Arc::new(Cache::new(1000));
        let barrier = Arc::new(Barrier::new(N_THREADS));

        let mut handles = Vec::new();
        for t in 0..N_THREADS {
            let cache = cache.clone();
            let barrier = barrier.clone();
            handles.push(thread::spawn(move || {
                barrier.wait();
                for i in 0..N_OPS {
                    let key = (t * N_OPS + i) % N_KEYS;
                    let result = cache.entry(&key, Some(Duration::from_millis(10)), |_k, v| {
                        EntryAction::Retain(*v)
                    });
                    match result {
                        EntryResult::Retained(_) => {}
                        EntryResult::Vacant(g) => {
                            let _ = g.insert(key * 10);
                        }
                        EntryResult::Replaced(g, _) => {
                            let _ = g.insert(key * 10);
                        }
                        EntryResult::Timeout => {}
                        EntryResult::Removed(_, _) => {}
                    }
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert!(cache.len() <= N_KEYS);
        for key in 0..N_KEYS {
            if let Some(v) = cache.get(&key) {
                assert_eq!(v, key * 10);
            }
        }
    }

    /// Async counterpart of `test_entry_actions`.
    #[tokio::test]
    async fn test_entry_async_actions() {
        let cache = Cache::new(100);
        cache.insert(1, 10);
        cache.insert(2, 20);

        let result = cache.entry_async(&1, |_k, v| EntryAction::Retain(*v)).await;
        assert!(matches!(result, EntryResult::Retained(10)));
        assert_eq!(cache.get(&1), Some(10));

        let result = cache
            .entry_async(&1, |_k, _v| EntryAction::<()>::Remove)
            .await;
        assert!(matches!(result, EntryResult::Removed(1, 10)));
        assert_eq!(cache.get(&1), None);

        let result = cache
            .entry_async(&2, |_k, _v| EntryAction::<()>::ReplaceWithGuard)
            .await;
        match result {
            EntryResult::Replaced(g, old) => {
                assert_eq!(old, 20);
                let _ = g.insert(42);
                assert_eq!(cache.get(&2), Some(42));
            }
            _ => panic!("expected Replaced"),
        }

        let result = cache.entry_async(&3, |_k, v| EntryAction::Retain(*v)).await;
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(99);
                assert_eq!(cache.get(&3), Some(99));
            }
            _ => panic!("expected Vacant"),
        }
    }

    /// `entry_async` waits for a placeholder held by an OS thread and sees the value.
    // Spawns a real thread that sleeps, like the sync concurrency tests, so it is
    // skipped under miri for consistency with them.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_entry_async_concurrent_wait() {
        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));

        let cache1 = cache.clone();
        let barrier1 = barrier.clone();
        let holder = thread::spawn(move || {
            let guard = match cache1.get_value_or_guard(&1, None) {
                GuardResult::Guard(g) => g,
                _ => panic!("expected guard"),
            };
            barrier1.wait();
            std::thread::sleep(Duration::from_millis(50));
            let _ = guard.insert(42);
        });

        barrier.wait();
        let result = cache.entry_async(&1, |_k, v| EntryAction::Retain(*v)).await;
        assert!(matches!(result, EntryResult::Retained(42)));
        holder.join().unwrap();
    }

    /// `entry_async` receives a `Vacant` guard when the placeholder holder gives up.
    // Threaded + sleeping: skipped under miri like the sync concurrency tests.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_entry_async_concurrent_guard_abandoned() {
        let cache = Arc::new(Cache::new(100));
        let barrier = Arc::new(Barrier::new(2));

        let cache1 = cache.clone();
        let barrier1 = barrier.clone();
        let holder = thread::spawn(move || {
            let guard = match cache1.get_value_or_guard(&1, None) {
                GuardResult::Guard(g) => g,
                _ => panic!("expected guard"),
            };
            barrier1.wait();
            std::thread::sleep(Duration::from_millis(50));
            drop(guard);
        });

        barrier.wait();
        let result = cache.entry_async(&1, |_k, v| EntryAction::Retain(*v)).await;
        match result {
            EntryResult::Vacant(g) => {
                let _ = g.insert(99);
            }
            _ => panic!("expected Vacant after abandoned placeholder"),
        }
        assert_eq!(cache.get(&1), Some(99));
        holder.join().unwrap();
    }

    /// Many tasks racing `get_or_insert_async` on a small key set store consistent values.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_entry_async_concurrent_stress() {
        const N_TASKS: usize = 16;
        const N_KEYS: usize = 50;
        const N_OPS: usize = 200;

        let cache = Arc::new(Cache::new(1000));
        let barrier = Arc::new(tokio::sync::Barrier::new(N_TASKS));

        let mut handles = Vec::new();
        for t in 0..N_TASKS {
            let cache = cache.clone();
            let barrier = barrier.clone();
            handles.push(tokio::spawn(async move {
                barrier.wait().await;
                for i in 0..N_OPS {
                    let key = (t * N_OPS + i) % N_KEYS;
                    let _ = cache
                        .get_or_insert_async(&key, async { Ok::<_, ()>(key * 10) })
                        .await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert!(cache.len() <= N_KEYS);
        for key in 0..N_KEYS {
            if let Some(v) = cache.get(&key) {
                assert_eq!(v, key * 10);
            }
        }
    }
}