Skip to content

Commit

Permalink
fix(launcher): fix issue where launcher does not properly report shard failures (huggingface#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Jun 30, 2023
1 parent ecf6dc3 commit 2b53d71
Showing 1 changed file with 37 additions and 21 deletions.
58 changes: 37 additions & 21 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::Arc;
use std::sync::{mpsc, Mutex};
use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -274,7 +273,7 @@ struct Args {
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, String)),
Failed((usize, Option<String>)),
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -296,7 +295,7 @@ fn shard_manager(
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
Expand Down Expand Up @@ -433,20 +432,20 @@ fn shard_manager(
}
}
status_sender
.send(ShardStatus::Failed((rank, err.to_string())))
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
.unwrap();
return;
}
};

// Redirect STDOUT to the console
let shard_stdout = p.stdout.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());

thread::spawn(move || {
// Enter shard-manager tracing span
let stdout = BufReader::new(shard_stdout);
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in stdout.lines() {
for line in shard_stdout_reader.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
Expand All @@ -460,8 +459,22 @@ fn shard_manager(
loop {
// Process exited
if let Some(exit_status) = p.poll() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
// We read stderr in another thread as it seems that `read_to_string` can block
// indefinitely in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
let mut err = String::new();
shard_stderr_reader.read_to_string(&mut err).unwrap();
err_sender.send(err).unwrap_or(());
});

let err = err_receiver
.recv_timeout(Duration::from_millis(100))
.map_err(|err| {
tracing::error!("Unable to read shard {rank} error from stderr");
err
})
.ok();

if let ExitStatus::Signaled(signal) = exit_status {
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
Expand All @@ -474,7 +487,7 @@ fn shard_manager(
}

// We received a shutdown signal
if *shutdown.lock().unwrap() {
if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {rank} terminated");
Expand All @@ -494,14 +507,11 @@ fn shard_manager(
}
}

fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
fn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {
tracing::info!("Shutting down shards");
// Update shutdown value to true
// This will be picked up by the shard manager
{
let mut shutdown = shutdown.lock().unwrap();
*shutdown = true;
}
shutdown.store(true, Ordering::SeqCst);

// Wait for shards to shutdown
// This will block till all shutdown_sender are dropped
Expand Down Expand Up @@ -743,7 +753,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
fn spawn_shards(
num_shard: usize,
args: &Args,
shutdown: Arc<Mutex<bool>>,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>,
status_receiver: &mpsc::Receiver<ShardStatus>,
Expand Down Expand Up @@ -819,7 +829,10 @@ fn spawn_shards(
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
tracing::error!("Shard {rank} failed to start");
if let Some(err) = err {
tracing::error!("{err}");
}
shutdown_shards(shutdown, shutdown_receiver);
return Err(LauncherError::ShardCannotStart);
}
Expand All @@ -835,7 +848,7 @@ fn spawn_shards(

fn spawn_webserver(
args: Args,
shutdown: Arc<Mutex<bool>>,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Popen, LauncherError> {
// All shard started
Expand Down Expand Up @@ -1002,7 +1015,7 @@ fn main() -> Result<(), LauncherError> {
download_convert_model(&args, running.clone())?;

// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
Expand Down Expand Up @@ -1034,7 +1047,10 @@ fn main() -> Result<(), LauncherError> {

while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed:\n{err}");
tracing::error!("Shard {rank} failed to start");
if let Some(err) = err {
tracing::error!("{err}");
}
exit_code = Err(LauncherError::ShardFailed);
break;
};
Expand Down

0 comments on commit 2b53d71

Please sign in to comment.