1#![allow(clippy::type_complexity)]
77#![cfg_attr(docsrs, feature(doc_cfg))]
78
79#[cfg(not(fuzzing))]
80mod linked_slab;
81#[cfg(fuzzing)]
82pub mod linked_slab;
83mod options;
84#[cfg(not(feature = "shuttle"))]
85mod rw_lock;
86mod shard;
87mod shim;
88pub mod sync;
90mod sync_placeholder;
91pub mod unsync;
93pub use equivalent::Equivalent;
94
95#[cfg(all(test, feature = "shuttle"))]
96mod shuttle_tests;
97
98pub use options::{Options, OptionsBuilder};
99
/// Hash builder used by caches unless one is supplied explicitly:
/// `ahash::RandomState` when the `ahash` feature is enabled.
#[cfg(feature = "ahash")]
pub type DefaultHashBuilder = ahash::RandomState;
/// Fallback hash builder when the `ahash` feature is disabled:
/// the standard library's `RandomState`.
#[cfg(not(feature = "ahash"))]
pub type DefaultHashBuilder = std::collections::hash_map::RandomState;
104
/// Assigns a weight (relative cost) to each cache entry.
///
/// The cache presumably sums entry weights against its configured capacity
/// when deciding what to evict — confirm against the shard implementation.
pub trait Weighter<Key, Val> {
    /// Returns the weight of the entry `(key, val)`.
    fn weight(&self, key: &Key, val: &Val) -> u64;
}

/// A [`Weighter`] that charges every entry exactly one unit, turning the
/// cache capacity into a plain entry count.
#[derive(Debug, Clone, Default)]
pub struct UnitWeighter;

impl<Key, Val> Weighter<Key, Val> for UnitWeighter {
    /// Every entry weighs one unit, regardless of key or value.
    #[inline]
    fn weight(&self, _k: &Key, _v: &Val) -> u64 {
        1
    }
}
153
/// Hooks invoked by the cache at key points of an entry's lifecycle
/// (pinning checks and eviction), scoped per cache request.
pub trait Lifecycle<Key, Val> {
    /// Scratch state created at the start of a request via
    /// [`Lifecycle::begin_request`] and consumed by
    /// [`Lifecycle::end_request`]; threaded through the eviction callbacks.
    type RequestState;

    /// Whether this entry is pinned. NOTE(review): presumably a pinned
    /// entry is exempt from eviction — confirm against the shard logic.
    /// Default: nothing is pinned.
    #[allow(unused_variables)]
    #[inline]
    fn is_pinned(&self, key: &Key, val: &Val) -> bool {
        false
    }

    /// Creates the per-request state passed to the callbacks below.
    fn begin_request(&self) -> Self::RequestState;

    /// Called with mutable access to a value before it is evicted,
    /// allowing the implementation to extract or alter it. Default: no-op.
    #[allow(unused_variables)]
    #[inline]
    fn before_evict(&self, state: &mut Self::RequestState, key: &Key, val: &mut Val) {}

    /// Called when an entry is evicted; receives ownership of the key and value.
    fn on_evict(&self, state: &mut Self::RequestState, key: Key, val: Val);

    /// Consumes the request state when the request finishes. Default: no-op.
    #[allow(unused_variables)]
    #[inline]
    fn end_request(&self, state: Self::RequestState) {}
}
201
/// Breakdown of a cache's memory usage, split between the entries
/// themselves and the backing map structure.
///
/// Marked `#[non_exhaustive]`: more fields may be added in the future.
#[non_exhaustive]
#[derive(Debug, Copy, Clone)]
pub struct MemoryUsed {
    /// Memory attributed to the stored entries.
    pub entries: usize,
    /// Memory attributed to the map structure.
    pub map: usize,
}

impl MemoryUsed {
    /// Combined memory usage: `entries + map`.
    pub fn total(&self) -> usize {
        let Self { entries, map } = *self;
        entries + map
    }
}
217
// Integration-style tests for the sync and unsync cache APIs, including
// multi-threaded / async races around get-or-insert and guard insertion.
#[cfg(test)]
mod tests {
    use std::{
        hash::Hash,
        sync::{atomic::AtomicUsize, Arc},
        time::Duration,
    };

    use super::*;
    // Weighter that charges a String value its byte length; keys are free.
    #[derive(Clone)]
    struct StringWeighter;

    impl Weighter<u64, String> for StringWeighter {
        fn weight(&self, _key: &u64, val: &String) -> u64 {
            val.len() as u64
        }
    }

