diff --git a/src/lib.rs b/src/lib.rs index 3d46e0a..b86c878 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -445,6 +445,13 @@ pub trait IOCached { /// Should return `Self::Error` if the operation fails fn cache_remove(&self, k: &K) -> Result, Self::Error>; + /// Remove all cached values + /// + /// # Errors + /// + /// Should return `Self::Error` if the operation fails + fn cache_clear(&self) -> Result<(), Self::Error>; + /// Set the flag to control whether cache hits refresh the ttl of cached values, returns the old flag value fn cache_set_refresh(&mut self, refresh: bool) -> bool; @@ -479,6 +486,9 @@ pub trait IOCachedAsync { /// Remove a cached value async fn cache_remove(&self, k: &K) -> Result, Self::Error>; + /// Remove all cached values + async fn cache_clear(&self) -> Result<(), Self::Error>; + /// Set the flag to control whether cache hits refresh the ttl of cached values, returns the old flag value fn cache_set_refresh(&mut self, refresh: bool) -> bool; diff --git a/src/stores/disk.rs b/src/stores/disk.rs index c65e313..3aa4b4e 100644 --- a/src/stores/disk.rs +++ b/src/stores/disk.rs @@ -353,6 +353,13 @@ where self.ttl } + fn cache_clear(&self) -> Result<(), Self::Error> { + for (key, _value) in self.connection.iter().flatten() { + self.connection.remove(key)?; + } + Ok(()) + } + fn cache_set_lifespan(&mut self, ttl: Duration) -> Option { let old = self.ttl; self.ttl = Some(ttl); diff --git a/src/stores/redis.rs b/src/stores/redis.rs index ce57658..bbfb8c2 100644 --- a/src/stores/redis.rs +++ b/src/stores/redis.rs @@ -1,4 +1,5 @@ use crate::IOCached; +use redis::Commands; use serde::de::DeserializeOwned; use serde::Serialize; use std::marker::PhantomData; @@ -217,7 +218,7 @@ where RedisCacheBuilder::new(prefix, ttl) } - fn generate_key(&self, key: &K) -> String { + fn generate_key(&self, key: impl Display) -> String { format!("{}{}{}", self.namespace, self.prefix, key) } @@ -345,6 +346,19 @@ where Some(self.ttl) } + fn cache_clear(&self) -> Result<(), RedisCacheError> { + // `scan_match` takes `&mut self`, so we need two connection objects to scan and + // delete...? + let mut scan = self.pool.get()?; + let mut delete = self.pool.get()?; + + for key in scan.scan_match::<_, String>(self.generate_key("*"))? { + let () = delete.del(key)?; + } + + Ok(()) + } + fn cache_set_lifespan(&mut self, ttl: Duration) -> Option { let old = self.ttl; self.ttl = ttl; @@ -365,6 +379,8 @@ where mod async_redis { use std::time::Duration; + use redis::AsyncCommands; + use super::{ CachedRedisValue, DeserializeOwned, Display, PhantomData, RedisCacheBuildError, RedisCacheError, Serialize, DEFAULT_NAMESPACE, ENV_KEY, @@ -530,7 +546,7 @@ mod async_redis { AsyncRedisCacheBuilder::new(prefix, ttl) } - fn generate_key(&self, key: &K) -> String { + fn generate_key(&self, key: impl Display) -> String { format!("{}{}{}", self.namespace, self.prefix, key) } @@ -628,6 +644,21 @@ mod async_redis { } } + async fn cache_clear(&self) -> Result<(), Self::Error> { + // `scan_match` takes `&mut self`, so we need two connection objects to scan and + // delete...? + let mut scan = self.connection.clone(); + let mut delete = self.connection.clone(); + + let mut scanner = scan.scan_match::<_, String>(self.generate_key("*")).await?; + + while let Some(key) = scanner.next_item().await { + let () = delete.del(key).await?; + } + + Ok(()) + } + /// Set the flag to control whether cache hits refresh the ttl of cached values, returns the old flag value fn cache_set_refresh(&mut self, refresh: bool) -> bool { let old = self.refresh;