diff --git a/src/lib.rs b/src/lib.rs index 2d39b74..a12ca47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -158,15 +158,17 @@ impl Borrow<[T]> for KeyRef> { struct LruEntry { key: mem::MaybeUninit, val: mem::MaybeUninit, + cost: usize, prev: *mut LruEntry, next: *mut LruEntry, } impl LruEntry { - fn new(key: K, val: V) -> Self { + fn new(key: K, val: V, cost: usize) -> Self { LruEntry { key: mem::MaybeUninit::new(key), val: mem::MaybeUninit::new(val), + cost, prev: ptr::null_mut(), next: ptr::null_mut(), } @@ -176,6 +178,7 @@ impl LruEntry { LruEntry { key: mem::MaybeUninit::uninit(), val: mem::MaybeUninit::uninit(), + cost: 0, prev: ptr::null_mut(), next: ptr::null_mut(), } @@ -190,7 +193,8 @@ pub type DefaultHasher = std::collections::hash_map::RandomState; /// An LRU Cache pub struct LruCache { map: HashMap, Box>, S>, - cap: NonZeroUsize, + cost_cap: NonZeroUsize, + cost: usize, // head and tail are sigil nodes to facilitate inserting entries head: *mut LruEntry, @@ -272,7 +276,8 @@ impl LruCache { // declare it as such since we only mutate it inside the unsafe block. let cache = LruCache { map, - cap, + cost_cap: cap, + cost: 0, head: Box::into_raw(Box::new(LruEntry::new_sigil())), tail: Box::into_raw(Box::new(LruEntry::new_sigil())), }; @@ -285,7 +290,7 @@ impl LruCache { cache } - /// Puts a key-value pair into cache. If the key already exists in the cache, then it updates + /// Puts a key-value pair into cache, with a cost of 1. If the key already exists in the cache, then it updates /// the key's value and returns the old value. Otherwise, `None` is returned. /// /// # Example @@ -303,10 +308,31 @@ impl LruCache { /// assert_eq!(cache.get(&2), Some(&"beta")); /// ``` pub fn put(&mut self, k: K, v: V) -> Option { - self.capturing_put(k, v, false).map(|(_, v)| v) + self.put_with_cost(k, v, 1) + } + + /// Puts a key-value pair into cache, with the given cost. If the key already exists in the cache, then it updates + /// the key's value and returns the old value. 
Otherwise, `None` is returned. + /// + /// # Example + /// + /// ``` + /// use lru::LruCache; + /// use std::num::NonZeroUsize; + /// let mut cache = LruCache::new(NonZeroUsize::new(12).unwrap()); + /// + /// assert_eq!(None, cache.put_with_cost(1, "aa", 2)); + /// assert_eq!(None, cache.put_with_cost(2, "bbbb", 4)); + /// assert_eq!(Some("bbbb"), cache.put_with_cost(2, "cccccc", 6)); + /// + /// assert_eq!(cache.get(&1), Some(&"aa")); + /// assert_eq!(cache.get(&2), Some(&"cccccc")); + /// ``` + pub fn put_with_cost(&mut self, k: K, v: V, cost: usize) -> Option { + self.capturing_put(k, v, false, cost).map(|(_, v)| v) } - /// Pushes a key-value pair into the cache. If an entry with key `k` already exists in + /// Pushes a key-value pair into the cache, with a cost of 1. If an entry with key `k` already exists in /// the cache or another cache entry is removed (due to the lru's capacity), /// then it returns the old entry's key-value pair. Otherwise, returns `None`. /// @@ -331,17 +357,54 @@ impl LruCache { /// assert_eq!(cache.get(&3), Some(&"alpha")); /// ``` pub fn push(&mut self, k: K, v: V) -> Option<(K, V)> { - self.capturing_put(k, v, true) + self.push_with_cost(k, v, 1) + } + + /// Pushes a key-value pair into the cache, with a given cost. If an entry with key `k` already exists in + /// the cache or another cache entry is removed (due to the lru's capacity), + /// then it returns the old entry's key-value pair. Otherwise, returns `None`. + /// + /// # Example + /// + /// ``` + /// use lru::LruCache; + /// use std::num::NonZeroUsize; + /// let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap()); + /// + /// assert_eq!(None, cache.push_with_cost(1, "a", 1)); + /// assert_eq!(None, cache.push_with_cost(2, "b", 1)); + /// + /// // This push call returns (2, "b") because that was previously 2's entry in the cache. 
+ /// assert_eq!(Some((2, "b")), cache.push_with_cost(2, "beta", 1)); + /// + /// // This push call returns (1, "a") because the cache is at capacity and 1's entry was the lru entry. + /// assert_eq!(Some((1, "a")), cache.push_with_cost(3, "alpha", 1)); + /// + /// assert_eq!(cache.get(&1), None); + /// assert_eq!(cache.get(&2), Some(&"beta")); + /// assert_eq!(cache.get(&3), Some(&"alpha")); + /// ``` + pub fn push_with_cost(&mut self, k: K, v: V, cost: usize) -> Option<(K, V)> { + self.capturing_put(k, v, true, cost) } // Used internally by `put` and `push` to add a new entry to the lru. // Takes ownership of and returns entries replaced due to the cache's capacity // when `capture` is true. - fn capturing_put(&mut self, k: K, mut v: V, capture: bool) -> Option<(K, V)> { + fn capturing_put(&mut self, k: K, mut v: V, capture: bool, cost: usize) -> Option<(K, V)> { + if cost > self.cost_cap.get() { + // Special case: if the new value can't ever fit inside the cache, then + // we invalidate the cache entry and don't evict anything else. 
+ return self.pop_entry(&k); + } + let node_ref = self.map.get_mut(&KeyRef { k: &k }); match node_ref { Some(node_ref) => { + let old_cost = node_ref.cost; + node_ref.cost = cost; + let node_ptr: *mut LruEntry = &mut **node_ref; // if the key is already in the cache just update its value and move it to the @@ -349,10 +412,16 @@ impl LruCache { unsafe { mem::swap(&mut v, &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V) } self.detach(node_ptr); self.attach(node_ptr); + + self.cost -= old_cost; + self.cost += cost; + + self.shrink_within_cost(); + Some((k, v)) } None => { - let (replaced, mut node) = self.replace_or_create_node(k, v); + let (replaced, mut node) = self.replace_or_create_node(k, v, cost); let node_ptr: *mut LruEntry = &mut *node; self.attach(node_ptr); @@ -360,27 +429,50 @@ impl LruCache { let keyref = unsafe { (*node_ptr).key.as_ptr() }; self.map.insert(KeyRef { k: keyref }, node); + self.shrink_within_cost(); + replaced.filter(|_| capture) } } } + fn shrink_within_cost(&mut self) { + let mut did_shrink = false; + while self.cost() > self.cost_cap.get() { + self.pop_lru(); + did_shrink = true; + } + if did_shrink { + self.map.shrink_to_fit(); + } + } + // Used internally to swap out a node if the cache is full or to create a new node if space // is available. Shared between `put`, `push`, `get_or_insert`, and `get_or_insert_mut`. 
#[allow(clippy::type_complexity)] - fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, Box>) { - if self.len() == self.cap().get() { + fn replace_or_create_node( + &mut self, + k: K, + v: V, + cost: usize, + ) -> (Option<(K, V)>, Box>) { + if self.cost + cost > self.cost_cap.get() && !self.is_empty() { // if the cache is full, remove the last entry so we can use it for the new key let old_key = KeyRef { k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) }, }; let mut old_node = self.map.remove(&old_key).unwrap(); + let old_cost = old_node.cost; // read out the node's old key and value and then replace it let replaced = unsafe { (old_node.key.assume_init(), old_node.val.assume_init()) }; old_node.key = mem::MaybeUninit::new(k); old_node.val = mem::MaybeUninit::new(v); + old_node.cost = cost; + + self.cost -= old_cost; + self.cost += cost; let node_ptr: *mut LruEntry = &mut *old_node; self.detach(node_ptr); @@ -388,7 +480,8 @@ impl LruCache { (Some(replaced), old_node) } else { // if the cache is not full allocate a new LruEntry - (None, Box::new(LruEntry::new(k, v))) + self.cost += cost; + (None, Box::new(LruEntry::new(k, v, cost))) } } @@ -499,13 +592,16 @@ impl LruCache { unsafe { &(*(*node_ptr).val.as_ptr()) as &V } } else { let v = f(); - let (_, mut node) = self.replace_or_create_node(k, v); + let (_, mut node) = self.replace_or_create_node(k, v, 1); let node_ptr: *mut LruEntry = &mut *node; self.attach(node_ptr); let keyref = unsafe { (*node_ptr).key.as_ptr() }; self.map.insert(KeyRef { k: keyref }, node); + + self.shrink_within_cost(); + unsafe { &(*(*node_ptr).val.as_ptr()) as &V } } } @@ -545,13 +641,16 @@ impl LruCache { unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V } } else { let v = f(); - let (_, mut node) = self.replace_or_create_node(k, v); + let (_, mut node) = self.replace_or_create_node(k, v, 1); let node_ptr: *mut LruEntry = &mut *node; self.attach(node_ptr); let keyref = unsafe { (*node_ptr).key.as_ptr() }; 
self.map.insert(KeyRef { k: keyref }, node); + + self.shrink_within_cost(); + unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V } } } @@ -693,6 +792,7 @@ impl LruCache { match self.map.remove(k) { None => None, Some(mut old_node) => { + self.cost -= old_node.cost; unsafe { ptr::drop_in_place(old_node.key.as_mut_ptr()); } @@ -730,6 +830,7 @@ impl LruCache { match self.map.remove(k) { None => None, Some(mut old_node) => { + self.cost -= old_node.cost; let node_ptr: *mut LruEntry = &mut *old_node; self.detach(node_ptr); unsafe { Some((old_node.key.assume_init(), old_node.val.assume_init())) } @@ -759,6 +860,7 @@ impl LruCache { /// ``` pub fn pop_lru(&mut self) -> Option<(K, V)> { let node = self.remove_last()?; + self.cost -= node.cost; // N.B.: Can't destructure directly because of https://github.com/rust-lang/rust/issues/28536 let node = *node; let LruEntry { key, val, .. } = node; @@ -858,6 +960,29 @@ impl LruCache { self.map.len() } + /// Returns the total cost of all key-value pairs that are currently in the cache. + /// + /// # Example + /// + /// ``` + /// use lru::LruCache; + /// use std::num::NonZeroUsize; + /// let mut cache = LruCache::new(NonZeroUsize::new(4).unwrap()); + /// assert_eq!(cache.cost(), 0); + /// + /// cache.put_with_cost(1, "a", 2); + /// assert_eq!(cache.cost(), 2); + /// + /// cache.put_with_cost(2, "b", 1); + /// assert_eq!(cache.cost(), 3); + /// + /// cache.put_with_cost(3, "c", 3); + /// assert_eq!(cache.cost(), 4); + /// ``` + pub fn cost(&self) -> usize { + self.cost + } + /// Returns a bool indicating whether the cache is empty or not. /// /// # Example @@ -875,7 +1000,7 @@ impl LruCache { self.map.len() == 0 } - /// Returns the maximum number of key-value pairs the cache can hold. + /// Returns the maximum allowed total cost of all key-value pairs in the cache. 
/// /// # Example /// @@ -883,10 +1008,10 @@ impl LruCache { /// use lru::LruCache; /// use std::num::NonZeroUsize; /// let mut cache: LruCache = LruCache::new(NonZeroUsize::new(2).unwrap()); - /// assert_eq!(cache.cap().get(), 2); + /// assert_eq!(cache.cost_cap().get(), 2); /// ``` - pub fn cap(&self) -> NonZeroUsize { - self.cap + pub fn cost_cap(&self) -> NonZeroUsize { + self.cost_cap } /// Resizes the cache. If the new capacity is smaller than the size of the current @@ -913,16 +1038,12 @@ impl LruCache { /// ``` pub fn resize(&mut self, cap: NonZeroUsize) { // return early if capacity doesn't change - if cap == self.cap { + if cap == self.cost_cap { return; } - while self.map.len() > cap.get() { - self.pop_lru(); - } - self.map.shrink_to_fit(); - - self.cap = cap; + self.cost_cap = cap; + self.shrink_within_cost() } /// Clears the contents of the cache. @@ -1097,7 +1218,7 @@ impl fmt::Debug for LruCache { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("LruCache") .field("len", &self.len()) - .field("cap", &self.cap()) + .field("cost_cap", &self.cost_cap()) .finish() } } @@ -1361,7 +1482,7 @@ mod tests { assert_eq!(cache.put("apple", "red"), None); assert_eq!(cache.put("banana", "yellow"), None); - assert_eq!(cache.cap().get(), 2); + assert_eq!(cache.cost_cap().get(), 2); assert_eq!(cache.len(), 2); assert!(!cache.is_empty()); assert_opt_eq(cache.get(&"apple"), "red"); @@ -1376,7 +1497,7 @@ mod tests { assert_eq!(cache.put("apple", "red"), None); assert_eq!(cache.put("banana", "yellow"), None); - assert_eq!(cache.cap().get(), 2); + assert_eq!(cache.cost_cap().get(), 2); assert_eq!(cache.len(), 2); assert!(!cache.is_empty()); assert_eq!(cache.get_or_insert("apple", || "orange"), &"red"); @@ -1393,7 +1514,7 @@ mod tests { assert_eq!(cache.put("apple", "red"), None); assert_eq!(cache.put("banana", "yellow"), None); - assert_eq!(cache.cap().get(), 2); + assert_eq!(cache.cost_cap().get(), 2); assert_eq!(cache.len(), 2); let v = 
cache.get_or_insert_mut("apple", || "orange"); @@ -1413,7 +1534,7 @@ mod tests { cache.put("apple", "red"); cache.put("banana", "yellow"); - assert_eq!(cache.cap().get(), 2); + assert_eq!(cache.cost_cap().get(), 2); assert_eq!(cache.len(), 2); assert_opt_eq_mut(cache.get_mut(&"apple"), "red"); assert_opt_eq_mut(cache.get_mut(&"banana"), "yellow"); @@ -1431,7 +1552,7 @@ mod tests { *v = 4; } - assert_eq!(cache.cap().get(), 2); + assert_eq!(cache.cost_cap().get(), 2); assert_eq!(cache.len(), 2); assert_opt_eq_mut(cache.get_mut(&"apple"), 4); assert_opt_eq_mut(cache.get_mut(&"banana"), 3); @@ -1470,6 +1591,57 @@ mod tests { assert_opt_eq(cache.get(&"tomato"), "red"); } + #[test] + fn test_put_with_cost_removes_oldest() { + let mut cache = LruCache::new(NonZeroUsize::new(11).unwrap()); + + assert_eq!(cache.put_with_cost("apple", "red", 3), None); + assert_eq!(cache.put_with_cost("banana", "yellow", 6), None); + assert_eq!(cache.put_with_cost("pear", "green", 5), None); + + assert!(cache.get(&"apple").is_none()); + assert_opt_eq(cache.get(&"banana"), "yellow"); + assert_opt_eq(cache.get(&"pear"), "green"); + + // Even though we inserted "apple" into the cache earlier it has since been removed from + // the cache so there is no current value for `put` to return. 
+ assert_eq!(cache.put_with_cost("apple", "green", 5), None); + assert_eq!(cache.put_with_cost("tomato", "red", 3), None); + + assert!(cache.get(&"pear").is_none()); + assert!(cache.get(&"banana").is_none()); + assert_opt_eq(cache.get(&"apple"), "green"); + assert_opt_eq(cache.get(&"tomato"), "red"); + } + + #[test] + fn test_put_with_cost_way_too_big() { + let mut cache = LruCache::new(NonZeroUsize::new(11).unwrap()); + + assert_eq!(cache.put_with_cost("apple", "red", 3), None); + assert_eq!(cache.put_with_cost("banana", "yellow", 6), None); + assert_eq!(cache.put_with_cost("pear", "green", 5000), None); + + assert_opt_eq(cache.get(&"apple"), "red"); + assert_opt_eq(cache.get(&"banana"), "yellow"); + assert!(cache.get(&"pear").is_none()); + } + + #[test] + fn test_put_with_cost_evict_multiple2() { + let mut cache = LruCache::new(NonZeroUsize::new(16).unwrap()); + + assert_eq!(cache.put_with_cost("apple", "red", 3), None); + assert_eq!(cache.put_with_cost("banana", "yellow", 6), None); + assert_eq!(cache.put_with_cost("plum", "purple", 6), None); + assert_eq!(cache.put_with_cost("pear", "green", 10), None); + + assert!(cache.get(&"apple").is_none()); + assert!(cache.get(&"banana").is_none()); + assert_opt_eq(cache.get(&"plum"), "purple"); + assert_opt_eq(cache.get(&"pear"), "green"); + } + #[test] fn test_peek() { let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap()); @@ -2006,15 +2178,9 @@ mod tests { fn test_no_memory_leaks_with_pop() { static DROP_COUNT: AtomicUsize = AtomicUsize::new(0); - #[derive(Hash, Eq)] + #[derive(Hash, PartialEq, Eq)] struct KeyDropCounter(usize); - impl PartialEq for KeyDropCounter { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } - } - impl Drop for KeyDropCounter { fn drop(&mut self) { DROP_COUNT.fetch_add(1, Ordering::SeqCst);