[WIP] Support co-locate training and inference (#81)#92
Draft
zhubohao911 wants to merge 53 commits into
Placeholder commit for tracking work on issue lightseekorg#81. Implementation will land across multiple PRs following the phased plan.
Lays the foundation for putting trainer + inference engine on the same GPU
via NVIDIA MPS. All new code is gated behind `training.colocate_strategy=mps`;
the legacy disagg / `colocate=True` paths are unchanged.
Phase 0 — config plumbing & validation:
* Add `colocate_strategy`, `transfer_mode`, `train_frac`, `infer_frac` to
`TrainingConfig`.
* `torchspec/colocate/config.py` validates supported combinations
(`mooncake` is the only legal `transfer_mode` when colocate is off;
`mps`/`nccl` is the only combination when on), the
`train_frac + infer_frac + 0.10 headroom <= 1.0` budget, and the
`engine_count * engine_tp == world_size` topology invariant.
* Wired into `train_entry.parse_config`. 18/18 unit tests pass locally.
Phase 1 — placement + MPS env injection:
* `torchspec/colocate/mps.py`: idempotent NVIDIA MPS daemon lifecycle helper
(start, status, env-var build, stop). Dependency-free so the Ray driver can
call it from a headless box; subprocess is mocked in the 17-test unit suite.
* `placement_group.create_placement_groups` now branches on
`is_mps_colocate(args)`, logs which strategy fired, and revalidates the
engine-count/TP topology invariant.
* `RayTrainGroup` claims `train_frac` per actor under MPS (was hard-coded
0.4); `_prepare_sgl_engines` claims `infer_frac` (was a 0.2 placeholder).
* Both inject `mps_client_env()` + `expandable_segments:True` allocator
config into the Ray actor `runtime_env`.
* `SglEngine.init` overrides sglang's `mem_fraction_static` from `infer_frac`
so users don't have to keep two budgets in sync.
* `train_entry` starts the MPS daemon once at driver init and skips Mooncake
master setup under MPS colocate (Phase 5 will rip Mooncake out properly).
* Modal smoke test `phase1_placement` (`H100:4`, 22 s test) confirms a shared
placement group; 4 trainer + 4 engine pairs land on matching
`(node_ip, gpu_id)` with distinct PIDs and propagated MPS env vars.
Phase 2 — union NCCL world bootstrap helper:
* `torchspec/colocate/world.py`: `UnionWorldSpec` + `init_union_world` build
a 2N-rank NCCL default PG, an FSDP-only NCCL subgroup of size N, and a
2N-rank gloo metadata subgroup. Sets `TORCHSPEC_COLOCATE_UNION_WORLD=1` so
a follow-up sglang patch can detect the union-world setup. Trainers occupy
ranks `[0, N)`, engines occupy `[N, 2N)`.
* Modal smoke test `phase2_union_world` (`H100:8`, 55 s test): 8 ranks
bootstrap the union world; allreduce on the union, FSDP, and gloo groups
all succeed; trainer/engine rank assignment matches spec.
* Phase 2 is intentionally tested in isolation from MPS sharing (8 GPUs, one
rank per GPU). The MPS + union-world integration and the sglang scheduler
patch (so sglang TP reuses our union world) are the highest-risk pieces and
are deferred to Phase 4, where they're actually exercised by the
hidden-state hook.
Modal infrastructure:
* `scripts/modal/setup_modal_secrets.sh` — sandbox env secrets bootstrap
(HF + W&B).
* `scripts/modal/modal_colocate_smoke.py` — sandbox-targeted Modal app with
one entrypoint per phase. Image is built on
`nvidia/cuda:12.4.0-devel-ubuntu22.04` + sglang `0f2df9370`.
Test artifacts:
* `tests/colocate/`: phase 0 validation, phase 1 MPS helper, phase 2
rank-assignment, phase 1 placement (Modal-only), phase 2 union world
(Modal-only). 45 unit tests pass locally; Modal smoke tests for Phase 1 and
Phase 2 are green.
Tracking:
* `docs/colocate/implementation_log.md` records status, work logs, and every
plan deviation, phase by phase.
Co-authored-by: Claude
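For reference, a minimal sketch of the Phase-0 rules described above (the
`args` field names mirror the new config fields; the real checks live in
`torchspec/colocate/config.py`):

```python
def validate_colocate_config(args) -> None:
    # Sketch only: field access via a flat `args` namespace is an assumption.
    if args.colocate_strategy != "mps":
        if args.transfer_mode != "mooncake":
            raise ValueError("transfer_mode must be 'mooncake' when colocate is off")
        return
    if args.transfer_mode != "nccl":
        raise ValueError("colocate_strategy='mps' requires transfer_mode='nccl'")
    if args.train_frac + args.infer_frac + 0.10 > 1.0:
        raise ValueError("train_frac + infer_frac must leave >= 0.10 GPU-memory headroom")
    if args.engine_count * args.engine_tp != args.world_size:
        raise ValueError("engine_count * engine_tp must equal world_size")
```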
Adds the trainer-side `NcclDataFetcher` + engine-side `send_dummy`
that together form the colocate hidden-state transfer mechanism. Both
use `dist.batch_isend_irecv`, the production-correct primitive for
union-world P2P (avoids NCCL's lazy 2-rank sub-comm init pathology
on the size-2N parent group).
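A minimal sketch of the mechanism in isolation (the helper shape is
illustrative, not the real fetcher/sender code):

```python
import torch.distributed as dist

def p2p_transfer(tensor, peer_rank, is_sender):
    # Routing even a single send/recv through batch_isend_irecv makes NCCL
    # set up the 2-rank sub-communicator in one shot, instead of the lazy
    # per-op init that can wedge on a large parent group.
    op = dist.isend if is_sender else dist.irecv
    reqs = dist.batch_isend_irecv([dist.P2POp(op, tensor, peer_rank)])
    for req in reqs:
        req.wait()
```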
`init_union_world` extended:
- Adds `paired_global_rank` to `UnionWorld` so callers can target the
opposite-role rank without re-deriving.
- Accepts `device_id` (defaults to `cuda.current_device()`); without it
NCCL guesses device-by-rank, which under Ray's CUDA_VISIBLE_DEVICES
isolation maps to a non-existent local GPU and silently deadlocks
P2P.
- Skips the trainer-only NCCL subgroup when there's only one trainer
(1-rank NCCL groups can hang in eager-init mode).
Modal smoke (`phase3_p2p_dummy` on H100:2) — 3/3 tests pass in 137s:
- 100-iter byte-equality on bare NCCL (data plane core)
- 1-iter round trip through the full `init_union_world` +
`NcclDataFetcher` + `send_dummy` path (proves union-world helper
integrates with the data plane)
- shape-mismatch errors cleanly via 90s watchdog (production wraps
recvs in a watchdog timeout for the same reason)
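One simple shape such a watchdog can take (illustrative; the production
wrapper lives in the fetcher):

```python
import threading

def recv_with_watchdog(recv_fn, timeout_s=90.0):
    # A shape-mismatched NCCL recv hangs rather than erroring, and the stuck
    # call can't be cancelled; joining a worker thread on a deadline turns
    # the silent hang into a loud TimeoutError (treat it as process-fatal).
    result, error = {}, {}

    def _run():
        try:
            result["v"] = recv_fn()
        except Exception as exc:  # surface real recv errors too
            error["v"] = exc

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    t.join(timeout_s)
    if t.is_alive():
        raise TimeoutError(f"recv did not complete within {timeout_s}s")
    if "v" in error:
        raise error["v"]
    return result.get("v")
```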
Deviation from plan: ran at 2-rank/no-MPS scale instead of the plan's
4-GPU-MPS topology. Multi-pair concurrent P2P inside a shared parent
group is what Phase 4 builds (each pair gets its own NCCL world inside
its MPS-shared GPU); Phase 3's job is just to verify the data plane
mechanism. Documented in implementation_log.md §Phase 3 deviations.
Verification:
PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q
→ 45 passed, 9 skipped (local; torch absent → skip)
modal run --env sandbox \
scripts/modal/modal_colocate_smoke.py::phase3_p2p_dummy
→ 3 passed in 137.78s
AI-assisted (Claude). Human submitter reviewed and ran tests.
Co-authored-by: Claude
Lands the TorchSpec side of the engine→trainer P2P data plane so the colocate
path can ship hidden states without Mooncake:
* `NcclHiddenStatesConnector` (engine-side multi-tensor sender) — one
`dist.batch_isend_irecv` over a sorted-by-key tensor dict, contiguous +
CUDA preconditions, env-var contract for the upstream sglang patch
(TORCHSPEC_COLOCATE_TRANSFER_MODE, TORCHSPEC_COLOCATE_PAIRED_TRAINER_RANK).
* `NcclMultiTensorFetcher` (trainer-side receiver) — symmetric sorted-by-key,
fresh buffer per step (revisit in Phase 6 if allocator churn shows up).
* `ColocateTrainSample` + `ColocateDataset` + `ColocateDataFetcher` —
Mooncake-shaped batch dict consumed identically by `_train_step`.
* `TrainerActor.init` branches on `transfer_mode`: when `nccl`, runs
`init_union_world` (master_port + 5000), binds the union world's meta_group
as GLOO_GROUP, and overrides args.rank / args.world_size to the trainer-only
N-rank view so FSDP arithmetic stays correct. Stamps
TORCHSPEC_COLOCATE_UNION_* env vars for the upstream patch.
* `Trainer.set_train_queue` swaps to `ColocateDataFetcher` when the union
world is wired up; warns about + bypasses any Mooncake config.
* `SglEngine.init` exports the env contract + flips
`enable_spec_training_mooncake` to False so the patch's NCCL path is the
only writer.
* `docs/colocate/sglang_patch.md` documents the upstream sglang patch surface
(env-var contract + 3 patch points + verification recipe + a diagnostic for
"patch not picked up").
Verified on Modal sandbox (2×H100, 40.4 s):
* `test_p2p_multi_tensor_round_trip` — 4-tensor Mooncake-shaped dict via the
union world, byte equality per tensor.
* `test_send_step_helper_matches_connector` — symmetric helper.
Phase 4's "one full training step" deliverable is gated on the upstream
sglang patch (no patches/_sglang/ in this repo); test_one_step is parked
behind that.
Implementation log: `docs/colocate/implementation_log.md` Phase 4 section is
now 🟢 (TorchSpec-side complete).
Co-authored-by: Claude
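The core of the sender/receiver symmetry described above, as a sketch (the
spec format and helper names are assumptions; the real classes add the
env-var contract and precondition checks):

```python
import torch
import torch.distributed as dist

def send_tensor_dict(tensors: dict, dst: int) -> None:
    # Sender: sorted key order is the whole wire protocol — both sides agree
    # on ordering without exchanging metadata. Tensors must be contiguous CUDA.
    ops = [dist.P2POp(dist.isend, tensors[k], dst) for k in sorted(tensors)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()

def recv_tensor_dict(specs: dict, src: int, device: torch.device) -> dict:
    # Receiver: fresh buffer per step, allocated from (shape, dtype) specs,
    # received in the same sorted order the sender used.
    out = {k: torch.empty(shape, dtype=dtype, device=device)
           for k, (shape, dtype) in specs.items()}
    ops = [dist.P2POp(dist.irecv, out[k], src) for k in sorted(out)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return out
```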
Lands the TorchSpec side of the colocate (`transfer_mode=nccl`) path's
controller and stability infrastructure. The colocate sync training
loop body itself is gated on the upstream sglang patch
(docs/colocate/sglang_patch.md); train_entry now reaches setup, prints
the timer summary, and raises a clearly-worded NotImplementedError
when the patch is absent.
Phase 5 — controller trim:
* `torchspec/controller/setup.py` adds
`setup_colocate_training_with_engines(args, train_group,
inference_engines, controller=None)` — the slim sibling of
`setup_async_training_with_engines`. Skips `AsyncInferenceManager`
entirely (returns `(controller, None)`), passes `mooncake_config=None`
to both `train_group.set_train_queues` and `set_eval_queues`, and
deliberately does **not** import `torchspec.transfer.mooncake.*` so a
`sys.modules` guard test can enforce the property (shape sketched after
this list).
* `torchspec/controller/__init__.py` exports the new entry point.
* `torchspec/training/trainer.py` `set_eval_queue` now branches on the
`_union_world` handle: when colocate, builds a `ColocateDataFetcher`
on top of `NcclMultiTensorFetcher` instead of `MooncakeDataFetcher`;
any incidental `mooncake_config` is logged + bypassed (defence in
depth — the controller no longer sends it). Same shape downstream so
`Eagle3Trainer._train_step` is unchanged.
* `torchspec/train_entry.py`:
- `[8] Setup training` branches on `is_mps_colocate(args)` to call
the colocate setup; mooncake_config build/master are skipped.
- After `timer.log_summary()`, raises `NotImplementedError` with a
pointer to `docs/colocate/sglang_patch.md` and the multi-tensor
smoke target. The pre-loop wiring (controller actor, train_group,
inference_engines, train queues) is fully set up at this point;
the loop body is the only remaining gap.
* `tests/colocate/test_phase5_no_mooncake.py` (3 tests, all green
locally): module-import guard via fresh interpreter +
`sys.modules`, signature compatibility with the async setup,
`inference_manager is None` post-condition.
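The shape of the slim sibling, as a sketch (signatures follow the
description above; the queue-setter calls are assumptions, bodies elided):

```python
def setup_colocate_training_with_engines(args, train_group, inference_engines,
                                         controller=None):
    # Slim sibling of setup_async_training_with_engines: no
    # AsyncInferenceManager, and mooncake_config is pinned to None so no
    # torchspec.transfer.mooncake.* import can sneak in.
    train_group.set_train_queues(inference_engines, mooncake_config=None)
    train_group.set_eval_queues(inference_engines, mooncake_config=None)
    return controller, None  # (controller, inference_manager=None)
```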
Phase 6 — memory caps + MPS hygiene + stability skeleton:
* `torchspec/train_entry.py` init-order fence (gated on
`is_mps_colocate(args)`): runs `ray.get(train_init_refs)` *before*
invoking `prepare_inference_engines`. Under MPS the engine and
trainer share one CUDA memory pool; sequencing trainer-first
guarantees `set_per_process_memory_fraction(train_frac)` is applied
before sglang's KV-cache pre-allocator runs. Disagg path keeps the
original parallel init.
* `torchspec/colocate/mps.py` `setup_for_colocate(register_atexit=True)`
registers a `quit`-the-daemon `atexit` hook iff *this* process
started the daemon (uses the helper's `started_by_us` ownership
flag). Idempotent + race-free; SIGKILL/OOM-kill paths still leak
the daemon by design — next driver run reuses it.
* `torchspec/utils/profiling.py` `TrainProfiler.peak_alloc_metrics(reset=True)`
returns peak / current / reserved bytes for `torch.cuda.current_device()`
and optionally resets the counter. Empty dict on CPU-only test runs.
* `torchspec/training/trainer.py` `_gather_metrics` emits the four
`perf/peak_*` / `perf/current_*` keys per step and resets the peak
counter so each metric reflects only that step's window (sketched after
this list).
* `tests/colocate/test_stability.py` skeleton — two tests pinned to
`pytest.skip` until the upstream sglang patch unblocks
`phase6_stability`. Encodes the plan's "peak_alloc(step=10) ≈
peak_alloc(step=999) within 1 %" acceptance bar in code.
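A sketch of the metric helper (the exact key names are illustrative; the
real helper returns the four `perf/*` keys above):

```python
import torch

def peak_alloc_metrics(reset: bool = True) -> dict:
    # Empty dict on CPU-only test runs, mirroring the helper described above.
    if not torch.cuda.is_available():
        return {}
    dev = torch.cuda.current_device()
    metrics = {
        "perf/peak_allocated_bytes": torch.cuda.max_memory_allocated(dev),
        "perf/peak_reserved_bytes": torch.cuda.max_memory_reserved(dev),
        "perf/current_allocated_bytes": torch.cuda.memory_allocated(dev),
        "perf/current_reserved_bytes": torch.cuda.memory_reserved(dev),
    }
    if reset:
        # Per-step window: the next step's peak starts from "now".
        torch.cuda.reset_peak_memory_stats(dev)
    return metrics
```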
Bundling note: trainer.py and train_entry.py contain hunks from both
phases. Bundling them into one commit keeps each file's view of the
colocate branch internally consistent (Phase 6's init-order fence is
gated on Phase 5's `is_mps_colocate(args)`; the peak-alloc metric
lands inside the same metrics dict Phase 5 routes through). The
implementation log preserves separate Phase 5 / Phase 6 work-log
sections for the review trail.
Verification:
* `PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q`
→ 45 passed, 27 skipped (skips: torch absent / CUDA absent /
upstream-patch-gated).
* End-to-end `phase4_one_step` / `phase6_stability` on Modal remain
parked until the upstream sglang patch lands (documented in
implementation_log.md).
AI-assisted (Claude). Human submitter reviewed and ran tests.
Co-authored-by: Claude
Adds the parking spots for the per-parameter gradient parity and
end-to-end convergence comparisons described in
implementation.md §Phase 7. Both files are deliberate skeletons that
encode the acceptance bars in code (so the bar can't drift on the
branch) but stay `pytest.skip` until the upstream sglang patch
unblocks the colocate sync loop and the Modal `phase7_*` entrypoints.
* `tests/colocate/test_grad_parity.py` —
`test_phase7_grad_parity_per_parameter` skeleton. Acceptance pinned
in docstring:
`torch.allclose(g_disagg, g_colocate, atol=1e-6, rtol=0)` per
parameter on the same prompts + seed (NCCL is bit-deterministic
given identical reduction order; we don't change it). A minimal form of
the check is sketched below.
* `tests/colocate/test_convergence.py` —
`test_phase7_convergence_curves_match_within_2pct` and
`test_phase7_eval_loss_matches`. Both `pytest.skip` +
`pytest.mark.slow`. Acceptance: per-step loss within 1–2 %, eval
loss within tokenizer-deterministic noise. Hooks into the
`phase7_grad_parity` / `phase7_convergence` Modal targets that the
Phase 5 commit already wired in (currently parked because the
colocate sync loop body raises NotImplementedError until the
upstream patch is applied).
Both skeletons depend on a "disagg control run" snapshot we don't
generate yet; once the upstream patch is in, the skeleton needs (a)
a recorded disagg gradient/loss baseline on identical prompts/seed
and (b) a colocate run to compare against.
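A minimal form of the per-parameter check (the baseline/candidate dicts come
from the not-yet-generated snapshots described above):

```python
import torch

def assert_grad_parity(disagg_grads: dict, colocate_grads: dict) -> None:
    # Exact-tolerance acceptance bar: atol=1e-6, rtol=0, per parameter,
    # on identical prompts + seed.
    assert disagg_grads.keys() == colocate_grads.keys()
    for name, g_disagg in disagg_grads.items():
        assert torch.allclose(g_disagg, colocate_grads[name],
                              atol=1e-6, rtol=0), f"grad mismatch: {name}"
```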
Verification:
PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q
→ 45 passed, 29 skipped (test_grad_parity + test_convergence join
the skip list as expected).
AI-assisted (Claude). Human submitter reviewed.
Co-authored-by: Claude
Wraps up the colocate (MPS + NCCL) feature with user-facing
documentation and a runnable single-node example. No code changes;
this is purely the discoverability layer the prior phases lacked.
* `docs/colocate/usage.md` (new) — end-to-end user guide:
- When to use colocate vs disaggregated (single-node, sglang-only,
spec-training; vs multi-node / multi-replica / async / vLLM).
- Hardware + software prereqs (driver R535+, MPS daemon, CUDA
toolkit, expandable_segments allocator, sglang colocate patch).
- GPU layout invariants — 1:1 trainer↔engine pairing, `tp_size==1`,
union-world rank assignment (`[0,N)` trainer, `[N,2N)` engine).
- Memory-split formula `train_frac + infer_frac + 0.10 ≤ 1.0` with
rationale for the 0.10 reserved slack.
- The four colocate config fields with the three Phase-0 validation
rules.
- "What changes inside a run" section mapping each phase to the
runtime behaviour change (placement, MPS daemon, distributed
init, fetcher swap, engine init, controller).
- Validation matrix — which Modal smoke entrypoint proves each
phase, and which are still upstream-patch-gated.
- Known limitations + troubleshooting section (hangs, OOM,
daemon-not-running, `via PCIe`, daemon zombies).
- "Where the code lives" map back to the source files.
* `configs/colocate_qwen3_8b.yaml` (new) — colocate sibling of
`configs/sglang_qwen3_8b.yaml`. Differs only in:
- `colocate_strategy=mps`, `transfer_mode=nccl`,
`train_frac=0.45`, `infer_frac=0.45`.
- `training_num_gpus_per_node=4`, `inference_num_gpus=4`,
`inference_num_gpus_per_engine=1`, `tp_size=1` (the 1:1 pairing
invariant).
- `output_dir` / `cache_dir` paths.
Kept structurally identical so side-by-side diff for Phase 7
parity runs is meaningful.
* `examples/colocate-qwen3-8b-1node/{run.sh,README.md}` (new) —
colocate sibling of `examples/qwen3-8b-single-node/`:
- `run.sh` exports `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`
and `PYTORCH_ALLOC_CONF` (PyTorch ≥ 2.9 alias), defaults
`CUDA_VISIBLE_DEVICES=0,1,2,3`, pins `tp_size=1` /
`inference_num_gpus_per_engine=1`, and forwards extra args to
`python -m torchspec.train_entry`.
- `README.md` short overview that links into
`docs/colocate/usage.md` for the full design rationale; calls
out the upstream-patch dependency and the expected hang
signature when the patch is missing.
* `docs/ray.md` placement-group table now has a row for
`colocate_strategy=mps + transfer_mode=nccl` (shared PG, fractional
`num_gpus=train_frac`/`infer_frac`, link to the new usage doc).
* `docs/colocate/implementation_log.md`:
- Status snapshot updated: Phase 5 / 6 / 7 → 🟢, Phase 8 → ✅.
- Phase 5 / 6 / 7 / 8 sections written out with work logs,
deviations, and verification gates. The bundling note from the
Phases 5+6 commit is recorded so reviewers see why those two
files are co-committed.
Verification:
* `python -m torchspec.train_entry --config configs/colocate_qwen3_8b.yaml`
on a non-colocate-patched sglang reaches setup and raises the
Phase-5 NotImplementedError as documented (this is the dry-run
signature, not a bug).
* Existing examples / configs still parse — Phase 0 validation only
fires the new errors when the colocate fields are set.
* `PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q`
→ 45 passed, 29 skipped (unchanged).
AI-assisted (Claude). Human submitter reviewed.
Co-authored-by: Claude
The Phase-4 plan called for the upstream sglang colocate patch to land
in a separate PR; this commit instead vends it as an in-repo patch
(applied on top of the existing disagg patch) so Phase-4 end-to-end
runs become possible without an external dependency. Pseudocode in
docs/colocate/sglang_patch.md is kept as the upstream-PR spec; the
real diff is now `patches/sglang/v0.5.8.post1/colocate.patch`.
* `patches/sglang/v0.5.8.post1/colocate.patch` (new, ~700 lines) —
the engine-side colocate patch generated from the
`feature/torchspec-colocate-patch` branch on the local sglang clone
(base: 0f2df9370a + sglang.patch). Five-file surface:
- **NEW** `sglang/srt/distributed/torchspec_colocate.py` — env-var
contract parser + union-world default-PG joiner +
NcclHiddenStatesConnector factory. Lazy torchspec import so disagg
runs never pull torchspec into sglang's import graph. Self-contained
helper, only stdlib at module level.
- `parallel_state.py::initialize_model_parallel` — accepts
`tp_world_ranks`. When passed, skips the
`world_size == tp_size * pp_size` assertion and uses the explicit
rank list as the single TP group; creates singleton PP groups.
Defends against non-default MoE-EP/MoE-TP layouts under colocate.
- `model_executor/model_runner.py` — branches at the
init_distributed_environment call site. Under colocate: joins the
union world via init_union_default_pg, calls
init_distributed_environment redundantly to set sglang's `_WORLD`,
then initialize_model_parallel(tp_world_ranks=...).
- `managers/scheduler.py::Scheduler.__init__` — adds `eagle_nccl_writer`
next to `eagle_mooncake_store`; skips Mooncake background init
under colocate; instantiates NcclHiddenStatesConnector on every TP
rank (each pairs 1:1 with one trainer rank).
- `managers/scheduler_output_processor_mixin.py` — adds
`_send_hidden_states_to_nccl`; `_process_hidden_states_for_req`
branches NCCL-first, Mooncake-fallback, none-otherwise.
* `scripts/modal/modal_colocate_smoke.py` — wires the Modal image to
apply both `sglang.patch` and `colocate.patch` (in order) when
building the sglang container layer. Adds `--recount` to both
`git apply` calls (the colocate.patch comes from `git format-patch`
which appends a trailing `2.51.2` git-version line that older
applies reject without `--recount`).
* `docs/colocate/sglang_patch.md` — adds a banner pointing at the
real patch file. The pseudocode patch points are kept as the
upstream-PR spec; the in-repo patch is the runnable artifact.
Verified locally:
* Patch round-trips: reset sglang clone to disagg-only state, re-apply
colocate.patch with `--recount` → success, no conflicts.
* `torchspec_colocate.py` imports cleanly via `importlib` on Mac
(Python 3.11). End-to-end exercises:
- `is_colocate_active()` toggles correctly with env var.
- `read_colocate_env()` parses 6 env vars; `engine_global_rank()`
maps tp_rank → union global rank correctly (4 → N+0..N+3 for N=4);
out-of-range rejected with ValueError.
- Missing env var raises clear RuntimeError pointing at the doc.
- `build_engine_tp_ranks()` returns `[N, ..., 2N-1]`.
* Five patched files all parse cleanly via `ast.parse`.
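The env-var contract's shape, as a sketch (not the patch's literal code;
only the behaviours verified above are assumed):

```python
import os

def is_colocate_active() -> bool:
    return os.environ.get("TORCHSPEC_COLOCATE_TRANSFER_MODE") == "nccl"

def engine_global_rank(tp_rank: int, n_per_role: int) -> int:
    # Engines occupy the upper half [N, 2N) of the union world, paired 1:1
    # with trainers at [0, N).
    if not 0 <= tp_rank < n_per_role:
        raise ValueError(f"tp_rank {tp_rank} out of range for N={n_per_role}")
    return n_per_role + tp_rank

def build_engine_tp_ranks(n_per_role: int) -> list:
    return list(range(n_per_role, 2 * n_per_role))
```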
End-to-end NCCL P2P validation (the `phase4_one_step` Modal target)
needs the matching trainer + GPUs and is still parked behind the
synchronous loop body in `train_entry.py` (Phase 5 follow-up).
AI-assisted (Claude). Human submitter reviewed and ran the patch
round-trip + helper module tests on Mac.
Co-authored-by: Claude
…in probe
The previous image recipe applied colocate.patch from the cloned
(pinned-commit) TorchSpec repo, which doesn't contain it yet. Restructure the
build into three layers:
1. clone sglang + pip install + apply the existing disagg patch
2. add_local_dir overlays (brings in the new colocate.patch file)
3. apply colocate.patch from the overlaid path
This makes patch iteration invalidate only the thin top layer instead of
rebuilding base + disagg, and keeps the heavy
`pip install -e _sglang/python[all]` cached.
Extend `_run_probe` to assert the four patch-surface invariants inside the
live container (helper module imports + round-trips, parallel_state gets the
tp_world_ranks kwarg, the output processor gets the
_send_hidden_states_to_nccl method, Scheduler.__init__ wires
eagle_nccl_writer + the colocate gate). Any future image build that fails to
apply the patch will now fail the probe loudly instead of silently sliding
through to e2e training.
Verified on Modal sandbox (doordash/sandbox env):
probe                1xH100  26 s   4/4 patch-surface checks pass
phase1_placement     4xH100  18 s   5/5
phase3_p2p_dummy     2xH100  128 s  3/3
phase4_multi_tensor  2xH100  39 s   2/2
AI assistance was used to draft the layered image recipe and the
patch-surface assertions; the human submitter reviewed and ran the
verification.
Co-authored-by: Claude
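A sketch of the three-layer structure (Modal image API as I understand it;
the repo URL, pin, and paths are placeholders, not the real recipe):

```python
import modal

# Layer 1 (heavy, cached): CUDA base + sglang clone + editable install +
# the existing disagg sglang.patch.
base = (
    modal.Image.from_registry("nvidia/cuda:12.4.0-devel-ubuntu22.04",
                              add_python="3.11")
    .run_commands(
        "git clone <sglang-repo> /root/_sglang && "
        "cd /root/_sglang && git checkout <pinned-commit>",
        "pip install -e '/root/_sglang/python[all]'",
        "cd /root/_sglang && git apply <path-to>/sglang.patch",
    )
)

# Layer 2: overlay the local working tree (brings in the new colocate.patch);
# copy=True materialises it as a real layer so later build steps can read it.
# Layer 3 (thin): apply colocate.patch — iterating on the patch only
# rebuilds this layer.
image = (
    base
    .add_local_dir(".", "/root/torchspec", copy=True)
    .run_commands(
        "cd /root/_sglang && git apply --recount "
        "/root/torchspec/patches/sglang/v0.5.8.post1/colocate.patch"
    )
)
```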
The previous train_entry.py raised NotImplementedError under colocate
because (a) no synchronous training loop existed and (b) the engine
side had no way to discover the trainer-computed union-world
master_addr/master_port — every actor was self-computing it, which
deadlocked the collective rendezvous. This commit closes both gaps.
Driver-side union spec + env-var injection:
* RayTrainGroup now exposes its master_addr/master_port (set by the
rank-0 trainer's setup_master) so the driver can derive the
union-world endpoint (master_port + 5000).
* train_entry.py computes TORCHSPEC_COLOCATE_TRANSFER_MODE=nccl plus
the four UNION_* env vars on the driver process BEFORE async_init
fires, then forwards them via prepare_inference_engines'
extra_env_vars argument into the SglEngine actors' runtime_env.
The patched sglang TP scheduler subprocess inherits this env from
its actor and joins the union world as ranks [N, 2N).
* TrainerActor._init_distributed_colocate now reads from these env
vars when the driver has set them, falling back to the legacy
self-computed spec for tests that spin up a TrainerActor in
isolation. Both sides now agree on the same rendezvous endpoint.
Drop the deadlock-prone init-order fence:
* The previous "trainer-first" fence (await train_init_refs before
starting engines) is fundamentally incompatible with a collective
init_process_group(world_size=2N) — every trainer would block
forever waiting for engines that hadn't been spawned. Both sides
now run init() in parallel; both block on the rendezvous and
unblock together. Memory contention under MPS is handled by
expandable_segments + the train_frac/infer_frac budget split.
Synchronous loop body (torchspec/controller/colocate_loop.py):
* run_colocate_training_loop drives one batch per step:
1. Pull dp_size prompts from the controller.
2. For each (engine, trainer) pair: push ColocateTrainSample(
tensor_specs computed from prompt seq_len + engine hidden_size)
to trainer's queue, then dispatch engine.generate as a Ray
remote.
3. Concurrently fire train_from_queue on every trainer.
4. Await both — the engine's spec_training callback NCCL-sends to
its paired trainer, the trainer's NcclMultiTensorFetcher
recv_step picks it up.
5. Log step metrics, advance counter.
* Hard requires draft_accumulation_steps=1 and per_dp_rank_batch_size=1
for now; multi-step accumulation + multi-sample-per-rank batching
threaded through controller dispatch is parked Phase-5 follow-up.
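A sketch of the per-step orchestration above (actor method names and the two
helpers are assumptions drawn from the description, not the real
`colocate_loop.py` code):

```python
import ray

def run_colocate_training_loop(controller, pairs, num_steps):
    for step in range(num_steps):
        prompts = get_prompts(controller, dp_size=len(pairs))  # 1. (assumed helper)
        gen_refs = []
        for (engine, trainer), prompt in zip(pairs, prompts):
            sample = make_colocate_sample(prompt)        # 2. tensor_specs from
            trainer.push_train_sample.remote(sample)     #    seq_len + hidden_size
            gen_refs.append(engine.generate.remote(prompt))
        train_refs = [t.train_from_queue.remote() for _, t in pairs]  # 3.
        ray.get(gen_refs + train_refs)  # 4. engine NCCL-sends, trainer recvs
        log_step_metrics(step)          # 5. (assumed helper)
```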
Engine-side cleanup:
* SglEngine.generate short-circuits in colocate (NCCL) mode after
sgl.Engine.generate completes. The post-processing was building
Mooncake-key output dicts that don't apply when the data plane is
NCCL P2P; previously it logged a noisy "no mooncake keys returned"
error every step.
* The colocate.patch's _send_hidden_states_to_nccl now sends the
same 3-tensor dict shape that the disagg Mooncake path stores
(hidden_states already aux-concatenated by sglang's
spec_training, plus input_ids and optional last_hidden_states).
The previous version sent a separate aux_hidden_states key that
the trainer fetcher didn't know how to consume.
One-step e2e test (tests/colocate/test_one_step.py):
* Runs the colocate Qwen3-8B example through train_entry with
num_train_steps=1 on H100:4 and asserts the loop reports
"completed_steps=1 / num_steps=1". This is the maximal e2e check
that catches: rendezvous deadlocks, MPS daemon misconfigurations,
tensor-spec mismatches between trainer fetcher and engine sender,
and aux-layer count mismatch on hidden_states' last dim.
AI assistance was used to draft the orchestration plumbing and the
loop body; the human submitter reviewed and ran the verification on
Modal H100:4.
Co-authored-by: Claude
Phase-4 one-step on Modal H100:4 surfaced a real bug: once the MPS
control daemon is up on a node, *every* CUDA context on that node has
to set CUDA_MPS_PIPE_DIRECTORY or CUDA fails with error 805 ("MPS
client failed to connect to the MPS control daemon or the MPS
server"). Our previous order was
ray.init -> create AsyncTrainingController -> setup_for_colocate
so:
- the controller actor inherited an os.environ without the MPS pipe
and crashed on the first torch.cuda.is_available() inside
dataset / tokenizer code;
- trainer / engine actors that *did* have CUDA_MPS_PIPE_DIRECTORY in
their runtime_env still raced with daemon startup because
start_mps_daemon returned before the control pipe file existed.
This commit:
- moves setup_for_colocate to step [0] of train_async_no_generation,
before any Ray actor (including the controller) is created, and
exports the client env into the driver's os.environ so all child
processes inherit it;
- adds mps_client_env() to the controller's runtime_env_vars
(defense-in-depth — the explicit override is independent of
os.environ inheritance);
- polls for /tmp/nvidia-mps/control inside start_mps_daemon with a
10s timeout so callers can rely on "function returned ==> daemon
is reachable" and don't race with daemon init.
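A sketch of the polling contract ("function returned ==> daemon is
reachable"):

```python
import os
import subprocess
import time

def start_mps_daemon(pipe_dir="/tmp/nvidia-mps", timeout_s=10.0) -> None:
    # The control daemon forks and returns immediately; poll for the control
    # pipe so callers can't race an MPS client connect against daemon startup.
    env = {**os.environ, "CUDA_MPS_PIPE_DIRECTORY": pipe_dir}
    subprocess.run(["nvidia-cuda-mps-control", "-d"], env=env, check=False)
    control = os.path.join(pipe_dir, "control")
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.exists(control):
            return
        time.sleep(0.1)
    raise TimeoutError(f"{control} did not appear within {timeout_s}s")
```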
Also fills out the remaining Phase-7 test bodies:
- tests/colocate/test_grad_parity.py: one-step smoke that asserts a
finite, non-zero training loss came out of the colocate loop (the
full per-parameter byte-equality test is parked as a follow-up
because it needs deterministic-seed plumbing across both transfer
modes plus a gradient-snapshot hook in the trainer);
- tests/colocate/test_convergence.py: short-horizon (50-step
default, configurable) convergence test asserting the late-window
average loss is below the early-window average — a cheap e2e
signal that gradients are actually flowing.
Modal smoke script now also overlays examples/ so the colocate config
can resolve dataset.train_data_path inside the container, and
test_one_step.py invokes train_entry directly instead of relying on
the example run.sh script.
When trainer setup_gpu hits 'MPS client failed to connect to the control
daemon', the actual root cause lives in /tmp/nvidia-log/control.log on the
node. Surface a tail of it (plus the relevant CUDA env) at error time so we
don't have to re-run with extra logging.
…sharing
Modal sandbox H100 nodes (and any container without --ipc=host /
matching capabilities) start the MPS control daemon cleanly but the
per-GPU server crashes on first client connect with 'Failed to start
: operation not supported' — every actor that subsequently does
torch.cuda.* then dies with CUDA error 805.
This commit makes the colocate driver detect that case and gracefully
degrade:
- setup_for_colocate now eagerly probes the MPS server (sends a
get_server_list to force the daemon to spawn its per-GPU server,
then watches /tmp/nvidia-log/server.log for either a clean start
or 'operation not supported');
- on probe failure we tear down the daemon and return (None, {}) so
the caller knows MPS isn't usable here;
- train_entry / train_group / inference.factory all check the new
args.colocate_mps_unavailable flag and skip injecting
CUDA_MPS_PIPE_DIRECTORY into actor runtime_envs when MPS is dead;
- TORCHSPEC_DISABLE_MPS=1 lets ops force the same fallback.
Functional outcome: colocate still works (trainer + engine still claim
fractional GPU resources via Ray and end up sharing the same physical
GPU) — they just can't run kernels concurrently on Modal sandbox so
throughput is lower than a real DGX-style box.
…etCount)
The previous probe used 'get_server_list', which doesn't actually spawn a
per-GPU server, so the 'operation not supported' failure was only visible
after the first real CUDA client connected — too late. Replace it with a 30s
subprocess probe that calls cuInit + cuDeviceGetCount via libcuda.so.1; that
exercises the real client codepath and surfaces the failure (or success)
before any actor spawns. Run in an isolated process so the driver's CUDA
state is untouched.
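A sketch of the isolated probe (helper name is illustrative):

```python
import subprocess
import sys

_PROBE_SRC = """
import ctypes, sys
cuda = ctypes.CDLL("libcuda.so.1")
rc = cuda.cuInit(0)                       # exercises the real MPS client connect
n = ctypes.c_int(0)
rc2 = cuda.cuDeviceGetCount(ctypes.byref(n))
sys.exit(0 if (rc == 0 and rc2 == 0 and n.value > 0) else 1)
"""

def mps_client_probe(timeout_s: float = 30.0) -> bool:
    # Fresh interpreter: the probe creates a CUDA context (and an MPS client
    # connection) without touching the driver process's CUDA state.
    try:
        proc = subprocess.run([sys.executable, "-c", _PROBE_SRC],
                              timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return False
    return proc.returncode == 0
```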
…e startup
Phase-4 one-step on Modal H100:4 surfaced the next issue after the MPS
fallback: the engine's sglang TP scheduler subprocess takes ~1 minute to
initialise (fork sglang scheduler, download Qwen3-8B weights from HF, build
the worker, etc.), but `init_process_group(device_id=...)` switches NCCL into
*eager init* mode where every rank's NCCL listener has to be reachable within
the socketPollConnect 35-retry window (~30 s). Trainers were ready in seconds
and timed out polling the engine's not-yet-bound NCCL listener with
    socketPollConnect: connect ... returned Connection refused,
    exceeded error retry count after 35 attempts
The fix is to drop `device_id=` from both sides of the union-world
`init_process_group` so NCCL falls back to lazy init: the handshake happens
on the first collective op, which inherits the 10-minute `timeout=` we
already pass. Slow engines now have plenty of slack to catch up; the only
thing we lose is a microscopic init-latency optimisation for fast-startup
workloads, which we don't have here.
This commit:
- in torchspec/colocate/world.py: drop `device_id=` from init_process_group
  (and add a comment explaining why);
- in patches/sglang/.../colocate.patch: the same change to
  `init_union_default_pg` (the engine side, called inside the sglang
  scheduler subprocess via the patch).
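The resulting call, sketched (the rendezvous wiring around it is an
assumption):

```python
from datetime import timedelta
import torch.distributed as dist

def init_union_pg_lazy(master_addr: str, union_port: int,
                       union_rank: int, n_per_role: int) -> None:
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{union_port}",
        rank=union_rank,
        world_size=2 * n_per_role,
        timeout=timedelta(minutes=10),
        # device_id intentionally omitted: passing it turns on eager NCCL
        # init, which gives slow-starting engine ranks only the ~30 s
        # socketPollConnect window. Lazy init defers the handshake to the
        # first collective, which inherits the 10-minute timeout above.
    )
```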
…can read it on timeout)
Phase-4 one-step on Modal sandbox surfaced that working colocate
fundamentally needs functioning NVIDIA MPS — without it, two processes
(trainer + engine) on the same physical GPU can't reliably do inter-process
NCCL P2P, the rendezvous never completes, and we burn the test's 30-minute
budget on a doomed run. Modal sandbox H100 nodes start the MPS daemon
successfully, but the per-GPU server crashes with 'operation not supported'
(containers don't have --ipc=host or the matching capability).
This commit:
- adds tests/colocate/_mps_probe.py with two helpers (has_h100_quad,
  mps_works) so each phase test can fail fast rather than hanging;
- gates test_one_step (Phase 4), test_stability (Phase 6), test_grad_parity
  (Phase 7), and test_convergence (Phase 7) behind the new mps_works skip —
  when MPS is broken the test is a `pytest.skip` with a clear
  "needs --ipc=host host" message rather than running for 30 minutes and
  timing out.
The probe itself spawns `nvidia-cuda-mps-control -d` (idempotent) and a tiny
libcuda.so.1 `cuInit + cuDeviceGetCount` subprocess to exercise the actual
MPS client codepath. A false positive (probe says working but the real run
fails) requires the MPS server to die *between* the probe and the real run;
in that case the existing phase4 fallback path (TORCHSPEC_DISABLE_MPS /
colocate_mps_unavailable flag) still kicks in and the test fails for the
right reason instead of hanging.
…er unit tests
The 4×H100 + Qwen3-8B Phase-4/6/7 tests are gated behind has_h100_quad()
because Modal sandbox can't run them (no MPS server support — see the
"Modal sandbox MPS limitation" note in implementation_log.md). Renting
4×H100 elsewhere is expensive enough to be a barrier for one-shot
correctness validation.
This adds a single-GPU + tiny-model variant that exercises the same
colocate code path (MPS daemon, fractional GPU sharing, NCCL P2P union
world, NcclMultiTensorFetcher, sglang colocate.patch hidden-state hook)
on a $0.20–$2/hr cheap GPU rental:
configs/colocate_qwen0p6b_tiny.yaml   — Qwen3-0.6B-Base, 1-GPU layout
tests/colocate/test_colocate_tiny.py  — Phase-4 one-step + Phase-7 mini
                                        convergence (1 GPU)
scripts/colocate/run_smoke_host.sh    — cheap-host runner: clones sglang,
                                        applies patches, pip-installs deps,
                                        runs the tiny test
scripts/modal/modal_colocate_smoke.py — phase_tiny entry point (verifies
                                        skip behaviour on Modal sandbox)
Also extends tests/colocate/_mps_probe.py with has_n_gpus(n) so the
tiny tests can gate on >=1 GPU instead of the 4-GPU has_h100_quad().
The same _mps_probe.mps_works() check skips cleanly on hosts where
the MPS server reports "operation not supported" (e.g. Modal sandbox)
so a SKIP outcome means *the host* doesn't support MPS, not that the
colocate code is broken.
Re-verified Modal sandbox correctness while we were here:
probe — patch surface 4/4 OK (35 s)
phase1_placement — 1 passed, 4 skipped (40 s)
phase4_multi_tensor — 2/2 PASSED (69 s, no MPS dependency)
phase4_one_step — 1 SKIPPED in 1.5 s (clean MPS skip)
phase_tiny — 2 SKIPPED in 0.7 s (clean MPS skip)
Drive-by fix: tests/colocate/test_phase1_mps_helper.py had two pre-
existing local-test failures introduced by the earlier MPS-fallback
work — the tests didn't account for (a) start_mps_daemon polling for
the control pipe file (added when CUDA error 805 turned out to be a
startup race during the Modal sandbox debugging), or (b) setup_for_colocate now
doing a real cuInit/cuDeviceGetCount probe and falling back to
(None, {}) when CUDA isn't available. Both tests now mock the right
surface and a new test_setup_for_colocate_falls_back_when_probe_fails
pins down the graceful-degradation contract that the Modal-sandbox
SKIPs depend on. All 46 local pytest now pass (vs 44 passed + 2
failed before).
Co-authored-by: Claude
…ndoff
Adds a self-contained handoff doc so another agent can pick up the
MPS-required Phase-4/6/7 validation on a non-Modal host without
re-deriving the rationale or the failure modes:
docs/colocate/cheap_host_test_plan.md — TL;DR, cost-tier matrix (RunPod /
                                        Vast.ai / Lambda / Hyperstack),
                                        explicit RunPod + Vast.ai recipes,
                                        success-criteria checklist,
                                        failure-mode diagnostic table,
                                        report-back checklist
scripts/colocate/README.md            — short pointer to the plan
Extends scripts/colocate/run_smoke_host.sh with:
--full          run tiny + every 4×H100 Phase-4/6/7 test in one go
                (each test self-skips if its preconditions aren't met)
--tests=A,B,C   explicit test-file override
GPU-aware default for CUDA_VISIBLE_DEVICES (0,1,2,3 when --full + 4 GPUs;
0 otherwise)
Documented exit codes + new env knobs (PHASE6_STABILITY_STEPS,
PHASE7_CONVERGE_STEPS) so callers can dial up confidence.
implementation_log.md gets a backlink to the new test plan so future
readers find the cheap-host workflow instead of re-discovering it.
Co-authored-by: Claude
- Pre-flight (nvidia-smi, GPU count, MPS server probe) now runs
before the 5-10 min pip-install setup, so a host without working
MPS exits 1 in ~30s instead of after a SKIP-only pytest run.
Probes via `python -m tests.colocate._mps_probe` with a verbose
reason (cuInit rc + server.log tail). Bypass:
COLOCATE_SKIP_MPS_PROBE=1.
- Stale-state cleanup at pre-flight: `ray stop -f` plus a guarded
`rm -rf /tmp/nvidia-{mps,log}` (only when no daemon is running).
Both were documented as manual recipes in the failure-modes table.
- Pytest output tee'd to colocate-smoke-pytest.log; a structured
colocate-smoke-report.txt is written on exit with everything the
plan's "Reporting back" section asks for (host details, exit code,
pytest summary, [colocate_loop] loss lines, skipped tests, failure
tails + /tmp/nvidia-log/{server,control}.log tails).
- bash EXIT trap best-effort-sends `quit` to the MPS daemon so it
doesn't leak (skip with COLOCATE_KEEP_MPS=1).
tests/colocate/_mps_probe.py: split mps_works() into
mps_works_verbose() -> (bool, reason) so callers can surface the
actual failure mode instead of a bare False. Added a
`python -m tests.colocate._mps_probe` CLI for the new pre-flight and
for humans following the doc's "Quick MPS sanity check".
Docs updated to reflect new pre-flight exit-code semantics, the
auto-report, and retire two now-automated manual cleanup recipes.
No colocate code path changes.
mooncake.store's native .so dlopens libibverbs.so.1 + libnuma.so.1 +
librdmacm + libnl-3 at import time. On RunPod's stock pytorch-2.4 template
none of those are installed, and a top-level
`from mooncake.store import MooncakeDistributedStore` brings down the entire
torchspec.training.trainer import chain — including the colocate MPS+NCCL
path (transfer_mode=nccl), which by design never touches Mooncake.
Wrap the import in try/except and define a stub class on failure. The stub
satisfies the Optional[MooncakeDistributedStore] type annotation on _store
and raises RuntimeError with an actionable apt-get hint if the disagg path
actually tries to instantiate the store at runtime. The lazy ReplicateConfig
import in _build_replicate_config() (line ~300) was already structured the
same way; this just extends the pattern to the one remaining top-level
mooncake import.
Surfaced by the cheap-host smoke on RunPod A100 SXM (a96eaef): MPS pre-flight
+ daemon spawn pass cleanly, but train_entry fails at module load with
libibverbs.so.1 missing — then libnuma.so.1 after we apt-install libibverbs1,
confirming the dep chain doesn't end there.
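A sketch of the pattern (the apt package names in the hint are assumptions):

```python
try:
    from mooncake.store import MooncakeDistributedStore
except ImportError as exc:  # missing libibverbs/libnuma/librdmacm on the host
    _MOONCAKE_IMPORT_ERROR = exc

    class MooncakeDistributedStore:  # type: ignore[no-redef]
        """Stub that satisfies the Optional[MooncakeDistributedStore]
        annotation on _store; only the disagg path ever instantiates it,
        so colocate (transfer_mode=nccl) runs never hit this error."""

        def __init__(self, *args, **kwargs):
            raise RuntimeError(
                "mooncake.store unavailable — its native .so failed to load "
                "(try `apt-get install libibverbs1 libnuma1`): "
                f"{_MOONCAKE_IMPORT_ERROR}"
            )
```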
Companion to knowledge.zh-en.md covering SMs, contexts, streams, MPS internals, caching allocator fragmentation, NCCL communicators, and the colocate stack end-to-end. Cross-references main doc sections.
…aces
Several modules — torchspec/colocate/{world,mps}.py,
torchspec/training/nccl_data_fetcher.py,
torchspec/inference/engine/nccl_hidden_states_connector.py — create
their own loggers via `logging.getLogger("torchspec.X.Y")` instead of
importing the central `logger` from torchspec.utils.logging. These
child loggers' default level is the root logger's WARNING, so every
INFO they emit gets silently dropped.
This bites hard when debugging the colocate path: TrainerActor calls
init_union_world, which has an "Initialising union world: ..." INFO log
right before dist.init_process_group, and the patched sglang ModelRunner
similarly has "Joining TorchSpec union world: ..." right before its
rendezvous. Both are invisible at default config, so a stuck NCCL
rendezvous looks like complete actor silence — exactly the failure mode
that surfaced on the RunPod H100 PCIe smoke run (a96eaef / 3f7e708):
TrainerActor's worker-out had 2 lines and SglEngine's stopped at the
"BEFORE init" line, even though both processes had ~14 minutes of
runtime before the harness killed them.
Fix: in setup_logger(), also attach the same handler to
`logging.getLogger("torchspec")` (lowercase namespace) with
propagate=False. Every child logger in that hierarchy inherits the
handler via standard propagation, so the previously-silenced INFO logs
become visible in actor stdout/stderr without changing any submodule.
Guarded by `if not _ts_logger.handlers:` so re-entrant calls to
setup_logger don't pile on duplicate handlers.
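A sketch of the fix's shape (handler construction elided; the guard and
hierarchy wiring follow the description above):

```python
import logging

def setup_logger() -> None:
    handler = logging.StreamHandler()  # same handler the central logger uses
    # ... existing central-logger wiring ...
    # Attach to the lowercase "torchspec" namespace so every child logger
    # created via logging.getLogger("torchspec.X.Y") inherits the handler
    # through standard propagation.
    _ts_logger = logging.getLogger("torchspec")
    if not _ts_logger.handlers:  # guard re-entrant calls: no duplicate handlers
        _ts_logger.addHandler(handler)
        _ts_logger.setLevel(logging.INFO)
        _ts_logger.propagate = False
```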
Documents the 2026-05-13 RunPod validation session in implementation_log.md,
including:
- The four sequential failures hit on real hardware: libibverbs -> libnuma ->
  sgl_kernel SM gap (sm80/sm86/sm89 not in the wheel) -> TP scheduler
  subprocess silent hang.
- Why commit 3f7e708 (mooncake lazy-import) and 0089ad3 (logger visibility
  for the torchspec.X.Y namespace) were needed and what they unblocked.
- Run-by-run timeline (A100 -> H100 PCIe -> H100 SXM) with cost, outcomes,
  and the per-layer "what worked vs what blocked" matrix.
- Hypothesis space for the remaining TP-scheduler hang (init_union_default_pg
  gating, NCCL TCPStore rendezvous timeout defaults, or a silent exception in
  the patched scheduler init) plus the specific action items for the next
  iteration.
Updates cheap_host_test_plan.md:
- The cost-tier matrix now requires SM90+ GPUs (H100/H200/B200) because the
  bundled sgl_kernel 0.3.21 wheel ships only sm90+sm100 binaries. Strikes
  through the A6000/4090/L40S "cheap" tiers that the original plan
  recommended — they're unusable without a source build.
- Pre-flight requirements bump the CUDA capability floor from 8.0 to 9.0.
- New explicit "libnuma.so.1 required" note: RunPod's stock runpod-torch-v240
  image doesn't ship it; the runner bootstrap apt-installs it. (libibverbs
  etc. are no longer required for the colocate path thanks to 3f7e708.)
- runpodctl orchestration tips: the correct H100 PCIe gpu-id string; don't
  retry pod-create in a tight loop (the race window can leak pods).
No code changes in this commit; pure documentation of what running the
cheap-host smoke on a paid GPU actually surfaces.
…s hang
The Phase-4 tiny smoke hangs on real H100 inside `sgl.Engine(...)` with
no diagnostic output from the TP scheduler subprocess. After 14 minutes
the pytest harness kills it. Both the trainer and the engine end up
silent because:
1. The patched sglang code in init_union_default_pg and ModelRunner
uses `logger.info(...)`, but sglang's TP scheduler subprocess
inherits the engine's log_level which torchspec wires to "warning"
in sgl_engine.py:309. INFO logs are silenced.
2. NCCL_DEBUG output only appears after the NCCL backend initialises;
since we're stuck in (or before) the TCPStore rendezvous, NCCL
itself stays quiet.
Two changes here, surgical and zero functional impact when the
diagnostic env is not set:
* `torchspec/inference/engine/sgl_engine.py`: `log_level` becomes
env-overridable via `SGLANG_LOG_LEVEL`. Default unchanged.
* `patches/sglang/v0.5.8.post1/colocate.patch`: eight unconditional
`print(..., flush=True)` checkpoints with `[TS-COLOCATE-TRACE
pid=N]` prefix at:
- is_colocate_active() (every call records the env-var value)
- init_union_default_pg ENTRY
- init_union_default_pg after read_colocate_env succeeds
- init_union_default_pg right before dist.init_process_group
- init_union_default_pg right after dist.init_process_group returns
- ModelRunner.init_torch_distributed at the is_colocate_active()
dispatch point and entering the colocate branch
These bypass Python logging entirely so the captured output
survives any sglang log-level config and any silent exception
handling in the subprocess.
Next iteration: rerun on H100 with `SGLANG_LOG_LEVEL=info` and
`NCCL_DEBUG=INFO`. The captured [TS-COLOCATE-TRACE] checkpoints will
pinpoint whether the TP scheduler reaches init_union_default_pg,
gets stuck in dist.init_process_group, or crashes silently between
the two.
…tions
d99b599 added 8 print() blocks but left the @@ -X,Y +A,B @@ header counts at
their pre-edit values, which made `git apply --recount` choke
("warning: recount: unexpected line: 2.51.2",
"error: corrupt patch at line 707") because --recount tried to fix up Y/B but
ran past where it expected the hunks to end, hitting the git-format-patch
trailer.
Updated headers:
@@ -0,0 +1,257 @@       ->  @@ -0,0 +1,292 @@       (torchspec_colocate.py)
@@ -782,21 +787,59 @@   ->  @@ -782,21 +787,75 @@   (ModelRunner)
Counts measured by `grep -c '^+'` within each hunk's body. --recount will
still adjust if I'm a line or two off; this just gets us close enough that
the patch applies cleanly.
Iter 2 on H100 SXM applied the patch cleanly, but my print() output was
missing from the captured log even though the sibling logger.info "Joining
TorchSpec union world: ..." DID appear. That means sglang's TP scheduler
subprocess redirects/suppresses stdout but keeps stderr flowing — print()
defaults to stdout, while logger.info goes to stderr through Python logging's
StreamHandler.
Convert all 8 TS-COLOCATE-TRACE checkpoints to logger.warning() so they
(a) go through the same stream as the visible logger.info and (b) survive any
log_level config (WARNING is always shown — even under sgl.Engine's default
"warning" floor). Both model_runner.py and torchspec_colocate.py have
`logger = logging.getLogger(__name__)` already in scope, so no new imports
are needed.
Also adjusts the @@ hunk line counts:
@@ -0,0 +1,292 @@       ->  @@ -0,0 +1,287 @@       (5 fewer lines)
@@ -782,21 +787,75 @@   ->  @@ -782,21 +787,72 @@   (3 fewer lines)
(removing 8 `flush=True,` lines, one per checkpoint).
…eadlock
Iter 3's TS-COLOCATE-TRACE markers proved the union NCCL rendezvous
itself works — `dist.init_process_group` returned cleanly on both
sides. The real hang is one layer deeper:
Trainer (world.py:init_union_world after init_process_group):
1. (skipped for 1 trainer) new_group(nccl, trainer_ranks)
2. new_group(gloo, all 2N ranks) <-- meta_group barrier
Engine TP scheduler (after init_union_default_pg returns):
1. init_distributed_environment(...) ~ no-op
2. initialize_model_parallel(...)
-> GroupCoordinator.__init__ for each TP/MoE_EP/MoE_TP/PP:
new_group(nccl, engine_ranks)
new_group(gloo, engine_ranks)
8 calls total, ranks=[1] each (tiny config: tp=ep=pp=1).
The world-collective default for `dist.new_group` means *every* rank
in the world group must call it with matching args, even if not in
`ranks`. Trainer is waiting at the gloo meta_group call expecting all
2N ranks; engine is busy with the first sglang TP new_group expecting
the same. Args don't match -> both block forever.
Fix in init_union_default_pg, immediately after init_process_group
returns: monkey-patch `dist.new_group` to default
`use_local_synchronization=True`. From PyTorch docs:
> use_local_synchronization (bool, optional) - perform a group-local
> barrier at the end of the process group creation. This is
> different in that non-member ranks don't need to call into API
> and don't join the barrier.
This:
* Only applies inside the engine subprocess (each TP scheduler is a
separate process — the trainer is untouched).
* Is a `setdefault`, so any sglang call that explicitly passes
`use_local_synchronization=False` (none currently do) continues to
behave as the caller intended.
* Leaves the meta_group call (collective on all 2N) unaffected
because it's done by the trainer side via world.py before any
engine new_group; with sglang now using local-sync, the trainer's
meta_group barrier completes when the engine reaches its sibling
meta_group call inside init_union_default_pg (next commit will
add that explicit call for ordering symmetry, but with
use_local_synchronization=True we don't actually need it).
For Phase 4+ (multi-trainer FSDP), the trainer's own fsdp_group
new_group call will need use_local_synchronization=True too. Tracked
as a follow-up; the tiny test only has 1 trainer so fsdp_group is
skipped today.
Adjusts @@ -0,0 +1,287 @@ -> @@ -0,0 +1,318 @@ (31 lines added
to the torchspec_colocate.py new-file hunk: 1 logger.warning + the
monkey-patch wrapper + comments).
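The wrapper, sketched:

```python
import functools
import torch.distributed as dist

def _patch_new_group_local_sync() -> None:
    """Engine-subprocess-only monkey-patch: default every new_group to a
    group-local barrier so non-member ranks never have to participate."""
    original = dist.new_group

    @functools.wraps(original)
    def patched(*args, **kwargs):
        # setdefault: an explicit use_local_synchronization=False from the
        # caller (none in sglang today) still wins.
        kwargs.setdefault("use_local_synchronization", True)
        return original(*args, **kwargs)

    dist.new_group = patched
```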
Iter 4 got past `dist.new_group(ranks=[engine_only_ranks], ...)` thanks to
the use_local_synchronization=True monkey-patch (0a96522), but the next hang
surfaced inside sglang's `init_distributed_environment`:
`init_world_group(ranks=[0..world_size), local_rank, backend)` creates a
GroupCoordinator whose __init__ calls two world-spanning new_groups:
    new_group(ranks=[0..2N), backend=nccl)  # _WORLD device_group
    new_group(ranks=[0..2N), backend=gloo)  # _WORLD cpu_group
These are on *every* world rank, so use_local_synchronization=True doesn't
help — every member must still call. The trainer's init_union_world
(world.py) currently calls only ONE new_group at the world level
(`new_group(ranks=all, gloo)` for its meta_group), and nowhere else. Engine
calls 2, trainer calls 1, in different positions — c10d deadlock.
Fix: make the call sequence identical on both sides.
torchspec/colocate/world.py — before the existing meta_group call, add the
two world-spanning new_groups that mirror sglang's init_world_group. Their
handles are discarded; the trainer doesn't use sglang's _WORLD, but the
collective bookkeeping must match. Also flips fsdp_group's new_group call to
use_local_synchronization=True so multi-trainer Phase-4+ runs don't need the
engine to participate.
colocate.patch — in the patched ModelRunner.init_torch_distributed colocate
branch, after `init_distributed_environment` (which creates the 2
world-spanning new_groups via init_world_group), add one more
new_group(ranks=all, gloo) so the engine catches up to the trainer's
meta_group. For ranks covering the whole world, the
use_local_synchronization=True monkey-patch default is equivalent to a
world-collective call.
Resulting matched sequence (each row is one collective barrier):
Trainer (world.py)         | Engine (ModelRunner colocate path)
---------------------------|------------------------------------------
init_process_group(nccl)   | init_process_group(nccl) via
                           |   init_union_default_pg
new_group(nccl, all)       | (sglang init_world_group _WORLD device_group)
new_group(gloo, all)       | (sglang init_world_group _WORLD cpu_group)
new_group(gloo, all)       | new_group(gloo, all) explicit barrier in
  meta_group               |   colocate.patch
[done]                     | initialize_model_parallel (engine-local
                           |   groups via use_local_synchronization)
This pattern works for any 2N union world (tiny test = 2 ranks,
Phase 4+ = 8 ranks).
@@ -782,21 +787,72 @@ -> @@ -782,21 +787,88 @@ (+16 lines for the
meta_group call + comments).
Iter 5 hit `error: corrupt patch at line 750` again. Off-by-4 in the @@
+line count of the last hunk — I forgot that the 6 context lines wrap around
the 86 added lines, so the +side count is 86+6=92, not 88. (0a96522 went
75 -> 88 to account for +13 monkey-patch lines, but that already missed the
same +/-3-context adjustment.)
Counted directly with `grep -c '^+' < hunk` -> 86 strict + lines. Add the 6
context lines (3 before, 3 after the init_distributed_environment block)
-> 92 total +side hunk lines.
Iter 6 hit the next mismatch: the engine TP scheduler subprocess has a
monkey-patched dist.new_group that defaults use_local_synchronization=True,
while the trainer's world.py was passing use_local_synchronization=False (the
default) for the sglang-paired nccl/gloo world new_groups and the meta_group.
c10d's rendezvous semantics don't reconcile across two ranks passing
different flag values, so the very first paired new_group hung.
Make all of world.py's union-world subgroups pass
use_local_synchronization=True, matching the engine's monkey-patch default.
For groups whose `ranks` covers the entire world (2N ranks, every rank a
member) this is equivalent to a world-collective call — every rank
participates either way — but the semantics now agree on both sides.
Also add INFO log lines around the new_group sequence so the next iteration
can confirm the trainer side actually completed.
Iter 7 cleared the new_group rendezvous deadlock; iter 8 surfaces the
next bug: sglang's initialize_dp_attention computes its attn_tp group
ranks from `range(0, pp_size * tp_size)`, which lands in the trainer
half [0, N) of the union world. Engine ranks are at [N, 2N), so
GroupCoordinator's `self.rank in ranks` membership check fails on
every engine and the `assert self.cpu_group is not None` at the end
of __init__ fires:
File ".../sglang/srt/layers/dp_attention.py", line 296, in initialize_dp_attention
_ATTN_TP_GROUP = GroupCoordinator(
File ".../sglang/srt/distributed/parallel_state.py", line 292, in __init__
assert self.cpu_group is not None
AssertionError
I attempted to fix this with a colocate.patch hunk on dp_attention.py
but the unified-diff context+line-count was finicky enough that
--recount kept choking on the format-patch trailer. Switched to a
post-`git apply` Python string-substitution step inside
`setup_sglang` in `scripts/colocate/run_smoke_host.sh`:
1. Match the literal anchor " _ATTN_TP_GROUP = GroupCoordinator(".
2. Inject 14 lines of code above it that compute
`_ts_offset = read_colocate_env().n_per_role` if
is_colocate_active() else 0.
3. Rewrite the offending list comprehension to use the offset:
`list(range(_ts_offset + head, _ts_offset + head + _ATTN_TP_SIZE))`.
Both substitutions are string-stable across sglang 0.5.x; the script
asserts the anchor exists and that a substitution actually happened
(no silent drift).
Default offset 0 means non-colocate runs (where
`is_colocate_active()` returns False) are byte-identical.
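The injected logic, restated as a self-contained sketch (`head` and the
attn-TP size are sglang's surrounding locals; the offset comes from the
patch's env-contract helper):

```python
def attn_tp_ranks(head: int, attn_tp_size: int, colocate_offset: int) -> list:
    # colocate_offset is read_colocate_env().n_per_role when
    # is_colocate_active(), else 0 — so non-colocate runs get the original
    # trainer-half range back byte-identically, while colocate runs shift
    # the attn_tp ranks into the engine half [N, 2N).
    return list(range(colocate_offset + head,
                      colocate_offset + head + attn_tp_size))
```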
Iter 9 made it past every previous deadlock and reached
trainer.py:_setup_device_mesh, where it ran:
self.mesh = init_device_mesh(
"cuda", mesh_shape=(dist.get_world_size(),), ...
)
Under colocate the default PG is the 2N-rank union world (1 trainer
+ 1 engine for the tiny test), but the trainer's DP mesh should only
span the trainer half [0, N). The unconditional dist.get_world_size()
made the trainer try to build a 2-rank DP mesh that includes the
engine — FSDP collectives on that mesh would deadlock on the first
all-reduce because the engine isn't an FSDP participant.
Fix _setup_device_mesh to:
1. Prefer args.world_size / args.rank (which trainer_actor.py
overrides to the trainer-subgroup values when colocate is on)
over dist.get_*; falls back to the dist values for the
non-colocate path.
2. When the resulting world_size is smaller than the dist world
(i.e. we're in a sub-world), build a trainer-only NCCL group
via dist.new_group(use_local_synchronization=True) and attach
a DeviceMesh.from_group rather than the default init_device_mesh
path that requires the default PG to match the mesh shape.
The non-colocate path is byte-identical: dist_world_size ==
args_world_size, so the existing init_device_mesh branch fires.
Logged the mesh kind ("1D" vs "1D-colocate-sub") + both world sizes
for diagnosability.
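The branch, sketched (assumes args.world_size / args.rank carry the
trainer-subgroup view under colocate, per trainer_actor.py's override):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

def setup_device_mesh(args) -> DeviceMesh:
    if args.world_size == dist.get_world_size():
        # Non-colocate ("1D"): the default PG matches the mesh shape.
        return init_device_mesh("cuda", mesh_shape=(args.world_size,))
    # Colocate sub-world ("1D-colocate-sub"): trainer-only group, local-sync
    # so engine ranks never have to participate in its creation.
    trainer_group = dist.new_group(
        ranks=list(range(args.world_size)),  # trainer half [0, N)
        backend="nccl",
        use_local_synchronization=True,
    )
    return DeviceMesh.from_group(trainer_group, "cuda")
```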
Documents the 10-iteration chain that peeled off NCCL deadlock layers between
trainer rank 0 and the engine TP scheduler subprocess (rank 1) in the
colocate union world. Each iter found one more collective new_group mismatch
or rank-offset bug; each fix progressed the test further.
End state: both sides are past every previously-known deadlock + the
dp_attention rank offset + the trainer DP mesh world_size confusion. The
remaining hang is in model load / the first NCCL collective on a 1-rank NCCL
group (a regime the original world.py comment explicitly warned about).
Lists every commit landed this session, lays out the next-session debug plan,
and the cost ($8.46 spent on 10 iters at $2.99/hr on H100 SXM SECURE
Iceland). No code changes in this commit, just documentation of what we
learned and where we left off.
Iter 10 found the hang: eagle3_trainer.py:82 calls
`dist.barrier(group=get_gloo_group())` after the rank-0-only load_embedding.
In colocate mode `GLOO_GROUP` was bound to `union_world.meta_group` — the
2N-rank gloo group spanning both trainer and engine. The engine never enters
the trainer's init_model code path, so it never enters that barrier, and the
trainer hangs forever waiting.
Three changes:
1. `colocate/world.py`: build a `trainer_gloo_group` alongside `meta_group` —
   gloo, ranks=`[0, N)`, use_local_synchronization=True so engines skip the
   call. Expose it on `UnionWorld`. For the 1-trainer tiny test this is a
   1-rank gloo group; gloo handles 1-rank groups cleanly (unlike NCCL).
2. `training/trainer_actor.py`: bind
   `_dist_utils.GLOO_GROUP = self._union_world.trainer_gloo_group` instead of
   `.meta_group`. Now every `dist.barrier(get_gloo_group())` in the trainer's
   init / step path syncs only the trainer half. Comment updated to spell out
   the pitfall.
3. `training/trainer.py:_setup_device_mesh`: switch the colocate-sub-world DP
   group's backend from NCCL to GLOO for the 1-trainer case. NCCL 1-rank
   groups have known eager-init issues (the original world.py comment flagged
   this); gloo doesn't. For >=2 trainers we keep NCCL so the DP all-reduce
   stays on the GPU path.
Also adds `[TS-COLOCATE-TRACE-T]` `logger.warning` markers around
eagle3_trainer.init_model's major phases (draft model build, load embedding,
gloo barrier, FSDP wrap, full state-dict load) so the next iter pinpoints
any new hang within model init at sub-second resolution.
Iter 11 with the trainer-only gloo group fix (08976e5) cleared every
previously-known deadlock and got the trainer all the way to `apply_fsdp2`
(~10 s after device-mesh creation). The next hang is
`fsdp2_load_full_state_dict`:
    File ".../torchspec/training/fsdp.py", line 137
        for _name, buf in model.named_buffers():
            dist.broadcast(buf, src=0)  # <-- uses the DEFAULT PG (union world)
In colocate mode the default PG is the 2N-rank union world (1 trainer +
1 engine for the tiny test). The engine never enters this code path, so the
trainer blocks forever waiting for engine participation in the buffer
broadcast.
Fix: pull the mesh group out of `device_mesh.get_group()` (which is the
trainer-only sub-mesh under colocate, and the full world under non-colocate)
and pass it explicitly to `dist.broadcast`. Also translate `src=0` to the
global rank corresponding to mesh-rank 0 via
`dist.get_global_rank(mesh_group, 0)`, so the broadcast source is correct
regardless of the mesh-vs-default-PG offset (under colocate the trainer's
mesh-rank-0 is also union-rank-0, so no translation happens; under
non-colocate this is also a no-op).
Adds [TS-COLOCATE-TRACE-T] markers around the three steps
(set_model_state_dict, buffer broadcasts, finish) so the next iter can
pinpoint any remaining hang.
Note: `set_model_state_dict(broadcast_from_rank0=True)` uses the
FSDP-wrapped module's mesh internally (FSDP attached `mesh` at apply_fsdp2
time), so it should already broadcast only on the trainer sub-mesh. If
iter 12 hangs at the AFTER set_model_state_dict marker, that assumption is
wrong and we need to override its group too.
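The fixed loop, sketched:

```python
import torch.distributed as dist

def broadcast_buffers(model, device_mesh) -> None:
    # Trainer-only group under colocate, full world otherwise.
    mesh_group = device_mesh.get_group()
    # Translate mesh-rank 0 to its global rank so src survives any
    # mesh-vs-default-PG offset (a no-op in both current layouts).
    src = dist.get_global_rank(mesh_group, 0)
    for _name, buf in model.named_buffers():
        dist.broadcast(buf, src=src, group=mesh_group)
```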
Iter 12 confirmed: the trainer-only-gloo + mesh-scoped-broadcast fixes cleared everything up to `fsdp2_load_full_state_dict`, which now hangs *inside* `set_model_state_dict`. PyTorch's set_model_state_dict with `broadcast_from_rank0=True` broadcasts the rank-0 state dict across the **default** process group — which under colocate is the 2N-rank union world. The engine never enters this code path, so the broadcast deadlocks. For the tiny smoke (dp_size=1) the FSDP mesh is a single trainer rank. There is nothing to broadcast — rank 0 already holds the full state dict — so disable broadcast_from_rank0 when the mesh has one rank and let rank 0 load locally. The buffer-broadcast loop below (already scoped to mesh_group in 2d44799) is a no-op on a 1-rank group. Multi-trainer colocate (dp_size>=2) still passes broadcast_from_rank0=True; that path would need set_model_state_dict to accept an explicit group (it doesn't today) or a manual per-param broadcast on mesh_group. Tracked as a follow-up — the 1-GPU tiny smoke is what we're unblocking right now, and it's dp_size=1.
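A sketch of the 1-rank special case, using PyTorch's `set_model_state_dict` / `StateDictOptions` API (the wrapper name is illustrative):

```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

def load_full_state_dict_colocate(model, full_sd, mesh_group):
    # With a 1-rank mesh there is nothing to broadcast: rank 0 already holds
    # the full state dict, so skip the default-PG broadcast that deadlocks
    # under colocate. Multi-trainer (dp_size >= 2) keeps broadcast_from_rank0.
    single_rank = mesh_group.size() == 1
    set_model_state_dict(
        model,
        full_sd,
        options=StateDictOptions(
            full_state_dict=True,
            broadcast_from_rank0=not single_rank,
        ),
    )
```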
Iter 13 cleared the FSDP state-dict load; the trainer then hangs
right after `init_model`'s fsdp2_load_full_state_dict with an NCCL
`dist.barrier()` warning ("Guessing device ID based on global rank,
this can cause a hang"). Root cause is the same family of bug as
every prior iter: bare `dist.*` collectives default to the union-
world PG, the engine never enters the trainer code path, deadlock.
Swept the trainer-side init + train-loop code for unscoped
collectives and scoped them all to `get_gloo_group()` (which
trainer_actor.py now binds to the trainer-only subgroup):
* eagle3_trainer.py `_init_target_lm_head`:
- `dist.broadcast(has_norm, src=0)` -> group=trainer
- `dist.barrier()` -> group=trainer
- `dist.broadcast(param.data, src=0)` (loop) -> group=trainer
* eagle3_trainer.py metric aggregation (train + eval paths):
- 2x `dist.all_reduce(avg_vlosses/avg_acces, AVG)` -> group=trainer
* checkpoint.py: 4x bare `dist.barrier()` -> group=trainer
(+ added the get_gloo_group import)
* trainer.py save path: `dist.barrier()` -> group=trainer
(+ added the get_gloo_group import)
On the 1-trainer tiny config the trainer group is a single rank, so
every one of these collectives is a no-op; they were only hanging
because they (wrongly) spanned the 2-rank union world. On
>=2 trainers they become real trainer-replica collectives, which is
the correct semantics: these sync trainer state, and the engine has
no business participating.
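The pattern of the sweep, as a minimal sketch (the `aggregate_metrics` wrapper is illustrative; `get_gloo_group()` is the real accessor named above):

```python
import torch
import torch.distributed as dist

def aggregate_metrics(avg_vloss: torch.Tensor, trainer_group: dist.ProcessGroup):
    # Before (deadlocks under colocate: a bare collective defaults to the
    # 2N-rank union world, which engine ranks never reach):
    #   dist.barrier()
    #   dist.all_reduce(avg_vloss)
    dist.barrier(group=trainer_group)
    # gloo has no ReduceOp.AVG (NCCL-only), so sum and divide by group size.
    dist.all_reduce(avg_vloss, group=trainer_group)
    avg_vloss /= trainer_group.size()
```

On a 1-rank trainer group both collectives are no-ops and the divide is by 1, so the tiny-config semantics are unchanged.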
Iter 14 cleared every distributed-init deadlock — the run now FAILS
FAST (55s, not a 15-min timeout) with a real error:
KeyError: 'Key lm_head.weight not found in
.../Qwen3-0.6B-Base/model.safetensors'
Qwen3-0.6B-Base (like Llama-3.2, Gemma, and most small models) has
`tie_word_embeddings=True` — there is no standalone `lm_head.weight`
tensor; the LM head shares the input-embedding matrix stored under
`model.embed_tokens.weight`.
`TargetLMHead._load_lm_head` only ever looked for the literal
`lm_head_key` ("lm_head.weight" by default) and KeyError'd when it
was absent. Fix: when the model config has
`tie_word_embeddings=True`, fall back to `model.embed_tokens.weight`.
`_load_key_from_file` now takes an optional `fallback_key` and tries
both in order. Applies to both the sharded (index.json) and
single-file checkpoint paths.
This is a general correctness fix, not colocate-specific — any
tied-embedding target model was previously unloadable by
TargetLMHead.
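A sketch of the fallback for the single-file checkpoint path, assuming a safetensors file (the sharded index.json path works the same way; the function name is illustrative):

```python
from safetensors import safe_open

def load_lm_head_weight(ckpt_path, config, lm_head_key="lm_head.weight"):
    # Tied-embedding models ship no standalone lm_head.weight; the LM head
    # shares the input-embedding matrix, so fall back to that key.
    fallback_key = (
        "model.embed_tokens.weight"
        if getattr(config, "tie_word_embeddings", False)
        else None
    )
    with safe_open(ckpt_path, framework="pt") as f:
        keys = set(f.keys())
        for key in (lm_head_key, fallback_key):
            if key is not None and key in keys:
                return f.get_tensor(key)
    raise KeyError(f"Neither {lm_head_key} nor {fallback_key} found in {ckpt_path}")
```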
Iter 15 cleared init_model entirely (tied-embedding fix landed) —
trainer logs "Eagle3 model initialized with FSDP2". The engine TP
scheduler now hangs right after initialize_dp_attention, in
model_runner.py:
    min_per_gpu_memory = get_available_gpu_memory(
        self.device, self.gpu_id,
        distributed=get_world_group().world_size > 1,
        cpu_group=get_world_group().cpu_group,
    )
sglang's `_WORLD` was built by init_distributed_environment spanning
the full 2N-rank union world (the patch deliberately passed
world_size=colocate_env.world_size=2). So `get_world_group().world_size
> 1` is True and get_available_gpu_memory does a distributed memory
sync on `_WORLD.cpu_group` — a collective across all 2N ranks. The
trainer ranks [0, N) never run sglang code, so it hangs.
The patch's original comment claimed downstream sglang "expects" a
2N-rank _WORLD. That's wrong: the engine is one half of the union
world; sglang's notion of "world" should be the engine ranks only.
Fix: new `rebuild_world_group_engine_only(env, local_rank, backend)`
in torchspec_colocate.py. Called from the ModelRunner colocate branch
right after init_distributed_environment — it drops the 2N-rank
_WORLD GroupCoordinator and rebuilds it spanning only
build_engine_tp_ranks(env) = [N, 2N). The new_group calls inside
init_world_group inherit the use_local_synchronization=True
monkey-patch, so only engine ranks participate.
After this, get_world_group().world_size == 1 (tiny config) so the
memory sync short-circuits, and every other sglang world-level
collective is engine-scoped.
@@ headers updated: torchspec_colocate.py new-file +1,318 -> +1,354;
model_runner import hunk +58,11 -> +58,12; ModelRunner hunk
+787,92 -> +787,105.
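The idea behind the rebuild, sketched with plain torch.distributed (the real fix rebuilds sglang's `_WORLD` GroupCoordinator via init_world_group; the names here are illustrative):

```python
import torch.distributed as dist

def build_engine_tp_ranks(num_trainers: int, world_size: int) -> list[int]:
    # Engine ranks occupy the back half of the union world: [N, 2N).
    return list(range(num_trainers, world_size))

def engine_only_world_group(num_trainers: int) -> dist.ProcessGroup:
    # Rebuild the "world" group so it spans engine ranks only.
    # use_local_synchronization=True keeps trainer ranks out of the call,
    # mirroring the monkey-patched new_group behavior described above.
    engine_ranks = build_engine_tp_ranks(num_trainers, dist.get_world_size())
    return dist.new_group(
        ranks=engine_ranks,
        backend="nccl",
        use_local_synchronization=True,
    )
```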
Iter 16 with the engine-only _WORLD rebuild (8bdc8d4) got the engine all the way through model load + KV cache allocation ("KV Cache is allocated. #tokens: 315251"). It then fast-failed (55s, no hang) with:

    File ".../sglang/srt/managers/tp_worker.py", line 292, in __init__
        self.random_seed = broadcast_pyobj(
    IndexError: list index out of range

Root cause is in the sglang callsite, not our code. tp_worker.py's random-seed sync does:

    self.random_seed = broadcast_pyobj(
        [server_args.random_seed],
        self.tp_size * self.pp_rank + tp_rank,  # tp-local rank = 0
        self.world_group.cpu_group,
        src=self.world_group.ranks[0],          # global rank = 1
    )[0]

broadcast_pyobj's docstring says the `rank` arg is the *global* rank. In standalone sglang the engine owns the whole world, so tp-local rank == global rank and this works. Under colocate the engine's global rank is N (=1) but its tp-local rank is 0, so `rank(0) != src(1)`: the engine wrongly takes broadcast_pyobj's *receiver* branch, receives size 0, returns `[]`, and the trailing `[0]` raises IndexError.
Fix: pass `self.world_group.rank` (the GroupCoordinator's global rank, == torch.distributed.get_rank()) instead of the tp-local arithmetic. Correct for both colocate (global rank 1 == src 1 -> sender path) and standalone (global == tp-local, unchanged). Implemented as a second post-`git apply` string-substitution step in setup_sglang (alongside the dp_attention one); the unified-diff format has been too fragile for these small sglang tweaks.
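A sketch of what such a substitution step might look like (paths and exact strings are illustrative; the real step lives in setup_sglang):

```python
from pathlib import Path

def patch_tp_worker_rank(sglang_srt_root: str) -> None:
    # Swap the tp-local rank arithmetic in tp_worker.py's broadcast_pyobj
    # call for the GroupCoordinator's global rank, so sender/receiver roles
    # line up under colocate. Fail loudly if upstream moved the code.
    path = Path(sglang_srt_root) / "managers" / "tp_worker.py"
    src = path.read_text()
    old = "self.tp_size * self.pp_rank + tp_rank,"
    new = "self.world_group.rank,"
    assert old in src, "tp_worker.py changed upstream; update the patch"
    path.write_text(src.replace(old, new, 1))
```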
The colocate training loop sizes the NCCL hidden-states transfer buffer up front and needs aux_hidden_states_layers on args before it starts. DFlash already had an auto-resolver; Eagle3 colocate runs hit a hard RuntimeError. Resolve via get_default_eagle3_aux_layer_ids — the same function sgl_engine falls back to — so trainer and engine agree.
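A sketch of the resolver wiring, assuming an argparse-style `args` namespace (the stub stands in for the real `get_default_eagle3_aux_layer_ids`, whose selection logic is not reproduced here):

```python
def get_default_eagle3_aux_layer_ids(num_hidden_layers: int) -> list[int]:
    # Placeholder stub: the real defaults come from the function of this
    # name that sgl_engine already falls back to, per the commit above.
    raise NotImplementedError

def resolve_aux_layers(args, model_config):
    # Resolve the layer ids up front so the hidden-states transfer buffer
    # can be sized before the loop starts, instead of raising RuntimeError.
    if getattr(args, "aux_hidden_states_layers", None) is None:
        args.aux_hidden_states_layers = get_default_eagle3_aux_layer_ids(
            model_config.num_hidden_layers
        )
    return args.aux_hidden_states_layers
```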
In colocate mode the trainer and engine share one physical GPU, so the union world's NCCL backend cannot form a communicator spanning both ranks — NCCL hard-errors "Duplicate GPU detected". The Phase-3 P2P smoke validated on 2 separate GPUs (1 rank each), which never exercised this. Route the engine->trainer hidden-state transfer over the existing all-rank gloo meta_group with host-memory staging instead: - NcclHiddenStatesConnector.send / NcclMultiTensorFetcher.recv_step branch on the group backend; gloo path stages through CPU and uses tagged dist.send/recv (no batch_isend_irecv, which is NCCL-only). - trainer passes union_world.meta_group to the fetcher. - colocate.patch exposes the engine-side meta_group via set/get_union_meta_group so build_hidden_states_writer can hand it to the connector. The NCCL batched path is kept for the separate-GPU dummy P2P tests.
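A sketch of the gloo staged path on each side, assuming shapes and tags are agreed out of band (function names are illustrative, not the connector's real API):

```python
import torch
import torch.distributed as dist

def send_hidden_states_gloo(t: torch.Tensor, dst: int, group, tag: int) -> None:
    # gloo path: stage through host memory and use a tagged point-to-point
    # send (batch_isend_irecv is NCCL-only, hence plain dist.send here).
    dist.send(t.detach().cpu(), dst=dst, group=group, tag=tag)

def recv_hidden_states_gloo(shape, dtype, src: int, group, tag: int, device):
    # Receiver side: land in a CPU buffer, then move onto the GPU.
    buf = torch.empty(shape, dtype=dtype)
    dist.recv(buf, src=src, group=group, tag=tag)
    return buf.to(device, non_blocking=True)

# The connector branches on dist.get_backend(group): "gloo" takes the staged
# path above; the NCCL batched path is kept for the separate-GPU P2P tests.
```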
The colocate loop logged metrics.get("train/loss"), but both the
Eagle3 and DFlash trainers' _aggregate_metrics emit "train/avg_loss"
(matching the disagg loop in loop.py). The wrong key made every step
log loss=None even though training ran fine — which also tripped
test_phase7_tiny_loss_decreases, whose log parser found zero loss
points.
Documents the iter chain that took the 1-GPU colocate smoke from the end-of-iter-10 hang to both tiny tests passing on 1xH100, and records the key architectural correction: NCCL cannot form a communicator with two ranks on one physical GPU, so the colocate hidden-state plane runs over gloo (host-staged), not the NCCL union world.
test_stability parses `step=N ... peak_alloc...=X` from the colocate loop's per-step log line, but the line only carried step/step_time/loss/lr. Add peak_alloc from metrics["perf/peak_bytes_allocated"] (which the trainer already populates every step via TrainProfiler.peak_alloc_metrics) so the Phase-6 leak check has data points to compare.
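A sketch of the underlying counters (the real values come from TrainProfiler.peak_alloc_metrics; the metric key is the one named in the commit above):

```python
import torch

def peak_alloc_metrics(device=None):
    # Per-step peak allocation for the leak check; reset afterwards so each
    # step reports its own peak rather than the run-wide maximum.
    peak = torch.cuda.max_memory_allocated(device)
    torch.cuda.reset_peak_memory_stats(device)
    return {"perf/peak_bytes_allocated": peak}

# The colocate loop appends it to the per-step log line, e.g.:
#   step=12 step_time=... loss=... lr=... peak_alloc=...
```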
Draft PR tracking work on #81 — co-locate training and inference on the same GPUs via CUDA MPS + NCCL/gloo P2P.
Each phase is gated behind a `colocate_strategy=mps` + `transfer_mode=nccl` flag pair so the disaggregated baseline keeps working throughout.
Status
* `peak_alloc` logging landed; full leak/stability test pending 4×H100.
* `grad_parity` / `convergence` pending 4×H100.
Testing progress
`tests/colocate/test_colocate_tiny.py` is GREEN on 1×H100 SXM. The full colocate path is exercised end-to-end on a single GPU: MPS daemon → 2-rank union world → patched sglang (engine-only `_WORLD`, union-default PG, `dp_attention` rank offset) → engine→trainer hidden-state transfer → `NcclMultiTensorFetcher` → Eagle3 draft forward/backward → optimizer step. Loss decreases monotonically, so gradients flow through real transferred hidden states.
Key architectural corrections found during RunPod validation
* NCCL cannot form a communicator with two ranks on one physical GPU; it is hard-rejected (`ncclInvalidUsage`, "Duplicate GPU detected"), and that is exactly the colocate topology. The hidden-state plane was rerouted over the all-rank gloo `meta_group` with CPU staging. The NCCL batched path is retained only for the separate-GPU Phase-3 dummy tests; CUDA IPC remains a possible future zero-copy optimization.
* Bare `dist.*` collectives deadlock on the 2N union default PG (trainer and engine run different code paths). All trainer-side collectives are now scoped to a trainer-only gloo group, FSDP broadcasts go to the mesh group, and sglang's `_WORLD` is rebuilt as engine-only `[N, 2N)`.
* `transfer_mode=nccl` is now genuinely Mooncake-free: the top-level `mooncake.store` import was made lazy, so the colocate path no longer needs libibverbs/libnuma.
Environment constraint
The bundled `sgl_kernel` wheel ships sm90+ kernels only (no Ampere sm80/sm86, no Ada sm89). Cheap single-GPU testing is effectively limited to H100 / H200 / B200.
Remaining work
Provision 4×H100 and run the suite with `--full` for the MPS-gated tests (`test_one_step`, `test_grad_parity`, `test_stability`, `test_convergence`). The 4-GPU union world (2 ranks/GPU × 4 GPUs) gives FSDP across the 4-trainer NCCL subgroup its first real ≥2-rank exercise.