Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: limit the max amount of concurrent network requests #116

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/cli/src/commands/add.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use clap::Parser;
use std::collections::VecDeque;
use std::{collections::VecDeque, sync::Arc};

use crate::{
package::{fetch_package_version_directly, find_package_version_from_registry},
Expand All @@ -9,6 +9,7 @@ use futures_util::future;
use pacquet_diagnostics::miette::WrapErr;
use pacquet_package_json::DependencyGroup;
use pacquet_registry::PackageVersion;
use tokio::sync::Semaphore;

#[derive(Parser, Debug)]
pub struct AddCommandArgs {
Expand Down Expand Up @@ -59,13 +60,15 @@ impl PackageManager {
/// 5. Symlink all dependencies to node_modules/.pacquet/pkg@version/node_modules
/// 6. Update package.json
pub async fn add(&mut self, args: &AddCommandArgs) -> Result<(), PackageManagerError> {
let semaphore = Arc::new(Semaphore::new(16));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be configurable somehow

let latest_version = fetch_package_version_directly(
&self.tarball_cache,
&self.config,
&self.http_client,
&args.package,
"latest",
&self.config.modules_dir,
&semaphore,
)
.await?;
let package_node_modules_path =
Expand All @@ -85,6 +88,7 @@ impl PackageManager {
name,
version,
path,
&semaphore,
)
});

Expand All @@ -109,6 +113,7 @@ impl PackageManager {
name,
version,
&node_modules_path,
&semaphore,
)
},
);
Expand Down
37 changes: 23 additions & 14 deletions crates/cli/src/commands/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use pacquet_diagnostics::tracing;
use pacquet_package_json::DependencyGroup;
use pacquet_registry::PackageVersion;
use pipe_trait::Pipe;
use tokio::sync::Semaphore;
use std::sync::Arc;

