[Feature] Support co-locate training and inference #81

@yubofredwang

Design: Support co-locate training and inference

Motivation

Today torchspec places inference engines and the FSDP trainer on disjoint GPUs and ferries hidden states through a Mooncake distributed KV store (torchspec/transfer/mooncake/eagle_store.py, torchspec/training/data_fetcher.py). For a 2-node / 16-GPU job, this forces a training/inference split (e.g. 8/8) and a network-heavy hidden-state path, even though the producer and consumer could live next to each other.

We want a colocate mode where both roles share every GPU:

  • 2 nodes x 8 GPUs, e.g. 2 engines x 8-GPU TP + FSDP-16 training, all 16 GPUs used by both sides.
  • Hidden states move on-device via direct NCCL send/recv from the engine TP rank on GPU i to the FSDP trainer rank on GPU i. No Mooncake, no Ray queue of big payloads, no master/metadata RPC.

Target topology

flowchart LR
    subgraph gpu [GPU i]
      e[Engine TP rank i]
      t[Trainer FSDP rank i]
    end
    e -->|NCCL P2P, same device| t

Two processes per GPU run concurrently under CUDA MPS. Engine 0's TP ranks 0..TP-1 sit on GPUs 0..TP-1, engine 1's on GPUs TP..2*TP-1, and so on; trainer ranks 0..N-1 sit on GPUs 0..N-1. Each engine's full batch (size B_eng) is split along the batch dimension, so each of its TP ranks feeds a B_eng/TP-sized shard to its colocated trainer rank.
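
The pairing above can be written down as a small lookup table. This is only a sketch of the intended layout: the `Slot` type and `colocated_layout` helper are hypothetical names, and it assumes each engine occupies a contiguous range of GPUs, as the paragraph describes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Slot:
    gpu: int           # global GPU index, 0..N-1
    engine_id: int     # which inference engine owns this GPU
    tp_rank: int       # engine-local tensor-parallel rank
    trainer_rank: int  # FSDP rank colocated on the same GPU


def colocated_layout(num_gpus: int, tp_size: int) -> list[Slot]:
    """Pair every engine TP rank with the trainer rank on the same GPU."""
    if num_gpus % tp_size != 0:
        raise ValueError("num_gpus must be a multiple of tp_size")
    return [
        Slot(gpu=g, engine_id=g // tp_size, tp_rank=g % tp_size, trainer_rank=g)
        for g in range(num_gpus)
    ]


# 2 nodes x 8 GPUs, 2 engines x 8-GPU TP, FSDP-16.
layout = colocated_layout(num_gpus=16, tp_size=8)
print(layout[9])  # Slot(gpu=9, engine_id=1, tp_rank=1, trainer_rank=9)
```

GPU 9 is the second GPU of the second node, so it hosts TP rank 1 of engine 1 next to trainer rank 9.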

Design

1. Process placement (MPS)

  • Single placement group of N = world_size GPU bundles. Both pgs["training"] and pgs["inference"] point at the same bundles. This extends the colocate=True branch already present in torchspec/ray/placement_group.py but makes 1:1 bundle pairing between engine rank i and trainer rank i a hard invariant.
  • Trainer actors: num_gpus = train_frac. Engine actors: num_gpus = infer_frac. Ray schedules two actors per bundle.
  • Driver helper launches nvidia-cuda-mps-control -d per node before any actor starts; sets CUDA_MPS_PIPE_DIRECTORY / CUDA_MPS_LOG_DIRECTORY in the actor env. Document MPS prerequisite in docs/ray.md.
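
As a rough illustration of the placement bullets, the per-role actor options could be assembled as plain data before being handed to Ray. The helper names, the `mps_root` path, and its `pipe` / `log` subdirectories are assumptions made for this sketch; only the `nvidia-cuda-mps-control -d` command and the two `CUDA_MPS_*` variables come from the design.

```python
import os


def mps_daemon_command() -> list[str]:
    """Command the driver helper would run once per node, before any actor."""
    return ["nvidia-cuda-mps-control", "-d"]


def mps_actor_env(mps_root: str) -> dict[str, str]:
    """Env vars that point a CUDA process at the node's MPS daemon."""
    return {
        # Subdirectory names are hypothetical; any per-node paths would do.
        "CUDA_MPS_PIPE_DIRECTORY": os.path.join(mps_root, "pipe"),
        "CUDA_MPS_LOG_DIRECTORY": os.path.join(mps_root, "log"),
    }


def actor_specs(train_frac: float, infer_frac: float, mps_root: str) -> dict[str, dict]:
    """Fractional-GPU options so two actors fit into one 1-GPU bundle."""
    if train_frac + infer_frac > 1.0:
        raise ValueError("two actors cannot request more than one GPU per bundle")
    env = mps_actor_env(mps_root)
    return {
        "training": {"num_gpus": train_frac, "env_vars": dict(env)},
        "inference": {"num_gpus": infer_frac, "env_vars": dict(env)},
    }


specs = actor_specs(train_frac=0.45, infer_frac=0.45, mps_root="/tmp/torchspec-mps")
print(specs["training"]["num_gpus"], specs["inference"]["num_gpus"])
```

How these dictionaries map onto actor options and the runtime environment is left to the placement code; the sketch only shows which values each role needs.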

2. Data plane (NCCL P2P, GPU-local)

After the target model's TP forward inside sglang, each TP rank holds the same [B_eng, S, H] hidden-states tensor (standard TP all-reduce at the block boundary). The handoff to the trainer is therefore:

  1. Engine rank i takes a local chunk of its tensor along the batch dimension, where i is the engine-local TP rank: shard_i = hidden_states[i*B_eng/TP : (i+1)*B_eng/TP]. No cross-rank collective, just indexing.
  2. Engine rank i calls dist.send(shard_i, dst=trainer_rank_i) on the union world described in section 3.
  3. Trainer rank i calls dist.recv(buffer, src=engine_rank_i) into a pre-allocated GPU buffer, then proceeds into the DFlash/Eagle3 forward.
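
The indexing in step 1 can be checked in isolation. The sketch below uses a plain list as a stand-in for the replicated [B_eng, S, H] tensor; `shard_bounds` is a hypothetical helper, and the requirement that B_eng be divisible by TP is an assumption the design does not state explicitly.

```python
def shard_bounds(tp_rank: int, b_eng: int, tp_size: int) -> tuple[int, int]:
    """Return the [start, stop) batch slice owned by one engine TP rank."""
    if b_eng % tp_size != 0:
        # Assumption: uneven shards are rejected rather than padded.
        raise ValueError("B_eng must be divisible by the TP size")
    per_rank = b_eng // tp_size
    return tp_rank * per_rank, (tp_rank + 1) * per_rank


# Stand-in for the replicated hidden states: one entry per sample in the batch.
hidden_states = list(range(32))

# On the engine side, each slice is what would be passed to dist.send;
# the trainer would dist.recv into a buffer sized for B_eng/TP samples.
shards = [hidden_states[slice(*shard_bounds(r, 32, 8))] for r in range(8)]
print(shard_bounds(3, 32, 8))  # (12, 16)
print(shards[3])               # [12, 13, 14, 15]
```

Because every TP rank slices the same replicated tensor, the shards are disjoint and together cover the whole batch without any collective.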

Because both processes share the same physical device under MPS, the P2P copy is expected to stay in device-local memory (no PCIe / NVLink hop). This should be confirmed early, since NCCL normally assumes one rank per device within a communicator. Aux-layer hidden states (Eagle3 3-layer target) and last_hidden_states are sent in the same step with additional P2P calls on the same group.

Why P2P and not reduce-scatter: the hidden states are already replicated across engine TP ranks, so there is nothing to reduce; the collective would degenerate to a scatter. Local chunk + P2P is the simpler primitive and avoids any patch to sglang's TP boundary. Reduce-scatter is listed as a future optimisation if we ever want to skip the engine's final TP all-reduce and fuse it with the scatter.

3. Control plane (fully sync, single-buffer)

The union NCCL world has 2*N ranks: N trainer ranks + N engine ranks (one per TP rank, counted across all engines). Subgroups:

  • FSDP DP group: the N trainer ranks only, used by FSDP for its own reduce-scatter / all-gather of params and grads. Unchanged semantics.
  • Transfer pairs: for each pair (engine_rank_i, trainer_rank_i) on the same GPU, we use the global world for dist.send/recv. No dedicated subgroup needed for P2P.
  • CPU gloo group: small-payload broadcast of step metadata (step id, B_eng, S, loss_mask, input_ids).
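
One possible numbering of the union world is sketched below. The ordering (trainers first, engines second) and the `union_world` helper are assumptions for illustration; the design only fixes the sizes and the same-GPU pairing.

```python
def union_world(n: int) -> dict:
    """Describe a 2*N-rank world: N trainer ranks followed by N engine ranks."""
    trainer_ranks = list(range(n))          # assumed to be global ranks 0..N-1
    engine_ranks = list(range(n, 2 * n))    # assumed to be global ranks N..2N-1
    return {
        "world_size": 2 * n,
        # Rank list for the trainer-only FSDP subgroup.
        "fsdp_group": trainer_ranks,
        # (engine global rank, trainer global rank) sharing GPU i.
        "transfer_pairs": [(engine_ranks[i], trainer_ranks[i]) for i in range(n)],
    }


world = union_world(16)
print(world["world_size"])          # 32
print(world["transfer_pairs"][9])   # (25, 9)
```

With this numbering, the FSDP subgroup would be built from `fsdp_group`, while each transfer pair addresses its peer by global rank on the default group.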

Per training step:

  1. Trainer broadcasts "ready" to engines via the CPU group.
  2. Engine generates one batch, runs target forward, sends shard to colocated trainer.
  3. Trainer runs fwd/bwd/opt.
  4. Repeat.

Engine is idle while trainer runs and vice versa. One pre-allocated GPU recv buffer per trainer rank. No async pool, no queue-backed pipeline, no double-buffering.
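
The lockstep handoff can be simulated on CPU to make the ordering explicit. This is a toy model, not the real data path: two `queue.Queue` objects stand in for the gloo "ready" broadcast and the single recv buffer, and `run_lockstep` is a hypothetical name.

```python
import queue
import threading


def run_lockstep(num_steps: int) -> list[str]:
    """Simulate the strictly serialised trainer/engine handoff for a few steps."""
    ready = queue.Queue(maxsize=1)    # stands in for the CPU-group "ready" signal
    handoff = queue.Queue(maxsize=1)  # stands in for the single GPU recv buffer
    events: list[str] = []

    def engine() -> None:
        for _ in range(num_steps):
            step = ready.get()                     # idle until the trainer is ready
            events.append(f"engine: generate + target forward {step}")
            handoff.put(f"shard-{step}")           # stands in for dist.send

    worker = threading.Thread(target=engine)
    worker.start()
    for step in range(num_steps):
        events.append(f"trainer: ready {step}")
        ready.put(step)
        shard = handoff.get()                      # stands in for dist.recv
        events.append(f"trainer: fwd/bwd/opt on {shard}")
    worker.join()
    return events


for line in run_lockstep(2):
    print(line)
```

Each side blocks on the other, so the event log always alternates in the same order; there is never more than one shard in flight.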

All of the current async machinery -- AsyncInferenceManager background thread, SamplePool, per-DP Ray Queue of TrainSamples, Mooncake retry loops -- is not needed in colocate. The controller keeps dataset iteration and prompt dispatch (small payloads, stays on Ray) but is otherwise much thinner.

4. Memory isolation (soft caps)

MPS does not isolate VRAM -- both processes allocate from the same physical pool. We combine three layers:

Config-time budget. train_frac + infer_frac + 0.10 <= 1.00 (10% safety for cuBLAS/cuDNN/NCCL workspaces used by both sides). For DFlash draft training on H100, a 0.45 / 0.45 split is a reasonable starting point.
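
A config-time check for this inequality might look like the sketch below. `check_budget` is a hypothetical helper, and the 80 GiB device size is an assumption used only to turn the fractions into absolute numbers.

```python
def check_budget(
    train_frac: float,
    infer_frac: float,
    safety: float = 0.10,
    total_gib: float = 80.0,  # assumed device memory, for illustration only
) -> dict[str, float]:
    """Validate train_frac + infer_frac + safety <= 1.00 and report GiB per role."""
    # Small tolerance so that e.g. 0.45 + 0.45 + 0.10 is not rejected by rounding.
    if train_frac + infer_frac + safety > 1.0 + 1e-9:
        raise ValueError("train_frac + infer_frac + safety exceeds 1.00")
    return {
        "trainer_gib": train_frac * total_gib,
        "engine_gib": infer_frac * total_gib,
        "headroom_gib": (1.0 - train_frac - infer_frac) * total_gib,
    }


budget = check_budget(train_frac=0.45, infer_frac=0.45)
print({name: round(gib, 2) for name, gib in budget.items()})
```

Under the assumed device size, the 0.45 / 0.45 split gives each role about 36 GiB and leaves about 8 GiB for shared workspaces.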

Enforced per-process caps.

  • Trainer: torch.cuda.set_per_process_memory_fraction(train_frac, device) in torchspec/training/trainer_actor.py. PyTorch's caching allocator enforces this as a hard ceiling in that process.
  • sglang: mem_fraction_static=infer_frac in torchspec/inference/engine/sgl_engine.py. sglang derives this budget from the memory that is free at its startup, so the trainer must start first and claim its fraction before sglang initialises.
  • Both processes: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation pile-up from concurrent alloc/free under MPS.

Deterministic pre-warm. Trainer runs a one-step dummy fwd/bwd at init to bring its allocator to peak; sglang pre-allocates its KV pool up-front via mem_fraction_static. After init neither side grows VRAM.

Files that will change

Core entry + topology:

Process groups and actors:

Data plane:

  • torchspec/inference/engine/sgl_engine.py -- after target forward, chunk along batch and dist.send to the colocated trainer rank instead of Mooncake put.
  • sglang patch inside _sglang -- expose a hook at the spec-training hidden-state boundary so the torchspec engine wrapper can emit via NCCL P2P.
  • torchspec/training/trainer.py -- replace init_mooncake_store with init_nccl_receiver (pre-allocates the recv buffer(s)).
  • torchspec/training/data_fetcher.py -- new NcclDataFetcher / NcclHiddenStatesDataset that posts dist.recv per step.

Controller trim:

Non-goals

  • vLLM colocate path. We touch only sglang here; torchspec/inference/engine/mooncake_hidden_states_connector.py stays on Mooncake. A sibling nccl_hidden_states_connector.py can land later if needed.
  • Any form of engine-ahead pipelining, double-buffering, or async producer/consumer between engine and trainer. Strictly step-serialised handoff.
  • Mixed colocate + disaggregated in the same run.

Open risks

  • Allocator fragmentation between two concurrent processes on one GPU. Mitigated by expandable_segments:True and pre-warm; verify with nvidia-smi + torch.cuda.memory_stats across 100+ steps.
  • NCCL + FSDP interaction. FSDP's collectives and the transfer P2P share the same world. Run FSDP on its own subgroup (trainer ranks only) to avoid interleaving; put transfer P2P on a dedicated CUDA stream so it doesn't serialise behind FSDP's comms.
  • Straggler engine blocks the paired trainer ranks on dist.recv. At FSDP-16 everyone syncs at backward anyway, so one slow engine is already the step bottleneck. Acceptable as a baseline; add a timeout / step-skip policy if it becomes an issue.
  • MPS scheduling fairness. If one side dominates SMs, throughput drops. Expose CUDA_MPS_ACTIVE_THREAD_PERCENTAGE per role as a tuning knob (off by default).

Validation plan

  • Smoke: 1 node, 8 GPUs, 1 engine x 4-GPU TP + FSDP-4 trainer. Run one training step end-to-end. Compare per-layer gradients against the Mooncake baseline on identical prompts + seeds; require numerical match up to NCCL non-determinism (<1e-6 abs).
  • Stability: 1000 steps, log peak memory per process, assert no growth after step 10.
  • Scale-out: 2 nodes x 8 GPUs, 2 engines x 8-GPU TP + FSDP-16, full DFlash training job, compare convergence curves against Mooncake baseline over 1k steps.
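
The smoke and stability checks could be expressed as two small predicates. Both helpers are hypothetical, and "no growth after step 10" is interpreted here as "no later step exceeds the peak seen during the first 10 steps", which is an assumption about the intended assertion.

```python
from typing import Optional


def grads_match(
    candidate: dict[str, list[float]],
    baseline: dict[str, list[float]],
    tol: float = 1e-6,
) -> bool:
    """True if every layer's gradient agrees with the baseline within tol (abs)."""
    if candidate.keys() != baseline.keys():
        return False
    for name, grad in candidate.items():
        ref = baseline[name]
        if len(grad) != len(ref):
            return False
        if max((abs(x - y) for x, y in zip(grad, ref)), default=0.0) >= tol:
            return False
    return True


def first_growth_step(peaks: list[int], warmup: int = 10) -> Optional[int]:
    """Return the first step after warmup whose peak memory exceeds the warmup peak."""
    ceiling = max(peaks[:warmup])
    for step in range(warmup, len(peaks)):
        if peaks[step] > ceiling:
            return step
    return None  # memory stayed flat after warmup


# Flattened per-layer gradients from the colocate run and the Mooncake baseline.
print(grads_match({"fc": [0.1, 0.2]}, {"fc": [0.1, 0.2 + 1e-8]}))  # True
# Peak bytes per step: flat after warmup except for a bump at step 12.
print(first_growth_step([100] * 12 + [101] + [100] * 5))           # 12
```

In a real run the gradient lists would come from flattened per-layer tensors and the peaks from per-process allocator statistics.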
