#![allow(clippy::type_complexity)]
#![cfg_attr(docsrs, feature(doc_cfg))]

// The slab is private in normal builds but exported when fuzzing so fuzz
// targets can drive it directly.
#[cfg(not(fuzzing))]
mod linked_slab;
#[cfg(fuzzing)]
pub mod linked_slab;
mod options;
// NOTE(review): only the non-shuttle `rw_lock` is declared here; presumably
// the "shuttle" feature supplies a replacement via `shim`/`sync_placeholder`
// — confirm against those modules.
#[cfg(not(feature = "shuttle"))]
mod rw_lock;
mod shard;
mod shim;
pub mod sync;
mod sync_placeholder;
pub mod unsync;
// Re-exported so users can implement borrowed-key lookups without depending
// on the `equivalent` crate directly.
pub use equivalent::Equivalent;

#[cfg(all(test, feature = "shuttle"))]
mod shuttle_tests;

pub use options::{Options, OptionsBuilder};

// Default hash builder: `ahash` when that feature is enabled, otherwise the
// standard library's `RandomState`.
#[cfg(feature = "ahash")]
pub type DefaultHashBuilder = ahash::RandomState;
#[cfg(not(feature = "ahash"))]
pub type DefaultHashBuilder = std::collections::hash_map::RandomState;
104
/// Assigns a weight (cost) to each cache entry; the cache accounts for
/// capacity in terms of total weight.
pub trait Weighter<Key, Val> {
    /// Returns the weight of the entry identified by `key` holding `val`.
    fn weight(&self, key: &Key, val: &Val) -> u64;
}

/// A [`Weighter`] that charges every entry a flat weight of `1`, so the
/// cache's weight capacity behaves as a plain entry count.
#[derive(Debug, Clone, Default)]
pub struct UnitWeighter;

impl<K, V> Weighter<K, V> for UnitWeighter {
    #[inline]
    fn weight(&self, _key: &K, _val: &V) -> u64 {
        1
    }
}
153
/// Hooks into the lifecycle of cache entries.
///
/// NOTE(review): semantics inferred from method names; the call sites live in
/// other modules (`shard`/`sync`/`unsync`) — confirm there.
pub trait Lifecycle<Key, Val> {
    /// Scratch state created once per cache request by [`Lifecycle::begin_request`]
    /// and consumed by [`Lifecycle::end_request`].
    type RequestState;

    /// Whether the entry should be treated as pinned (presumably exempt from
    /// eviction — confirm). Default implementation returns `false`.
    #[allow(unused_variables)]
    #[inline]
    fn is_pinned(&self, key: &Key, val: &Val) -> bool {
        false
    }

    /// Creates the state value that accompanies a single request.
    fn begin_request(&self) -> Self::RequestState;

    /// Called with mutable access to an entry prior to eviction.
    /// Default implementation does nothing.
    #[allow(unused_variables)]
    #[inline]
    fn before_evict(&self, state: &mut Self::RequestState, key: &Key, val: &mut Val) {}

    /// Called with ownership of the key/value pair when an entry is evicted.
    fn on_evict(&self, state: &mut Self::RequestState, key: Key, val: Val);

    /// Finishes the request, consuming its state. Default implementation
    /// does nothing.
    #[allow(unused_variables)]
    #[inline]
    fn end_request(&self, state: Self::RequestState) {}
}
201
/// Breakdown of a cache's estimated memory usage.
#[non_exhaustive]
#[derive(Debug, Copy, Clone)]
pub struct MemoryUsed {
    /// Memory attributed to the stored entries.
    pub entries: usize,
    /// Memory attributed to the hash map structure.
    pub map: usize,
}

