Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: limit the max amount of concurrent network requests #116

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/cli/src/commands/add.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use clap::Parser;
use std::collections::VecDeque;
use std::{collections::VecDeque, sync::Arc};

use crate::{
package::{fetch_package_version_directly, find_package_version_from_registry},
Expand All @@ -9,6 +9,7 @@ use futures_util::future;
use pacquet_diagnostics::miette::WrapErr;
use pacquet_package_json::DependencyGroup;
use pacquet_registry::PackageVersion;
use tokio::sync::Semaphore;

#[derive(Parser, Debug)]
pub struct AddCommandArgs {
Expand Down Expand Up @@ -59,13 +60,15 @@ impl PackageManager {
/// 5. Symlink all dependencies to node_modules/.pacquet/pkg@version/node_modules
/// 6. Update package.json
pub async fn add(&mut self, args: &AddCommandArgs) -> Result<(), PackageManagerError> {
let semaphore = Arc::new(Semaphore::new(16));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be configurable somehow

let latest_version = fetch_package_version_directly(
&self.tarball_cache,
self.config,
&self.http_client,
&args.package,
"latest",
&self.config.modules_dir,
&semaphore,
)
.await?;
let package_node_modules_path =
Expand All @@ -85,6 +88,7 @@ impl PackageManager {
name,
version,
path,
&semaphore,
)
});

Expand All @@ -109,6 +113,7 @@ impl PackageManager {
name,
version,
&node_modules_path,
&semaphore,
)
},
);
Expand Down
36 changes: 22 additions & 14 deletions crates/cli/src/commands/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use pacquet_diagnostics::tracing;
use pacquet_package_json::DependencyGroup;
use pacquet_registry::PackageVersion;
use pipe_trait::Pipe;
use std::sync::Arc;
use tokio::sync::Semaphore;

