Skip to content

Commit

Permalink
add get_or_insert and Send
Browse files Browse the repository at this point in the history
  • Loading branch information
w273732573 committed Jun 16, 2024
1 parent 31708ec commit 7d6227f
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/bench_lru.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ fn do_bench(num: usize) {
}

fn main() {
do_bench(1e4 as usize);
do_bench(1e6 as usize);
}
85 changes: 70 additions & 15 deletions src/cache/arc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,48 @@ impl<K: Hash + Eq, V, S: BuildHasher> ArcCache<K, V, S> {
}
}


pub fn get_or_insert<F>(&mut self, k: K, f: F) -> &V
where
F: FnOnce() -> V, {
&*self.get_or_insert_mut(k, f)
}

pub fn get_or_insert_mut<F>(&mut self, k: K, f: F) -> &mut V
where
F: FnOnce() -> V, {

if let Some((key, val)) = self.main_lru.remove(&k) {
self.main_lfu.insert(key, val);
return self.main_lfu.get_mut_key_value(&k).map(|(_, v)| v).unwrap();
}

if let Some((key, val)) = self.ghost_lfu.remove(&k) {
self.main_lfu.full_increase();
self.main_lru.full_decrease();
self.main_lfu.insert(key, val);
return self.main_lfu.get_mut_key_value(&k).map(|(_, v)| v).unwrap();
}

if let Some((key, val)) = self.ghost_lru.remove(&k) {
self.main_lru.full_increase();
self.main_lfu.full_decrease();
self.main_lru.insert(key, val);
return self.main_lru.get_mut_key_value(&k).map(|(_, v)| v).unwrap();
}

if self.main_lfu.contains_key(&k) {
return self.main_lfu.get_mut_key_value(&k).map(|(_, v)| v).unwrap();
}

if self.main_lru.is_full() {
let (pk, pv) = self.main_lru.pop_last().unwrap();
self.ghost_lru.insert(pk, pv);
}
self.get_or_insert_mut(k, f)
}


/// 移除元素
///
/// ```
Expand Down Expand Up @@ -490,6 +532,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> ArcCache<K, V, S> {
}
}


impl<K: Hash + Eq, V: Default, S: BuildHasher> ArcCache<K, V, S> {
pub fn get_or_insert_default(&mut self, k: K) -> &V {
&*self.get_or_insert_mut(k, || V::default())
}

pub fn get_or_insert_default_mut(&mut self, k: K) -> &mut V {
self.get_or_insert_mut(k, || V::default())
}
}

impl<K: Clone + Hash + Eq, V: Clone, S: Clone + BuildHasher> Clone for ArcCache<K, V, S> {
fn clone(&self) -> Self {
ArcCache {
Expand Down Expand Up @@ -739,6 +792,9 @@ where
}
}

unsafe impl<K: Send, V: Send, S: Send> Send for ArcCache<K, V, S> {}
unsafe impl<K: Sync, V: Sync, S: Sync> Sync for ArcCache<K, V, S> {}

#[cfg(test)]
mod tests {
use std::collections::hash_map::RandomState;
Expand Down Expand Up @@ -1111,19 +1167,18 @@ mod tests {
assert_eq!(a[&3], "three");
}

// #[test]
// fn test_drain() {
// let mut a = ArcCache::new(3);
// a.insert(1, 1);
// a.insert(2, 2);
// a.insert(3, 3);

// assert_eq!(a.len(), 3);
// {
// let mut drain = a.drain();
// assert_eq!(drain.next().unwrap(), (1, 1));
// assert_eq!(drain.next().unwrap(), (2, 2));
// }
// assert_eq!(a.len(), 0);
// }

#[test]
fn test_send() {
use std::thread;

let mut cache = ArcCache::new(4);
cache.insert(1, "a");

let handle = thread::spawn(move || {
assert_eq!(cache.get(&1), Some(&"a"));
});

assert!(handle.join().is_ok());
}
}
59 changes: 59 additions & 0 deletions src/cache/lfu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,37 @@ impl<K: Hash + Eq, V, S: BuildHasher> LfuCache<K, V, S> {
}
}


pub fn get_or_insert<F>(&mut self, k: K, f: F) -> &V
where
F: FnOnce() -> V, {
&*self.get_or_insert_mut(k, f)
}


pub fn get_or_insert_mut<F>(&mut self, k: K, f: F) -> &mut V
where
F: FnOnce() -> V, {
if let Some(l) = self.map.get(KeyWrapper::from_ref(&k)) {
let node = l.as_ptr();
self.detach(node);
self.attach(node);
unsafe { &mut *(*node).val.as_mut_ptr() }
} else {
let v = f();

let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LfuEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
unsafe { &mut *(*node_ptr).val.as_mut_ptr() }
}
}


/// 移除元素
///
/// ```
Expand Down Expand Up @@ -813,6 +844,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LfuCache<K, V, S> {
}
}


impl<K: Hash + Eq, V: Default, S: BuildHasher> LfuCache<K, V, S> {
pub fn get_or_insert_default(&mut self, k: K) -> &V {
&*self.get_or_insert_mut(k, || V::default())
}

pub fn get_or_insert_default_mut(&mut self, k: K) -> &mut V {
self.get_or_insert_mut(k, || V::default())
}
}

impl<K: Clone + Hash + Eq, V: Clone, S: Clone + BuildHasher> Clone for LfuCache<K, V, S> {
fn clone(&self) -> Self {
let mut new_lru = LfuCache::with_hasher(self.cap, self.map.hasher().clone());
Expand Down Expand Up @@ -1204,6 +1246,9 @@ where
}
}

unsafe impl<K: Send, V: Send, S: Send> Send for LfuCache<K, V, S> {}
unsafe impl<K: Sync, V: Sync, S: Sync> Sync for LfuCache<K, V, S> {}