impl MemoryUsed {
    /// Combined memory usage: entries plus map overhead.
    pub fn total(&self) -> usize {
        let Self { entries, map } = *self;
        entries + map
    }
}
217
218#[cfg(test)]
219mod tests {
220 use std::{
221 hash::Hash,
222 sync::{atomic::AtomicUsize, Arc},
223 time::Duration,
224 };
225
226 use super::*;
    // Test weighter that charges each entry its string length in bytes,
    // ignoring the key.
    #[derive(Clone)]
    struct StringWeighter;

    impl Weighter<u64, String> for StringWeighter {
        fn weight(&self, _key: &u64, val: &String) -> u64 {
            val.len() as u64
        }
    }
235
236 #[test]
237 fn test_new() {
238 sync::Cache::<(u64, u64), u64>::new(0);
239 sync::Cache::<(u64, u64), u64>::new(1);
240 sync::Cache::<(u64, u64), u64>::new(2);
241 sync::Cache::<(u64, u64), u64>::new(3);
242 sync::Cache::<(u64, u64), u64>::new(usize::MAX);
243 sync::Cache::<u64, u64>::new(0);
244 sync::Cache::<u64, u64>::new(1);
245 sync::Cache::<u64, u64>::new(2);
246 sync::Cache::<u64, u64>::new(3);
247 sync::Cache::<u64, u64>::new(usize::MAX);
248 }
249
250 #[test]
251 fn test_capacity_one() {
252 let cache = sync::Cache::<u64, u64>::new(1);
254 cache.insert(1, 10);
255 assert_eq!(cache.get(&1), Some(10));
256 cache.insert(2, 20);
258 assert_eq!(cache.get(&2), Some(20));
259 assert_eq!(cache.get(&1), None);
260
261 let mut cache = unsync::Cache::<u64, u64>::new(1);
263 cache.insert(1, 10);
264 assert_eq!(cache.get(&1), Some(&10));
265 cache.insert(2, 20);
266 assert_eq!(cache.get(&2), Some(&20));
267 assert_eq!(cache.get(&1), None);
268
269 let cache = sync::Cache::<u64, u64>::new(0);
271 cache.insert(1, 10);
272 assert_eq!(cache.get(&1), None);
273 }
274
275 #[test]
276 fn test_custom_cost() {
277 let cache = sync::Cache::with_weighter(100, 100_000, StringWeighter);
278 cache.insert(1, "1".to_string());
279 cache.insert(54, "54".to_string());
280 cache.insert(1000, "1000".to_string());
281 assert_eq!(cache.get(&1000).unwrap(), "1000");
282 }
283
284 #[test]
285 fn test_change_get_mut_change_weight() {
286 let mut cache = unsync::Cache::with_weighter(100, 100_000, StringWeighter);
287 cache.insert(1, "1".to_string());
288 assert_eq!(cache.get(&1).unwrap(), "1");
289 assert_eq!(cache.weight(), 1);
290 let _old = {
291 cache
292 .get_mut(&1)
293 .map(|mut v| std::mem::replace(&mut *v, "11".to_string()))
294 };
295 let _old = {
296 cache
297 .get_mut(&1)
298 .map(|mut v| std::mem::replace(&mut *v, "".to_string()))
299 };
300 assert_eq!(cache.get(&1).unwrap(), "");
301 assert_eq!(cache.weight(), 0);
302 cache.validate(false);
303 }
304
    // Composite lookup helper: lets tests query a cache keyed by an owned
    // tuple `(A, B)` using borrowed components (e.g. `Pair("sq", 2022)` for a
    // `(String, i32)` key) without allocating.
    // NOTE(review): relies on `Hash` for the borrowed components matching the
    // owned tuple's hash (e.g. `&str` vs `String`) — confirm hasher behavior.
    #[derive(Debug, Hash)]
    pub struct Pair<A, B>(pub A, pub B);

    // Component-wise comparison between a Pair and a 2-tuple.
    impl<A, B, C, D> PartialEq<(A, B)> for Pair<C, D>
    where
        C: PartialEq<A>,
        D: PartialEq<B>,
    {
        fn eq(&self, rhs: &(A, B)) -> bool {
            self.0 == rhs.0 && self.1 == rhs.1
        }
    }

    // Allow Pair to serve as a borrowed lookup key for any key type X that
    // it can be compared with.
    impl<A, B, X> Equivalent<X> for Pair<A, B>
    where
        Pair<A, B>: PartialEq<X>,
        A: Hash + Eq,
        B: Hash + Eq,
    {
        fn equivalent(&self, other: &X) -> bool {
            *self == *other
        }
    }
328
329 #[test]
330 fn test_equivalent() {
331 let mut cache = unsync::Cache::new(5);
332 cache.insert(("square".to_string(), 2022), "blue".to_string());
333 cache.insert(("square".to_string(), 2023), "black".to_string());
334 assert_eq!(cache.get(&Pair("square", 2022)).unwrap(), "blue");
335 }
336
337 #[test]
338 fn test_borrow_keys() {
339 let cache = sync::Cache::<(Vec<u8>, Vec<u8>), u64>::new(0);
340 cache.get(&Pair(&b""[..], &b""[..]));
341 let cache = sync::Cache::<(String, String), u64>::new(0);
342 cache.get(&Pair("", ""));
343 }
344
    // Stress test: many threads race `get_or_insert_with` on one key. Every
    // invocation of the init closure bumps `entered`; only the randomly
    // chosen `solve_at`-th attempt returns Ok, earlier ones return Err (which
    // the cache propagates to that caller and retries with the next waiter).
    // Afterwards the closure must have run exactly `solve_at + 1` times.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_get_or_insert() {
        use rand::prelude::*;
        for _i in 0..2000 {
            // dbg! output is captured by the harness and shown only on
            // failure, identifying the failing iteration.
            dbg!(_i);
            let mut entered = AtomicUsize::default();
            let cache = sync::Cache::<(u64, u64), u64>::new(100);
            const THREADS: usize = 100;
            // Barrier maximizes contention: all threads start together.
            let wg = std::sync::Barrier::new(THREADS);
            let solve_at = rand::rng().random_range(0..THREADS);
            std::thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        wg.wait();
                        let result = cache.get_or_insert_with(&(1, 1), || {
                            let before = entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            if before == solve_at {
                                Ok(1)
                            } else {
                                Err(())
                            }
                        });
                        // Each caller either observes the inserted value or
                        // the Err produced by its own failed closure.
                        assert!(matches!(result, Ok(1) | Err(())));
                    });
                }
            });
            assert_eq!(*entered.get_mut(), solve_at + 1);
        }
    }