#[derive(Parser, Debug)]
pub struct InstallCommandArgs {
Expand Down Expand Up @@ -46,7 +48,7 @@ impl PackageManager {
///
/// This function is used by [`PackageManager::install`].
#[async_recursion]
async fn install_dependencies(&self, package: &PackageVersion) {
async fn install_dependencies(&self, package: &PackageVersion, semaphore: &Semaphore) {
let node_modules_path =
self.config.virtual_store_dir.join(package.to_store_name()).join("node_modules");

Expand All @@ -55,17 +57,19 @@ impl PackageManager {
package
.dependencies(self.config.auto_install_peers)
.map(|(name, version)| async {
let semaphore_clone = semaphore.clone();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

semaphore is not an Arc here. Cloning it would create a new Semaphore. Is this intentional?

let dependency = find_package_version_from_registry(
&self.tarball_cache,
&self.config,
&self.http_client,
name,
version,
&node_modules_path,
&semaphore_clone,
)
.await
.unwrap();
self.install_dependencies(&dependency).await;
self.install_dependencies(&dependency, &semaphore_clone).await;
})
.pipe(future::join_all)
.await;
Expand All @@ -76,21 +80,26 @@ impl PackageManager {
/// Jobs of the `install` command.
pub async fn install(&self, args: &InstallCommandArgs) -> Result<(), PackageManagerError> {
tracing::info!(target: "pacquet::install", "Start all");
let semaphore = Arc::new(Semaphore::new(16));

self.package_json
.dependencies(args.dependency_groups())
.map(|(name, version)| async move {
let dependency = find_package_version_from_registry(
&self.tarball_cache,
&self.config,
&self.http_client,
name,
version,
&self.config.modules_dir,
)
.await
.unwrap();
self.install_dependencies(&dependency).await;
.map(|(name, version)| {
let semaphore_clone = semaphore.clone();
Copy link
Contributor

@KSXGitHub KSXGitHub Sep 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend Arc::clone(&semaphore) to avoid confusing it with regular expensive .clone().

async move {
let dependency = find_package_version_from_registry(
&self.tarball_cache,
&self.config,
&self.http_client,
name,
version,
&self.config.modules_dir,
&semaphore_clone,
)
.await
.unwrap();
self.install_dependencies(&dependency, &semaphore_clone).await;
}
})
.pipe(future::join_all)
.await;
Expand Down
13 changes: 9 additions & 4 deletions crates/cli/src/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pacquet_npmrc::Npmrc;
use pacquet_registry::{Package, PackageVersion};
use pacquet_tarball::{download_tarball_to_store, Cache};
use std::path::Path;
use tokio::sync::Semaphore;

/// This function executes the following and returns the package
/// - retrieves the package from the registry
Expand All @@ -20,10 +21,11 @@ pub async fn find_package_version_from_registry(
name: &str,
version: &str,
symlink_path: &Path,
semaphore: &Semaphore,
) -> Result<PackageVersion, PackageManagerError> {
let package = Package::fetch_from_registry(name, http_client, &config.registry).await?;
let package = Package::fetch_from_registry(name, http_client, &config.registry, semaphore).await?;
let package_version = package.pinned_version(version).unwrap();
internal_fetch(tarball_cache, package_version, config, symlink_path).await?;
internal_fetch(tarball_cache, package_version, config, symlink_path, semaphore).await?;
Ok(package_version.to_owned())
}

Expand All @@ -34,10 +36,11 @@ pub async fn fetch_package_version_directly(
name: &str,
version: &str,
symlink_path: &Path,
semaphore: &Semaphore,
) -> Result<PackageVersion, PackageManagerError> {
let package_version =
PackageVersion::fetch_from_registry(name, version, http_client, &config.registry).await?;
internal_fetch(tarball_cache, &package_version, config, symlink_path).await?;
PackageVersion::fetch_from_registry(name, version, http_client, &config.registry, semaphore).await?;
internal_fetch(tarball_cache, &package_version, config, symlink_path, semaphore).await?;
Ok(package_version.to_owned())
}

Expand All @@ -46,6 +49,7 @@ async fn internal_fetch(
package_version: &PackageVersion,
config: &Npmrc,
symlink_path: &Path,
semaphore: &Semaphore,
) -> Result<(), PackageManagerError> {
let store_folder_name = package_version.to_store_name();

Expand All @@ -56,6 +60,7 @@ async fn internal_fetch(
package_version.dist.integrity.as_ref().expect("has integrity field"),
package_version.dist.unpacked_size,
package_version.as_tarball_url(),
semaphore,
)
.await?;

Expand Down
3 changes: 3 additions & 0 deletions crates/registry/src/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pipe_trait::Pipe;
use serde::{Deserialize, Serialize};

use crate::{package_version::PackageVersion, NetworkError, RegistryError};
use tokio::sync::Semaphore;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Package {
Expand All @@ -30,9 +31,11 @@ impl Package {
name: &str,
http_client: &reqwest::Client,
registry: &str,
semaphore: &Semaphore,
) -> Result<Self, RegistryError> {
let url = || format!("{registry}{name}"); // TODO: use reqwest URL directly
let network_error = |error| NetworkError { error, url: url() };
let _permit = semaphore.acquire().await;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you gave this value a name? If you want to postpone dropping it, I recommend explicitly calling drop (i.e. drop(permit)) at the end of the scope. If you only want to .await the semaphore, please remove let _permit =.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't it be dropped immediately if I don't assign it to a variable? As far as I understand, the permit is dropped when it goes out of scope.

http_client
.get(url())
.header("content-type", "application/json")
Expand Down
3 changes: 3 additions & 0 deletions crates/registry/src/package_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pipe_trait::Pipe;
use serde::{Deserialize, Serialize};

use crate::{package_distribution::PackageDistribution, NetworkError, RegistryError};
use tokio::sync::Semaphore;

#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
#[serde(rename_all = "camelCase")]
Expand All @@ -28,9 +29,11 @@ impl PackageVersion {
version: &str,
http_client: &reqwest::Client,
registry: &str,
semaphore: &Semaphore,
) -> Result<Self, RegistryError> {
let url = || format!("{registry}{name}/{version}");
let network_error = |error| NetworkError { error, url: url() };
let _permit = semaphore.acquire().await;

http_client
.get(url())
Expand Down
5 changes: 4 additions & 1 deletion crates/tarball/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use pipe_trait::Pipe;
use reqwest::Client;
use ssri::{Integrity, IntegrityChecker};
use tar::Archive;
use tokio::sync::RwLock;
use tokio::sync::{RwLock, Semaphore};
use zune_inflate::{errors::InflateDecodeErrors, DeflateDecoder, DeflateOptions};

#[derive(Error, Debug, Diagnostic)]
Expand Down Expand Up @@ -119,6 +119,7 @@ pub async fn download_tarball_to_store(
package_integrity: &str,
package_unpacked_size: Option<usize>,
package_url: &str,
semaphore: &Semaphore
) -> Result<Arc<HashMap<OsString, PathBuf>>, TarballError> {
while let Some(cache_lock) = cache.get(package_url) {
tracing::info!(target: "pacquet::download", ?package_url, "Job taken");
Expand All @@ -145,6 +146,7 @@ pub async fn download_tarball_to_store(
}

let network_error = |error| NetworkError { url: package_url.to_string(), error };
let permit = semaphore.acquire().await;
let response = Client::new()
.get(package_url)
.send()
Expand All @@ -153,6 +155,7 @@ pub async fn download_tarball_to_store(
.bytes()
.await
.map_err(network_error)?;
drop(permit);

tracing::info!(target: "pacquet::download", ?package_url, "Download completed");

Expand Down
5 changes: 4 additions & 1 deletion tasks/benchmark/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::{fs, path::Path};
use std::{fs, path::Path, sync::Arc};

use clap::Parser;
use criterion::{Criterion, Throughput};
use mockito::ServerGuard;
use pacquet_tarball::download_tarball_to_store;
use project_root::get_project_root;
use tempfile::tempdir;
use tokio::sync::Semaphore;

#[derive(Debug, Parser)]
struct CliArgs {
Expand All @@ -25,6 +26,7 @@ fn bench_tarball(c: &mut Criterion, server: &mut ServerGuard, fixtures_folder: &
group.throughput(Throughput::Bytes(file.len() as u64));
group.bench_function("download_dependency", |b| {
b.to_async(&rt).iter(|| async {
let semaphore = Arc::new(Semaphore::new(16));
let dir = tempdir().unwrap();
let cas_map =
download_tarball_to_store(
Expand All @@ -33,6 +35,7 @@ fn bench_tarball(c: &mut Criterion, server: &mut ServerGuard, fixtures_folder: &
"sha512-dj7vjIn1Ar8sVXj2yAXiMNCJDmS9MQ9XMlIecX2dIzzhjSHCyKo4DdXjXMs7wKW2kj6yvVRSpuQjOZ3YLrh56w==",
Some(16697),
url,
&semaphore,
).await.unwrap();
drop(dir);
cas_map.len()
Expand Down