    // Construction must not panic for edge capacities (0, usize::MAX),
    // with both tuple and scalar key types.
    #[test]
    fn test_new() {
        sync::Cache::<(u64, u64), u64>::new(0);
        sync::Cache::<(u64, u64), u64>::new(1);
        sync::Cache::<(u64, u64), u64>::new(2);
        sync::Cache::<(u64, u64), u64>::new(3);
        sync::Cache::<(u64, u64), u64>::new(usize::MAX);
        sync::Cache::<u64, u64>::new(0);
        sync::Cache::<u64, u64>::new(1);
        sync::Cache::<u64, u64>::new(2);
        sync::Cache::<u64, u64>::new(3);
        sync::Cache::<u64, u64>::new(usize::MAX);
    }

    // A custom weighter (value byte length) should not prevent retrieval
    // of inserted entries while under capacity.
    #[test]
    fn test_custom_cost() {
        let cache = sync::Cache::with_weighter(100, 100_000, StringWeighter);
        cache.insert(1, "1".to_string());
        cache.insert(54, "54".to_string());
        cache.insert(1000, "1000".to_string());
        assert_eq!(cache.get(&1000).unwrap(), "1000");
    }

    // Mutating a value through get_mut must update the cache's recorded
    // weight: "1" (1) -> "11" (2) -> "" (0).
    #[test]
    fn test_change_get_mut_change_weight() {
        let mut cache = unsync::Cache::with_weighter(100, 100_000, StringWeighter);
        cache.insert(1, "1".to_string());
        assert_eq!(cache.get(&1).unwrap(), "1");
        assert_eq!(cache.weight(), 1);
        let _old = {
            cache
                .get_mut(&1)
                .map(|mut v| std::mem::replace(&mut *v, "11".to_string()))
        };
        let _old = {
            cache
                .get_mut(&1)
                .map(|mut v| std::mem::replace(&mut *v, "".to_string()))
        };
        assert_eq!(cache.get(&1).unwrap(), "");
        assert_eq!(cache.weight(), 0);
        // NOTE(review): presumably checks internal invariants of the
        // unsync cache — see the unsync module for what `false` skips.
        cache.validate(false);
    }

    // Borrowed stand-in for a (A, B) tuple key, used to exercise
    // Equivalent-based lookups without allocating owned keys.
    #[derive(Debug, Hash)]
    pub struct Pair<A, B>(pub A, pub B);

    impl<A, B, C, D> PartialEq<(A, B)> for Pair<C, D>
    where
        C: PartialEq<A>,
        D: PartialEq<B>,
    {
        fn eq(&self, rhs: &(A, B)) -> bool {
            self.0 == rhs.0 && self.1 == rhs.1
        }
    }

    impl<A, B, X> Equivalent<X> for Pair<A, B>
    where
        Pair<A, B>: PartialEq<X>,
        A: Hash + Eq,
        B: Hash + Eq,
    {
        fn equivalent(&self, other: &X) -> bool {
            *self == *other
        }
    }

    // Lookup with a Pair<&str, i32> against owned (String, i32) keys
    // via the Equivalent trait.
    #[test]
    fn test_equivalent() {
        let mut cache = unsync::Cache::new(5);
        cache.insert(("square".to_string(), 2022), "blue".to_string());
        cache.insert(("square".to_string(), 2023), "black".to_string());
        assert_eq!(cache.get(&Pair("square", 2022)).unwrap(), "blue");
    }

    // Compile/behavior check: get() accepts borrowed forms of owned
    // tuple keys (byte slices for Vec<u8>, &str for String).
    #[test]
    fn test_borrow_keys() {
        let cache = sync::Cache::<(Vec<u8>, Vec<u8>), u64>::new(0);
        cache.get(&Pair(&b""[..], &b""[..]));
        let cache = sync::Cache::<(String, String), u64>::new(0);
        cache.get(&Pair("", ""));
    }