375
376 #[test]
377 fn test_get_or_insert_unsync() {
378 let mut cache = unsync::Cache::<u64, u64>::new(100);
379 let guard = cache.get_ref_or_guard(&0).unwrap_err();
380 guard.insert(0);
381 assert_eq!(cache.get_ref_or_guard(&0).ok().copied(), Some(0));
382 let guard = cache.get_mut_or_guard(&1).err().unwrap();
383 guard.insert(1);
384 let v = *cache.get_mut_or_guard(&1).ok().unwrap().unwrap();
385 assert_eq!(v, 1);
386 let result = cache.get_or_insert_with::<_, ()>(&0, || panic!());
387 assert_eq!(result, Ok(Some(&0)));
388 let result = cache.get_or_insert_with::<_, ()>(&1, || panic!());
389 assert_eq!(result, Ok(Some(&1)));
390 let result = cache.get_or_insert_with::<_, ()>(&3, || Ok(3));
391 assert_eq!(result, Ok(Some(&3)));
392 let result = cache.get_or_insert_with::<_, ()>(&4, || Err(()));
393 assert_eq!(result, Err(()));
394 }
395
396 #[tokio::test]
397 async fn test_get_or_insert_sync() {
398 use crate::sync::*;
399 let cache = sync::Cache::<u64, u64>::new(100);
400 let GuardResult::Guard(guard) = cache.get_value_or_guard(&0, None) else {
401 panic!();
402 };
403 guard.insert(0).unwrap();
404 let GuardResult::Value(v) = cache.get_value_or_guard(&0, None) else {
405 panic!();
406 };
407 assert_eq!(v, 0);
408 let Err(guard) = cache.get_value_or_guard_async(&1).await else {
409 panic!();
410 };
411 guard.insert(1).unwrap();
412 let Ok(v) = cache.get_value_or_guard_async(&1).await else {
413 panic!();
414 };
415 assert_eq!(v, 1);
416
417 let result = cache.get_or_insert_with::<_, ()>(&0, || panic!());
418 assert_eq!(result, Ok(0));
419 let result = cache.get_or_insert_with::<_, ()>(&3, || Ok(3));
420 assert_eq!(result, Ok(3));
421 let result = cache.get_or_insert_with::<_, ()>(&4, || Err(()));
422 assert_eq!(result, Err(()));
423 let result = cache
424 .get_or_insert_async::<_, ()>(&0, async { panic!() })
425 .await;
426 assert_eq!(result, Ok(0));
427 let result = cache
428 .get_or_insert_async::<_, ()>(&4, async { Err(()) })
429 .await;
430 assert_eq!(result, Err(()));
431 let result = cache
432 .get_or_insert_async::<_, ()>(&4, async { Ok(4) })
433 .await;
434 assert_eq!(result, Ok(4));
435 }
436
437 #[test]
438 fn test_retain_unsync() {
439 let mut cache = unsync::Cache::<u64, u64>::new(100);
440 let ranges = 0..10;
441 for i in ranges.clone() {
442 let guard = cache.get_ref_or_guard(&i).unwrap_err();
443 guard.insert(i);
444 assert_eq!(cache.get_ref_or_guard(&i).ok().copied(), Some(i));
445 }
446 let small = 3;
447 cache.retain(|&key, &val| val > small && key > small);
448 for i in ranges.clone() {
449 let actual = cache.get(&i);
450 if i > small {
451 assert!(actual.is_some());
452 assert_eq!(*actual.unwrap(), i);
453 } else {
454 assert!(actual.is_none());
455 }
456 }
457 let big = 7;
458 cache.retain(|&key, &val| val < big && key < big);
459 for i in ranges {
460 let actual = cache.get(&i);
461 if i > small && i < big {
462 assert!(actual.is_some());
463 assert_eq!(*actual.unwrap(), i);
464 } else {
465 assert!(actual.is_none());
466 }
467 }
468 }
469
470 #[tokio::test]
471 async fn test_retain_sync() {
472 use crate::sync::*;
473 let cache = Cache::<u64, u64>::new(100);
474 let ranges = 0..10;
475 for i in ranges.clone() {
476 let GuardResult::Guard(guard) = cache.get_value_or_guard(&i, None) else {
477 panic!();
478 };
479 guard.insert(i).unwrap();
480 let GuardResult::Value(v) = cache.get_value_or_guard(&i, None) else {
481 panic!();
482 };
483 assert_eq!(v, i);
484 }
485 let small = 4;
486 cache.retain(|&key, &val| val > small && key > small);
487 for i in ranges.clone() {
488 let actual = cache.get(&i);
489 if i > small {
490 assert!(actual.is_some());
491 assert_eq!(actual.unwrap(), i);
492 } else {
493 assert!(actual.is_none());
494 }
495 }
496 let big = 8;
497 cache.retain(|&key, &val| val < big && key < big);
498 for i in ranges {
499 let actual = cache.get(&i);
500 if i > small && i < big {
501 assert!(actual.is_some());
502 assert_eq!(actual.unwrap(), i);
503 } else {
504 assert!(actual.is_none());
505 }
506 }
507 }
508
    // Stress test: threads race `get_value_or_guard` on one key. Each thread
    // that receives a guard bumps `entered`; only the randomly chosen
    // `solve_at`-th guard holder inserts, the rest drop their guard (letting
    // the next waiter acquire one) and retry until they see the value. So
    // exactly `solve_at + 1` guards must have been handed out.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_value_or_guard() {
        use crate::sync::*;
        use rand::prelude::*;
        for _i in 0..2000 {
            // Captured by the harness; identifies the failing iteration.
            dbg!(_i);
            let mut entered = AtomicUsize::default();
            let cache = sync::Cache::<(u64, u64), u64>::new(100);
            const THREADS: usize = 100;
            // Barrier maximizes contention: all threads start together.
            let wg = std::sync::Barrier::new(THREADS);
            let solve_at = rand::rng().random_range(0..THREADS);
            std::thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        wg.wait();
                        loop {
                            match cache.get_value_or_guard(&(1, 1), Some(Duration::from_millis(1)))
                            {
                                GuardResult::Value(v) => assert_eq!(v, 1),
                                GuardResult::Guard(g) => {
                                    let before =
                                        entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                                    if before == solve_at {
                                        g.insert(1).unwrap();
                                    }
                                }
                                // Timed out waiting for the guard holder; retry.
                                GuardResult::Timeout => continue,
                            }
                            break;
                        }
                    });
                }
            });
            assert_eq!(*entered.get_mut(), solve_at + 1);
        }
    }
