1use crate::freelist::{FreeList, FreeListHandle, WeakFreeListHandle};
6use std::{mem, num};
7
/// A single value stored in an `LRUCache`, plus the bookkeeping needed to
/// locate the strong handle for this entry inside its partition's
/// `LRUTracker`.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(MallocSizeOf)]
struct LRUCacheEntry<T> {
    /// Index into `LRUCache::lru` of the partition whose tracker currently
    /// holds this entry's strong handle.
    partition_index: u8,

    /// Slot of this entry's strong handle inside that partition's tracker;
    /// passed to `LRUTracker::remove` / `LRUTracker::mark_used`.
    lru_index: ItemIndex,

    /// The cached value itself.
    value: T,
}
56
/// A cache of `T` values split into partitions, each of which keeps its own
/// least-recently-used ordering.
///
/// Values live in a `FreeList` and callers refer to them through the weak
/// handles returned by `push_new`. The strong handles are owned by the
/// per-partition `LRUTracker`s, ordered from oldest (front, see
/// `peek_oldest`/`pop_oldest`) to most recently pushed or touched (back).
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(MallocSizeOf)]
pub struct LRUCache<T, M> {
    /// Storage for cached values together with their LRU bookkeeping.
    entries: FreeList<LRUCacheEntry<T>, M>,
    /// One LRU tracker per partition, indexed by partition index.
    lru: Vec<LRUTracker<FreeListHandle<M>>>,
}
67
68impl<T, M> LRUCache<T, M> {
69 pub fn new(lru_partition_count: usize) -> Self {
71 assert!(lru_partition_count <= u8::MAX as usize + 1);
72 LRUCache {
73 entries: FreeList::new(),
74 lru: (0..lru_partition_count).map(|_| LRUTracker::new()).collect(),
75 }
76 }
77
78 pub fn push_new(
82 &mut self,
83 partition_index: u8,
84 value: T,
85 ) -> WeakFreeListHandle<M> {
86 let handle = self.entries.insert(LRUCacheEntry {
92 partition_index: 0,
93 lru_index: ItemIndex(num::NonZeroU32::new(1).unwrap()),
94 value
95 });
96
97 let weak_handle = handle.weak();
99
100 let entry = self.entries.get_mut(&handle);
103 let lru_index = self.lru[partition_index as usize].push_new(handle);
104 entry.partition_index = partition_index;
105 entry.lru_index = lru_index;
106
107 weak_handle
108 }
109
110 pub fn get_opt(
113 &self,
114 handle: &WeakFreeListHandle<M>,
115 ) -> Option<&T> {
116 self.entries
117 .get_opt(handle)
118 .map(|entry| {
119 &entry.value
120 })
121 }
122
123 pub fn get_opt_mut(
126 &mut self,
127 handle: &WeakFreeListHandle<M>,
128 ) -> Option<&mut T> {
129 self.entries
130 .get_opt_mut(handle)
131 .map(|entry| {
132 &mut entry.value
133 })
134 }
135
136 pub fn peek_oldest(&self, partition_index: u8) -> Option<&T> {
139 self.lru[partition_index as usize]
140 .peek_front()
141 .map(|handle| {
142 let entry = self.entries.get(handle);
143 &entry.value
144 })
145 }
146
147 pub fn pop_oldest(
150 &mut self,
151 partition_index: u8,
152 ) -> Option<T> {
153 self.lru[partition_index as usize]
154 .pop_front()
155 .map(|handle| {
156 let entry = self.entries.free(handle);
157 entry.value
158 })
159 }
160
161 #[must_use]
167 pub fn replace_or_insert(
168 &mut self,
169 handle: &mut WeakFreeListHandle<M>,
170 partition_index: u8,
171 data: T,
172 ) -> Option<T> {
173 match self.entries.get_opt_mut(handle) {
174 Some(entry) => {
175 if entry.partition_index != partition_index {
176 let strong_handle = self.lru[entry.partition_index as usize].remove(entry.lru_index);
178 let lru_index = self.lru[partition_index as usize].push_new(strong_handle);
179 entry.partition_index = partition_index;
180 entry.lru_index = lru_index;
181 }
182 Some(mem::replace(&mut entry.value, data))
183 }
184 None => {
185 *handle = self.push_new(partition_index, data);
186 None
187 }
188 }
189 }
190
191 pub fn remove(&mut self, handle: &WeakFreeListHandle<M>) -> Option<T> {
193 if let Some(entry) = self.entries.get_opt_mut(handle) {
194 let strong_handle = self.lru[entry.partition_index as usize].remove(entry.lru_index);
195 return Some(self.entries.free(strong_handle).value);
196 }
197
198 None
199 }
200
201 pub fn touch(
206 &mut self,
207 handle: &WeakFreeListHandle<M>,
208 ) -> Option<&mut T> {
209 let lru = &mut self.lru;
210
211 self.entries
212 .get_opt_mut(handle)
213 .map(|entry| {
214 lru[entry.partition_index as usize].mark_used(entry.lru_index);
215 &mut entry.value
216 })
217 }
218
219 #[cfg(test)]
221 fn validate(&self) {
222 for lru in &self.lru {
223 lru.validate();
224 }
225 }
226}
227
/// Index of a slot in an `LRUTracker`'s `items` vector.
///
/// Backed by `NonZeroU32`, so valid indices start at 1; `LRUTracker::new`
/// reserves slot 0 as an unused placeholder to make this work.
/// `Option<ItemIndex>` is the same size as `ItemIndex`.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, MallocSizeOf)]
struct ItemIndex(num::NonZeroU32);
233
234impl ItemIndex {
235 fn as_usize(&self) -> usize {
236 self.0.get() as usize
237 }
238}
239
/// A slot in an `LRUTracker`.
///
/// While linked into the LRU list, `prev`/`next` point at neighbouring slots
/// and `handle` is `Some`. Once freed by `remove`/`pop_front`, `handle` is
/// `None` and `next` chains to the next free slot.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(Debug, MallocSizeOf)]
struct Item<H> {
    prev: Option<ItemIndex>,
    next: Option<ItemIndex>,
    handle: Option<H>,
}
252
/// Tracks least-recently-used order for a set of handles.
///
/// Slots in `items` form a doubly linked list running from `head` (least
/// recently used) to `tail` (most recently pushed or marked used). Freed slots
/// are chained through their `next` field starting at `free_list_head` and are
/// reused by `push_new` before `items` grows. Slot 0 is a placeholder that is
/// never handed out.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(MallocSizeOf)]
struct LRUTracker<H> {
    /// Least recently used slot, or `None` when the list is empty.
    head: Option<ItemIndex>,
    /// Most recently used slot, or `None` when the list is empty.
    tail: Option<ItemIndex>,
    /// First reusable slot in `items`, if any.
    free_list_head: Option<ItemIndex>,
    /// Slot storage; index 0 is an unused placeholder.
    items: Vec<Item<H>>,
}
267
268impl<H> LRUTracker<H> where H: std::fmt::Debug {
269 fn new() -> Self {
271 let items = vec![
274 Item {
275 prev: None,
276 next: None,
277 handle: None,
278 },
279 ];
280
281 LRUTracker {
282 head: None,
283 tail: None,
284 free_list_head: None,
285 items,
286 }
287 }
288
289 fn link_as_new_tail(
292 &mut self,
293 item_index: ItemIndex,
294 ) {
295 match (self.head, self.tail) {
296 (Some(..), Some(tail)) => {
297 self.items[item_index.as_usize()].prev = Some(tail);
299 self.items[item_index.as_usize()].next = None;
300
301 self.items[tail.as_usize()].next = Some(item_index);
302 self.tail = Some(item_index);
303 }
304 (None, None) => {
305 self.items[item_index.as_usize()].prev = None;
307 self.items[item_index.as_usize()].next = None;
308
309 self.head = Some(item_index);
310 self.tail = Some(item_index);
311 }
312 (Some(..), None) | (None, Some(..)) => {
313 unreachable!();
315 }
316 }
317 }
318
319 fn unlink(
323 &mut self,
324 item_index: ItemIndex,
325 ) {
326 let (next, prev) = {
327 let item = &self.items[item_index.as_usize()];
328 (item.next, item.prev)
329 };
330
331 match next {
332 Some(next) => {
333 self.items[next.as_usize()].prev = prev;
334 }
335 None => {
336 debug_assert_eq!(self.tail, Some(item_index));
337 self.tail = prev;
338 }
339 }
340
341 match prev {
342 Some(prev) => {
343 self.items[prev.as_usize()].next = next;
344 }
345 None => {
346 debug_assert_eq!(self.head, Some(item_index));
347 self.head = next;
348 }
349 }
350 }
351
352 fn push_new(
355 &mut self,
356 handle: H,
357 ) -> ItemIndex {
358 let item_index = match self.free_list_head {
360 Some(index) => {
361 let item = &mut self.items[index.as_usize()];
363
364 assert!(item.handle.is_none());
365 item.handle = Some(handle);
366
367 self.free_list_head = item.next;
368
369 index
370 }
371 None => {
372 let index = ItemIndex(num::NonZeroU32::new(self.items.len() as u32).unwrap());
374
375 self.items.push(Item {
376 prev: None,
377 next: None,
378 handle: Some(handle),
379 });
380
381 index
382 }
383 };
384
385 self.link_as_new_tail(item_index);
387
388 item_index
389 }
390
391 fn peek_front(&self) -> Option<&H> {
393 self.head.map(|head| self.items[head.as_usize()].handle.as_ref().unwrap())
394 }
395
396 fn pop_front(
399 &mut self,
400 ) -> Option<H> {
401 let handle = match (self.head, self.tail) {
402 (Some(head), Some(tail)) => {
403 let item_index = head;
404
405 if head == tail {
407 self.head = None;
408 self.tail = None;
409 } else {
410 let new_head = self.items[head.as_usize()].next.unwrap();
412 self.head = Some(new_head);
413 self.items[new_head.as_usize()].prev = None;
414 }
415
416 self.items[item_index.as_usize()].next = self.free_list_head;
418 self.free_list_head = Some(item_index);
419
420 Some(self.items[item_index.as_usize()].handle.take().unwrap())
422 }
423 (None, None) => {
424 None
426 }
427 (Some(..), None) | (None, Some(..)) => {
428 unreachable!();
430 }
431 };
432
433 handle
434 }
435
436 fn remove(
439 &mut self,
440 index: ItemIndex,
441 ) -> H {
442 self.unlink(index);
444
445 let handle = self.items[index.as_usize()].handle.take().unwrap();
446
447 self.items[index.as_usize()].next = self.free_list_head;
449 self.free_list_head = Some(index);
450
451 handle
452 }
453
454 fn mark_used(
457 &mut self,
458 index: ItemIndex,
459 ) {
460 self.unlink(index);
461 self.link_as_new_tail(index);
462 }
463
464 #[cfg(test)]
466 fn validate(&self) {
467 use std::collections::HashSet;
468
469 assert!((self.head.is_none() && self.tail.is_none()) || (self.head.is_some() && self.tail.is_some()));
471
472 if let Some(head) = self.head {
474 assert!(self.items[head.as_usize()].prev.is_none());
475 }
476
477 if let Some(tail) = self.tail {
479 assert!(self.items[tail.as_usize()].next.is_none());
480 }
481
482 let mut free_items = Vec::new();
484 let mut free_items_set = HashSet::new();
485 let mut valid_items_front = Vec::new();
486 let mut valid_items_front_set = HashSet::new();
487 let mut valid_items_reverse = Vec::new();
488 let mut valid_items_reverse_set = HashSet::new();
489
490 let mut current = self.free_list_head;
491 while let Some(index) = current {
492 let item = &self.items[index.as_usize()];
493 free_items.push(index);
494 assert!(free_items_set.insert(index));
495 current = item.next;
496 }
497
498 current = self.head;
499 while let Some(index) = current {
500 let item = &self.items[index.as_usize()];
501 valid_items_front.push(index);
502 assert!(valid_items_front_set.insert(index));
503 current = item.next;
504 }
505
506 current = self.tail;
507 while let Some(index) = current {
508 let item = &self.items[index.as_usize()];
509 valid_items_reverse.push(index);
510 assert!(!valid_items_reverse_set.contains(&index));
511 valid_items_reverse_set.insert(index);
512 current = item.prev;
513 }
514
515 assert_eq!(valid_items_front.len(), valid_items_front_set.len());
517 assert_eq!(valid_items_reverse.len(), valid_items_reverse_set.len());
518
519 assert_eq!(free_items.len() + valid_items_front.len() + 1, self.items.len());
521
522 assert_eq!(valid_items_front.len(), valid_items_reverse.len());
524
525 assert!(free_items_set.intersection(&valid_items_reverse_set).collect::<HashSet<_>>().is_empty());
527 assert!(free_items_set.intersection(&valid_items_front_set).collect::<HashSet<_>>().is_empty());
528
529 assert_eq!(valid_items_front_set.len(), valid_items_reverse_set.len());
531
532 for (i0, i1) in valid_items_front.iter().zip(valid_items_reverse.iter().rev()) {
534 assert_eq!(i0, i1);
535 }
536 }
537}
538
#[test]
fn test_lru_tracker_push_peek() {
    // Marker type for the cache's `M` parameter.
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    cache.validate();

    // An empty partition has nothing to peek at.
    assert_eq!(cache.peek_oldest(0), None);

    (0..NUM_ELEMENTS).for_each(|value| {
        cache.push_new(0, value);
    });
    cache.validate();

    // Peeking is non-destructive: repeated peeks see the same oldest value.
    for _ in 0..2 {
        assert_eq!(cache.peek_oldest(0), Some(&0));
    }

    cache.pop_oldest(0);
    assert_eq!(cache.peek_oldest(0), Some(&1));
}
564
#[test]
fn test_lru_tracker_push_pop() {
    // Marker type for the cache's `M` parameter.
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    cache.validate();

    for value in 0..NUM_ELEMENTS {
        cache.push_new(0, value);
    }
    cache.validate();

    // With no touches in between, values come back in insertion order.
    for expected in (0..NUM_ELEMENTS).map(Some) {
        assert_eq!(cache.pop_oldest(0), expected);
    }
    cache.validate();

    // The partition is now exhausted.
    assert_eq!(cache.pop_oldest(0), None);
}
588
#[test]
fn test_lru_tracker_push_touch_pop() {
    // Marker type for the cache's `M` parameter.
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    cache.validate();

    let handles: Vec<_> = (0..NUM_ELEMENTS).map(|value| cache.push_new(0, value)).collect();
    cache.validate();

    // Touch every even-indexed entry, moving it behind all odd ones.
    for handle in handles.iter().step_by(2) {
        cache.touch(handle);
    }
    cache.validate();

    // Untouched (odd) values are now the oldest...
    for odd in (1..NUM_ELEMENTS).step_by(2) {
        assert_eq!(cache.pop_oldest(0), Some(odd));
    }
    cache.validate();

    // ...followed by the touched (even) values, in the order they were touched.
    for even in (0..NUM_ELEMENTS).step_by(2) {
        assert_eq!(cache.pop_oldest(0), Some(even));
    }
    cache.validate();

    assert_eq!(cache.pop_oldest(0), None);
}
622
#[test]
fn test_lru_tracker_push_get() {
    // Marker type for the cache's `M` parameter.
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    cache.validate();

    let handles: Vec<_> = (0..NUM_ELEMENTS).map(|value| cache.push_new(0, value)).collect();
    cache.validate();

    // Every handle, not just the first half, must resolve to its value.
    for (expected, handle) in handles.iter().enumerate() {
        assert_eq!(cache.get_opt(handle), Some(&expected));
    }
    cache.validate();

    // Lookups must not disturb LRU order: values still pop in insertion order.
    for expected in 0..NUM_ELEMENTS {
        assert_eq!(cache.pop_oldest(0), Some(expected));
    }
    cache.validate();

    // Popped entries are gone, so their weak handles no longer resolve.
    for handle in &handles {
        assert_eq!(cache.get_opt(handle), None);
    }
}
644
#[test]
fn test_lru_tracker_push_replace_get() {
    // Marker type for the cache's `M` parameter.
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    cache.validate();

    let mut handles: Vec<_> = (0..NUM_ELEMENTS).map(|value| cache.push_new(0, value)).collect();
    cache.validate();

    // Replacing a live entry returns the previous value.
    for (i, handle) in handles.iter_mut().enumerate() {
        assert_eq!(cache.replace_or_insert(handle, 0, i * 2), Some(i));
    }
    cache.validate();

    // Every handle (not just the first half) sees its replaced value.
    for (i, handle) in handles.iter().enumerate() {
        assert_eq!(cache.get_opt(handle), Some(&(i * 2)));
    }
    cache.validate();

    // An unresolvable handle is a miss: the value is inserted and the handle
    // is updated to point at it.
    let mut empty_handle = WeakFreeListHandle::invalid();
    assert_eq!(cache.replace_or_insert(&mut empty_handle, 0, 100), None);
    assert_eq!(cache.get_opt(&empty_handle), Some(&100));
    cache.validate();

    // Replacing into a different partition must move the entry: it leaves the
    // old partition's LRU order and joins the new one.
    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(2);
    let mut moved = cache.push_new(0, 1);
    let _stays = cache.push_new(0, 2);
    assert_eq!(cache.replace_or_insert(&mut moved, 1, 10), Some(1));
    cache.validate();
    assert_eq!(cache.peek_oldest(0), Some(&2));
    assert_eq!(cache.peek_oldest(1), Some(&10));
    assert_eq!(cache.get_opt(&moved), Some(&10));
    assert_eq!(cache.pop_oldest(1), Some(10));
    assert_eq!(cache.pop_oldest(1), None);
    cache.validate();
}
675}