2 changes: 0 additions & 2 deletions crates/messages/src/lib.rs
@@ -163,13 +163,11 @@ pub mod action {
},
SendModel {
target: Reference,
- timeout: SystemTime,
},
ExecuteBatch,
SendUpdate {
target: Reference,
weight: f32,
- timeout: SystemTime,
},
ApplyUpdate {
source: Reference,
78 changes: 35 additions & 43 deletions crates/scheduler/src/scheduling/batch_scheduler.rs
@@ -38,6 +38,7 @@ use crate::{
// decide when to instruct the parameter server to aggregate.
#[derive(Default)]
struct RoundState {
+ aggregated_updates: bool,
sent_updates: HashSet<PeerId>,
first_update_at: Option<Instant>,
min_quorum: usize,
@@ -184,7 +185,7 @@ where
.await
.push_worker_without_model(peer_id);
ExecutorAction::Train(TrainAction::WaitForModel {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
}
}
@@ -198,25 +199,21 @@ where
strategy: SelectionStrategy::All,
resource: None,
},
- timeout: now + Duration::from_secs(60),
+ timeout: now + Duration::from_secs(10),
})
} else {
ExecutorAction::Train(TrainAction::WaitForModel {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
}
}
TrainStatus::ReceivedModel => {
// Lazy transition to other state
- ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
- })
+ ExecutorAction::Train(TrainAction::Idle { timeout: now })
}
TrainStatus::SentModel => {
// Lazy transition to other state
- ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
- })
+ ExecutorAction::Train(TrainAction::Idle { timeout: now })
}
TrainStatus::Idle => {
let mut state = round_state.lock().await;
@@ -281,13 +278,22 @@ where
(false, count)
};

- if !should_update {
+ if state.aggregated_updates {
+ ExecutorAction::Train(TrainAction::ApplyUpdate {
+ source: Reference::Peers {
+ peers: parameter_servers,
+ strategy: SelectionStrategy::All,
+ resource: None,
+ },
+ timeout: now + Duration::from_secs(10),
+ })
+ } else if !should_update {
ExecutorAction::Train(TrainAction::ExecuteBatch)
} else if parameter_servers.is_empty() {
// NOTE: If we need to send an update but there are no parameter servers,
// we must wait (idle) until one becomes available.
ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
} else {
ExecutorAction::Train(TrainAction::SendUpdate {
@@ -298,8 +304,6 @@ where
resource: None,
},
weight: peer_contribution as f32 / projected_target as f32,
- // TODO: We need a way to properly determine a good sent timeout
- timeout: now + Duration::from_secs(30),
})
}
} else if state.push_done {
@@ -321,7 +325,7 @@ where
}
} else {
ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
}
}
@@ -346,7 +350,7 @@ where

if round_state.lock().await.training_complete {
ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
} else {
let stats: Vec<u64> =
@@ -400,7 +404,7 @@ where
// NOTE: If we need to send an update but there are no parameter servers,
// we must wait (idle) until one becomes available.
ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
} else {
ExecutorAction::Train(TrainAction::SendUpdate {
@@ -411,8 +415,6 @@ where
resource: None,
},
weight: peer_contribution as f32 / projected_target as f32,
- // TODO: We need a way to properly determine a good sent timeout
- timeout: now + Duration::from_secs(30),
})
}
}
@@ -443,20 +445,10 @@ where
since_first_ms = elapsed_ms,
"Worker reported SentUpdate; recorded for round"
);
- if parameter_servers.is_empty() {
- ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
- })
- } else {
- ExecutorAction::Train(TrainAction::ApplyUpdate {
- source: Reference::Peers {
- peers: parameter_servers,
- strategy: SelectionStrategy::All,
- resource: None,
- },
- timeout: now + Duration::from_secs(30),
- })
- }
+
+ ExecutorAction::Train(TrainAction::Idle {
+ timeout: now + Duration::from_millis(500),
+ })
}
TrainStatus::AppliedUpdate => {
let training_complete = {
@@ -472,7 +464,7 @@ where

if training_complete {
ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
} else {
let mut training = training_state.lock().await;
@@ -484,7 +476,6 @@ where
strategy: SelectionStrategy::One,
resource: None,
},
- timeout: now + Duration::from_secs(30),
})
} else {
ExecutorAction::Train(TrainAction::ExecuteBatch)
@@ -508,7 +499,7 @@ where
}
}
ExecutorAction::Train(TrainAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
}
TrainStatus::Error(TrainError::Other { message }) => {
@@ -537,7 +528,7 @@ where
);
}
ExecutorAction::Aggregate(AggregateAction::Idle {
- timeout: now + Duration::from_secs(5),
+ timeout: now + Duration::from_millis(500),
})
} else {
let workers: Vec<_> = {
@@ -553,7 +544,7 @@ where

if workers.is_empty() {
ExecutorAction::Aggregate(AggregateAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
} else {
// Start aggregation when either all workers have sent updates,
@@ -601,12 +592,13 @@ where

if workers.is_empty() {
ExecutorAction::Aggregate(AggregateAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
} else {
// Log that we are moving to broadcast for this round.
let round = {
- let state = round_state.lock().await;
+ let mut state = round_state.lock().await;
+ state.aggregated_updates = true;
state.round
};
tracing::info!(round = %round, "Trigger BroadcastUpdate");
@@ -670,14 +662,14 @@ where
ExecutorAction::Aggregate(AggregateAction::Terminate)
} else {
ExecutorAction::Aggregate(AggregateAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
}
}
AggregateStatus::Error(AggregateError::Connection { message }) => {
tracing::warn!(%peer_id, message = %message, "Aggregator reported connection error");
ExecutorAction::Aggregate(AggregateAction::Idle {
- timeout: now + Duration::from_secs(1),
+ timeout: now + Duration::from_millis(500),
})
}
AggregateStatus::Error(AggregateError::Other { message }) => {
@@ -740,6 +732,7 @@ impl BatchScheduler {
training_complete: false,
applied_final_update: HashSet::new(),
push_done: false,
+ aggregated_updates: false,
}));
let training_state = Arc::new(Mutex::new(TrainingState::new(samples_between_updates)));
network
@@ -1276,6 +1269,7 @@ mod batch_scheduler_tests {
training_complete: false,
applied_final_update: Default::default(),
push_done: false,
+ aggregated_updates: false,
}));
let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(800)));
let batch_sizer = std::sync::Arc::new(|resources: &Resources| resources.gpu() as u32);
@@ -1351,7 +1345,6 @@ mod batch_scheduler_tests {
resource: None,
},
weight: 0.3,
- timeout: SystemTime::now(),
}),
2000,
),
@@ -1365,7 +1358,6 @@ mod batch_scheduler_tests {
resource: None,
},
weight: 0.3,
- timeout: SystemTime::now(),
}),
2400,
),
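For orientation, a minimal, self-contained sketch of the worker-side decision order this diff converges on in the Idle branch: apply an already-aggregated update first, otherwise keep executing batches, otherwise send a weighted update (or idle briefly when no parameter server is available). The `aggregated_updates` flag, the action names, and the timeout values are taken from the diff above; the simplified types, the `next_action` helper, and the placeholder peer list are illustrative assumptions, not the crate's real API.

use std::time::{Duration, Instant};

// Simplified stand-ins for the scheduler's types; the crate's real
// `TrainAction` and `Reference` carry more fields than shown here.
#[derive(Debug)]
enum TrainAction {
    ApplyUpdate { timeout: Instant },
    ExecuteBatch,
    SendUpdate { weight: f32 },
    Idle { timeout: Instant },
}

struct RoundState {
    // Set by the aggregator path once it triggers BroadcastUpdate.
    aggregated_updates: bool,
}

// Hypothetical helper mirroring the new if/else order in the Idle branch.
fn next_action(
    state: &RoundState,
    should_update: bool,
    parameter_servers: &[&str], // placeholder for peer references
    peer_contribution: u64,
    projected_target: u64,
    now: Instant,
) -> TrainAction {
    if state.aggregated_updates {
        // An aggregated model update is already available: apply it first.
        TrainAction::ApplyUpdate { timeout: now + Duration::from_secs(10) }
    } else if !should_update {
        // Not enough local progress yet: keep executing batches.
        TrainAction::ExecuteBatch
    } else if parameter_servers.is_empty() {
        // No parameter server yet: idle briefly and re-evaluate.
        TrainAction::Idle { timeout: now + Duration::from_millis(500) }
    } else {
        // Send a weighted update proportional to this peer's contribution.
        TrainAction::SendUpdate {
            weight: peer_contribution as f32 / projected_target as f32,
        }
    }
}

fn main() {
    let state = RoundState { aggregated_updates: false };
    let action = next_action(&state, true, &["ps-0"], 300, 1000, Instant::now());
    println!("{action:?}"); // SendUpdate { weight: 0.3 }
}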