Skip to content

Commit

Permalink
refactor: locking of tarball fetching (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
zkochan authored Sep 19, 2023
1 parent cf9825b commit ef60ea6
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions crates/tarball/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use pipe_trait::Pipe;
use reqwest::Client;
use ssri::{Integrity, IntegrityChecker};
use tar::Archive;
use tokio::sync::RwLock;
use tokio::sync::{Notify, RwLock};
use zune_inflate::{errors::InflateDecodeErrors, DeflateDecoder, DeflateOptions};

#[derive(Error, Debug, Diagnostic)]
Expand Down Expand Up @@ -80,10 +80,10 @@ pub enum TarballError {
}

/// Value of the cache.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone)]
pub enum CacheValue {
/// The package is being processed.
InProgress,
InProgress(Arc<Notify>),
/// The package is saved.
Available(Arc<HashMap<OsString, PathBuf>>),
}
Expand Down Expand Up @@ -121,30 +121,57 @@ pub async fn download_tarball_to_store(
package_unpacked_size: Option<usize>,
package_url: &str,
) -> Result<Arc<HashMap<OsString, PathBuf>>, TarballError> {
while let Some(cache_lock) = cache.get(package_url) {
tracing::info!(target: "pacquet::download", ?package_url, "Job taken");

match &*cache_lock.read().await {
CacheValue::Available(cas_paths) => {
tracing::info!(target: "pacquet::download", ?package_url, cas_paths_len = cas_paths.len(), "Cache hit");
return Ok(cas_paths.clone());
if let Some(cache_lock) = cache.get(package_url) {
let notify;
{
let cache_value = cache_lock.write().await;
match &*cache_value {
CacheValue::Available(cas_paths) => {
return Ok(cas_paths.clone());
}
CacheValue::InProgress(existing_notify) => {
notify = existing_notify.clone();
}
}
CacheValue::InProgress => {
tracing::info!(target: "pacquet::download", ?package_url, "Wait for cache");
}
notify.notified().await;
if let Some(cached) = cache.get(package_url) {
if let CacheValue::Available(cas_paths) = &*cached.read().await {
return Ok(cas_paths.clone());
}
}
drop(cache_lock);
tokio::task::yield_now().await; // prevent deadlock
continue;
Err(TarballError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to get or compute tarball data",
)))
} else {
let notify = Arc::new(Notify::new());
let cache_lock = Arc::new(RwLock::new(CacheValue::InProgress(notify.clone())));
cache.insert(package_url.to_string(), cache_lock.clone());
let cas_paths = download_tarball_to_store_uncached(
package_url,
http_client,
store_dir,
package_integrity,
package_unpacked_size,
)
.await?;
let mut cache_write = cache_lock.write().await;
*cache_write = CacheValue::Available(cas_paths.clone());
notify.notify_waiters();
Ok(cas_paths)
}
}

async fn download_tarball_to_store_uncached(
package_url: &str,
http_client: &Client,
store_dir: &'static Path,
package_integrity: &str,
package_unpacked_size: Option<usize>,
) -> Result<Arc<HashMap<OsString, PathBuf>>, TarballError> {
tracing::info!(target: "pacquet::download", ?package_url, "New cache");

let cache_lock = CacheValue::InProgress.pipe(RwLock::new).pipe(Arc::new);
if cache.insert(package_url.to_string(), cache_lock.clone()).is_some() {
tracing::warn!(target: "pacquet::download", ?package_url, "Race condition detected when writing to cache");
}

let network_error = |error| NetworkError { url: package_url.to_string(), error };
let response = http_client
.get(package_url)
Expand Down Expand Up @@ -204,9 +231,6 @@ pub async fn download_tarball_to_store(

tracing::info!(target: "pacquet::download", ?package_url, "Checksum verified");

let mut cache_write = cache_lock.write().await;
*cache_write = CacheValue::Available(cas_paths.clone());

Ok(cas_paths)
}

Expand Down

0 comments on commit ef60ea6

Please sign in to comment.