
Commit a77725e

samsja and claude authored
refactor(slurm): move NCCL and runtime configs from template to config (#1818)
* refactor(slurm): move NCCL and runtime configs from template to config

  - Add port/timeout fields to SharedWeightBroadcastConfig
  - Propagate port/timeout in auto_setup_weight_broadcast
  - Add auto_setup_slurm_nccl validator that sets:
    - orchestrator.num_train_workers from SLURM topology
    - trainer.weight_broadcast.host to 0.0.0.0
    - trainer.weight_broadcast.inference_world_size from SLURM topology
  - Add validate_inference_config to catch missing inference config early
  - Remove hardcoded env vars and CLI args from SLURM template

  Only --weight_broadcast.host $MASTER_ADDR and --client.base-url remain as CLI args (runtime values from SLURM).

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* feat(slurm): add rl_slurm entrypoint and update docs

  - Add rl_slurm script entrypoint to pyproject.toml
  - Update docs to use `uv run rl_slurm` instead of `uv run python -m prime_rl.slurm.rl`

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: sami jaghouar <sami@primeintellect.ai>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 3102ca4 commit a77725e

File tree

5 files changed: +38 −17 lines changed

docs/slurm.md

Lines changed: 3 additions & 3 deletions

@@ -3,7 +3,7 @@
 For SLURM clusters, use the `rl_slurm` entrypoint. It resolves the full config (trainer, orchestrator, inference), dumps sub-configs as TOML files, renders a SLURM batch script from a Jinja2 template, and submits it with `sbatch`.
 
 ```bash
-uv run python -m prime_rl.slurm.rl @ examples/slurm/hendrycks_math.toml
+uv run rl_slurm @ examples/slurm/hendrycks_math.toml
 ```
 
 This will:
@@ -15,7 +15,7 @@ This will:
 To only generate the script without submitting, use `--dry-run`:
 
 ```bash
-uv run python -m prime_rl.slurm.rl @ examples/slurm/hendrycks_math.toml --dry-run
+uv run rl_slurm @ examples/slurm/hendrycks_math.toml --dry-run
 ```
 
 ## Configuration
@@ -123,7 +123,7 @@ dp = 2
 The default template handles a standard multi-node setup with NCCL weight broadcast, InfiniBand detection, and `srun`-based process dispatch. For more advanced use cases (custom partitions, account settings, module loads, different networking setups, etc.), provide your own Jinja2 template:
 
 ```bash
-uv run python -m prime_rl.slurm.rl \
+uv run rl_slurm \
   @ my_config.toml \
   --slurm-template path/to/my_template.sh.j2
 ```
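Since this commit moves port and timeout from template env vars into config fields, they can be set directly in the user's TOML instead of `BROADCAST_PORT`/`BROADCAST_TIMEOUT`. A hypothetical fragment (the `[weight_broadcast]` table name follows the `--weight_broadcast.*` CLI args in this commit; the filename is illustrative):

```toml
# examples/slurm/my_run.toml (hypothetical file)
[weight_broadcast]
type = "nccl"
port = 29501     # default added in this commit
timeout = 1200   # seconds; default added in this commit
```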

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ dependencies = [
 
 [project.scripts]
 rl = "prime_rl.rl:main"
+rl_slurm = "prime_rl.slurm.rl:main"
 trainer = "prime_rl.trainer.rl.train:main"
 orchestrator = "prime_rl.orchestrator.orchestrator:main"
 inference = "prime_rl.inference.server:main"

src/prime_rl/rl_config.py

Lines changed: 10 additions & 2 deletions

@@ -92,6 +92,9 @@ class SharedWeightBroadcastConfig(BaseSettings):
         "filesystem"
     )
 
+    port: Annotated[int, Field(description="The port to use for NCCL weight broadcast.")] = 29501
+    timeout: Annotated[int, Field(description="The timeout in seconds for NCCL weight broadcast.")] = 1200
+
 
 class BaseRLConfig(BaseSettings):
     """Configures an RL training run."""
@@ -306,10 +309,15 @@ def auto_setup_weight_broadcast(self):
         if self.weight_broadcast.type == "nccl":
             inference_world_size = self.inference.parallel.dp * self.inference.parallel.tp if self.inference else 1
             self.trainer.weight_broadcast = TrainerNCCLWeightBroadcastConfig(
-                type=self.weight_broadcast.type, inference_world_size=inference_world_size
+                type=self.weight_broadcast.type,
+                inference_world_size=inference_world_size,
+                port=self.weight_broadcast.port,
+                timeout=self.weight_broadcast.timeout,
             )
             self.orchestrator.weight_broadcast = OrchestratorNCCLWeightBroadcastConfig(
-                type=self.weight_broadcast.type
+                type=self.weight_broadcast.type,
+                port=self.weight_broadcast.port,
+                timeout=self.weight_broadcast.timeout,
             )
         elif self.weight_broadcast.type == "filesystem":
             self.trainer.weight_broadcast = TrainerFileSystemWeightBroadcastConfig()
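The propagation in `auto_setup_weight_broadcast` can be illustrated with a minimal, dependency-free sketch. Plain dataclasses stand in for the repo's pydantic models, and the field set is reduced; only the class names and the copied fields come from this commit:

```python
from dataclasses import dataclass


@dataclass
class SharedWeightBroadcastConfig:
    """Reduced stand-in for the shared config with the new fields."""
    type: str = "nccl"
    port: int = 29501    # new field: NCCL weight-broadcast port
    timeout: int = 1200  # new field: NCCL weight-broadcast timeout (seconds)


@dataclass
class TrainerNCCLWeightBroadcastConfig:
    """Reduced stand-in for the trainer-side config."""
    type: str
    inference_world_size: int
    port: int
    timeout: int


def propagate(shared: SharedWeightBroadcastConfig, dp: int, tp: int) -> TrainerNCCLWeightBroadcastConfig:
    # Mirror the validator: port/timeout are copied from the shared config,
    # and the inference world size is dp * tp.
    return TrainerNCCLWeightBroadcastConfig(
        type=shared.type,
        inference_world_size=dp * tp,
        port=shared.port,
        timeout=shared.timeout,
    )


trainer_cfg = propagate(SharedWeightBroadcastConfig(), dp=2, tp=4)
```

With this in place, overriding `port` or `timeout` on the shared config reaches both the trainer and orchestrator sides without any template changes.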

src/prime_rl/slurm/rl.py

Lines changed: 24 additions & 0 deletions

@@ -49,6 +49,30 @@ def auto_setup_dp_replicate(self):
         self.trainer.dp_replicate = self.num_train_nodes // self.nodes_per_fsdp_group
         return self
 
+    @model_validator(mode="after")
+    def auto_setup_slurm_nccl(self):
+        """Set SLURM-specific values for NCCL weight broadcast and num_train_workers."""
+        # Set num_train_workers based on SLURM topology
+        self.orchestrator.num_train_workers = self.num_train_nodes * self.gpus_per_node
+
+        # Set NCCL-specific values if using NCCL weight broadcast
+        if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl":
+            # Trainer listens on all interfaces
+            self.trainer.weight_broadcast.host = "0.0.0.0"
+            # Compute inference world size from SLURM topology
+            self.trainer.weight_broadcast.inference_world_size = self.gpus_per_node * self.num_infer_nodes
+        return self
+
+    @model_validator(mode="after")
+    def validate_inference_config(self):
+        """Validate that inference config is provided when num_infer_nodes > 0."""
+        if self.num_infer_nodes > 0 and self.inference is None:
+            raise ValueError(
+                f"inference config is required when num_infer_nodes > 0 (got {self.num_infer_nodes}). "
+                "The SLURM template will launch inference servers on these nodes."
+            )
+        return self
+
 
 def write_subconfigs(config: RLSLURMConfig, output_dir: Path) -> None:
     """Write resolved subconfigs to disk as TOML files."""
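The topology arithmetic behind `auto_setup_slurm_nccl` is worth spelling out: every GPU on a train node becomes one train worker, and every GPU on an inference node becomes one rank in the NCCL broadcast group. A hedged sketch (the dataclass and function names here are illustrative, not the repo's):

```python
from dataclasses import dataclass


@dataclass
class SlurmTopology:
    """Hypothetical reduced model of the SLURM config's topology fields."""
    num_train_nodes: int
    num_infer_nodes: int
    gpus_per_node: int


def derive_nccl_settings(t: SlurmTopology) -> dict:
    # Mirror auto_setup_slurm_nccl: workers and world size are pure
    # functions of node counts and GPUs per node.
    return {
        "num_train_workers": t.num_train_nodes * t.gpus_per_node,
        "host": "0.0.0.0",  # trainer binds all interfaces for the broadcast
        "inference_world_size": t.gpus_per_node * t.num_infer_nodes,
    }


# Example: 2 train nodes + 1 inference node, 8 GPUs each.
settings = derive_nccl_settings(SlurmTopology(num_train_nodes=2, num_infer_nodes=1, gpus_per_node=8))
```

Deriving these in a validator instead of the template means a `--dry-run` config dump already shows the final values, and `validate_inference_config` can fail fast before `sbatch` is ever invoked.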

src/prime_rl/slurm/rl_slurm.sh.j2

Lines changed: 0 additions & 12 deletions

@@ -38,9 +38,6 @@ export OMP_NUM_THREADS=1
 export HOSTNAMES=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
 export INFER_HOSTS=${HOSTNAMES[@]:0:$NUM_INFER_NODES}
 export TRAIN_HOSTS=${HOSTNAMES[@]:$NUM_INFER_NODES:$SLURM_JOB_NUM_NODES}
-export BROADCAST_PORT=${BROADCAST_PORT:-29501}
-export BROADCAST_TIMEOUT=${BROADCAST_TIMEOUT:-12000}
-export NCCL_COMM_TIMEOUT=${NCCL_BROADCAST_TIMEOUT:-12000}
 
 INFER_URLS=""
 for host in ${INFER_HOSTS[@]}; do
@@ -105,11 +102,7 @@ else
     uv run orchestrator \
       @ $CONFIG_DIR/orchestrator.toml \
       --weight_broadcast.host $MASTER_ADDR \
-      --weight_broadcast.port $BROADCAST_PORT \
-      --weight_broadcast.timeout $BROADCAST_TIMEOUT \
       --client.base-url $INFER_URLS \
-      --client.timeout 3600 \
-      --num-train-workers $((NUM_TRAIN_NODES * GPUS_PER_NODE)) \
       2>&1 | tee $OUTPUT_DIR/slurm/latest_orchestrator.log $OUTPUT_DIR/slurm/job_${SLURM_JOB_ID}_orchestrator.log & disown
 fi
@@ -132,11 +125,6 @@ else
     --local-ranks-filter=$(seq -s, 0 $((GPUS_PER_NODE - 1))) \
     -m prime_rl.trainer.rl.train \
     @ $CONFIG_DIR/trainer.toml \
-    --weight_broadcast.host 0.0.0.0 \
-    --weight_broadcast.port $BROADCAST_PORT \
-    --weight_broadcast.inference_world_size $((GPUS_PER_NODE * NUM_INFER_NODES)) \
-    --weight_broadcast.timeout $BROADCAST_TIMEOUT \
-    --dist_timeout_seconds $NCCL_COMM_TIMEOUT \
     2>&1 | tee -a $OUTPUT_DIR/slurm/latest_train_node_rank_${TRAIN_NODE_RANK}.log $OUTPUT_DIR/slurm/job_${SLURM_JOB_ID}_train_node_rank_${TRAIN_NODE_RANK}.log
 fi
'
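What survives in the template is the runtime-only logic: splitting the SLURM node list into inference and train hosts and building the client URLs. A standalone sketch of that slicing, with a fixed host list standing in for `scontrol show hostnames $SLURM_JOB_NODELIST` (the `:8000/v1` URL shape is an assumption for illustration):

```shell
#!/usr/bin/env bash
# Fixed host list in place of the real `scontrol` output.
HOSTNAMES=( node01 node02 node03 node04 )
SLURM_JOB_NUM_NODES=${#HOSTNAMES[@]}
NUM_INFER_NODES=1

# First NUM_INFER_NODES hosts serve inference; the remainder train.
INFER_HOSTS=( "${HOSTNAMES[@]:0:$NUM_INFER_NODES}" )
TRAIN_HOSTS=( "${HOSTNAMES[@]:$NUM_INFER_NODES:$SLURM_JOB_NUM_NODES}" )

# Build a comma-separated list of inference base URLs
# (the port and /v1 path are illustrative, not from the template).
INFER_URLS=""
for host in "${INFER_HOSTS[@]}"; do
  INFER_URLS="${INFER_URLS:+$INFER_URLS,}http://$host:8000/v1"
done

echo "infer: ${INFER_HOSTS[*]}"
echo "train: ${TRAIN_HOSTS[*]}"
echo "urls:  $INFER_URLS"
```

Quoting the arrays (`"${HOSTNAMES[@]:...}"`) keeps hostnames intact even under unusual IFS settings, which the unquoted template expansions rely on implicitly.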
