diff --git a/src/downloader.rs b/src/downloader.rs index 87f0462b3..22453d528 100644 --- a/src/downloader.rs +++ b/src/downloader.rs @@ -141,7 +141,7 @@ pub enum GetOutput { } /// Concurrency limits for the [`Downloader`]. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ConcurrencyLimits { /// Maximum number of requests the service performs concurrently. pub max_concurrent_requests: usize, @@ -193,7 +193,7 @@ impl ConcurrencyLimits { } /// Configuration for retry behavior of the [`Downloader`]. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RetryConfig { /// Maximum number of retry attempts for a node that failed to dial or failed with IO errors. pub max_retries_per_node: u32, @@ -325,13 +325,29 @@ impl Future for DownloadHandle { } } +/// All numerical config options for the downloader. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct Config { + /// Concurrency limits for the downloader. + pub concurrency: ConcurrencyLimits, + /// Retry configuration for the downloader. + pub retry: RetryConfig, +} + /// Handle for the download services. -#[derive(Clone, Debug)] +#[derive(Debug, Clone)] pub struct Downloader { + inner: Arc, +} + +#[derive(Debug)] +struct Inner { /// Next id to use for a download intent. - next_id: Arc, + next_id: AtomicU64, /// Channel to communicate with the service. msg_tx: mpsc::Sender, + /// Configuration for the downloader. + config: Arc, } impl Downloader { @@ -340,44 +356,46 @@ impl Downloader { where S: Store, { - Self::with_config(store, endpoint, rt, Default::default(), Default::default()) + Self::with_config(store, endpoint, rt, Default::default()) } /// Create a new Downloader with custom [`ConcurrencyLimits`] and [`RetryConfig`]. - pub fn with_config( - store: S, - endpoint: Endpoint, - rt: LocalPoolHandle, - concurrency_limits: ConcurrencyLimits, - retry_config: RetryConfig, - ) -> Self + pub fn with_config(store: S, endpoint: Endpoint, rt: LocalPoolHandle, config: Config) -> Self where S: Store, { + let config = Arc::new(config); let me = endpoint.node_id().fmt_short(); let (msg_tx, msg_rx) = mpsc::channel(SERVICE_CHANNEL_CAPACITY); let dialer = Dialer::new(endpoint); - + let config2 = config.clone(); let create_future = move || { let getter = get::IoGetter { store: store.clone(), }; - - let service = Service::new(getter, dialer, concurrency_limits, retry_config, msg_rx); + let service = Service::new(getter, dialer, config2, msg_rx); service.run().instrument(error_span!("downloader", %me)) }; rt.spawn_detached(create_future); Self { - next_id: Arc::new(AtomicU64::new(0)), - msg_tx, + inner: Arc::new(Inner { + next_id: AtomicU64::new(0), + msg_tx, + config, + }), } } + /// Get the current configuration. + pub fn config(&self) -> &Config { + &self.inner.config + } + /// Queue a download. pub async fn queue(&self, request: DownloadRequest) -> DownloadHandle { let kind = request.kind; - let intent_id = IntentId(self.next_id.fetch_add(1, Ordering::SeqCst)); + let intent_id = IntentId(self.inner.next_id.fetch_add(1, Ordering::SeqCst)); let (sender, receiver) = oneshot::channel(); let handle = DownloadHandle { id: intent_id, @@ -391,7 +409,7 @@ impl Downloader { }; // if this fails polling the handle will fail as well since the sender side of the oneshot // will be dropped - if let Err(send_err) = self.msg_tx.send(msg).await { + if let Err(send_err) = self.inner.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "download not sent"); } @@ -407,7 +425,7 @@ impl Downloader { receiver: _, } = handle; let msg = Message::CancelIntent { id, kind }; - if let Err(send_err) = self.msg_tx.send(msg).await { + if let Err(send_err) = self.inner.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "cancel not sent"); } @@ -419,7 +437,7 @@ impl Downloader { /// downloads. Use [`Self::queue`] to queue a download. pub async fn nodes_have(&mut self, hash: Hash, nodes: Vec) { let msg = Message::NodesHave { hash, nodes }; - if let Err(send_err) = self.msg_tx.send(msg).await { + if let Err(send_err) = self.inner.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "nodes have not been sent") } @@ -567,19 +585,13 @@ struct Service { progress_tracker: ProgressTracker, } impl, D: DialerT> Service { - fn new( - getter: G, - dialer: D, - concurrency_limits: ConcurrencyLimits, - retry_config: RetryConfig, - msg_rx: mpsc::Receiver, - ) -> Self { + fn new(getter: G, dialer: D, config: Arc, msg_rx: mpsc::Receiver) -> Self { Service { getter, dialer, msg_rx, - concurrency_limits, - retry_config, + concurrency_limits: config.concurrency, + retry_config: config.retry, connected_nodes: Default::default(), retry_node_state: Default::default(), providers: Default::default(), diff --git a/src/downloader/test.rs b/src/downloader/test.rs index 0b5ea1f79..3b452f35a 100644 --- a/src/downloader/test.rs +++ b/src/downloader/test.rs @@ -47,16 +47,24 @@ impl Downloader { let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY); let lp = LocalPool::default(); + let config = Arc::new(Config { + concurrency: concurrency_limits, + retry: retry_config, + }); + let config2 = config.clone(); lp.spawn_detached(move || async move { // we want to see the logs of the service - let service = Service::new(getter, dialer, concurrency_limits, retry_config, msg_rx); + let service = Service::new(getter, dialer, config2, msg_rx); service.run().await }); ( Downloader { - next_id: Arc::new(AtomicU64::new(0)), - msg_tx, + inner: Arc::new(Inner { + next_id: AtomicU64::new(0), + msg_tx, + config, + }), }, lp, ) diff --git a/src/net_protocol.rs b/src/net_protocol.rs index c8b0d83b8..a0f6e6574 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use tracing::debug; use crate::{ - downloader::{ConcurrencyLimits, Downloader, RetryConfig}, + downloader::{self, ConcurrencyLimits, Downloader, RetryConfig}, provider::EventSender, store::GcConfig, util::{ @@ -147,9 +147,8 @@ impl BlobBatches { pub struct Builder { store: S, events: Option, + downloader_config: Option, rt: Option, - concurrency_limits: Option, - retry_config: Option, } impl Builder { @@ -165,15 +164,23 @@ impl Builder { self } + /// Set custom downloader config + pub fn downloader_config(mut self, downloader_config: downloader::Config) -> Self { + self.downloader_config = Some(downloader_config); + self + } + /// Set custom [`ConcurrencyLimits`] to use. pub fn concurrency_limits(mut self, concurrency_limits: ConcurrencyLimits) -> Self { - self.concurrency_limits = Some(concurrency_limits); + let downloader_config = self.downloader_config.get_or_insert_with(Default::default); + downloader_config.concurrency = concurrency_limits; self } /// Set a custom [`RetryConfig`] to use. pub fn retry_config(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = Some(retry_config); + let downloader_config = self.downloader_config.get_or_insert_with(Default::default); + downloader_config.retry = retry_config; self } @@ -184,12 +191,12 @@ impl Builder { .rt .map(Rt::Handle) .unwrap_or_else(|| Rt::Owned(LocalPool::default())); + let downloader_config = self.downloader_config.unwrap_or_default(); let downloader = Downloader::with_config( self.store.clone(), endpoint.clone(), rt.clone(), - self.concurrency_limits.unwrap_or_default(), - self.retry_config.unwrap_or_default(), + downloader_config, ); Blobs::new( self.store, @@ -207,9 +214,8 @@ impl Blobs { Builder { store, events: None, + downloader_config: None, rt: None, - concurrency_limits: None, - retry_config: None, } } } diff --git a/tests/rpc.rs b/tests/rpc.rs index 7dc12e7b2..ab96c8f65 100644 --- a/tests/rpc.rs +++ b/tests/rpc.rs @@ -1,7 +1,7 @@ #![cfg(feature = "test")] use std::{net::SocketAddr, path::PathBuf, vec}; -use iroh_blobs::net_protocol::Blobs; +use iroh_blobs::{downloader, net_protocol::Blobs}; use quic_rpc::client::QuinnConnector; use tempfile::TempDir; use testresult::TestResult; @@ -85,3 +85,28 @@ async fn quinn_rpc_large() -> TestResult<()> { assert_eq!(data, &data2[..]); Ok(()) } + +#[tokio::test] +async fn downloader_config() -> TestResult<()> { + let _ = tracing_subscriber::fmt::try_init(); + let endpoint = iroh::Endpoint::builder().bind().await?; + let store = iroh_blobs::store::mem::Store::default(); + let expected = downloader::Config { + concurrency: downloader::ConcurrencyLimits { + max_concurrent_requests: usize::MAX, + max_concurrent_requests_per_node: usize::MAX, + max_open_connections: usize::MAX, + max_concurrent_dials_per_hash: usize::MAX, + }, + retry: downloader::RetryConfig { + max_retries_per_node: u32::MAX, + initial_retry_delay: std::time::Duration::from_secs(1), + }, + }; + let blobs = Blobs::builder(store) + .downloader_config(expected) + .build(&endpoint); + let actual = blobs.downloader().config(); + assert_eq!(&expected, actual); + Ok(()) +}