546
    // Async analog of `test_get_or_insert`: tasks race `get_or_insert_async`
    // on one key; the init future errs until the randomly chosen
    // `solve_at`-th attempt succeeds, so the future must have been driven
    // exactly `solve_at + 1` times and the value must end up cached.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_get_or_insert_async() {
        use rand::prelude::*;
        for _i in 0..5000 {
            // Captured by the harness; identifies the failing iteration.
            dbg!(_i);
            let entered = Arc::new(AtomicUsize::default());
            let cache = Arc::new(sync::Cache::<(u64, u64), u64>::new(100));
            const TASKS: usize = 100;
            // Async barrier maximizes contention: all tasks start together.
            let wg = Arc::new(tokio::sync::Barrier::new(TASKS));
            let solve_at = rand::rng().random_range(0..TASKS);
            let mut tasks = Vec::new();
            for _ in 0..TASKS {
                let cache = cache.clone();
                let wg = wg.clone();
                let entered = entered.clone();
                let task = tokio::spawn(async move {
                    wg.wait().await;
                    let result = cache
                        .get_or_insert_async(&(1, 1), async {
                            let before = entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            if before == solve_at {
                                Ok(1)
                            } else {
                                Err(())
                            }
                        })
                        .await;
                    // Each task either sees the inserted value or the Err
                    // produced by its own failed init future.
                    assert!(matches!(result, Ok(1) | Err(())));
                });
                tasks.push(task);
            }
            for task in tasks {
                task.await.unwrap();
            }
            assert_eq!(cache.get(&(1, 1)), Some(1));
            assert_eq!(
                entered.load(std::sync::atomic::Ordering::Relaxed),
                solve_at + 1
            );
        }
    }