#[derive(Parser, Debug)]
pub struct InstallCommandArgs {
Expand Down Expand Up @@ -46,7 +48,7 @@ impl PackageManager {
///
/// This function is used by [`PackageManager::install`].
#[async_recursion]
async fn install_dependencies(&self, package: &PackageVersion) {
async fn install_dependencies(&self, package: &PackageVersion, semaphore: &Semaphore) {
let node_modules_path =
self.config.virtual_store_dir.join(package.to_store_name()).join("node_modules");

Expand All @@ -62,10 +64,11 @@ impl PackageManager {
name,
version,
&node_modules_path,
&semaphore,
)
.await
.unwrap();
self.install_dependencies(&dependency).await;
self.install_dependencies(&dependency, &semaphore).await;
})
.pipe(future::join_all)
.await;
Expand All @@ -76,21 +79,26 @@ impl PackageManager {
/// Jobs of the `install` command.
pub async fn install(&self, args: &InstallCommandArgs) -> Result<(), PackageManagerError> {
tracing::info!(target: "pacquet::install", "Start all");
let semaphore = Arc::new(Semaphore::new(16));

self.package_json
.dependencies(args.dependency_groups())
.map(|(name, version)| async move {
let dependency = find_package_version_from_registry(
&self.tarball_cache,
self.config,
&self.http_client,
name,
version,
&self.config.modules_dir,
)
.await
.unwrap();
self.install_dependencies(&dependency).await;
.map(|(name, version)| {
let semaphore_clone = Arc::clone(&semaphore);
async move {
let dependency = find_package_version_from_registry(
&self.tarball_cache,
self.config,
&self.http_client,
name,
version,
&self.config.modules_dir,
&semaphore_clone,
)
.await
.unwrap();
self.install_dependencies(&dependency, &semaphore_clone).await;
}
})
.pipe(future::join_all)
.await;
Expand Down
24 changes: 19 additions & 5 deletions crates/cli/src/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use pacquet_registry::{Package, PackageVersion};
use pacquet_tarball::{download_tarball_to_store, Cache};
use reqwest::Client;
use std::path::Path;
use tokio::sync::Semaphore;

/// This function execute the following and returns the package
/// - retrieves the package from the registry
Expand All @@ -21,10 +22,13 @@ pub async fn find_package_version_from_registry(
name: &str,
version: &str,
symlink_path: &Path,
semaphore: &Semaphore,
) -> Result<PackageVersion, PackageManagerError> {
let package = Package::fetch_from_registry(name, http_client, &config.registry).await?;
let package =
Package::fetch_from_registry(name, http_client, &config.registry, semaphore).await?;
let package_version = package.pinned_version(version).unwrap();
internal_fetch(tarball_cache, http_client, package_version, config, symlink_path).await?;
internal_fetch(tarball_cache, http_client, package_version, config, symlink_path, semaphore)
.await?;
Ok(package_version.to_owned())
}

Expand All @@ -35,10 +39,18 @@ pub async fn fetch_package_version_directly(
name: &str,
version: &str,
symlink_path: &Path,
semaphore: &Semaphore,
) -> Result<PackageVersion, PackageManagerError> {
let package_version =
PackageVersion::fetch_from_registry(name, version, http_client, &config.registry).await?;
internal_fetch(tarball_cache, http_client, &package_version, config, symlink_path).await?;
let package_version = PackageVersion::fetch_from_registry(
name,
version,
http_client,
&config.registry,
semaphore,
)
.await?;
internal_fetch(tarball_cache, http_client, &package_version, config, symlink_path, semaphore)
.await?;
Ok(package_version.to_owned())
}

Expand All @@ -48,6 +60,7 @@ async fn internal_fetch(
package_version: &PackageVersion,
config: &'static Npmrc,
symlink_path: &Path,
semaphore: &Semaphore,
) -> Result<(), PackageManagerError> {
let store_folder_name = package_version.to_store_name();

Expand All @@ -59,6 +72,7 @@ async fn internal_fetch(
package_version.dist.integrity.as_ref().expect("has integrity field"),
package_version.dist.unpacked_size,
package_version.as_tarball_url(),
semaphore,
)
.await?;

Expand Down
3 changes: 3 additions & 0 deletions crates/registry/src/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pipe_trait::Pipe;
use serde::{Deserialize, Serialize};

use crate::{package_version::PackageVersion, NetworkError, RegistryError};
use tokio::sync::Semaphore;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Package {
Expand All @@ -30,9 +31,11 @@ impl Package {
name: &str,
http_client: &reqwest::Client,
registry: &str,
semaphore: &Semaphore,
) -> Result<Self, RegistryError> {
let url = || format!("{registry}{name}"); // TODO: use reqwest URL directly
let network_error = |error| NetworkError { error, url: url() };
let _permit = semaphore.acquire().await;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you gave this value a name? If you want to postpone dropping it, I recommend explicitly calling drop (i.e. drop(permit)) at the end of the scope. If you only want to .await the semaphore, please remove let _permit =.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

won't it be dropped immediately if I don't assign it to a value? As far as I understand, the permit is dropped, when it goes out of scope.

http_client
.get(url())
.header("content-type", "application/json")
Expand Down
3 changes: 3 additions & 0 deletions crates/registry/src/package_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pipe_trait::Pipe;
use serde::{Deserialize, Serialize};

use crate::{package_distribution::PackageDistribution, NetworkError, RegistryError};
use tokio::sync::Semaphore;

#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
#[serde(rename_all = "camelCase")]
Expand All @@ -28,9 +29,11 @@ impl PackageVersion {
version: &str,
http_client: &reqwest::Client,
registry: &str,
semaphore: &Semaphore,
) -> Result<Self, RegistryError> {
let url = || format!("{registry}{name}/{version}");
let network_error = |error| NetworkError { error, url: url() };
let _permit = semaphore.acquire().await;

http_client
.get(url())
Expand Down
7 changes: 6 additions & 1 deletion crates/tarball/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use pipe_trait::Pipe;
use reqwest::Client;
use ssri::{Integrity, IntegrityChecker};
use tar::Archive;
use tokio::sync::{Notify, RwLock};
use tokio::sync::{Notify, RwLock, Semaphore};
use zune_inflate::{errors::InflateDecodeErrors, DeflateDecoder, DeflateOptions};

#[derive(Error, Debug, Diagnostic)]
Expand Down Expand Up @@ -120,6 +120,7 @@ pub async fn download_tarball_to_store(
package_integrity: &str,
package_unpacked_size: Option<usize>,
package_url: &str,
semaphore: &Semaphore,
) -> Result<Arc<HashMap<OsString, PathBuf>>, TarballError> {
if let Some(cache_lock) = cache.get(package_url) {
let notify;
Expand Down Expand Up @@ -154,6 +155,7 @@ pub async fn download_tarball_to_store(
store_dir,
package_integrity,
package_unpacked_size,
semaphore,
)
.await?;
let mut cache_write = cache_lock.write().await;
Expand All @@ -169,10 +171,12 @@ async fn download_tarball_to_store_uncached(
store_dir: &'static Path,
package_integrity: &str,
package_unpacked_size: Option<usize>,
semaphore: &Semaphore,
) -> Result<Arc<HashMap<OsString, PathBuf>>, TarballError> {
tracing::info!(target: "pacquet::download", ?package_url, "New cache");

let network_error = |error| NetworkError { url: package_url.to_string(), error };
let permit = semaphore.acquire().await;
let response = http_client
.get(package_url)
.send()
Expand All @@ -181,6 +185,7 @@ async fn download_tarball_to_store_uncached(
.bytes()
.await
.map_err(network_error)?;
drop(permit);

tracing::info!(target: "pacquet::download", ?package_url, "Download completed");

Expand Down
5 changes: 4 additions & 1 deletion tasks/benchmark/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fs, path::Path};
use std::{fs, path::Path, sync::Arc};

use clap::Parser;
use criterion::{Criterion, Throughput};
Expand All @@ -8,6 +8,7 @@ use pipe_trait::Pipe;
use project_root::get_project_root;
use reqwest::Client;
use tempfile::tempdir;
use tokio::sync::Semaphore;

#[derive(Debug, Parser)]
struct CliArgs {
Expand All @@ -30,6 +31,7 @@ fn bench_tarball(c: &mut Criterion, server: &mut ServerGuard, fixtures_folder: &
// NOTE: the tempdir is being leaked, meaning the cleanup would be postponed until the end of the benchmark
let dir = tempdir().unwrap().pipe(Box::new).pipe(Box::leak);
let http_client = Client::new();
let semaphore = Arc::new(Semaphore::new(16));

let cas_map =
download_tarball_to_store(
Expand All @@ -39,6 +41,7 @@ fn bench_tarball(c: &mut Criterion, server: &mut ServerGuard, fixtures_folder: &
"sha512-dj7vjIn1Ar8sVXj2yAXiMNCJDmS9MQ9XMlIecX2dIzzhjSHCyKo4DdXjXMs7wKW2kj6yvVRSpuQjOZ3YLrh56w==",
Some(16697),
url,
&semaphore,
).await.unwrap();
cas_map.len()
});
Expand Down