diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py
index 9e33d217e2..88cb6800da 100644
--- a/src/prime_rl/entrypoints/rl.py
+++ b/src/prime_rl/entrypoints/rl.py
@@ -296,6 +296,7 @@ def sigterm_handler(signum, frame):
     # Start training process
     trainer_cmd = [
         "torchrun",
+        "--role=trainer",
         f"--rdzv-endpoint=localhost:{get_free_port()}",
         f"--rdzv-id={uuid.uuid4().hex}",
         # Pipe all logs to file, and only master rank logs to stdout
@@ -342,7 +343,10 @@ def sigterm_handler(signum, frame):
 
     # Monitor all processes for failures
     logger.success("Startup complete. Showing trainer logs...")
-    tail_process = Popen(["tail", "-F", log_dir / "trainer.log"])
+    tail_process = Popen(
+        f"tail -F '{log_dir / 'trainer.log'}' | sed -u 's/^\\[[a-zA-Z]*[0-9]*\\]://'",
+        shell=True,
+    )
     processes.append(tail_process)
 
     # Check for errors from monitor threads
diff --git a/src/prime_rl/entrypoints/sft.py b/src/prime_rl/entrypoints/sft.py
index df79803458..2bc735bc7d 100644
--- a/src/prime_rl/entrypoints/sft.py
+++ b/src/prime_rl/entrypoints/sft.py
@@ -115,6 +115,7 @@ def sft_local(config: SFTConfig):
 
     trainer_cmd = [
         "torchrun",
+        "--role=trainer",
         f"--rdzv-endpoint=localhost:{get_free_port()}",
         f"--rdzv-id={uuid.uuid4().hex}",
         f"--log-dir={config.output_dir / 'logs' / 'trainer' / 'torchrun'}",
@@ -159,7 +160,10 @@ def sft_local(config: SFTConfig):
         monitor_threads.append(monitor_thread)
 
     logger.success("Startup complete. Showing trainer logs...")
-    tail_process = Popen(["tail", "-F", str(log_dir / "trainer.log")])
+    tail_process = Popen(
+        f"tail -F '{log_dir / 'trainer.log'}' | sed -u 's/^\\[[a-zA-Z]*[0-9]*\\]://'",
+        shell=True,
+    )
     processes.append(tail_process)
 
     stop_event.wait()
diff --git a/src/prime_rl/templates/multi_node_rl.sbatch.j2 b/src/prime_rl/templates/multi_node_rl.sbatch.j2
index d6957733c1..828841af62 100644
--- a/src/prime_rl/templates/multi_node_rl.sbatch.j2
+++ b/src/prime_rl/templates/multi_node_rl.sbatch.j2
@@ -342,6 +342,7 @@ else
 
     echo $HOSTNAMES_STR | tee $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log
     {% if wandb_shared %}WANDB_SHARED_LABEL=trainer {% endif %}uv run torchrun \
+        --role=trainer \
         --nnodes=$NUM_TRAIN_NODES \
         --nproc-per-node=$GPUS_PER_NODE \
         --node-rank=$TRAIN_NODE_RANK \
@@ -353,6 +354,6 @@ else
         --local-ranks-filter={{ ranks_filter }} \
         -m prime_rl.trainer.rl.train \
         @ $CONFIG_DIR/trainer.toml \
-        2>&1 | tee -a $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log
+        2>&1 | sed -u 's/^\[[a-zA-Z]*[0-9]*\]://' | tee -a $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log
 {% if num_infer_nodes > 0 %}
 fi{% endif %}
 '
diff --git a/src/prime_rl/templates/multi_node_sft.sbatch.j2 b/src/prime_rl/templates/multi_node_sft.sbatch.j2
index 3d904b319d..079129d3f2 100644
--- a/src/prime_rl/templates/multi_node_sft.sbatch.j2
+++ b/src/prime_rl/templates/multi_node_sft.sbatch.j2
@@ -82,6 +82,7 @@ srun bash -c '
     echo $HOSTNAMES | tee $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log
 
     uv run torchrun \
+        --role=trainer \
         --nnodes=$NUM_NODES \
         --nproc-per-node=$GPUS_PER_NODE \
         --node-rank=$TRAIN_NODE_RANK \
@@ -93,5 +94,5 @@ srun bash -c '
         --local-ranks-filter={{ ranks_filter }} \
         -m prime_rl.trainer.sft.train \
         @ $CONFIG_PATH \
-        2>&1 | tee -a $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log
+        2>&1 | sed -u 's/^\[[a-zA-Z]*[0-9]*\]://' | tee -a $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log
 '