589
    // Async analog of `test_value_or_guard`: tasks race
    // `get_value_or_guard_async` under a 1 ms timeout. Each task that obtains
    // a guard bumps `entered`; only the `solve_at`-th holder inserts, others
    // drop the guard and loop until they read the value. Exactly
    // `solve_at + 1` guards must have been handed out.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_value_or_guard_async() {
        use rand::prelude::*;
        for _i in 0..5000 {
            // Captured by the harness; identifies the failing iteration.
            dbg!(_i);
            let entered = Arc::new(AtomicUsize::default());
            let cache = Arc::new(sync::Cache::<(u64, u64), u64>::new(100));
            const TASKS: usize = 100;
            // Async barrier maximizes contention: all tasks start together.
            let wg = Arc::new(tokio::sync::Barrier::new(TASKS));
            let solve_at = rand::rng().random_range(0..TASKS);
            let mut tasks = Vec::new();
            for _ in 0..TASKS {
                let cache = cache.clone();
                let wg = wg.clone();
                let entered = entered.clone();
                let task = tokio::spawn(async move {
                    wg.wait().await;
                    loop {
                        match tokio::time::timeout(
                            Duration::from_millis(1),
                            cache.get_value_or_guard_async(&(1, 1)),
                        )
                        .await
                        {
                            Ok(Ok(r)) => assert_eq!(r, 1),
                            Ok(Err(g)) => {
                                let before =
                                    entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                                if before == solve_at {
                                    g.insert(1).unwrap();
                                }
                            }
                            // Timed out waiting on the guard holder; retry.
                            Err(_) => continue,
                        }
                        break;
                    }
                });
                tasks.push(task);
            }
            for task in tasks {
                task.await.unwrap();
            }
            assert_eq!(cache.get(&(1, 1)), Some(1));
            assert_eq!(
                entered.load(std::sync::atomic::Ordering::Relaxed),
                solve_at + 1
            );
        }
    }
640}