diff --git a/crates/data/src/bin/hypha-data.rs b/crates/data/src/bin/hypha-data.rs index f42c5ef9..8fceb73c 100644 --- a/crates/data/src/bin/hypha-data.rs +++ b/crates/data/src/bin/hypha-data.rs @@ -36,7 +36,7 @@ use tokio::{ }; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, jitter}, + strategy::{FixedInterval, jitter}, }; use tracing::level_filters::LevelFilter; use tracing_subscriber::{ @@ -156,7 +156,9 @@ async fn run(config: ConfigWithMetadata) -> Result<()> { // Dial each gateway and, on success, set up a relay circuit listen via it. let gateway_peer_ids = Retry::spawn( - ExponentialBackoff::from_millis(100).map(jitter).take(3), + FixedInterval::from_millis(config.network().rtt_ms().max(100)) + .map(jitter) + .take(6), || { let network = network.clone(); diff --git a/crates/data/src/network.rs b/crates/data/src/network.rs index dde88f85..f77a7a1b 100644 --- a/crates/data/src/network.rs +++ b/crates/data/src/network.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use futures_util::StreamExt; use hypha_config::NetworkConfig; @@ -84,8 +84,10 @@ impl Network { exclude_cidrs: Vec, network_config: &NetworkConfig, ) -> Result<(Self, NetworkDriver), SwarmError> { - let (action_sender, action_receiver) = mpsc::channel(5); + let (action_sender, action_receiver) = mpsc::channel(64); let meter = metrics::global::meter(); + let request_timeout = + (Duration::from_millis(network_config.rtt_ms()) * 10).max(Duration::from_secs(10)); let swarm = SwarmBuilder::with_existing_identity(cert_chain, private_key, ca_certs, crls) .with_tokio() @@ -150,20 +152,22 @@ impl Network { StreamProtocol::new(data_record::IDENTIFIER), request_response::ProtocolSupport::Inbound, )], - request_response::Config::default(), + request_response::Config::default() + .with_request_timeout(request_timeout), ), health_request_response: request_response::Behaviour::::new( [( StreamProtocol::new(health::IDENTIFIER), request_response::ProtocolSupport::Full, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), } }) .map_err(|_| { SwarmError::BehaviourCreation("Failed to create swarm behavior.".to_string()) })? + .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(30))) .build(); Ok(( diff --git a/crates/gateway/src/network.rs b/crates/gateway/src/network.rs index 59b8de69..c1d931fa 100644 --- a/crates/gateway/src/network.rs +++ b/crates/gateway/src/network.rs @@ -3,7 +3,7 @@ //! This module wires together the various networking primitives to run the //! gateway's event loop. -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use futures_util::stream::StreamExt; use hypha_config::NetworkConfig; @@ -163,8 +163,7 @@ impl Network { .map_err(|_| { SwarmError::BehaviourCreation("Failed to create swarm behavior.".to_string()) })? - // TODO: Tune swarm configuration - .with_swarm_config(|config| config) + .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(30))) .build(); swarm diff --git a/crates/messages/src/lib.rs b/crates/messages/src/lib.rs index 4957297f..2153f09a 100644 --- a/crates/messages/src/lib.rs +++ b/crates/messages/src/lib.rs @@ -163,13 +163,11 @@ pub mod action { }, SendModel { target: Reference, - timeout: SystemTime, }, ExecuteBatch, SendUpdate { target: Reference, weight: f32, - timeout: SystemTime, }, ApplyUpdate { source: Reference, diff --git a/crates/network/src/request_response.rs b/crates/network/src/request_response.rs index 69cc414e..d27e09d2 100644 --- a/crates/network/src/request_response.rs +++ b/crates/network/src/request_response.rs @@ -672,32 +672,34 @@ where let handler = handlers.iter().find(|h| (h.matcher)(&request)); if let Some(handler) = handler { - match handler - .sender - .send(Ok(InboundRequest { - request_id, - channel, - peer_id: peer, - request, - })) - .await - { - Ok(_) => { - tracing::trace!( - peer = %peer, - request_id = ?request_id, - handler_id = %handler.id, - "Successfully sent request to handler channel" - ); + let sender = handler.sender.clone(); + let handler_id = handler.id; + let inbound_request = InboundRequest { + request_id, + channel, + peer_id: peer, + request, + }; + + tokio::spawn(async move { + match sender.send(Ok(inbound_request)).await { + Ok(_) => { + tracing::trace!( + peer = %peer, + request_id = ?request_id, + handler_id = %handler_id, + "Successfully sent request to handler channel" + ); + } + Err(_) => { + tracing::warn!( + peer = %peer, + handler_id = %handler_id, + "Handler channel closed, request dropped" + ); + } } - Err(_) => { - tracing::warn!( - peer = %peer, - handler_id = %handler.id, - "Handler channel closed, request dropped" - ); - } - } + }); } else { tracing::warn!( peer = %peer, @@ -861,6 +863,9 @@ where /// /// This enables `network.on::

(...)` and `network.request::

(...)` where `P: Protocol`. pub trait RequestResponseInterfaceExt: Clone + Sized + Send + Sync + 'static { + /// Default channel capacity for handler streams. Sized to tolerate bursty arrivals under RTT. + const DEFAULT_HANDLER_BUFFER: usize = 512; + /// Create a handler builder for the protocol `P` using the given pattern. fn on(&self, pattern: Pat) -> HandlerBuilder<'_, TCodec, Self> where @@ -872,7 +877,7 @@ pub trait RequestResponseInterfaceExt: Clone + Sized + Send + Sync + 'static { HandlerBuilder { interface: self, matcher: pattern.into_matcher(), - buffer_size: 32, + buffer_size: Self::DEFAULT_HANDLER_BUFFER, } } diff --git a/crates/scheduler/src/bin/hypha-scheduler.rs b/crates/scheduler/src/bin/hypha-scheduler.rs index a1235acc..b3316d44 100644 --- a/crates/scheduler/src/bin/hypha-scheduler.rs +++ b/crates/scheduler/src/bin/hypha-scheduler.rs @@ -37,7 +37,7 @@ use miette::{IntoDiagnostic, Result}; use serde_json::Value; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, jitter}, + strategy::{FixedInterval, jitter}, }; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; @@ -144,7 +144,9 @@ async fn run(config: ConfigWithMetadata) -> Result<()> { // Dial each gateway and, on success, set up a relay circuit listen via it. let gateway_peer_ids = Retry::spawn( - ExponentialBackoff::from_millis(100).map(jitter).take(3), + FixedInterval::from_millis(config.network().rtt_ms().max(100)) + .map(jitter) + .take(6), || { let network = network.clone(); diff --git a/crates/scheduler/src/network.rs b/crates/scheduler/src/network.rs index 754be4ca..bb9eff69 100644 --- a/crates/scheduler/src/network.rs +++ b/crates/scheduler/src/network.rs @@ -3,7 +3,7 @@ //! The scheduler orchestrates workers via libp2p. This module brings together //! the networking primitives and drives the underlying swarm. -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use futures_util::stream::StreamExt; use hypha_config::NetworkConfig; @@ -110,8 +110,10 @@ impl Network { exclude_cidrs: Vec, network_config: &NetworkConfig, ) -> Result<(Self, NetworkDriver), SwarmError> { - let (action_sender, action_receiver) = mpsc::channel(5); + let (action_sender, action_receiver) = mpsc::channel(64); let meter = metrics::global::meter(); + let request_timeout = + (Duration::from_millis(network_config.rtt_ms()) * 10).max(Duration::from_secs(10)); // Build libp2p Swarm using the derived identity and mTLS config let swarm = SwarmBuilder::with_existing_identity(cert_chain, private_key, ca_certs, crls) @@ -182,21 +184,21 @@ impl Network { StreamProtocol::new(api::IDENTIFIER), request_response::ProtocolSupport::Full, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), health_request_response: request_response::Behaviour::::new( [( StreamProtocol::new(health::IDENTIFIER), request_response::ProtocolSupport::Outbound, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), action_request_response: request_response::Behaviour::::new( [( StreamProtocol::new(action::IDENTIFIER), request_response::ProtocolSupport::Full, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), data_record_request_response: request_response::Behaviour::::new( @@ -204,13 +206,15 @@ impl Network { StreamProtocol::new(data_record::IDENTIFIER), request_response::ProtocolSupport::Outbound, )], - request_response::Config::default(), + request_response::Config::default() + .with_request_timeout(request_timeout), ), } }) .map_err(|_| { SwarmError::BehaviourCreation("Failed to create swarm behavior.".to_string()) })? + .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(30))) .build(); Ok(( diff --git a/crates/scheduler/src/scheduling/batch_scheduler.rs b/crates/scheduler/src/scheduling/batch_scheduler.rs index 30e70907..bef0d6f3 100644 --- a/crates/scheduler/src/scheduling/batch_scheduler.rs +++ b/crates/scheduler/src/scheduling/batch_scheduler.rs @@ -38,6 +38,7 @@ use crate::{ // decide when to instruct the parameter server to aggregate. #[derive(Default)] struct RoundState { + aggregated_updates: bool, sent_updates: HashSet, first_update_at: Option, min_quorum: usize, @@ -48,6 +49,8 @@ struct RoundState { training_complete: bool, applied_final_update: HashSet, push_done: bool, + // NOTE: Tracks workers that have applied the update for the current round. + applied_updates: HashSet, } #[derive(Default)] @@ -164,6 +167,11 @@ where tracing::debug!(%peer_id, ?status, %job_id, "Received action request"); let now = SystemTime::now(); + // Rimeouts sized for ~100ms RTT with generous margins. + let short_idle = now + Duration::from_millis(500); + let wait_model = now + Duration::from_secs(1); + let long_io = now + Duration::from_secs(60); + let ps_broadcast_idle = now + Duration::from_secs(5); let since_start = start.elapsed().as_millis() as u64; // NOTE: We rely on Pool::members() being oldest-first ordered by join time. @@ -176,7 +184,7 @@ where let state = round_state.lock().await; if state.round == 0 { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_millis(100), + timeout: short_idle, }) } else { training_state @@ -184,7 +192,7 @@ where .await .push_worker_without_model(peer_id); ExecutorAction::Train(TrainAction::WaitForModel { - timeout: now + Duration::from_secs(1), + timeout: wait_model, }) } } @@ -198,25 +206,21 @@ where strategy: SelectionStrategy::All, resource: None, }, - timeout: now + Duration::from_secs(60), + timeout: long_io, }) } else { ExecutorAction::Train(TrainAction::WaitForModel { - timeout: now + Duration::from_secs(1), + timeout: wait_model, }) } } TrainStatus::ReceivedModel => { // Lazy transition to other state - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), - }) + ExecutorAction::Train(TrainAction::Idle { timeout: now }) } TrainStatus::SentModel => { // Lazy transition to other state - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), - }) + ExecutorAction::Train(TrainAction::Idle { timeout: now }) } TrainStatus::Idle => { let mut state = round_state.lock().await; @@ -281,13 +285,26 @@ where (false, count) }; - if !should_update { + if state.aggregated_updates && !state.applied_updates.contains(&peer_id) { + ExecutorAction::Train(TrainAction::ApplyUpdate { + source: Reference::Peers { + peers: parameter_servers, + strategy: SelectionStrategy::All, + resource: None, + }, + timeout: now + Duration::from_secs(10), + }) + } else if state.sent_updates.contains(&peer_id) { + ExecutorAction::Train(TrainAction::Idle { + timeout: short_idle, + }) + } else if !should_update { ExecutorAction::Train(TrainAction::ExecuteBatch) } else if parameter_servers.is_empty() { // NOTE: If we need to send an update but there are no parameter servers, // we must wait (idle) until one becomes available. ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } else { ExecutorAction::Train(TrainAction::SendUpdate { @@ -298,8 +315,6 @@ where resource: None, }, weight: peer_contribution as f32 / projected_target as f32, - // TODO: We need a way to properly determine a good sent timeout - timeout: now + Duration::from_secs(30), }) } } else if state.push_done { @@ -321,7 +336,7 @@ where } } else { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } } @@ -344,9 +359,17 @@ where ) }; - if round_state.lock().await.training_complete { + let (training_complete, sent_update) = { + let state = round_state.lock().await; + ( + state.training_complete, + state.sent_updates.contains(&peer_id), + ) + }; + + if training_complete || sent_update { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } else { let stats: Vec = @@ -400,7 +423,7 @@ where // NOTE: If we need to send an update but there are no parameter servers, // we must wait (idle) until one becomes available. ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } else { ExecutorAction::Train(TrainAction::SendUpdate { @@ -411,8 +434,6 @@ where resource: None, }, weight: peer_contribution as f32 / projected_target as f32, - // TODO: We need a way to properly determine a good sent timeout - timeout: now + Duration::from_secs(30), }) } } @@ -443,24 +464,15 @@ where since_first_ms = elapsed_ms, "Worker reported SentUpdate; recorded for round" ); - if parameter_servers.is_empty() { - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), - }) - } else { - ExecutorAction::Train(TrainAction::ApplyUpdate { - source: Reference::Peers { - peers: parameter_servers, - strategy: SelectionStrategy::All, - resource: None, - }, - timeout: now + Duration::from_secs(30), - }) - } + + ExecutorAction::Train(TrainAction::Idle { + timeout: short_idle, + }) } TrainStatus::AppliedUpdate => { let training_complete = { let mut state = round_state.lock().await; + state.applied_updates.insert(peer_id); if state.training_complete { state.applied_final_update.insert(peer_id); @@ -472,7 +484,7 @@ where if training_complete { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { let mut training = training_state.lock().await; @@ -484,7 +496,6 @@ where strategy: SelectionStrategy::One, resource: None, }, - timeout: now + Duration::from_secs(30), }) } else { ExecutorAction::Train(TrainAction::ExecuteBatch) @@ -508,7 +519,7 @@ where } } ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } TrainStatus::Error(TrainError::Other { message }) => { @@ -537,7 +548,7 @@ where ); } ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(5), + timeout: short_idle, }) } else { let workers: Vec<_> = { @@ -553,7 +564,7 @@ where if workers.is_empty() { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } else { // Start aggregation when either all workers have sent updates, @@ -580,7 +591,7 @@ where }) } else { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_millis(500), + timeout: short_idle, }) } } @@ -590,7 +601,7 @@ where // Only allow the primary PS to proceed to broadcast. if Some(peer_id) != primary_ps { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(5), + timeout: ps_broadcast_idle, }) } else { let workers: Vec<_> = worker_pool @@ -601,53 +612,18 @@ where if workers.is_empty() { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } else { // Log that we are moving to broadcast for this round. let round = { - let state = round_state.lock().await; + let mut state = round_state.lock().await; + state.aggregated_updates = true; + state.applied_updates.clear(); state.round }; tracing::info!(round = %round, "Trigger BroadcastUpdate"); - // Reset round state befor completing a broadcast on the primary PS. - // Because workers will directly continue training after receiving - // updates and speed varies. The PS returns after ALL workers - // received their updates. - if Some(peer_id) == primary_ps { - let next_round = { - let mut state = round_state.lock().await; - - tracing::info!( - round = state.round, - "Broadcast completed; advancing round" - ); - - state.sent_updates.clear(); - state.first_update_at = None; - state.round = state.round.saturating_add(1); - - if state.round >= state.update_rounds { - state.training_complete = true; - tracing::info!( - round = state.round, - target = state.update_rounds, - "Target update rounds reached; entering completion phase" - ); - } - - state.round - }; - - let mut training = training_state.lock().await; - training.reset_round(); - tracing::info!( - round = next_round, - "Next round started; training state reset" - ); - } - ExecutorAction::Aggregate(AggregateAction::BroadcastUpdate { target: Reference::Peers { peers: workers, @@ -665,23 +641,63 @@ where .map_err(BatchSchedulerError::from)?; } - let training_complete = { round_state.lock().await.training_complete }; + let next_round = { + let mut state = round_state.lock().await; + + tracing::info!(round = state.round, "Broadcast completed; advancing round"); + + state.sent_updates.clear(); + state.first_update_at = None; + state.round = state.round.saturating_add(1); + + if state.round >= state.update_rounds { + state.training_complete = true; + tracing::info!( + round = state.round, + target = state.update_rounds, + "Target update rounds reached; entering completion phase" + ); + } + + state.round + }; + + let mut training = training_state.lock().await; + training.reset_round(); + tracing::info!( + round = next_round, + "Next round started; training state reset" + ); + + let training_complete = { + let mut state = round_state.lock().await; + state.aggregated_updates = false; + state.training_complete + }; if training_complete { ExecutorAction::Aggregate(AggregateAction::Terminate) } else { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } } AggregateStatus::Error(AggregateError::Connection { message }) => { tracing::warn!(%peer_id, message = %message, "Aggregator reported connection error"); + { + let mut state = round_state.lock().await; + state.aggregated_updates = false; + } ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: short_idle, }) } AggregateStatus::Error(AggregateError::Other { message }) => { tracing::warn!(%peer_id, message = %message, "Aggregator reported error"); + { + let mut state = round_state.lock().await; + state.aggregated_updates = false; + } ExecutorAction::Aggregate(AggregateAction::Terminate) } @@ -740,6 +756,8 @@ impl BatchScheduler { training_complete: false, applied_final_update: HashSet::new(), push_done: false, + aggregated_updates: false, + applied_updates: HashSet::new(), })); let training_state = Arc::new(Mutex::new(TrainingState::new(samples_between_updates))); network @@ -801,7 +819,10 @@ impl BatchScheduler { #[cfg(test)] mod batch_scheduler_tests { - use std::{collections::HashMap, time::SystemTime}; + use std::{ + collections::{HashMap, HashSet}, + time::SystemTime, + }; use futures_util::StreamExt; use hypha_messages::{ @@ -1276,6 +1297,8 @@ mod batch_scheduler_tests { training_complete: false, applied_final_update: Default::default(), push_done: false, + aggregated_updates: false, + applied_updates: HashSet::default(), })); let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(800))); let batch_sizer = std::sync::Arc::new(|resources: &Resources| resources.gpu() as u32); @@ -1351,7 +1374,6 @@ mod batch_scheduler_tests { resource: None, }, weight: 0.3, - timeout: SystemTime::now(), }), 2000, ), @@ -1365,7 +1387,6 @@ mod batch_scheduler_tests { resource: None, }, weight: 0.3, - timeout: SystemTime::now(), }), 2400, ), diff --git a/crates/scheduler/src/worker.rs b/crates/scheduler/src/worker.rs index c6e6df5d..8542b5c4 100644 --- a/crates/scheduler/src/worker.rs +++ b/crates/scheduler/src/worker.rs @@ -11,6 +11,10 @@ use hypha_resources::Resources; use libp2p::PeerId; use thiserror::Error; use tokio::{task::JoinHandle, time::sleep}; +use tokio_retry::{ + Retry, + strategy::{FixedInterval, jitter}, +}; use uuid::Uuid; use crate::network::Network; @@ -85,19 +89,26 @@ impl Worker { async move { loop { tracing::debug!(%lease_id, %peer_id, "Refreshing lease"); - match network - .request::( - peer_id, - api::Request::RenewLease(renew_lease::Request { id: lease_id }), - ) - .await - { + let retry_strategy = FixedInterval::from_millis(200).map(jitter).take(6); + + let result = Retry::spawn(retry_strategy, || { + let network = network.clone(); + async move { + network + .request::( + peer_id, + api::Request::RenewLease(renew_lease::Request { id: lease_id }), + ) + .await + } + }) + .await; + + match result { Ok(api::Response::RenewLease(renew_lease::Response::Renewed { - id: _, timeout, + .. })) => { - // Handle successful response - // TODO: Make the min refresh configurable let duration = timeout @@ -118,15 +129,18 @@ impl Worker { sleep(safe_duration).await; } Ok(api::Response::RenewLease(renew_lease::Response::Failed)) => { - // Handle failed response return Err(WorkerError::LeaseExpired); } Err(error) => { - // Handle error + tracing::warn!( + %lease_id, + %peer_id, + error = %error, + "Lease renewal failed after retries" + ); return Err(WorkerError::NetworkError(error)); } _ => { - // Handle unexpected response return Err(WorkerError::DispatchFailed( "Unexpected response".to_string(), )); diff --git a/crates/worker/src/arbiter.rs b/crates/worker/src/arbiter.rs index b4852109..b416da28 100644 --- a/crates/worker/src/arbiter.rs +++ b/crates/worker/src/arbiter.rs @@ -25,9 +25,9 @@ const WORKER_TOPIC: &str = "hypha/worker"; // This allows proper handling of multiple schedulers by batching advertisements const WINDOW_LIMIT: usize = 100; const WINDOW_WAIT: std::time::Duration = std::time::Duration::from_millis(200); -const OFFER_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(500); +const OFFER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); const PRUNE_INTERVAL: std::time::Duration = std::time::Duration::from_millis(250); -const LEASE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); +const LEASE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); #[derive(Debug, Error)] #[error("lease error")] diff --git a/crates/worker/src/bin/hypha-worker.rs b/crates/worker/src/bin/hypha-worker.rs index a79d3eee..0cc66cf0 100644 --- a/crates/worker/src/bin/hypha-worker.rs +++ b/crates/worker/src/bin/hypha-worker.rs @@ -31,7 +31,7 @@ use miette::{IntoDiagnostic, Result}; use tokio::signal::unix::{SignalKind, signal}; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, jitter}, + strategy::{FixedInterval, jitter}, }; use tokio_util::sync::CancellationToken; use tracing::level_filters::LevelFilter; @@ -148,7 +148,9 @@ async fn run(config: ConfigWithMetadata) -> Result<()> { // Dial each gateway and, on success, set up a relay circuit listen via it. let gateway_peer_ids = Retry::spawn( - ExponentialBackoff::from_millis(100).map(jitter).take(3), + FixedInterval::from_millis(config.network().rtt_ms().max(100)) + .map(jitter) + .take(6), || { let network = network.clone(); diff --git a/crates/worker/src/connector/mod.rs b/crates/worker/src/connector/mod.rs index 8681165d..7ccb4a58 100644 --- a/crates/worker/src/connector/mod.rs +++ b/crates/worker/src/connector/mod.rs @@ -19,7 +19,7 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, jitter}, + strategy::{FixedInterval, jitter}, }; use tokio_util::io::StreamReader; @@ -347,44 +347,47 @@ where } => match strategy { SelectionStrategy::All => { let network = self.network.clone(); - let it = futures_util::stream::iter(peers.clone()).then(move |peer| { - let network = network.clone(); - async move { - let retry_strategy = - ExponentialBackoff::from_millis(100).map(jitter).take(3); - - async fn attempt_push( - network: T, - peer: PeerId, - payload_len: u64, - ) -> Result - where - T: StreamPushSenderInterface, - { - let writer = network - .open_push_stream(peer, payload_len) - .await - .map_err(ConnectorError::OpenStream)?; - Ok(Box::pin(writer)) - } - - let result = Retry::spawn(retry_strategy, move || { - attempt_push(network.clone(), peer, payload_len) - }) - .await; - - match result { - Ok(writer) => Ok(WriteItem { - meta: ItemMeta { - kind: "peer", - name: peer.to_string(), - }, - writer, - }), - Err(e) => Err(e), + let count = peers.len(); + let it = futures_util::stream::iter(peers.clone()) + .map(move |peer| { + let network = network.clone(); + async move { + let retry_strategy = + FixedInterval::from_millis(200).map(jitter).take(6); + + async fn attempt_push( + network: T, + peer: PeerId, + payload_len: u64, + ) -> Result + where + T: StreamPushSenderInterface, + { + let writer = network + .open_push_stream(peer, payload_len) + .await + .map_err(ConnectorError::OpenStream)?; + Ok(Box::pin(writer)) + } + + let result = Retry::spawn(retry_strategy, move || { + attempt_push(network.clone(), peer, payload_len) + }) + .await; + + match result { + Ok(writer) => Ok(WriteItem { + meta: ItemMeta { + kind: "peer", + name: peer.to_string(), + }, + writer, + }), + Err(e) => Err(e), + } } - } - }); + }) + .buffer_unordered(count.max(1)); Ok(Box::pin(it) as WriteItemStream) } SelectionStrategy::One | SelectionStrategy::Random => { diff --git a/crates/worker/src/executor/bridge.rs b/crates/worker/src/executor/bridge.rs index db1197f0..d930519c 100644 --- a/crates/worker/src/executor/bridge.rs +++ b/crates/worker/src/executor/bridge.rs @@ -3,20 +3,17 @@ use std::{ os::unix::fs::{MetadataExt, PermissionsExt}, path::{Path, PathBuf}, sync::Arc, - time::Duration, + time::SystemTime, }; use axum::{ Json, Router, extract::State, http::StatusCode, - response::{ - IntoResponse, Response, - sse::{Event, KeepAlive, Sse}, - }, + response::{IntoResponse, Response}, routing::{get, post}, }; -use futures_util::{StreamExt, stream}; +use futures_util::StreamExt; use hypha_data::hash::get_file_hash; use hypha_messages::{ DataSlice, Fetch, Receive, Reference, Send, @@ -35,10 +32,11 @@ use tokio::{ fs::{self, set_permissions}, io::{self, AsyncWriteExt}, net::UnixListener, + time::sleep, }; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, FibonacciBackoff, jitter}, + strategy::{FibonacciBackoff, FixedInterval, jitter}, }; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use utoipa::OpenApi; @@ -138,7 +136,6 @@ struct SockState { network: Network, job_id: Uuid, scheduler: PeerId, - task_tracker: TaskTracker, cancel: CancellationToken, } @@ -171,7 +168,6 @@ impl Bridge { network, job_id, scheduler, - task_tracker: task_tracker.clone(), cancel: cancel_token.clone(), }); @@ -408,9 +404,7 @@ async fn send_resource( State(state): State>, Json(req): Json, ) -> Result<(), Error> { - let retry_strategy = ExponentialBackoff::from_millis(100) - .map(jitter) // add jitter to delays - .take(3); // limit to 3 retries + let retry_strategy = FixedInterval::from_millis(50).map(jitter).take(20); Retry::spawn(retry_strategy, || { let state = state.clone(); @@ -518,6 +512,7 @@ fn validate_fetch(resource: &Fetch) -> Result<(), Error> { struct ReceiveSubscribeRequest { resource: Receive, path: Option, + timeout: Option, } #[derive(Debug, Serialize)] @@ -530,119 +525,114 @@ struct UpdatePointer { async fn receive_subscribe( State(state): State>, Json(req): Json, -) -> Result>>, Error> { +) -> Result { let dir_rel = req.path.unwrap_or_else(|| "incoming".to_string()); let dir_abs = safe_join(&state.work_dir, &dir_rel)?; fs::create_dir_all(&dir_abs).await?; - // Channel to push events to the SSE stream - let (tx, rx) = tokio::sync::mpsc::channel::(64); - let connector = state.connector.clone(); + let idle_timeout = req + .timeout + .and_then(|t| t.duration_since(SystemTime::now()).ok()); + if idle_timeout + .as_ref() + .is_some_and(|duration| duration.is_zero()) + { + return Ok(StatusCode::NO_CONTENT.into_response()); + } + + let mut incoming = match state.connector.receive(req.resource.clone()).await { + Ok(s) => s, + Err(err) => { + tracing::error!(error = %err, path = %dir_rel, "receive_subscribe: failed to start stream"); + return Ok(StatusCode::NO_CONTENT.into_response()); + } + }; let work_dir = state.work_dir.clone(); - let resource = req.resource.clone(); let cancel = state.cancel.clone(); - let task_tracker = state.task_tracker.clone(); - let dir_rel_clone = dir_rel.clone(); + let mut idle_timer = idle_timeout.map(|duration| Box::pin(sleep(duration))); + let mut pointer: Option = None; - // Background task: receive loops until the client disconnects or an error occurs - task_tracker.spawn(async move { - let mut incoming = match connector.receive(resource).await { - Ok(s) => s, + while let Some(item_result) = tokio::select! { + _ = cancel.cancelled() => { + tracing::debug!(path = %dir_rel, "receive_subscribe: task cancelled"); + None + } + _ = async { + if let Some(timer) = idle_timer.as_mut() { + timer.as_mut().await; + } else { + std::future::pending::<()>().await; + } + }, + if idle_timer.is_some() => { + tracing::warn!(path = %dir_rel, "receive_subscribe: idle timeout reached"); + None + } + item = incoming.next() => item, + } { + let item = match item_result { + Ok(item) => item, Err(err) => { - tracing::error!(error = %err, path = %dir_rel_clone, "receive_subscribe: failed to start stream"); - return; + tracing::warn!(error = %err, path = %dir_rel, "receive_subscribe: stream error"); + continue; } }; - let mut index = 0usize; - while let Some(item_result) = tokio::select! { - _ = tx.closed() => { - tracing::debug!(path = %dir_rel_clone, "receive_subscribe: client stream dropped"); - None - } - _ = cancel.cancelled() => { - tracing::debug!(path = %dir_rel_clone, "receive_subscribe: task cancelled"); - None - } - item = incoming.next() => item, - } { - let item = match item_result { - Ok(item) => item, - Err(err) => { - tracing::warn!(error = %err, path = %dir_rel_clone, "receive_subscribe: stream error"); - continue; - } - }; - let (file_name, mut reader) = derive_name_and_reader(item, index); - let file_rel = format!("{}/{}", dir_rel_clone, file_name); - let file_abs = match safe_join(&work_dir, &file_rel) { - Ok(p) => p, - Err(err) => { - tracing::error!(error = %err, file = %file_rel, "receive_subscribe: invalid target path"); - continue; - } - }; - if let Some(parent) = file_abs.parent() { - match fs::create_dir_all(parent).await { - Ok(()) => (), - Err(err) => { - tracing::error!(error = %err, directory = %parent.display(), "receive_subscribe: failed to create directory"); - continue; - } - } + // Once data starts flowing, disable idle timeout so long copies are not interrupted. + idle_timer = None; + let (file_name, mut reader) = derive_name_and_reader(item, 0); + let file_rel = format!("{}/{}", dir_rel, file_name); + let file_abs = match safe_join(&work_dir, &file_rel) { + Ok(p) => p, + Err(err) => { + tracing::error!(error = %err, file = %file_rel, "receive_subscribe: invalid target path"); + continue; } - let mut file = match fs::File::create(&file_abs).await { - Ok(f) => f, - Err(err) => { - tracing::error!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to create file"); - continue; - } - }; - let size = match tokio::io::copy(&mut reader, &mut file).await { - Ok(n) => n, + }; + if let Some(parent) = file_abs.parent() { + match fs::create_dir_all(parent).await { + Ok(()) => (), Err(err) => { - tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to copy resource"); + tracing::error!(error = %err, directory = %parent.display(), "receive_subscribe: failed to create directory"); continue; } - }; - if let Err(err) = file.sync_all().await { - tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to sync file"); } - if let Err(err) = set_permissions(&file_abs, Permissions::from_mode(0o600)).await { - tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to set permissions"); + } + let mut file = match fs::File::create(&file_abs).await { + Ok(f) => f, + Err(err) => { + tracing::error!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to create file"); + continue; } - - tracing::info!(size, file = %file_abs.display(), "Received resource"); - - let from_peer = file_name.split('.').next().unwrap_or("").to_string(); - let pointer = UpdatePointer { - path: file_rel, - size, - from_peer, - }; - let ev = match serde_json::to_string(&pointer) { - Ok(data) => Event::default().data(data), - Err(err) => { - tracing::error!(error = %err, "receive_subscribe: failed to serialize pointer"); - Event::default().data(r#"{"error":"serialize"}"#) - } - }; - if tx.send(ev).await.is_err() { - tracing::debug!(path = %dir_rel_clone, "receive_subscribe: client disconnected"); - break; + }; + let size = match tokio::io::copy(&mut reader, &mut file).await { + Ok(n) => n, + Err(err) => { + tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to copy resource"); + continue; } - index += 1; + }; + if let Err(err) = file.sync_all().await { + tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to sync file"); + } + if let Err(err) = set_permissions(&file_abs, Permissions::from_mode(0o600)).await { + tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to set permissions"); } - }); - let stream = stream::unfold(rx, |mut rx| async move { - rx.recv().await.map(|ev| (Ok(ev), rx)) - }); + tracing::info!(size, file = %file_abs.display(), "Received resource"); - Ok(Sse::new(stream).keep_alive( - KeepAlive::new() - .interval(Duration::from_secs(5)) - .text("keepalive"), - )) + let from_peer = file_name.split('.').next().unwrap_or("").to_string(); + pointer = Some(UpdatePointer { + path: file_rel, + size, + from_peer, + }); + break; + } + + Ok(match pointer { + Some(p) => (StatusCode::OK, Json(p)).into_response(), + None => StatusCode::NO_CONTENT.into_response(), + }) } async fn send_action( @@ -653,7 +643,7 @@ async fn send_action( return Err(Error::InvalidStatus("job_id mismatch".to_string())); } - let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); + let retry_strategy = FixedInterval::from_millis(200).map(jitter).take(6); // TODO we should ensure that a message is not received repeatedly. Otherwise it will distort the training. let result = Retry::spawn(retry_strategy, || { diff --git a/crates/worker/src/executor/parameter_server.rs b/crates/worker/src/executor/parameter_server.rs index ef08cb54..1904a053 100644 --- a/crates/worker/src/executor/parameter_server.rs +++ b/crates/worker/src/executor/parameter_server.rs @@ -28,7 +28,7 @@ use tokio::{ }; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, jitter}, + strategy::{FixedInterval, jitter}, }; use tokio_util::{future::FutureExt, sync::CancellationToken, task::TaskTracker}; use uuid::Uuid; @@ -92,9 +92,9 @@ impl JobExecutor for ParameterServerExecutor { ) -> Result { tracing::info!(job_spec = ?job, "Executing parameter server job"); - let retry_strategy = ExponentialBackoff::from_millis(100) - .map(jitter) // add jitter to delays - .take(3); // limit to 3 retries + // NOTE: Retry for a second, please note that this needs to align with + // the batch scheduler timings + let retry_strategy = FixedInterval::from_millis(50).map(jitter).take(20); let id = Uuid::new_v4(); let work_dir = self.work_dir_base.join(format!("hypha-{}", id)); @@ -183,6 +183,12 @@ impl JobExecutor for ParameterServerExecutor { let pid = peer.parse().unwrap_or_else(|_| PeerId::random()); let entry = store.entry(pid).or_default(); entry.push(file_path.clone()); + tracing::debug!( + peer_id = %peer, + stored = entry.len(), + total_peers = store.len(), + "Stored incoming update" + ); } updates_notify.notify_one(); }.with_cancellation_token_owned(cancel.clone())); @@ -268,9 +274,9 @@ impl JobExecutor for ParameterServerExecutor { // NOTE: Allowed peers come from scheduler. If empty, accept any. let allowed = receive.get_peers().clone(); - // TODO: These should come from the scheduler and must be configurable. - let max_delay = std::time::Duration::from_millis(500); - let action_deadline = std::time::Duration::from_secs(30); + // NOTE: Timeouts sized for ~100ms RTT. + let gap_timeout = std::time::Duration::from_secs(10); + let action_deadline = std::time::Duration::from_secs(60); match aggregate_updates( updates_store.clone(), @@ -278,7 +284,7 @@ impl JobExecutor for ParameterServerExecutor { work_dir.clone(), &device, &optimizer, - max_delay, + gap_timeout, action_deadline, updates_notify.clone(), cancel.clone(), @@ -295,10 +301,16 @@ impl JobExecutor for ParameterServerExecutor { } Err(e) => { tracing::warn!(error = %e, "Failed to aggregate updates"); + let agg_error = match e { + Error::InvalidExecutorConfig(msg) => { + AggregateError::Connection { message: msg } + } + other => AggregateError::Other { + message: other.to_string(), + }, + }; current_status = ExecutorStatus::Aggregate( - action::AggregateStatus::Error(AggregateError::Other { - message: e.to_string(), - }), + action::AggregateStatus::Error(agg_error), ); } } @@ -405,8 +417,7 @@ async fn aggregate_updates( let mut used: HashSet = HashSet::new(); let deadline = tokio::time::Instant::now() + action_deadline; - // NOTE: Max delay we allow for any peer to send an update - // when, if have not received an update within this time, we end the action. + // NOTE: Max delay we allow between updates before ending the action. let max_delay = tokio::time::sleep(gap_timeout); tokio::pin!(max_delay); @@ -488,14 +499,31 @@ async fn aggregate_updates( tokio::select! { _ = cancel.cancelled() => return Err(Error::InvalidExecutorConfig("aggregation cancelled".to_string())), _ = &mut max_delay => { - tracing::debug!("Aggregate max delay reached"); + tracing::debug!( + waited_ms = gap_timeout.as_millis(), + used = used.len(), + allow = allowed.len(), + "Aggregate max delay reached" + ); break; }, _ = tokio::time::sleep_until(deadline) => { - tracing::warn!("Aggregate deadline reached"); + tracing::warn!( + waited_ms = action_deadline.as_millis(), + used = used.len(), + allow = allowed.len(), + "Aggregate deadline reached" + ); break; }, _ = notify.notified() => { + tracing::debug!( + waited_ms = gap_timeout.as_millis(), + "Aggregation notified of new update; resetting gap timer" + ); + max_delay + .as_mut() + .reset(tokio::time::Instant::now() + gap_timeout); // New updates available; loop to try again continue; } @@ -510,6 +538,11 @@ async fn aggregate_updates( ) .await?; } else { + tracing::warn!( + used_peers = used.len(), + allowed_peers = allowed.len(), + "Aggregation finished without receiving any updates" + ); return Err(Error::InvalidExecutorConfig( "no updates available to aggregate".to_string(), )); @@ -535,6 +568,8 @@ async fn broadcast_update( gradient_file: &Path, cancel: CancellationToken, ) -> Result<(), Error> { + tracing::info!("Broadcasting update to {:?}", send); + let payload_len = fs::metadata(gradient_file).await?.len(); let mut writers = connector.send(send, payload_len).await?; diff --git a/crates/worker/src/network.rs b/crates/worker/src/network.rs index ffaa93a6..49dddb68 100644 --- a/crates/worker/src/network.rs +++ b/crates/worker/src/network.rs @@ -4,7 +4,7 @@ //! It ties together the networking primitives and drives the swarm. This //! documentation follows the [rustdoc guidelines](https://doc.rust-lang.org/rustdoc/how-to-write-documentation.html). -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use futures_util::stream::StreamExt; use hypha_config::NetworkConfig; @@ -104,8 +104,10 @@ impl Network { exclude_cidrs: Vec, network_config: &NetworkConfig, ) -> Result<(Self, NetworkDriver), SwarmError> { - let (action_sender, action_receiver) = mpsc::channel(5); + let (action_sender, action_receiver) = mpsc::channel(64); let meter = metrics::global::meter(); + let request_timeout = + (Duration::from_millis(network_config.rtt_ms()) * 10).max(Duration::from_secs(10)); let swarm = SwarmBuilder::with_existing_identity(cert_chain, private_key, ca_certs, crls) .with_tokio() @@ -174,27 +176,28 @@ impl Network { StreamProtocol::new(api::IDENTIFIER), request_response::ProtocolSupport::Full, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), health_request_response: request_response::Behaviour::::new( [( StreamProtocol::new(health::IDENTIFIER), request_response::ProtocolSupport::Full, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), action_request_response: request_response::Behaviour::::new( [( StreamProtocol::new(action::IDENTIFIER), request_response::ProtocolSupport::Full, )], - request_response::Config::default(), + request_response::Config::default().with_request_timeout(request_timeout), ), } }) .map_err(|_| { SwarmError::BehaviourCreation("Failed to create swarm behavior.".to_string()) })? + .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(30))) .build(); Ok(( diff --git a/executors/accelerate/src/hypha/accelerate_executor/api.py b/executors/accelerate/src/hypha/accelerate_executor/api.py index f917af59..1187e623 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/api.py +++ b/executors/accelerate/src/hypha/accelerate_executor/api.py @@ -1,6 +1,4 @@ -import json -from collections.abc import Iterator -from contextlib import AbstractContextManager, contextmanager +from contextlib import AbstractContextManager from types import TracebackType from typing import Any, override @@ -33,47 +31,32 @@ def send_resource(self, resource: Any, path: str, remove_file: bool = True, time req = {"resource": resource, "path": path, "timeout_ms": timeout_ms, "remove_file": remove_file} # We must allow the client to wait at least as long as the requested timeout. # If timeout is None, wait forever. - _ = self._client.post("http://hypha/resources/send", json=req, timeout=timeout).raise_for_status() + _ = self._client.post( + "http://hypha/resources/send", json=req, timeout=httpx.Timeout(None, connect=0.1) + ).raise_for_status() def send_action(self, payload: Any) -> Any: - resp = self._client.post("http://hypha/action/update", json=payload, timeout=None).raise_for_status() + resp = self._client.post( + "http://hypha/action/update", json=payload, timeout=httpx.Timeout(None, connect=0.1) + ).raise_for_status() return resp.json() def fetch(self, resource: Any) -> Any: - resp = self._client.post("http://hypha/resources/fetch", json=resource, timeout=None).raise_for_status() + resp = self._client.post( + "http://hypha/resources/fetch", json=resource, timeout=httpx.Timeout(None, connect=0.1) + ).raise_for_status() return resp.json() - @contextmanager - def receive(self, resource: Any, path: str, timeout: float | None = None) -> Iterator["EventSource"]: - req = {"resource": resource, "path": path} - # Use a short connect timeout to fail fast if the local side is unresponsive, - # but respect the provided timeout for the total duration/read. - # If timeout is None, we still enforce a connect timeout. - timeout_config = httpx.Timeout(timeout, connect=5.0) - with self._client.stream( - "POST", + def receive(self, resource: Any, path: str, timeout: Any | None = None) -> Any | None: + req = {"resource": resource, "path": path, "timeout": timeout} + resp = self._client.post( "http://hypha/resources/receive", json=req, - headers={"Accept": "text/event-stream"}, - timeout=timeout_config, - ) as resp: - yield EventSource(resp) - - -class EventSource: - def __init__(self, response: httpx.Response) -> None: - self._response: httpx.Response = response - - @property - def response(self) -> httpx.Response: - return self._response - - def __iter__(self) -> Iterator[Any]: - for line in self._response.iter_lines(): - fieldname, _, value = line.rstrip("\n").partition(":") - - if fieldname == "data": - result = json.loads(value) - - yield result - # Ignore other SSE fields (e.g., event:, id:, retry:) + timeout=httpx.Timeout(None, connect=0.1), + ) + if resp.status_code == httpx.codes.NO_CONTENT: + return None + resp.raise_for_status() + if resp.status_code == httpx.codes.NO_CONTENT: + return None + return resp.json() diff --git a/executors/accelerate/src/hypha/accelerate_executor/training.py b/executors/accelerate/src/hypha/accelerate_executor/training.py index f52bb644..343b3d02 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/training.py +++ b/executors/accelerate/src/hypha/accelerate_executor/training.py @@ -82,7 +82,7 @@ def system_time_to_epoch_ms(timeout: object) -> int | None: def sleep_until_epoch_ms(target_ms: int) -> None: - now_ms = int(time.time() * 1000.0) + now_ms = time.time() * 1000.0 if target_ms > now_ms: time.sleep((target_ms - now_ms) / 1000.0) @@ -211,13 +211,8 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - timeout_sec = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if timeout_sec is not None and timeout_sec < 1.0: - timeout_sec = 1.0 - try: - session.send_resource(target, last_gradient, timeout=timeout_sec) + session.send_resource(target, last_gradient) current_status = { "executor": "train", "details": { @@ -248,49 +243,31 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - read_timeout = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if read_timeout is not None and read_timeout <= 0: - # Scheduler will tell us what to do next. - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "ApplyUpdate timeout reached before receive", - }, - } - continue - receive_path = f"incoming-{uuid.uuid4()}" try: - with session.receive(source, receive_path, timeout=read_timeout) as receiver: - updates_iter = iter(receiver) - pointers = next(updates_iter) - if pointers: - latest = pointers[-1] if isinstance(pointers, list) else pointers - parameters = ( - latest.get("parameters") if isinstance(latest.get("parameters"), dict) else None - ) - rel_path = parameters.get("path") if parameters else latest.get("path") - if isinstance(rel_path, str): - path = os.path.join(work_dir, rel_path) - model.load_state_dict(merge_models(previous_model_path, path)) - save_model(model, previous_model_path) - - # Once we updated the model, we no longer need the parameter file. - os.remove(path) - except StopIteration: - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "Receiver stream closed; no updates to merge.", - }, - } - continue + pointers = session.receive(source, receive_path, timeout=action.get("timeout")) + if pointers: + latest = pointers[-1] if isinstance(pointers, list) else pointers + parameters = latest.get("parameters") if isinstance(latest.get("parameters"), dict) else None + rel_path = parameters.get("path") if parameters else latest.get("path") + if isinstance(rel_path, str): + path = os.path.join(work_dir, rel_path) + model.load_state_dict(merge_models(previous_model_path, path)) + save_model(model, previous_model_path) + + # Once we updated the model, we no longer need the parameter file. + os.remove(path) + else: + current_status = { + "executor": "train", + "details": { + "state": "error", + "type": "connection", + "message": "No updates received before timeout.", + }, + } + continue except Exception as exc: # noqa: BLE001 current_status = { "executor": "train", @@ -346,13 +323,8 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - timeout_sec = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if timeout_sec is not None and timeout_sec < 1.0: - timeout_sec = 1.0 - try: - session.send_resource(target, CURRENT_MODEL_NAME, remove_file=False, timeout=timeout_sec) + session.send_resource(target, CURRENT_MODEL_NAME, remove_file=False) current_status = { "executor": "train", "details": {"state": "sent-model"}, @@ -380,43 +352,28 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - read_timeout = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if read_timeout is not None and read_timeout <= 0: - # Scheduler will tell us what to do next. - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "ReceiveModel timeout reached before receive", - }, - } - continue try: receive_path = f"incoming-{uuid.uuid4()}" - with session.receive(source, receive_path, timeout=read_timeout) as receiver: - updates_iter = iter(receiver) - pointers = next(updates_iter) - if pointers: - incomming = pointers[-1] if isinstance(pointers, list) else pointers - rel_path = incomming.get("path") - if isinstance(rel_path, str): - path = os.path.join(work_dir, rel_path) - model.load_state_dict(load_file(path)) - os.remove(previous_model_path) - shutil.copy(path, previous_model_path) - os.remove(path) - except StopIteration: - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "Receiver stream closed; no updates to merge.", - }, - } - continue + pointers = session.receive(source, receive_path, timeout=action.get("timeout")) + if pointers: + incomming = pointers[-1] if isinstance(pointers, list) else pointers + rel_path = incomming.get("path") + if isinstance(rel_path, str): + path = os.path.join(work_dir, rel_path) + model.load_state_dict(load_file(path)) + os.remove(previous_model_path) + shutil.copy(path, previous_model_path) + os.remove(path) + else: + current_status = { + "executor": "train", + "details": { + "state": "error", + "type": "connection", + "message": "No model received before timeout.", + }, + } + continue except Exception as exc: # noqa: BLE001 current_status = { "executor": "train", diff --git a/scripts/network-sim.sh b/scripts/network-sim.sh new file mode 100755 index 00000000..abd3c87d --- /dev/null +++ b/scripts/network-sim.sh @@ -0,0 +1,192 @@ +#!/bin/bash +# Network simulation script for Hypha testing using dnctl (dummynet) +# +# This script safely configures packet filtering rules to simulate network +# conditions (latency, packet loss, bandwidth limits) on localhost connections +# between Hypha components. +# +# Usage: +# sudo ./network-sim.sh start [delay_ms] [packet_loss_%] [bandwidth_kbit] +# sudo ./network-sim.sh status +# sudo ./network-sim.sh stop +# +# Examples: +# sudo ./network-sim.sh start 100 5 1000 # 100ms delay, 5% loss, 1Mbps +# sudo ./network-sim.sh start 50 0 10000 # 50ms delay, no loss, 10Mbps +# sudo ./network-sim.sh start 200 # 200ms delay only +# sudo ./network-sim.sh stop # Remove all rules + +set -euo pipefail + +# PF anchor name for isolated rule management +# IMPORTANT: must match default dummynet-anchor "com.apple/*" on macOS +ANCHOR="com.apple/hypha-test" +PIPE_NUM=1 + +show_usage() { + cat < 0.0500 + local loss_ratio + loss_ratio=$(bc <<< "scale=4; $loss_pct / 100") + pipe_config="$pipe_config plr ${loss_ratio}" + fi + + # Add bandwidth if requested + if [[ $bw_kbit -gt 0 ]]; then + pipe_config="$pipe_config bw ${bw_kbit}Kbit/s" + fi + + # NOTE: Split stats by flow (src/dst/proto/ports) + pipe_config="$pipe_config mask all" + + echo "Configuring dummynet pipe $PIPE_NUM..." + dnctl pipe "$PIPE_NUM" config $pipe_config + + echo "Configuring packet filter rules (anchor: $ANCHOR) for localhost traffic..." + + # Apply to ALL traffic on localhost (lo0), IPv4 and IPv6 + # 'quick' ensures these rules are applied immediately when matched. + local pf_rules="" + # IPv4 localhost + pf_rules+="dummynet in quick on lo0 inet all pipe $PIPE_NUM"$'\n' + pf_rules+="dummynet out quick on lo0 inet all pipe $PIPE_NUM"$'\n' + # IPv6 localhost + pf_rules+="dummynet in quick on lo0 inet6 all pipe $PIPE_NUM"$'\n' + pf_rules+="dummynet out quick on lo0 inet6 all pipe $PIPE_NUM"$'\n' + + # Apply rules to PF anchor (isolated from other rules); -q silences the -f warning + echo "$pf_rules" | pfctl -q -a "$ANCHOR" -f - + + # Enable PF if not already enabled + if ! pfctl -s info | grep -q "Status: Enabled"; then + echo "Enabling packet filter..." + # -E is reference-counted enable on macOS + pfctl -E >/dev/null 2>&1 || pfctl -e >/dev/null 2>&1 || true + fi + + echo "Network simulation started successfully!" + echo "" + echo "To adjust settings, run: $0 stop && sudo $0 start [new_params]" + echo "To stop simulation, run: sudo $0 stop" +} + +show_status() { + echo "=== Network Simulation Status ===" + echo "" + + # Check if PF is enabled + echo "Packet Filter Status:" + if pfctl -s info | grep -q "Status: Enabled"; then + echo " ✓ Enabled" + else + echo " ✗ Disabled (simulation not active)" + echo "" + return + fi + echo "" + + # Show our dummynet rules in the anchor + echo "Hypha dummynet rules (anchor: $ANCHOR):" + if pfctl -a "$ANCHOR" -s dummynet 2>/dev/null | grep -q .; then + pfctl -a "$ANCHOR" -s dummynet | sed 's/^/ /' + else + echo " (no dummynet rules configured)" + fi + echo "" + + # Show dummynet pipe configuration and counters + echo "Dummynet Pipe $PIPE_NUM:" + if dnctl pipe "$PIPE_NUM" show 2>/dev/null | grep -q .; then + dnctl pipe "$PIPE_NUM" show | sed 's/^/ /' + else + echo " (pipe not configured)" + fi +} + +stop_simulation() { + echo "Stopping network simulation..." + + # Flush rules from our anchor (only this anchor, not system rules) + pfctl -q -a "$ANCHOR" -F all 2>/dev/null || true + + # Delete dummynet pipe + dnctl pipe delete "$PIPE_NUM" 2>/dev/null || true + + echo "Network simulation stopped. Normal network conditions restored." + echo "" + echo "Note: PF remains enabled but localhost is no longer affected by this script." +} + +# Main command dispatch +case "${1:-}" in + start) + check_root + shift + start_simulation "$@" + ;; + status) + check_root + show_status + ;; + stop) + check_root + stop_simulation + ;; + -h|--help|help) + show_usage + ;; + *) + show_usage + exit 1 + ;; +esac