    // Race test: 100 threads call get_or_insert_with for the same key;
    // the closure fails until a randomly chosen entrant succeeds. Exactly
    // solve_at + 1 closures should run (later callers see the value).
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_get_or_insert() {
        use rand::prelude::*;
        for _i in 0..2000 {
            // dbg! identifies the failing iteration on assertion failure.
            dbg!(_i);
            let mut entered = AtomicUsize::default();
            let cache = sync::Cache::<(u64, u64), u64>::new(100);
            const THREADS: usize = 100;
            let wg = std::sync::Barrier::new(THREADS);
            let solve_at = rand::rng().random_range(0..THREADS);
            std::thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        wg.wait();
                        let result = cache.get_or_insert_with(&(1, 1), || {
                            let before = entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            if before == solve_at {
                                Ok(1)
                            } else {
                                Err(())
                            }
                        });
                        assert!(matches!(result, Ok(1) | Err(())));
                    });
                }
            });
            // All threads joined by scope exit; safe to read non-atomically.
            assert_eq!(*entered.get_mut(), solve_at + 1);
        }
    }

    // Single-threaded guard workflow on the unsync cache: a miss yields a
    // guard, inserting through it populates the entry; get_or_insert_with
    // must not run its closure for present keys (hence the panics).
    #[test]
    fn test_get_or_insert_unsync() {
        let mut cache = unsync::Cache::<u64, u64>::new(100);
        let guard = cache.get_ref_or_guard(&0).unwrap_err();
        guard.insert(0);
        assert_eq!(cache.get_ref_or_guard(&0).ok().copied(), Some(0));
        let guard = cache.get_mut_or_guard(&1).err().unwrap();
        guard.insert(1);
        let v = *cache.get_mut_or_guard(&1).ok().unwrap().unwrap();
        assert_eq!(v, 1);
        let result = cache.get_or_insert_with::<_, ()>(&0, || panic!());
        assert_eq!(result, Ok(Some(&0)));
        let result = cache.get_or_insert_with::<_, ()>(&1, || panic!());
        assert_eq!(result, Ok(Some(&1)));
        let result = cache.get_or_insert_with::<_, ()>(&3, || Ok(3));
        assert_eq!(result, Ok(Some(&3)));
        let result = cache.get_or_insert_with::<_, ()>(&4, || Err(()));
        assert_eq!(result, Err(()));
    }

    // Same guard workflow on the sync cache, including the async guard
    // variant; closure/future must not run for keys already present.
    #[tokio::test]
    async fn test_get_or_insert_sync() {
        use crate::sync::*;
        let cache = sync::Cache::<u64, u64>::new(100);
        let GuardResult::Guard(guard) = cache.get_value_or_guard(&0, None) else {
            panic!();
        };
        guard.insert(0).unwrap();
        let GuardResult::Value(v) = cache.get_value_or_guard(&0, None) else {
            panic!();
        };
        assert_eq!(v, 0);
        let Err(guard) = cache.get_value_or_guard_async(&1).await else {
            panic!();
        };
        guard.insert(1).unwrap();
        let Ok(v) = cache.get_value_or_guard_async(&1).await else {
            panic!();
        };
        assert_eq!(v, 1);

        let result = cache.get_or_insert_with::<_, ()>(&0, || panic!());
        assert_eq!(result, Ok(0));
        let result = cache.get_or_insert_with::<_, ()>(&3, || Ok(3));
        assert_eq!(result, Ok(3));
        let result = cache.get_or_insert_with::<_, ()>(&4, || Err(()));
        assert_eq!(result, Err(()));
        let result = cache
            .get_or_insert_async::<_, ()>(&0, async { panic!() })
            .await;
        assert_eq!(result, Ok(0));
        let result = cache
            .get_or_insert_async::<_, ()>(&4, async { Err(()) })
            .await;
        assert_eq!(result, Err(()));
        let result = cache
            .get_or_insert_async::<_, ()>(&4, async { Ok(4) })
            .await;
        assert_eq!(result, Ok(4));
    }

    // retain() on the unsync cache: successive predicates narrow the
    // surviving keys to (small, big) exclusive.
    #[test]
    fn test_retain_unsync() {
        let mut cache = unsync::Cache::<u64, u64>::new(100);
        let ranges = 0..10;
        for i in ranges.clone() {
            let guard = cache.get_ref_or_guard(&i).unwrap_err();
            guard.insert(i);
            assert_eq!(cache.get_ref_or_guard(&i).ok().copied(), Some(i));
        }
        let small = 3;
        cache.retain(|&key, &val| val > small && key > small);
        for i in ranges.clone() {
            let actual = cache.get(&i);
            if i > small {
                assert!(actual.is_some());
                assert_eq!(*actual.unwrap(), i);
            } else {
                assert!(actual.is_none());
            }
        }
        let big = 7;
        cache.retain(|&key, &val| val < big && key < big);
        for i in ranges {
            let actual = cache.get(&i);
            if i > small && i < big {
                assert!(actual.is_some());
                assert_eq!(*actual.unwrap(), i);
            } else {
                assert!(actual.is_none());
            }
        }
    }

    // retain() on the sync cache, mirroring the unsync test with
    // different thresholds.
    #[tokio::test]
    async fn test_retain_sync() {
        use crate::sync::*;
        let cache = Cache::<u64, u64>::new(100);
        let ranges = 0..10;
        for i in ranges.clone() {
            let GuardResult::Guard(guard) = cache.get_value_or_guard(&i, None) else {
                panic!();
            };
            guard.insert(i).unwrap();
            let GuardResult::Value(v) = cache.get_value_or_guard(&i, None) else {
                panic!();
            };
            assert_eq!(v, i);
        }
        let small = 4;
        cache.retain(|&key, &val| val > small && key > small);
        for i in ranges.clone() {
            let actual = cache.get(&i);
            if i > small {
                assert!(actual.is_some());
                assert_eq!(actual.unwrap(), i);
            } else {
                assert!(actual.is_none());
            }
        }
        let big = 8;
        cache.retain(|&key, &val| val < big && key < big);
        for i in ranges {
            let actual = cache.get(&i);
            if i > small && i < big {
                assert!(actual.is_some());
                assert_eq!(actual.unwrap(), i);
            } else {
                assert!(actual.is_none());
            }
        }
    }

    // Race test for the guard API: threads that receive a guard count
    // themselves; only the randomly chosen entrant inserts. Dropped
    // (non-inserting) guards let other threads acquire a fresh guard,
    // so exactly solve_at + 1 guards are handed out.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_value_or_guard() {
        use crate::sync::*;
        use rand::prelude::*;
        for _i in 0..2000 {
            dbg!(_i);
            let mut entered = AtomicUsize::default();
            let cache = sync::Cache::<(u64, u64), u64>::new(100);
            const THREADS: usize = 100;
            let wg = std::sync::Barrier::new(THREADS);
            let solve_at = rand::rng().random_range(0..THREADS);
            std::thread::scope(|s| {
                for _ in 0..THREADS {
                    s.spawn(|| {
                        wg.wait();
                        loop {
                            match cache.get_value_or_guard(&(1, 1), Some(Duration::from_millis(1)))
                            {
                                GuardResult::Value(v) => assert_eq!(v, 1),
                                GuardResult::Guard(g) => {
                                    let before =
                                        entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                                    if before == solve_at {
                                        g.insert(1).unwrap();
                                    }
                                }
                                // Waiting for the guard holder timed out; retry.
                                GuardResult::Timeout => continue,
                            }
                            break;
                        }
                    });
                }
            });
            assert_eq!(*entered.get_mut(), solve_at + 1);
        }
    }

    // Async counterpart of test_get_or_insert: 100 tokio tasks race
    // get_or_insert_async; the future fails until the randomly chosen
    // entrant succeeds, so the future runs exactly solve_at + 1 times.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_get_or_insert_async() {
        use rand::prelude::*;
        for _i in 0..5000 {
            dbg!(_i);
            let entered = Arc::new(AtomicUsize::default());
            let cache = Arc::new(sync::Cache::<(u64, u64), u64>::new(100));
            const TASKS: usize = 100;
            let wg = Arc::new(tokio::sync::Barrier::new(TASKS));
            let solve_at = rand::rng().random_range(0..TASKS);
            let mut tasks = Vec::new();
            for _ in 0..TASKS {
                let cache = cache.clone();
                let wg = wg.clone();
                let entered = entered.clone();
                let task = tokio::spawn(async move {
                    wg.wait().await;
                    let result = cache
                        .get_or_insert_async(&(1, 1), async {
                            let before = entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            if before == solve_at {
                                Ok(1)
                            } else {
                                Err(())
                            }
                        })
                        .await;
                    assert!(matches!(result, Ok(1) | Err(())));
                });
                tasks.push(task);
            }
            for task in tasks {
                task.await.unwrap();
            }
            assert_eq!(cache.get(&(1, 1)), Some(1));
            assert_eq!(
                entered.load(std::sync::atomic::Ordering::Relaxed),
                solve_at + 1
            );
        }
    }

    // Async counterpart of test_value_or_guard, using tokio::time::timeout
    // around get_value_or_guard_async instead of a built-in timeout.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg_attr(miri, ignore)]
    async fn test_value_or_guard_async() {
        use rand::prelude::*;
        for _i in 0..5000 {
            dbg!(_i);
            let entered = Arc::new(AtomicUsize::default());
            let cache = Arc::new(sync::Cache::<(u64, u64), u64>::new(100));
            const TASKS: usize = 100;
            let wg = Arc::new(tokio::sync::Barrier::new(TASKS));
            let solve_at = rand::rng().random_range(0..TASKS);
            let mut tasks = Vec::new();
            for _ in 0..TASKS {
                let cache = cache.clone();
                let wg = wg.clone();
                let entered = entered.clone();
                let task = tokio::spawn(async move {
                    wg.wait().await;
                    loop {
                        match tokio::time::timeout(
                            Duration::from_millis(1),
                            cache.get_value_or_guard_async(&(1, 1)),
                        )
                        .await
                        {
                            Ok(Ok(r)) => assert_eq!(r, 1),
                            Ok(Err(g)) => {
                                let before =
                                    entered.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                                if before == solve_at {
                                    g.insert(1).unwrap();
                                }
                            }
                            // Timed out waiting; retry the lookup.
                            Err(_) => continue,
                        }
                        break;
                    }
                });
                tasks.push(task);
            }
            for task in tasks {
                task.await.unwrap();
            }
            assert_eq!(cache.get(&(1, 1)), Some(1));
            assert_eq!(
                entered.load(std::sync::atomic::Ordering::Relaxed),
                solve_at + 1
            );
        }
    }
}