#[cfg(test)]
mod tests {
use std::collections::hash_map::RandomState;
Expand Down Expand Up @@ -1591,4 +1636,18 @@ mod tests {
}
assert_eq!(a.len(), 0);
}

#[test]
fn test_send() {
use std::thread;

let mut cache = LfuCache::new(4);
cache.insert(1, "a");

let handle = thread::spawn(move || {
assert_eq!(cache.get(&1), Some(&"a"));
});

assert!(handle.join().is_ok());
}
}
64 changes: 64 additions & 0 deletions src/cache/lru.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ impl<K, V, S> LruCache<K, V, S> {
self.map.len()
}

pub fn is_full(&self) -> bool {
self.map.len() == self.cap
}

pub fn is_empty(&self) -> bool {
self.map.len() == 0
}
Expand Down Expand Up @@ -537,6 +541,36 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
}
}


pub fn get_or_insert<F>(&mut self, k: K, f: F) -> &V
where
F: FnOnce() -> V, {
&*self.get_or_insert_mut(k, f)
}


pub fn get_or_insert_mut<F>(&mut self, k: K, f: F) -> &mut V
where
F: FnOnce() -> V, {
if let Some(l) = self.map.get(KeyWrapper::from_ref(&k)) {
let node = l.as_ptr();
self.detach(node);
self.attach(node);
unsafe { &mut *(*node).val.as_mut_ptr() }
} else {
let v = f();

let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
unsafe { &mut *(*node_ptr).val.as_mut_ptr() }
}
}

/// 移除元素
///
/// ```
Expand Down Expand Up @@ -623,6 +657,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
}
}


impl<K: Hash + Eq, V: Default, S: BuildHasher> LruCache<K, V, S> {
pub fn get_or_insert_default(&mut self, k: K) -> &V {
&*self.get_or_insert_mut(k, || V::default())
}

pub fn get_or_insert_default_mut(&mut self, k: K) -> &mut V {
self.get_or_insert_mut(k, || V::default())
}
}

impl<K: Clone + Hash + Eq, V: Clone, S: Clone + BuildHasher> Clone for LruCache<K, V, S> {
fn clone(&self) -> Self {
let mut new_lru = LruCache::with_hasher(self.cap, self.map.hasher().clone());
Expand Down Expand Up @@ -928,6 +973,10 @@ where
}
}

unsafe impl<K: Send, V: Send, S: Send> Send for LruCache<K, V, S> {}
unsafe impl<K: Sync, V: Sync, S: Sync> Sync for LruCache<K, V, S> {}


#[cfg(test)]
mod tests {
use std::collections::hash_map::RandomState;
Expand Down Expand Up @@ -1306,4 +1355,19 @@ mod tests {
}
assert_eq!(a.len(), 0);
}


#[test]
fn test_send() {
use std::thread;

let mut cache = LruCache::new(4);
cache.insert(1, "a");

let handle = thread::spawn(move || {
assert_eq!(cache.get(&1), Some(&"a"));
});

assert!(handle.join().is_ok());
}
}
61 changes: 61 additions & 0 deletions src/cache/lruk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,35 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruKCache<K, V, S> {
}
}

pub fn get_or_insert<F>(&mut self, k: K, f: F) -> &V
where
F: FnOnce() -> V, {
&*self.get_or_insert_mut(k, f)
}


pub fn get_or_insert_mut<F>(&mut self, k: K, f: F) -> &mut V
where
F: FnOnce() -> V, {
if let Some(l) = self.map.get(KeyWrapper::from_ref(&k)) {
let node = l.as_ptr();
self.detach(node);
self.attach(node);
unsafe { &mut *(*node).val.as_mut_ptr() }
} else {
let v = f();

let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruKEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
unsafe { &mut *(*node_ptr).val.as_mut_ptr() }
}
}

/// 移除元素
///
/// ```
Expand Down Expand Up @@ -672,6 +701,17 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruKCache<K, V, S> {
}
}


impl<K: Hash + Eq, V: Default, S: BuildHasher> LruKCache<K, V, S> {
pub fn get_or_insert_default(&mut self, k: K) -> &V {
&*self.get_or_insert_mut(k, || V::default())
}

pub fn get_or_insert_default_mut(&mut self, k: K) -> &mut V {
self.get_or_insert_mut(k, || V::default())
}
}

impl<K: Clone + Hash + Eq, V: Clone, S: Clone + BuildHasher> Clone for LruKCache<K, V, S> {
fn clone(&self) -> Self {

Expand Down Expand Up @@ -998,6 +1038,12 @@ where
}
}



unsafe impl<K: Send, V: Send, S: Send> Send for LruKCache<K, V, S> {}
unsafe impl<K: Sync, V: Sync, S: Sync> Sync for LruKCache<K, V, S> {}


#[cfg(test)]
mod tests {
use std::collections::hash_map::RandomState;
Expand Down Expand Up @@ -1376,4 +1422,19 @@ mod tests {
}
assert_eq!(a.len(), 0);
}


#[test]
fn test_send() {
use std::thread;

let mut cache = LruKCache::new(4);
cache.insert(1, "a");

let handle = thread::spawn(move || {
assert_eq!(cache.get(&1), Some(&"a"));
});

assert!(handle.join().is_ok());
}
}
3 changes: 2 additions & 1 deletion src/cache/slab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,8 @@ impl<T: Default> Extend<T> for Slab<T> {
}
}


unsafe impl<T: Send + Default> Send for Slab<T> {}
unsafe impl<T: Sync + Default> Sync for Slab<T> {}

#[cfg(test)]
mod tests {
Expand Down

0 comments on commit 7d6227f

Please sign in to comment.