Design: Support co-locate training and inference
Motivation
Today torchspec places inference engines and the FSDP trainer on disjoint GPUs and ferries hidden states through a Mooncake distributed KV store (torchspec/transfer/mooncake/eagle_store.py, torchspec/training/data_fetcher.py). For a 2-node / 16-GPU job, this forces a training/inference split (e.g. 8/8) and a network-heavy hidden-state path, even though the producer and consumer could live next to each other.
We want a colocate mode where both roles share every GPU:
2 nodes x 8 GPUs, e.g. 2 engines x 8-GPU TP + FSDP-16 training, all 16 GPUs used by both sides.
Hidden states move on-device via direct NCCL send/recv from the engine TP rank on GPU i to the FSDP trainer rank on GPU i. No Mooncake, no Ray queue of big payloads, no master/metadata RPC.
Target topology
flowchart LR
subgraph gpu [GPU i]
e[Engine TP rank i]
t[Trainer FSDP rank i]
end
e -->|NCCL P2P, same device| t
Two processes per GPU run concurrently under CUDA MPS. Engine 0's TP ranks 0..TP-1 sit on GPUs 0..TP-1, engine 1's TP ranks sit on GPUs TP..2*TP-1, and so on. Trainer ranks 0..N-1 sit on GPUs 0..N-1. Each engine's full batch is split along the batch dim so that each of its TP ranks feeds a B_eng/TP-sized shard to its colocated trainer rank.
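The placement rule above can be made concrete as a small mapping from GPU index to (engine, TP rank, trainer rank, batch shard). This is a sketch that assumes engines are laid out contiguously (engine e owns GPUs e*TP..(e+1)*TP-1) and that B_eng is divisible by TP. `colocated_layout` and `GpuSlot` are hypothetical names, not torchspec APIs. Note that the shard offset comes from the engine-local TP rank (i mod TP) and not from the global GPU index.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class GpuSlot:
    gpu: int           # global GPU index i (0..N-1)
    engine: int        # which engine owns this GPU (assumed contiguous layout)
    tp_rank: int       # engine-local TP rank
    trainer_rank: int  # FSDP rank colocated on the same GPU
    start: int         # first batch row this GPU's trainer receives
    end: int           # one past the last batch row


def colocated_layout(num_gpus: int, tp: int, b_eng: int) -> List[GpuSlot]:
    """Map every GPU to its engine TP rank, trainer rank and batch shard."""
    if num_gpus % tp != 0:
        raise ValueError("num_gpus must be a multiple of the engine TP size")
    if b_eng % tp != 0:
        raise ValueError("B_eng must be divisible by TP so shards are equal-sized")
    shard = b_eng // tp
    slots = []
    for gpu in range(num_gpus):
        engine, tp_rank = divmod(gpu, tp)  # engine-local rank drives the slice
        slots.append(GpuSlot(gpu, engine, tp_rank, gpu,
                             tp_rank * shard, (tp_rank + 1) * shard))
    return slots
```

For the 2-node target (16 GPUs, TP=8, B_eng=16), GPU 9 maps to engine 1, TP rank 1, and trainer rank 9, and it receives batch rows 2..3 of engine 1's batch.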
Design
1. Process placement (MPS)
A single placement group of N = world_size GPU bundles. Both pgs["training"] and pgs["inference"] point at the same bundles. This extends the colocate=True branch already present in torchspec/ray/placement_group.py, but makes 1:1 bundle pairing between engine rank i and trainer rank i a hard invariant.
Trainer actors: num_gpus = train_frac. Engine actors: num_gpus = infer_frac. Ray schedules two actors per bundle, so train_frac + infer_frac must not exceed 1.
A driver helper launches nvidia-cuda-mps-control -d on each node before any actor starts and sets CUDA_MPS_PIPE_DIRECTORY / CUDA_MPS_LOG_DIRECTORY in the actor env. Document the MPS prerequisite in docs/ray.md.
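The placement and MPS setup can be sketched as plain data. The helper returns one 1-GPU bundle per GPU and two fractional-GPU actor specs per bundle, and it builds the MPS env. It assumes the MPS daemon and its clients share the same CUDA_MPS_PIPE_DIRECTORY, and it includes the optional CUDA_MPS_ACTIVE_THREAD_PERCENTAGE knob from the open risks (off by default). `colocated_plan`, `mps_env` and `start_mps_daemon` are hypothetical helpers, not existing torchspec code.

```python
import os
import subprocess
from typing import Dict, List, Optional, Tuple

Bundle = Dict[str, float]
ActorSpec = Dict[str, object]


def mps_env(pipe_dir: str, log_dir: str,
            active_thread_pct: Optional[int] = None) -> Dict[str, str]:
    """Env for the MPS daemon and for every actor on the node.

    The daemon and its clients must agree on CUDA_MPS_PIPE_DIRECTORY.
    """
    env = {"CUDA_MPS_PIPE_DIRECTORY": pipe_dir, "CUDA_MPS_LOG_DIRECTORY": log_dir}
    if active_thread_pct is not None:  # optional per-role SM cap, off by default
        if not 0 < active_thread_pct <= 100:
            raise ValueError("active_thread_pct must be in (0, 100]")
        env["CUDA_MPS_ACTIVE_THREAD_PERCENTAGE"] = str(active_thread_pct)
    return env


def start_mps_daemon(pipe_dir: str, log_dir: str) -> None:
    """Run once per node, before any actor starts."""
    subprocess.run(["nvidia-cuda-mps-control", "-d"],
                   env={**os.environ, **mps_env(pipe_dir, log_dir)}, check=True)


def colocated_plan(world_size: int, train_frac: float,
                   infer_frac: float) -> Tuple[List[Bundle], List[ActorSpec]]:
    """One 1-GPU bundle per GPU; trainer rank i and engine rank i share bundle i."""
    if train_frac + infer_frac > 1.0 + 1e-9:  # tolerance for float rounding
        raise ValueError("two fractional actors must fit in one 1-GPU bundle")
    bundles: List[Bundle] = [{"GPU": 1.0} for _ in range(world_size)]
    actors: List[ActorSpec] = []
    for i in range(world_size):
        actors.append({"role": "trainer", "rank": i, "bundle_index": i,
                       "num_gpus": train_frac})
        actors.append({"role": "engine", "rank": i, "bundle_index": i,
                       "num_gpus": infer_frac})
    return bundles, actors
```

With Ray, the bundles would feed `ray.util.placement_group(bundles)`, and each actor spec would map to `.options(num_gpus=..., scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=i))`. Pinning both roles to the same bundle index is what keeps the 1:1 pairing invariant.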
2. Data plane (NCCL P2P, GPU-local)
After the target model's TP forward inside sglang, each TP rank holds the same [B_eng, S, H] hidden-states tensor (standard TP all-reduce at the block boundary). The handoff to the trainer is therefore:
Engine rank i local-chunks its tensor along the batch dim using its engine-local TP rank r = i mod TP: shard_i = hidden_states[r*B_eng/TP : (r+1)*B_eng/TP] (B_eng must be divisible by TP). No cross-rank collective, just indexing.
Engine rank i calls dist.send(shard_i, dst=trainer_rank_i) on the transfer group (the global world).
Trainer rank i calls dist.recv(buffer, src=engine_rank_i) into a pre-allocated GPU buffer, then proceeds into the DFlash/Eagle3 forward.
Because both processes share the same physical device under MPS, NCCL uses device-local memory for the P2P (no PCIe / NVLink hop). Aux-layer hidden states (Eagle3 3-layer target) and last_hidden_states are sent in the same step with additional P2P calls on the same group.
Why P2P and not reduce-scatter: the hidden states are already replicated across engine TP ranks, so there is nothing to reduce; the collective would degenerate to a scatter. Local chunk + P2P is the simpler primitive and avoids any patch to sglang's TP boundary. Reduce-scatter is listed as a future optimisation if we ever want to skip the engine's final TP all-reduce and fuse it with the scatter.
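The data-plane argument can be checked on toy data. In the sketch below, nested Python lists stand in for tensors and an in-process FIFO per (src, dst) pair stands in for NCCL send/recv. This mirrors the fact that P2P operations between one pair are matched in posting order, so the trainer must post its recvs in the same order (and with the same shapes) as the engine's sends. `reduce_scatter_sum` shows why a reduce-scatter is unnecessary: over already-replicated inputs, it just returns TP times the local chunk. All names, including the payload names in `PAYLOAD_ORDER`, are hypothetical.

```python
from collections import defaultdict, deque
from typing import Deque, Dict, List, Tuple

Rows = List[List[float]]  # stand-in for a [B, ...] tensor: one inner list per batch row

# Hypothetical payload names; engine sends and trainer recvs use this same order.
PAYLOAD_ORDER = ("hidden_states", "aux_hidden_states", "last_hidden_states")


def local_chunk(x: Rows, tp_rank: int, tp: int) -> Rows:
    """Engine-side shard: pure indexing along the batch dim, no collective."""
    if len(x) % tp != 0:
        raise ValueError("B_eng must be divisible by TP")
    shard = len(x) // tp
    return x[tp_rank * shard:(tp_rank + 1) * shard]


def reduce_scatter_sum(replicas: List[Rows], tp_rank: int) -> Rows:
    """What a SUM reduce-scatter would hand tp_rank if every TP rank
    contributed its (already identical) replicated tensor."""
    tp = len(replicas)
    chunks = [local_chunk(r, tp_rank, tp) for r in replicas]
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*chunks)]


class FakeP2P:
    """In-process stand-in for dist.send/dist.recv: one FIFO per (src, dst)."""

    def __init__(self) -> None:
        self._chan: Dict[Tuple[int, int], Deque[Rows]] = defaultdict(deque)

    def send(self, x: Rows, src: int, dst: int) -> None:
        self._chan[(src, dst)].append(x)

    def recv(self, src: int, dst: int) -> Rows:
        return self._chan[(src, dst)].popleft()


def engine_emit(p2p: FakeP2P, payloads: Dict[str, Rows],
                tp_rank: int, tp: int, src: int, dst: int) -> None:
    for name in PAYLOAD_ORDER:  # one P2P call per payload, fixed order
        p2p.send(local_chunk(payloads[name], tp_rank, tp), src, dst)


def trainer_receive(p2p: FakeP2P, src: int, dst: int) -> Dict[str, Rows]:
    return {name: p2p.recv(src, dst) for name in PAYLOAD_ORDER}
```

With 8 rows and TP=4, TP rank 1 sends rows 2..3 of each payload. A SUM reduce-scatter over four identical replicas would hand the same rank 4x those rows, which confirms there is nothing to reduce.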
3. Control plane (fully sync, single-buffer)
The union NCCL world has 2*N ranks: N trainer ranks + N engine ranks (one per engine TP rank, across all engines). Subgroups:
FSDP DP group: the N trainer ranks only, used by FSDP for its own reduce-scatter / all-gather of params and grads. Unchanged semantics.
Transfer pairs: for each pair (engine_rank_i, trainer_rank_i) on the same GPU, we use the global world for dist.send/recv. No dedicated subgroup needed for P2P.
CPU gloo group: small-payload broadcast of step metadata (step id, B_eng, S, loss_mask, input_ids).
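The rank bookkeeping for the union world is small enough to write down. This sketch assumes an ordering the design does not fix: trainers take global ranks 0..N-1 and engine ranks take N..2N-1, both in GPU order. It also assumes the gloo group spans all 2*N ranks. `union_world_layout` and `colocated_peer` are hypothetical.

```python
from typing import Dict


def union_world_layout(n: int) -> Dict[str, object]:
    """Ranks in the 2*N union world (hypothetical ordering: trainers first)."""
    trainer_ranks = list(range(n))        # global ranks 0..N-1, one per GPU
    engine_ranks = list(range(n, 2 * n))  # global ranks N..2N-1, same GPU order
    return {
        "world_size": 2 * n,
        "fsdp_group": trainer_ranks,                # NCCL subgroup used only by FSDP
        "cpu_group": trainer_ranks + engine_ranks,  # gloo, small step metadata
        # (engine_rank, trainer_rank) pairs on the same GPU; P2P on the global world
        "transfer_pairs": [(engine_ranks[i], trainer_ranks[i]) for i in range(n)],
    }


def colocated_peer(rank: int, n: int) -> int:
    """Global rank of the process sharing this rank's GPU."""
    if not 0 <= rank < 2 * n:
        raise ValueError("rank out of range")
    return rank + n if rank < n else rank - n
```

In torch.distributed, every process in the default group would call `dist.new_group(ranks=layout["fsdp_group"], backend="nccl")` and `dist.new_group(ranks=layout["cpu_group"], backend="gloo")`, since `new_group` must be entered by all processes of the default group, including non-members.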
Per training step:
Trainer broadcasts "ready" to engines via the CPU group.
Engine generates one batch, runs target forward, sends shard to colocated trainer.
Trainer runs fwd/bwd/opt.
Repeat.
Engine is idle while trainer runs and vice versa. One pre-allocated GPU recv buffer per trainer rank. No async pool, no queue-backed pipeline, no double-buffering.
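The strictly serialised step can be modelled with two threads and two size-1 queues. One queue stands in for the CPU-group "ready" broadcast and the other for the single pre-allocated recv buffer. This is a behavioural sketch, not torchspec code: `run_lockstep` is hypothetical, and "generate" covers both generation and the target forward. Because every log entry happens after a queue handoff, the event order is deterministic, and it shows that the engine and trainer never overlap.

```python
import queue
import threading
from typing import List


def run_lockstep(num_steps: int) -> List[str]:
    log: List[str] = []
    ready_q: "queue.Queue[int]" = queue.Queue(maxsize=1)   # stands in for "ready" broadcast
    data_q: "queue.Queue[tuple]" = queue.Queue(maxsize=1)  # stands in for the single recv buffer

    def engine() -> None:
        for _ in range(num_steps):
            step = ready_q.get()  # idle until the trainer signals ready
            log.append(f"engine:generate:{step}")
            log.append(f"engine:send:{step}")
            data_q.put((step, "hidden-state shard"))

    def trainer() -> None:
        for step in range(num_steps):
            log.append(f"trainer:ready:{step}")
            ready_q.put(step)
            _got_step, _shard = data_q.get()  # stands in for dist.recv
            log.append(f"trainer:recv:{step}")
            log.append(f"trainer:fwd_bwd_opt:{step}")

    threads = [threading.Thread(target=engine), threading.Thread(target=trainer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log
```

With `run_lockstep(2)`, the log alternates ready, generate, send, recv, fwd_bwd_opt for step 0 and then repeats the same sequence for step 1.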
All of the current async machinery -- AsyncInferenceManager background thread, SamplePool, per-DP Ray Queue of TrainSamples, Mooncake retry loops -- is not needed in colocate. The controller keeps dataset iteration and prompt dispatch (small payloads, stays on Ray) but is otherwise much thinner.
4. Memory isolation (soft caps)
MPS does not isolate VRAM -- both processes allocate from the same physical pool. We combine three layers:
Config-time budget. train_frac + infer_frac + 0.10 <= 1.00 (10% safety for cuBLAS/cuDNN/NCCL workspaces used by both sides). For DFlash draft training on H100, a 0.45 / 0.45 split is a reasonable starting point.
Enforced per-process caps.
Trainer: torch.cuda.set_per_process_memory_fraction(train_frac, device) in torchspec/training/trainer_actor.py. PyTorch's caching allocator enforces this as a hard ceiling on what it allocates in that process. It does not bound memory allocated outside the caching allocator, such as the CUDA context or NCCL's internal buffers.
sglang: mem_fraction_static=infer_frac in torchspec/inference/engine/sgl_engine.py. sglang computes its budget from the memory that is free at its startup, so the trainer must start first and claim its share (via the pre-warm below) before sglang initialises.
Both processes: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation pile-up from concurrent alloc/free under MPS.
Deterministic pre-warm. Trainer runs a one-step dummy fwd/bwd at init to bring its allocator to peak; sglang pre-allocates its KV pool up-front via mem_fraction_static. After init neither side grows VRAM.
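The config-time check and the trainer-side cap can be sketched as follows. The budget helpers are pure arithmetic. The 80 GiB device size used in the usage note is an assumption for an H100 80GB part, and `check_budget`, `budget_bytes` and `cap_trainer_memory` are hypothetical names. Only `torch.cuda.set_per_process_memory_fraction` is a real PyTorch call, and it is imported lazily so the arithmetic runs without a GPU.

```python
from typing import Dict

SAFETY = 0.10  # headroom for cuBLAS/cuDNN/NCCL workspaces used by both sides


def check_budget(train_frac: float, infer_frac: float, safety: float = SAFETY) -> float:
    """Config-time check: train_frac + infer_frac + safety <= 1.00."""
    for name, frac in (("train_frac", train_frac), ("infer_frac", infer_frac)):
        if not 0.0 < frac < 1.0:
            raise ValueError(f"{name} must be in (0, 1), got {frac}")
    total = train_frac + infer_frac + safety
    if total > 1.0 + 1e-9:  # small tolerance for float rounding
        raise ValueError(f"train_frac + infer_frac + safety = {total:.2f} > 1.00")
    return total


def budget_bytes(train_frac: float, infer_frac: float,
                 device_bytes: int) -> Dict[str, int]:
    """Translate the fractions into per-role byte caps plus leftover headroom."""
    check_budget(train_frac, infer_frac)
    trainer = int(train_frac * device_bytes)
    engine = int(infer_frac * device_bytes)
    return {"trainer": trainer, "engine": engine,
            "headroom": device_bytes - trainer - engine}


def cap_trainer_memory(train_frac: float, device: int) -> None:
    """Trainer-side cap; call once in the trainer actor, before the pre-warm step."""
    import torch  # lazy import so the helpers above run without torch

    torch.cuda.set_per_process_memory_fraction(train_frac, device)
```

For the suggested 0.45 / 0.45 split on an assumed 80 GiB device, each role is capped at about 36 GiB, and 8 GiB remains as shared headroom.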
Files that will change
Core entry + topology:
torchspec/train_entry.py -- add colocate_strategy=mps + transfer_mode=nccl branch; skip launch_mooncake_master, skip build_mooncake_config when in this branch.
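torchspec/controller/setup.py -- conditional Mooncake wiring; new setup_colocate_training_with_engines that wires P2P pairs.
… N.
Process groups and actors:
… 2*N-rank union world, create FSDP subgroup over trainer ranks, set memory fraction + MPS env.
Data plane:
… dist.send to the colocated trainer rank instead of Mooncake put.
…_sglang -- expose a hook at the spec-training hidden-state boundary so the torchspec engine wrapper can emit via NCCL P2P.
… init_mooncake_store with init_nccl_receiver (pre-allocates the recv buffer(s)).
… NcclDataFetcher / NcclHiddenStatesDataset that posts dist.recv per step.
Controller trim:
… mooncake_key + shape plumbing from TrainSample; carry only step_id, seq_len, loss_mask, input_ids.
Non-goals
… nccl_hidden_states_connector.py can land later if needed.
Any form of engine-ahead pipelining, double-buffering, or async producer/consumer between engine and trainer. Strictly step-serialised handoff.
Mixed colocate + disaggregated in the same run.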
Open risks
Allocator fragmentation between two concurrent processes on one GPU. Mitigated by expandable_segments:True and pre-warm; verify with nvidia-smi + torch.cuda.memory_stats across 100+ steps.
NCCL + FSDP interaction. FSDP's collectives and the transfer P2P share the same world. Run FSDP on its own subgroup (trainer ranks only) to avoid interleaving; put transfer P2P on a dedicated CUDA stream so it doesn't serialise behind FSDP's comms.
Straggler engine blocks the paired trainer ranks on dist.recv. At FSDP-16 everyone syncs at backward anyway, so one slow engine is already the step bottleneck. Acceptable as a baseline; add a timeout / step-skip policy if it becomes an issue.
MPS scheduling fairness. If one side dominates SMs, throughput drops. Expose CUDA_MPS_ACTIVE_THREAD_PERCENTAGE per role as a tuning knob (off by default).
Validation plan
Smoke: 1 node, 8 GPUs, 1 engine x 4-GPU TP + FSDP-4 trainer. Run one training step end-to-end. Compare per-layer gradients against the Mooncake baseline on identical prompts + seeds; require numerical match up to NCCL non-determinism (<1e-6 abs).
Stability: 1000 steps, log peak memory per process, assert no growth after step 10.
Scale-out: 2 nodes x 8 GPUs, 2 engines x 8-GPU TP + FSDP-16, full DFlash training job, compare convergence curves against Mooncake baseline over 1k steps.
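The pass/fail criteria of the smoke and stability runs can be written as two small checks. This sketch assumes per-layer gradients have been flattened to float lists keyed by layer name. It also assumes per-step peaks are collected on each process, for example via `torch.cuda.reset_peak_memory_stats()` followed by `torch.cuda.max_memory_allocated()` on the trainer side. `grads_match` and `peak_memory_stable` are hypothetical helpers, and "no growth after step 10" is read here as "no later peak exceeds the highest peak seen in the first 10 steps".

```python
from typing import Dict, List, Sequence


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    if len(a) != len(b):
        raise ValueError("gradient shapes differ")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def grads_match(colocate: Dict[str, List[float]], baseline: Dict[str, List[float]],
                atol: float = 1e-6) -> bool:
    """Smoke test: every layer's gradient is within atol of the Mooncake baseline."""
    if colocate.keys() != baseline.keys():
        return False
    return all(max_abs_diff(colocate[k], baseline[k]) < atol for k in baseline)


def peak_memory_stable(peaks: Sequence[int], warmup: int = 10) -> bool:
    """Stability test: no per-step peak after `warmup` steps exceeds the
    highest peak observed during the warmup window."""
    if warmup < 1 or len(peaks) <= warmup:
        raise ValueError("need warmup >= 1 and more steps than the warmup window")
    return max(peaks[warmup:]) <= max(peaks[:warmup])
```

The same `peak_memory_stable` check would run separately on the trainer and engine logs of each GPU, because under MPS the two processes draw from one physical pool.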