diff --git a/crates/messages/src/lib.rs b/crates/messages/src/lib.rs index 4957297f..2153f09a 100644 --- a/crates/messages/src/lib.rs +++ b/crates/messages/src/lib.rs @@ -163,13 +163,11 @@ pub mod action { }, SendModel { target: Reference, - timeout: SystemTime, }, ExecuteBatch, SendUpdate { target: Reference, weight: f32, - timeout: SystemTime, }, ApplyUpdate { source: Reference, diff --git a/crates/scheduler/src/scheduling/batch_scheduler.rs b/crates/scheduler/src/scheduling/batch_scheduler.rs index 30e70907..273bd5db 100644 --- a/crates/scheduler/src/scheduling/batch_scheduler.rs +++ b/crates/scheduler/src/scheduling/batch_scheduler.rs @@ -38,6 +38,7 @@ use crate::{ // decide when to instruct the parameter server to aggregate. #[derive(Default)] struct RoundState { + aggregated_updates: bool, sent_updates: HashSet, first_update_at: Option, min_quorum: usize, @@ -184,7 +185,7 @@ where .await .push_worker_without_model(peer_id); ExecutorAction::Train(TrainAction::WaitForModel { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } } @@ -198,25 +199,21 @@ where strategy: SelectionStrategy::All, resource: None, }, - timeout: now + Duration::from_secs(60), + timeout: now + Duration::from_secs(10), }) } else { ExecutorAction::Train(TrainAction::WaitForModel { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } } TrainStatus::ReceivedModel => { // Lazy transition to other state - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), - }) + ExecutorAction::Train(TrainAction::Idle { timeout: now }) } TrainStatus::SentModel => { // Lazy transition to other state - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), - }) + ExecutorAction::Train(TrainAction::Idle { timeout: now }) } TrainStatus::Idle => { let mut state = round_state.lock().await; @@ -281,13 +278,22 @@ where (false, count) }; - if !should_update { + if state.aggregated_updates { + ExecutorAction::Train(TrainAction::ApplyUpdate { + source: Reference::Peers { + peers: parameter_servers, + strategy: SelectionStrategy::All, + resource: None, + }, + timeout: now + Duration::from_secs(10), + }) + } else if !should_update { ExecutorAction::Train(TrainAction::ExecuteBatch) } else if parameter_servers.is_empty() { // NOTE: If we need to send an update but there are no parameter servers, // we must wait (idle) until one becomes available. ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { ExecutorAction::Train(TrainAction::SendUpdate { @@ -298,8 +304,6 @@ where resource: None, }, weight: peer_contribution as f32 / projected_target as f32, - // TODO: We need a way to properly determine a good sent timeout - timeout: now + Duration::from_secs(30), }) } } else if state.push_done { @@ -321,7 +325,7 @@ where } } else { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } } @@ -346,7 +350,7 @@ where if round_state.lock().await.training_complete { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { let stats: Vec = @@ -400,7 +404,7 @@ where // NOTE: If we need to send an update but there are no parameter servers, // we must wait (idle) until one becomes available. ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { ExecutorAction::Train(TrainAction::SendUpdate { @@ -411,8 +415,6 @@ where resource: None, }, weight: peer_contribution as f32 / projected_target as f32, - // TODO: We need a way to properly determine a good sent timeout - timeout: now + Duration::from_secs(30), }) } } @@ -443,20 +445,10 @@ where since_first_ms = elapsed_ms, "Worker reported SentUpdate; recorded for round" ); - if parameter_servers.is_empty() { - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), - }) - } else { - ExecutorAction::Train(TrainAction::ApplyUpdate { - source: Reference::Peers { - peers: parameter_servers, - strategy: SelectionStrategy::All, - resource: None, - }, - timeout: now + Duration::from_secs(30), - }) - } + + ExecutorAction::Train(TrainAction::Idle { + timeout: now + Duration::from_millis(500), + }) } TrainStatus::AppliedUpdate => { let training_complete = { @@ -472,7 +464,7 @@ where if training_complete { ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { let mut training = training_state.lock().await; @@ -484,7 +476,6 @@ where strategy: SelectionStrategy::One, resource: None, }, - timeout: now + Duration::from_secs(30), }) } else { ExecutorAction::Train(TrainAction::ExecuteBatch) @@ -508,7 +499,7 @@ where } } ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } TrainStatus::Error(TrainError::Other { message }) => { @@ -537,7 +528,7 @@ where ); } ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(5), + timeout: now + Duration::from_millis(500), }) } else { let workers: Vec<_> = { @@ -553,7 +544,7 @@ where if workers.is_empty() { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { // Start aggregation when either all workers have sent updates, @@ -601,12 +592,13 @@ where if workers.is_empty() { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } else { // Log that we are moving to broadcast for this round. let round = { - let state = round_state.lock().await; + let mut state = round_state.lock().await; + state.aggregated_updates = true; state.round }; tracing::info!(round = %round, "Trigger BroadcastUpdate"); @@ -670,14 +662,14 @@ where ExecutorAction::Aggregate(AggregateAction::Terminate) } else { ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } } AggregateStatus::Error(AggregateError::Connection { message }) => { tracing::warn!(%peer_id, message = %message, "Aggregator reported connection error"); ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: now + Duration::from_secs(1), + timeout: now + Duration::from_millis(500), }) } AggregateStatus::Error(AggregateError::Other { message }) => { @@ -740,6 +732,7 @@ impl BatchScheduler { training_complete: false, applied_final_update: HashSet::new(), push_done: false, + aggregated_updates: false, })); let training_state = Arc::new(Mutex::new(TrainingState::new(samples_between_updates))); network @@ -1276,6 +1269,7 @@ mod batch_scheduler_tests { training_complete: false, applied_final_update: Default::default(), push_done: false, + aggregated_updates: false, })); let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(800))); let batch_sizer = std::sync::Arc::new(|resources: &Resources| resources.gpu() as u32); @@ -1351,7 +1345,6 @@ mod batch_scheduler_tests { resource: None, }, weight: 0.3, - timeout: SystemTime::now(), }), 2000, ), @@ -1365,7 +1358,6 @@ mod batch_scheduler_tests { resource: None, }, weight: 0.3, - timeout: SystemTime::now(), }), 2400, ), diff --git a/crates/worker/src/executor/bridge.rs b/crates/worker/src/executor/bridge.rs index db1197f0..0cbfb8df 100644 --- a/crates/worker/src/executor/bridge.rs +++ b/crates/worker/src/executor/bridge.rs @@ -3,20 +3,17 @@ use std::{ os::unix::fs::{MetadataExt, PermissionsExt}, path::{Path, PathBuf}, sync::Arc, - time::Duration, + time::SystemTime, }; use axum::{ Json, Router, extract::State, http::StatusCode, - response::{ - IntoResponse, Response, - sse::{Event, KeepAlive, Sse}, - }, + response::{IntoResponse, Response}, routing::{get, post}, }; -use futures_util::{StreamExt, stream}; +use futures_util::StreamExt; use hypha_data::hash::get_file_hash; use hypha_messages::{ DataSlice, Fetch, Receive, Reference, Send, @@ -35,10 +32,11 @@ use tokio::{ fs::{self, set_permissions}, io::{self, AsyncWriteExt}, net::UnixListener, + time::sleep, }; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, FibonacciBackoff, jitter}, + strategy::{ExponentialBackoff, FibonacciBackoff, FixedInterval, jitter}, }; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use utoipa::OpenApi; @@ -138,7 +136,6 @@ struct SockState { network: Network, job_id: Uuid, scheduler: PeerId, - task_tracker: TaskTracker, cancel: CancellationToken, } @@ -171,7 +168,6 @@ impl Bridge { network, job_id, scheduler, - task_tracker: task_tracker.clone(), cancel: cancel_token.clone(), }); @@ -408,9 +404,7 @@ async fn send_resource( State(state): State>, Json(req): Json, ) -> Result<(), Error> { - let retry_strategy = ExponentialBackoff::from_millis(100) - .map(jitter) // add jitter to delays - .take(3); // limit to 3 retries + let retry_strategy = FixedInterval::from_millis(50).map(jitter).take(20); Retry::spawn(retry_strategy, || { let state = state.clone(); @@ -518,6 +512,7 @@ fn validate_fetch(resource: &Fetch) -> Result<(), Error> { struct ReceiveSubscribeRequest { resource: Receive, path: Option, + timeout: Option, } #[derive(Debug, Serialize)] @@ -530,119 +525,114 @@ struct UpdatePointer { async fn receive_subscribe( State(state): State>, Json(req): Json, -) -> Result>>, Error> { +) -> Result { let dir_rel = req.path.unwrap_or_else(|| "incoming".to_string()); let dir_abs = safe_join(&state.work_dir, &dir_rel)?; fs::create_dir_all(&dir_abs).await?; - // Channel to push events to the SSE stream - let (tx, rx) = tokio::sync::mpsc::channel::(64); - let connector = state.connector.clone(); + let idle_timeout = req + .timeout + .and_then(|t| t.duration_since(SystemTime::now()).ok()); + if idle_timeout + .as_ref() + .is_some_and(|duration| duration.is_zero()) + { + return Ok(StatusCode::NO_CONTENT.into_response()); + } + + let mut incoming = match state.connector.receive(req.resource.clone()).await { + Ok(s) => s, + Err(err) => { + tracing::error!(error = %err, path = %dir_rel, "receive_subscribe: failed to start stream"); + return Ok(StatusCode::NO_CONTENT.into_response()); + } + }; let work_dir = state.work_dir.clone(); - let resource = req.resource.clone(); let cancel = state.cancel.clone(); - let task_tracker = state.task_tracker.clone(); - let dir_rel_clone = dir_rel.clone(); + let mut idle_timer = idle_timeout.map(|duration| Box::pin(sleep(duration))); + let mut pointer: Option = None; - // Background task: receive loops until the client disconnects or an error occurs - task_tracker.spawn(async move { - let mut incoming = match connector.receive(resource).await { - Ok(s) => s, + while let Some(item_result) = tokio::select! { + _ = cancel.cancelled() => { + tracing::debug!(path = %dir_rel, "receive_subscribe: task cancelled"); + None + } + _ = async { + if let Some(timer) = idle_timer.as_mut() { + timer.as_mut().await; + } else { + std::future::pending::<()>().await; + } + }, + if idle_timer.is_some() => { + tracing::warn!(path = %dir_rel, "receive_subscribe: idle timeout reached"); + None + } + item = incoming.next() => item, + } { + let item = match item_result { + Ok(item) => item, Err(err) => { - tracing::error!(error = %err, path = %dir_rel_clone, "receive_subscribe: failed to start stream"); - return; + tracing::warn!(error = %err, path = %dir_rel, "receive_subscribe: stream error"); + continue; } }; - let mut index = 0usize; - while let Some(item_result) = tokio::select! { - _ = tx.closed() => { - tracing::debug!(path = %dir_rel_clone, "receive_subscribe: client stream dropped"); - None - } - _ = cancel.cancelled() => { - tracing::debug!(path = %dir_rel_clone, "receive_subscribe: task cancelled"); - None - } - item = incoming.next() => item, - } { - let item = match item_result { - Ok(item) => item, - Err(err) => { - tracing::warn!(error = %err, path = %dir_rel_clone, "receive_subscribe: stream error"); - continue; - } - }; - let (file_name, mut reader) = derive_name_and_reader(item, index); - let file_rel = format!("{}/{}", dir_rel_clone, file_name); - let file_abs = match safe_join(&work_dir, &file_rel) { - Ok(p) => p, - Err(err) => { - tracing::error!(error = %err, file = %file_rel, "receive_subscribe: invalid target path"); - continue; - } - }; - if let Some(parent) = file_abs.parent() { - match fs::create_dir_all(parent).await { - Ok(()) => (), - Err(err) => { - tracing::error!(error = %err, directory = %parent.display(), "receive_subscribe: failed to create directory"); - continue; - } - } + // Once data starts flowing, disable idle timeout so long copies are not interrupted. + idle_timer = None; + let (file_name, mut reader) = derive_name_and_reader(item, 0); + let file_rel = format!("{}/{}", dir_rel, file_name); + let file_abs = match safe_join(&work_dir, &file_rel) { + Ok(p) => p, + Err(err) => { + tracing::error!(error = %err, file = %file_rel, "receive_subscribe: invalid target path"); + continue; } - let mut file = match fs::File::create(&file_abs).await { - Ok(f) => f, - Err(err) => { - tracing::error!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to create file"); - continue; - } - }; - let size = match tokio::io::copy(&mut reader, &mut file).await { - Ok(n) => n, + }; + if let Some(parent) = file_abs.parent() { + match fs::create_dir_all(parent).await { + Ok(()) => (), Err(err) => { - tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to copy resource"); + tracing::error!(error = %err, directory = %parent.display(), "receive_subscribe: failed to create directory"); continue; } - }; - if let Err(err) = file.sync_all().await { - tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to sync file"); } - if let Err(err) = set_permissions(&file_abs, Permissions::from_mode(0o600)).await { - tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to set permissions"); + } + let mut file = match fs::File::create(&file_abs).await { + Ok(f) => f, + Err(err) => { + tracing::error!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to create file"); + continue; } - - tracing::info!(size, file = %file_abs.display(), "Received resource"); - - let from_peer = file_name.split('.').next().unwrap_or("").to_string(); - let pointer = UpdatePointer { - path: file_rel, - size, - from_peer, - }; - let ev = match serde_json::to_string(&pointer) { - Ok(data) => Event::default().data(data), - Err(err) => { - tracing::error!(error = %err, "receive_subscribe: failed to serialize pointer"); - Event::default().data(r#"{"error":"serialize"}"#) - } - }; - if tx.send(ev).await.is_err() { - tracing::debug!(path = %dir_rel_clone, "receive_subscribe: client disconnected"); - break; + }; + let size = match tokio::io::copy(&mut reader, &mut file).await { + Ok(n) => n, + Err(err) => { + tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to copy resource"); + continue; } - index += 1; + }; + if let Err(err) = file.sync_all().await { + tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to sync file"); + } + if let Err(err) = set_permissions(&file_abs, Permissions::from_mode(0o600)).await { + tracing::warn!(error = %err, file = %file_abs.display(), "receive_subscribe: failed to set permissions"); } - }); - let stream = stream::unfold(rx, |mut rx| async move { - rx.recv().await.map(|ev| (Ok(ev), rx)) - }); + tracing::info!(size, file = %file_abs.display(), "Received resource"); - Ok(Sse::new(stream).keep_alive( - KeepAlive::new() - .interval(Duration::from_secs(5)) - .text("keepalive"), - )) + let from_peer = file_name.split('.').next().unwrap_or("").to_string(); + pointer = Some(UpdatePointer { + path: file_rel, + size, + from_peer, + }); + break; + } + + Ok(match pointer { + Some(p) => (StatusCode::OK, Json(p)).into_response(), + None => StatusCode::NO_CONTENT.into_response(), + }) } async fn send_action( diff --git a/crates/worker/src/executor/parameter_server.rs b/crates/worker/src/executor/parameter_server.rs index ef08cb54..a60a7199 100644 --- a/crates/worker/src/executor/parameter_server.rs +++ b/crates/worker/src/executor/parameter_server.rs @@ -28,7 +28,7 @@ use tokio::{ }; use tokio_retry::{ Retry, - strategy::{ExponentialBackoff, jitter}, + strategy::{FixedInterval, jitter}, }; use tokio_util::{future::FutureExt, sync::CancellationToken, task::TaskTracker}; use uuid::Uuid; @@ -92,9 +92,9 @@ impl JobExecutor for ParameterServerExecutor { ) -> Result { tracing::info!(job_spec = ?job, "Executing parameter server job"); - let retry_strategy = ExponentialBackoff::from_millis(100) - .map(jitter) // add jitter to delays - .take(3); // limit to 3 retries + // NOTE: Retry for a second, please note that this needs to align with + // the batch scheduler timings + let retry_strategy = FixedInterval::from_millis(50).map(jitter).take(20); let id = Uuid::new_v4(); let work_dir = self.work_dir_base.join(format!("hypha-{}", id)); @@ -535,6 +535,8 @@ async fn broadcast_update( gradient_file: &Path, cancel: CancellationToken, ) -> Result<(), Error> { + tracing::info!("Broadcasting update to {:?}", send); + let payload_len = fs::metadata(gradient_file).await?.len(); let mut writers = connector.send(send, payload_len).await?; diff --git a/executors/accelerate/src/hypha/accelerate_executor/api.py b/executors/accelerate/src/hypha/accelerate_executor/api.py index f917af59..4956b61b 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/api.py +++ b/executors/accelerate/src/hypha/accelerate_executor/api.py @@ -1,6 +1,4 @@ -import json -from collections.abc import Iterator -from contextlib import AbstractContextManager, contextmanager +from contextlib import AbstractContextManager from types import TracebackType from typing import Any, override @@ -43,37 +41,19 @@ def fetch(self, resource: Any) -> Any: resp = self._client.post("http://hypha/resources/fetch", json=resource, timeout=None).raise_for_status() return resp.json() - @contextmanager - def receive(self, resource: Any, path: str, timeout: float | None = None) -> Iterator["EventSource"]: - req = {"resource": resource, "path": path} + def receive(self, resource: Any, path: str, timeout: Any | None = None) -> Any | None: + req = {"resource": resource, "path": path, "timeout": timeout} # Use a short connect timeout to fail fast if the local side is unresponsive, - # but respect the provided timeout for the total duration/read. - # If timeout is None, we still enforce a connect timeout. - timeout_config = httpx.Timeout(timeout, connect=5.0) - with self._client.stream( - "POST", + # but hand off read timing entirely to the bridge. + timeout_config = httpx.Timeout(None, connect=5.0) + resp = self._client.post( "http://hypha/resources/receive", json=req, - headers={"Accept": "text/event-stream"}, timeout=timeout_config, - ) as resp: - yield EventSource(resp) - - -class EventSource: - def __init__(self, response: httpx.Response) -> None: - self._response: httpx.Response = response - - @property - def response(self) -> httpx.Response: - return self._response - - def __iter__(self) -> Iterator[Any]: - for line in self._response.iter_lines(): - fieldname, _, value = line.rstrip("\n").partition(":") - - if fieldname == "data": - result = json.loads(value) - - yield result - # Ignore other SSE fields (e.g., event:, id:, retry:) + ) + if resp.status_code == httpx.codes.NO_CONTENT: + return None + resp.raise_for_status() + if resp.status_code == httpx.codes.NO_CONTENT: + return None + return resp.json() diff --git a/executors/accelerate/src/hypha/accelerate_executor/training.py b/executors/accelerate/src/hypha/accelerate_executor/training.py index f52bb644..343b3d02 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/training.py +++ b/executors/accelerate/src/hypha/accelerate_executor/training.py @@ -82,7 +82,7 @@ def system_time_to_epoch_ms(timeout: object) -> int | None: def sleep_until_epoch_ms(target_ms: int) -> None: - now_ms = int(time.time() * 1000.0) + now_ms = time.time() * 1000.0 if target_ms > now_ms: time.sleep((target_ms - now_ms) / 1000.0) @@ -211,13 +211,8 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - timeout_sec = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if timeout_sec is not None and timeout_sec < 1.0: - timeout_sec = 1.0 - try: - session.send_resource(target, last_gradient, timeout=timeout_sec) + session.send_resource(target, last_gradient) current_status = { "executor": "train", "details": { @@ -248,49 +243,31 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - read_timeout = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if read_timeout is not None and read_timeout <= 0: - # Scheduler will tell us what to do next. - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "ApplyUpdate timeout reached before receive", - }, - } - continue - receive_path = f"incoming-{uuid.uuid4()}" try: - with session.receive(source, receive_path, timeout=read_timeout) as receiver: - updates_iter = iter(receiver) - pointers = next(updates_iter) - if pointers: - latest = pointers[-1] if isinstance(pointers, list) else pointers - parameters = ( - latest.get("parameters") if isinstance(latest.get("parameters"), dict) else None - ) - rel_path = parameters.get("path") if parameters else latest.get("path") - if isinstance(rel_path, str): - path = os.path.join(work_dir, rel_path) - model.load_state_dict(merge_models(previous_model_path, path)) - save_model(model, previous_model_path) - - # Once we updated the model, we no longer need the parameter file. - os.remove(path) - except StopIteration: - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "Receiver stream closed; no updates to merge.", - }, - } - continue + pointers = session.receive(source, receive_path, timeout=action.get("timeout")) + if pointers: + latest = pointers[-1] if isinstance(pointers, list) else pointers + parameters = latest.get("parameters") if isinstance(latest.get("parameters"), dict) else None + rel_path = parameters.get("path") if parameters else latest.get("path") + if isinstance(rel_path, str): + path = os.path.join(work_dir, rel_path) + model.load_state_dict(merge_models(previous_model_path, path)) + save_model(model, previous_model_path) + + # Once we updated the model, we no longer need the parameter file. + os.remove(path) + else: + current_status = { + "executor": "train", + "details": { + "state": "error", + "type": "connection", + "message": "No updates received before timeout.", + }, + } + continue except Exception as exc: # noqa: BLE001 current_status = { "executor": "train", @@ -346,13 +323,8 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - timeout_sec = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if timeout_sec is not None and timeout_sec < 1.0: - timeout_sec = 1.0 - try: - session.send_resource(target, CURRENT_MODEL_NAME, remove_file=False, timeout=timeout_sec) + session.send_resource(target, CURRENT_MODEL_NAME, remove_file=False) current_status = { "executor": "train", "details": {"state": "sent-model"}, @@ -380,43 +352,28 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09 } continue - timeout_ms = system_time_to_epoch_ms(action.get("timeout")) - read_timeout = (timeout_ms - int(time.time() * 1000.0)) / 1000.0 if timeout_ms else None - if read_timeout is not None and read_timeout <= 0: - # Scheduler will tell us what to do next. - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "ReceiveModel timeout reached before receive", - }, - } - continue try: receive_path = f"incoming-{uuid.uuid4()}" - with session.receive(source, receive_path, timeout=read_timeout) as receiver: - updates_iter = iter(receiver) - pointers = next(updates_iter) - if pointers: - incomming = pointers[-1] if isinstance(pointers, list) else pointers - rel_path = incomming.get("path") - if isinstance(rel_path, str): - path = os.path.join(work_dir, rel_path) - model.load_state_dict(load_file(path)) - os.remove(previous_model_path) - shutil.copy(path, previous_model_path) - os.remove(path) - except StopIteration: - current_status = { - "executor": "train", - "details": { - "state": "error", - "type": "connection", - "message": "Receiver stream closed; no updates to merge.", - }, - } - continue + pointers = session.receive(source, receive_path, timeout=action.get("timeout")) + if pointers: + incomming = pointers[-1] if isinstance(pointers, list) else pointers + rel_path = incomming.get("path") + if isinstance(rel_path, str): + path = os.path.join(work_dir, rel_path) + model.load_state_dict(load_file(path)) + os.remove(previous_model_path) + shutil.copy(path, previous_model_path) + os.remove(path) + else: + current_status = { + "executor": "train", + "details": { + "state": "error", + "type": "connection", + "message": "No model received before timeout.", + }, + } + continue except Exception as exc: # noqa: BLE001 current_status = { "executor": "train",