
[WIP] Support co-locate training and inference (#81)#92

Draft
zhubohao911 wants to merge 53 commits into
lightseekorg:main from
zhubohao911:feature/colocate-training-inference

Conversation


@zhubohao911 zhubohao911 commented May 7, 2026

Draft PR tracking work on #81 — co-locate training and inference on the same GPUs via CUDA MPS + NCCL/gloo P2P.

Each phase is gated behind a colocate_strategy=mps + transfer_mode=nccl flag pair so the disaggregated baseline keeps working throughout.

Status

  • Phase 0 — config flags & validation
  • Phase 1 — placement: 1:1 bundle pairing + MPS env
  • Phase 2 — union NCCL world bootstrap
  • Phase 3 — P2P data plane (smoke test)
  • Phase 4 — sglang hidden-state hook
  • Phase 5 — controller / sync training loop
  • Phase 6 — memory caps & stability (per-step peak_alloc logging landed; full leak/stability test pending 4×H100)
  • Phase 7 — numeric parity & convergence (tiny loss-decrease test green; grad_parity / convergence pending 4×H100)
  • Phase 8 — docs & example config

Testing progress

tests/colocate/test_colocate_tiny.py is GREEN on 1×H100 SXM:

test_phase4_tiny_one_step       PASSED   (1 step end-to-end)
test_phase7_tiny_loss_decreases PASSED   (loss 12.02 → 9.74 over 20 steps)
2 passed in 175.33s

The full colocate path is exercised end-to-end on a single GPU: MPS daemon → 2-rank
union world → patched sglang (engine-only _WORLD, union-default PG, dp_attention
rank offset) → engine→trainer hidden-state transfer → NcclMultiTensorFetcher
Eagle3 draft forward/backward → optimizer step. Loss decreases monotonically, so
gradients flow through real transferred hidden states.

Key architectural corrections found during RunPod validation

  • NCCL cannot do same-GPU P2P. A union-world NCCL communicator with two ranks on
    one physical GPU is hard-rejected (ncclInvalidUsage, "Duplicate GPU detected") —
    exactly the colocate topology. The hidden-state plane was rerouted over the all-rank
    gloo meta_group with CPU staging. The NCCL batched path is retained only for
    the separate-GPU Phase-3 dummy tests; CUDA IPC remains a possible future zero-copy
    optimization.
  • Unscoped dist.* collectives deadlock on the 2N union default PG (trainer and
    engine run different code paths). All trainer-side collectives are now scoped to a
    trainer-only gloo group, FSDP broadcasts to the mesh group, and sglang's _WORLD is
    rebuilt as engine-only [N, 2N).
  • transfer_mode=nccl is now genuinely Mooncake-free — the top-level
    mooncake.store import was made lazy so the colocate path no longer needs
    libibverbs/libnuma.

Environment constraint

The bundled sgl_kernel wheel ships sm90+ kernels only (no Ampere sm80/sm86, no
Ada sm89). Cheap single-GPU testing is effectively limited to H100 / H200 / B200.

Remaining work

Provision 4×H100 and run the suite with --full for the MPS-gated tests
(test_one_step, test_grad_parity, test_stability, test_convergence). The 4-GPU
union world (2 ranks/GPU × 4 GPUs) gives FSDP across the 4-trainer NCCL subgroup its
first real ≥2-rank exercise.

Placeholder commit for tracking work on issue lightseekorg#81. Implementation will land
across multiple PRs following the phased plan.

Lays the foundation for putting trainer + inference engine on the same
GPU via NVIDIA MPS. All new code is gated behind
`training.colocate_strategy=mps`; the legacy disagg / `colocate=True`
paths are unchanged.

Phase 0 — config plumbing & validation:
* Add `colocate_strategy`, `transfer_mode`, `train_frac`, `infer_frac`
  to `TrainingConfig`.
* `torchspec/colocate/config.py` validates supported combinations
  (`mooncake` is the only legal `transfer_mode` when colocate is off;
  `mps`/`nccl` is the only combination when on), the `train_frac +
  infer_frac + 0.10 headroom <= 1.0` budget, and the
  `engine_count * engine_tp == world_size` topology invariant.
* Wired into `train_entry.parse_config`. 18/18 unit tests pass locally.
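
For reference, the three validation rules above can be sketched as a standalone checker. This is an illustrative sketch only — the field and function names mirror the description, not the real `torchspec/colocate/config.py` API:

```python
from dataclasses import dataclass
from typing import Optional

HEADROOM = 0.10  # reserved slack outside the train/infer budgets


@dataclass
class ColocateConfig:
    colocate_strategy: Optional[str]  # None (disaggregated) or "mps"
    transfer_mode: str                # "mooncake" or "nccl"
    train_frac: float
    infer_frac: float
    engine_count: int
    engine_tp: int
    world_size: int


def validate(cfg: ColocateConfig) -> None:
    if cfg.colocate_strategy is None:
        # Disaggregated baseline: Mooncake is the only legal transfer mode.
        if cfg.transfer_mode != "mooncake":
            raise ValueError("transfer_mode must be 'mooncake' when colocate is off")
        return
    # Colocate: only the mps + nccl flag pair is supported.
    if (cfg.colocate_strategy, cfg.transfer_mode) != ("mps", "nccl"):
        raise ValueError("only colocate_strategy=mps + transfer_mode=nccl is supported")
    # Memory budget: both fractions plus the fixed headroom must fit in one GPU.
    if cfg.train_frac + cfg.infer_frac + HEADROOM > 1.0:
        raise ValueError("train_frac + infer_frac + 0.10 must be <= 1.0")
    # Topology invariant: engines tile the trainer world exactly.
    if cfg.engine_count * cfg.engine_tp != cfg.world_size:
        raise ValueError("engine_count * engine_tp must equal world_size")
```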

Phase 1 — placement + MPS env injection:
* `torchspec/colocate/mps.py`: idempotent NVIDIA MPS daemon lifecycle
  helper (start, status, env-var build, stop). Dependency-free so the
  Ray driver can call it from a headless box; subprocess is mocked in
  the 17-test unit suite.
* `placement_group.create_placement_groups` now branches on
  `is_mps_colocate(args)`, logs which strategy fired, and revalidates
  the engine-count/TP topology invariant.
* `RayTrainGroup` claims `train_frac` per actor under MPS (was hard-coded
  0.4); `_prepare_sgl_engines` claims `infer_frac` (was 0.2 placeholder).
* Both inject `mps_client_env()` + `expandable_segments:True` allocator
  config into the Ray actor `runtime_env`.
* `SglEngine.init` overrides sglang's `mem_fraction_static` from
  `infer_frac` so users don't keep two budgets in sync.
* `train_entry` starts the MPS daemon once at driver init and skips
  Mooncake master setup under MPS colocate (Phase 5 will rip Mooncake
  out properly).
* Modal smoke test `phase1_placement` (`H100:4`, 22 s test) confirms
  shared placement group, 4 trainer + 4 engine pairs land on matching
  `(node_ip, gpu_id)` with distinct PIDs and propagated MPS env vars.

Phase 2 — union NCCL world bootstrap helper:
* `torchspec/colocate/world.py`: `UnionWorldSpec` + `init_union_world`
  build a 2N-rank NCCL default PG, an FSDP-only NCCL subgroup of size N,
  and a 2N-rank gloo metadata subgroup. Sets
  `TORCHSPEC_COLOCATE_UNION_WORLD=1` so a follow-up sglang patch can
  detect the union-world setup. Trainers occupy ranks `[0, N)`, engines
  occupy `[N, 2N)`.
* Modal smoke test `phase2_union_world` (`H100:8`, 55 s test): 8 ranks
  bootstrap the union world; allreduce on the union, FSDP, and gloo
  groups all succeed; trainer/engine rank assignment matches spec.
* Phase 2 is intentionally tested in isolation from MPS sharing (8 GPUs,
  one rank per GPU). The MPS+union-world integration and the sglang
  scheduler patch (so sglang TP reuses our union world) are the
  highest-risk piece and are deferred to Phase 4 where they're actually
  exercised by the hidden-state hook.

Modal infrastructure:
* `scripts/modal/setup_modal_secrets.sh` — sandbox env secrets bootstrap
  (HF + W&B).
* `scripts/modal/modal_colocate_smoke.py` — sandbox-targeted Modal app
  with one entrypoint per phase. Image is built on
  `nvidia/cuda:12.4.0-devel-ubuntu22.04` + sglang `0f2df9370`.

Test artifacts:
* `tests/colocate/`: phase 0 validation, phase 1 MPS helper, phase 2
  rank-assignment, phase 1 placement (Modal-only), phase 2 union world
  (Modal-only). 45 unit tests pass locally; Modal smoke tests for
  Phase 1 and Phase 2 are green.

Tracking:
* `docs/colocate/implementation_log.md` records status, work logs, and
  every plan deviation phase by phase.

Co-authored-by: Claude

Adds the trainer-side `NcclDataFetcher` + engine-side `send_dummy`
that together form the colocate hidden-state transfer mechanism. Both
use `dist.batch_isend_irecv`, the production-correct primitive for
union-world P2P (avoids NCCL's lazy 2-rank sub-comm init pathology
on the size-2N parent group).

`init_union_world` extended:
- Adds `paired_global_rank` to `UnionWorld` so callers can target the
  opposite-role rank without re-deriving.
- Accepts `device_id` (defaults to `cuda.current_device()`); without it
  NCCL guesses device-by-rank, which under Ray's CUDA_VISIBLE_DEVICES
  isolation maps to a non-existent local GPU and silently deadlocks
  P2P.
- Skips the trainer-only NCCL subgroup when there's only one trainer
  (1-rank NCCL groups can hang in eager-init mode).
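
The rank layout and pairing the helper builds can be stated in a few lines. Illustrative sketch only; the real `UnionWorldSpec` / `paired_global_rank` live in `torchspec/colocate/world.py`:

```python
def trainer_ranks(n: int) -> list:
    """Trainers occupy the low half of the 2N union world."""
    return list(range(n))


def engine_ranks(n: int) -> list:
    """Engines occupy the high half, [N, 2N)."""
    return list(range(n, 2 * n))


def paired_global_rank(rank: int, n: int) -> int:
    """Opposite-role rank sharing the same physical GPU under 1:1 pairing."""
    return rank + n if rank < n else rank - n
```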

Modal smoke (`phase3_p2p_dummy` on H100:2) — 3/3 tests pass in 137s:
- 100-iter byte-equality on bare NCCL (data plane core)
- 1-iter round trip through the full `init_union_world` +
  `NcclDataFetcher` + `send_dummy` path (proves union-world helper
  integrates with the data plane)
- shape-mismatch errors cleanly via 90s watchdog (production wraps
  recvs in a watchdog timeout for the same reason)

Deviation from plan: ran at 2-rank/no-MPS scale instead of the plan's
4-GPU-MPS topology. Multi-pair concurrent P2P inside a shared parent
group is what Phase 4 builds (each pair gets its own NCCL world inside
its MPS-shared GPU); Phase 3's job is just to verify the data plane
mechanism. Documented in implementation_log.md §Phase 3 deviations.

Verification:
  PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q
    → 45 passed, 9 skipped (local; torch absent → skip)
  modal run --env sandbox \
    scripts/modal/modal_colocate_smoke.py::phase3_p2p_dummy
    → 3 passed in 137.78s

AI-assisted (Claude). Human submitter reviewed and ran tests.

Co-authored-by: Claude

Lands the TorchSpec side of the engine→trainer P2P data plane so the
colocate path can ship hidden states without Mooncake:

* `NcclHiddenStatesConnector` (engine-side multi-tensor sender) —
  one `dist.batch_isend_irecv` over a sorted-by-key tensor dict,
  contiguous + CUDA preconditions, env-var contract for the upstream
  sglang patch (TORCHSPEC_COLOCATE_TRANSFER_MODE,
  TORCHSPEC_COLOCATE_PAIRED_TRAINER_RANK).
* `NcclMultiTensorFetcher` (trainer-side receiver) — symmetric
  sorted-by-key, fresh-buffer-per-step (revisit in Phase 6 if churn).
* `ColocateTrainSample` + `ColocateDataset` + `ColocateDataFetcher`
  — Mooncake-shaped batch dict consumed identically by `_train_step`.
* `TrainerActor.init` branches on `transfer_mode`: when `nccl`, runs
  `init_union_world` (master_port + 5000), binds the union-world's
  meta_group as GLOO_GROUP, and overrides args.rank / args.world_size
  to the trainer-only N-rank view so FSDP arithmetic stays correct.
  Stamps TORCHSPEC_COLOCATE_UNION_* env vars for the upstream patch.
* `Trainer.set_train_queue` swaps to `ColocateDataFetcher` when the
  union world is wired up; warns + bypasses Mooncake config.
* `SglEngine.init` exports the env contract + flips
  `enable_spec_training_mooncake` to False so the patch's NCCL path
  is the only writer.
* `docs/colocate/sglang_patch.md` documents the upstream sglang
  patch surface (env-var contract + 3 patch points + verification
  recipe + diagnostic for "patch not picked up").
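
The sorted-by-key contract both sides rely on is worth spelling out: `dist.batch_isend_irecv` matches sends to recvs positionally, so a deterministic key order replaces any metadata exchange. A minimal pure-Python sketch (helper name illustrative, not the real connector/fetcher code):

```python
def ordered_specs(tensor_shapes: dict) -> list:
    """Canonical transfer order: (key, shape) pairs sorted by key."""
    return sorted(tensor_shapes.items())


# Insertion order differs on the two sides, but the enumeration used to
# build the batched isend/irecv op lists comes out identical, so op i on
# the sender always pairs with op i on the receiver.
sender = ordered_specs({"input_ids": (128,), "hidden_states": (128, 4096)})
receiver = ordered_specs({"hidden_states": (128, 4096), "input_ids": (128,)})
assert sender == receiver
```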

Verified on Modal sandbox (2×H100, 40.4 s):
* `test_p2p_multi_tensor_round_trip` — 4-tensor Mooncake-shaped dict
  via union world, byte equality per tensor.
* `test_send_step_helper_matches_connector` — symmetric helper.

Phase 4's "one full training step" deliverable is gated on the
upstream sglang patch (no patches/_sglang/ in this repo); test_one_step
is parked behind that.

Implementation log: `docs/colocate/implementation_log.md` Phase 4
section is now 🟢 (TorchSpec-side complete).

Co-authored-by: Claude

Lands the TorchSpec side of the colocate (transfer_mode=nccl) path's
controller and stability infrastructure. The colocate sync training
loop body itself is gated on the upstream sglang patch
(docs/colocate/sglang_patch.md); train_entry now reaches setup, prints
the timer summary, and raises a clearly-worded NotImplementedError
when the patch is absent.

Phase 5 — controller trim:
* `torchspec/controller/setup.py` adds
  `setup_colocate_training_with_engines(args, train_group,
  inference_engines, controller=None)` — the slim sibling of
  `setup_async_training_with_engines`. Skips `AsyncInferenceManager`
  entirely (returns `(controller, None)`), passes `mooncake_config=None`
  to both `train_group.set_train_queues` and `set_eval_queues`, and
  deliberately does **not** import `torchspec.transfer.mooncake.*` so a
  `sys.modules` guard test can enforce the property.
* `torchspec/controller/__init__.py` exports the new entry point.
* `torchspec/training/trainer.py` `set_eval_queue` now branches on the
  `_union_world` handle: when colocate, builds a `ColocateDataFetcher`
  on top of `NcclMultiTensorFetcher` instead of `MooncakeDataFetcher`;
  any incidental `mooncake_config` is logged + bypassed (defence in
  depth — the controller no longer sends it). Same shape downstream so
  `Eagle3Trainer._train_step` is unchanged.
* `torchspec/train_entry.py`:
  - `[8] Setup training` branches on `is_mps_colocate(args)` to call
    the colocate setup; mooncake_config build/master are skipped.
  - After `timer.log_summary()`, raises `NotImplementedError` with a
    pointer to `docs/colocate/sglang_patch.md` and the multi-tensor
    smoke target. The pre-loop wiring (controller actor, train_group,
    inference_engines, train queues) is fully set up at this point;
    the loop body is the only remaining gap.
* `tests/colocate/test_phase5_no_mooncake.py` (3 tests, all green
  locally): module-import guard via fresh interpreter +
  `sys.modules`, signature compatibility with the async setup,
  `inference_manager is None` post-condition.

Phase 6 — memory caps + MPS hygiene + stability skeleton:
* `torchspec/train_entry.py` init-order fence (gated on
  `is_mps_colocate(args)`): runs `ray.get(train_init_refs)` *before*
  invoking `prepare_inference_engines`. Under MPS the engine and
  trainer share one CUDA memory pool; sequencing trainer-first
  guarantees `set_per_process_memory_fraction(train_frac)` is applied
  before sglang's KV-cache pre-allocator runs. Disagg path keeps the
  original parallel init.
* `torchspec/colocate/mps.py` `setup_for_colocate(register_atexit=True)`
  registers a `quit`-the-daemon `atexit` hook iff *this* process
  started the daemon (uses the helper's `started_by_us` ownership
  flag). Idempotent + race-free; SIGKILL/OOM-kill paths still leak
  the daemon by design — next driver run reuses it.
* `torchspec/utils/profiling.py` `TrainProfiler.peak_alloc_metrics(reset=True)`
  returns peak / current / reserved bytes for `torch.cuda.current_device()`
  and optionally resets the counter. Empty dict on CPU-only test runs.
* `torchspec/training/trainer.py` _gather_metrics emits the four
  `perf/peak_*` / `perf/current_*` keys per step and resets the peak
  counter so each metric reflects only that step's window.
* `tests/colocate/test_stability.py` skeleton — two tests pinned to
  `pytest.skip` until the upstream sglang patch unblocks
  `phase6_stability`. Encodes the plan's "peak_alloc(step=10) ≈
  peak_alloc(step=999) within 1 %" acceptance bar in code.
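
That acceptance bar reduces to a one-line relative check. Hypothetical helper shown for clarity; the real skeleton is `tests/colocate/test_stability.py`:

```python
def peak_alloc_stable(early_bytes: int, late_bytes: int, tol: float = 0.01) -> bool:
    """True when the late-step peak allocation is within tol (default 1%)
    of the early-step peak — i.e. no per-step memory leak."""
    return abs(late_bytes - early_bytes) <= tol * early_bytes
```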

Bundling note: trainer.py and train_entry.py contain hunks from both
phases. Bundling them into one commit keeps each file's view of the
colocate branch internally consistent (Phase 6's init-order fence is
gated on Phase 5's `is_mps_colocate(args)`; the peak-alloc metric
lands inside the same metrics dict Phase 5 routes through). The
implementation log preserves separate Phase 5 / Phase 6 work-log
sections for the review trail.

Verification:
* `PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q`
  → 45 passed, 27 skipped (skips: torch absent / CUDA absent /
  upstream-patch-gated).
* End-to-end `phase4_one_step` / `phase6_stability` on Modal remain
  parked until the upstream sglang patch lands (documented in
  implementation_log.md).

AI-assisted (Claude). Human submitter reviewed and ran tests.

Co-authored-by: Claude

Adds the parking spots for the per-parameter gradient parity and
end-to-end convergence comparisons described in
implementation.md §Phase 7. Both files are deliberate skeletons that
encode the acceptance bars in code (so the bar can't drift on the
branch) but stay `pytest.skip` until the upstream sglang patch
unblocks the colocate sync loop and the Modal `phase7_*` entrypoints.

* `tests/colocate/test_grad_parity.py` —
  `test_phase7_grad_parity_per_parameter` skeleton. Acceptance pinned
  in docstring:
  `torch.allclose(g_disagg, g_colocate, atol=1e-6, rtol=0)` per
  parameter on the same prompts + seed (NCCL is bit-deterministic
  given identical reduction order; we don't change it).

* `tests/colocate/test_convergence.py` —
  `test_phase7_convergence_curves_match_within_2pct` and
  `test_phase7_eval_loss_matches`. Both `pytest.skip` +
  `pytest.mark.slow`. Acceptance: per-step loss within 1–2 %, eval
  loss within tokenizer-deterministic noise. Hooks into the
  `phase7_grad_parity` / `phase7_convergence` Modal targets that the
  Phase 5 commit already wired in (currently parked because the
  colocate sync loop body raises NotImplementedError until the
  upstream patch is applied).

Both skeletons depend on a "disagg control run" snapshot we don't
generate yet; once the upstream patch is in, the skeleton needs (a)
a recorded disagg gradient/loss baseline on identical prompts/seed
and (b) a colocate run to compare against.

Verification:
  PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q
  → 45 passed, 29 skipped (test_grad_parity + test_convergence join
    the skip list as expected).

AI-assisted (Claude). Human submitter reviewed.

Co-authored-by: Claude

Wraps up the colocate (MPS + NCCL) feature with user-facing
documentation and a runnable single-node example. No code changes;
this is purely the discoverability layer the prior phases lacked.

* `docs/colocate/usage.md` (new) — end-to-end user guide:
  - When to use colocate vs disaggregated (single-node, sglang-only,
    spec-training; vs multi-node / multi-replica / async / vLLM).
  - Hardware + software prereqs (driver R535+, MPS daemon, CUDA
    toolkit, expandable_segments allocator, sglang colocate patch).
  - GPU layout invariants — 1:1 trainer↔engine pairing, `tp_size==1`,
    union-world rank assignment (`[0,N)` trainer, `[N,2N)` engine).
  - Memory-split formula `train_frac + infer_frac + 0.10 ≤ 1.0` with
    rationale for the 0.10 reserved slack.
  - The four colocate config fields with the three Phase-0 validation
    rules.
  - "What changes inside a run" section mapping each phase to the
    runtime behaviour change (placement, MPS daemon, distributed
    init, fetcher swap, engine init, controller).
  - Validation matrix — which Modal smoke entrypoint proves each
    phase, and which are still upstream-patch-gated.
  - Known limitations + troubleshooting section (hangs, OOM,
    daemon-not-running, `via PCIe`, daemon zombies).
  - "Where the code lives" map back to the source files.

* `configs/colocate_qwen3_8b.yaml` (new) — colocate sibling of
  `configs/sglang_qwen3_8b.yaml`. Differs only in:
  - `colocate_strategy=mps`, `transfer_mode=nccl`,
    `train_frac=0.45`, `infer_frac=0.45`.
  - `training_num_gpus_per_node=4`, `inference_num_gpus=4`,
    `inference_num_gpus_per_engine=1`, `tp_size=1` (the 1:1 pairing
    invariant).
  - `output_dir` / `cache_dir` paths.
  Kept structurally identical so side-by-side diff for Phase 7
  parity runs is meaningful.

* `examples/colocate-qwen3-8b-1node/{run.sh,README.md}` (new) —
  colocate sibling of `examples/qwen3-8b-single-node/`:
  - `run.sh` exports `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`
    and `PYTORCH_ALLOC_CONF` (PyTorch ≥ 2.9 alias), defaults
    `CUDA_VISIBLE_DEVICES=0,1,2,3`, pins `tp_size=1` /
    `inference_num_gpus_per_engine=1`, and forwards extra args to
    `python -m torchspec.train_entry`.
  - `README.md` short overview that links into
    `docs/colocate/usage.md` for the full design rationale; calls
    out the upstream-patch dependency and the expected hang
    signature when the patch is missing.

* `docs/ray.md` placement-group table now has a row for
  `colocate_strategy=mps + transfer_mode=nccl` (shared PG, fractional
  `num_gpus=train_frac`/`infer_frac`, link to the new usage doc).

* `docs/colocate/implementation_log.md`:
  - Status snapshot updated: Phase 5 / 6 / 7 → 🟢, Phase 8 → ✅.
  - Phase 5 / 6 / 7 / 8 sections written out with work logs,
    deviations, and verification gates. The bundling note from the
    Phases 5+6 commit is recorded so reviewers see why those two
    files are co-committed.

Verification:
* `python -m torchspec.train_entry --config configs/colocate_qwen3_8b.yaml`
  on a non-colocate-patched sglang reaches setup and raises the
  Phase-5 NotImplementedError as documented (this is the dry-run
  signature, not a bug).
* Existing examples / configs still parse — Phase 0 validation only
  fires the new errors when the colocate fields are set.
* `PYENV_VERSION=3.11.8 python -m pytest tests/colocate/ -q`
  → 45 passed, 29 skipped (unchanged).

AI-assisted (Claude). Human submitter reviewed.

Co-authored-by: Claude

The Phase-4 plan called for the upstream sglang colocate patch to land
in a separate PR; this commit instead vends it as an in-repo patch
(applied on top of the existing disagg patch) so Phase-4 end-to-end
runs become possible without an external dependency. Pseudocode in
docs/colocate/sglang_patch.md is kept as the upstream-PR spec; the
real diff is now `patches/sglang/v0.5.8.post1/colocate.patch`.

* `patches/sglang/v0.5.8.post1/colocate.patch` (new, ~700 lines) —
  the engine-side colocate patch generated from the
  `feature/torchspec-colocate-patch` branch on the local sglang clone
  (base: 0f2df9370a + sglang.patch). Five-file surface:
  - **NEW** `sglang/srt/distributed/torchspec_colocate.py` — env-var
    contract parser + union-world default-PG joiner +
    NcclHiddenStatesConnector factory. Lazy torchspec import so disagg
    runs never pull torchspec into sglang's import graph. Self-contained
    helper, only stdlib at module level.
  - `parallel_state.py::initialize_model_parallel` — accepts
    `tp_world_ranks`. When passed, skips the
    `world_size == tp_size * pp_size` assertion and uses the explicit
    rank list as the single TP group; creates singleton PP groups.
    Defends against non-default MoE-EP/MoE-TP layouts under colocate.
  - `model_executor/model_runner.py` — branches at the
    init_distributed_environment call site. Under colocate: joins the
    union world via init_union_default_pg, calls
    init_distributed_environment redundantly to set sglang's `_WORLD`,
    then initialize_model_parallel(tp_world_ranks=...).
  - `managers/scheduler.py::Scheduler.__init__` — adds `eagle_nccl_writer`
    next to `eagle_mooncake_store`; skips Mooncake background init
    under colocate; instantiates NcclHiddenStatesConnector on every TP
    rank (each pairs 1:1 with one trainer rank).
  - `managers/scheduler_output_processor_mixin.py` — adds
    `_send_hidden_states_to_nccl`; `_process_hidden_states_for_req`
    branches NCCL-first, Mooncake-fallback, none-otherwise.

* `scripts/modal/modal_colocate_smoke.py` — wires the Modal image to
  apply both `sglang.patch` and `colocate.patch` (in order) when
  building the sglang container layer. Adds `--recount` to both
  `git apply` calls (the colocate.patch comes from `git format-patch`
  which appends a trailing `2.51.2` git-version line that older
  applies reject without `--recount`).

* `docs/colocate/sglang_patch.md` — adds a banner pointing at the
  real patch file. The pseudocode patch points are kept as the
  upstream-PR spec; the in-repo patch is the runnable artifact.

Verified locally:
* Patch round-trips: reset sglang clone to disagg-only state, re-apply
  colocate.patch with `--recount` → success, no conflicts.
* `torchspec_colocate.py` imports cleanly via `importlib` on Mac
  (Python 3.11). End-to-end exercises:
  - `is_colocate_active()` toggles correctly with env var.
  - `read_colocate_env()` parses 6 env vars; `engine_global_rank()`
    maps tp_rank → union global rank correctly (4 → N+0..N+3 for N=4);
    out-of-range rejected with ValueError.
  - Missing env var raises clear RuntimeError pointing at the doc.
  - `build_engine_tp_ranks()` returns `[N, ..., 2N-1]`.
* Five patched files all parse cleanly via `ast.parse`.
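
The verified rank-mapping behaviour amounts to the following (restated sketch, not the patch source itself):

```python
def engine_global_rank(tp_rank: int, num_trainers: int) -> int:
    """Map an sglang TP rank onto its union-world global rank.
    Engines occupy [N, 2N), so tp_rank r becomes N + r."""
    if not 0 <= tp_rank < num_trainers:
        raise ValueError(
            f"tp_rank {tp_rank} out of range [0, {num_trainers})"
        )
    return num_trainers + tp_rank


def build_engine_tp_ranks(num_trainers: int) -> list:
    """The explicit TP rank list handed to initialize_model_parallel."""
    return [engine_global_rank(r, num_trainers) for r in range(num_trainers)]
```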

End-to-end NCCL P2P validation (the `phase4_one_step` Modal target)
needs the matching trainer + GPUs and is still parked behind the
synchronous loop body in `train_entry.py` (Phase 5 follow-up).

AI-assisted (Claude). Human submitter reviewed and ran the patch
round-trip + helper module tests on Mac.

Co-authored-by: Claude

…in probe

The previous image recipe applied colocate.patch from the cloned
(pinned-commit) TorchSpec repo, which doesn't contain it yet. Restructure
the build into three layers so:

  1. clone sglang + pip install + apply existing disagg patch
  2. add_local_dir overlays (brings in the new colocate.patch file)
  3. apply colocate.patch from the overlaid path

This makes patch iteration only invalidate the thin top layer instead
of rebuilding base + disagg, and keeps the heavy `pip install -e
_sglang/python[all]` cached.

Extend `_run_probe` to assert the four patch-surface invariants inside
the live container (helper module imports + round-trips, parallel_state
gets the tp_world_ranks kwarg, output processor gets the
_send_hidden_states_to_nccl method, Scheduler.__init__ wires
eagle_nccl_writer + the colocate gate). Any future image build that
fails to apply the patch will now fail probe loudly instead of silently
sliding through to e2e training.

Verified on Modal sandbox (doordash/sandbox env):
  probe                    1xH100  26 s  4/4 patch-surface checks pass
  phase1_placement         4xH100  18 s  5/5
  phase3_p2p_dummy         2xH100 128 s  3/3
  phase4_multi_tensor      2xH100  39 s  2/2

AI assistance was used to draft the layered image recipe and the
patch-surface assertions; the human submitter reviewed and ran the
verification.

Co-authored-by: Claude

The previous train_entry.py raised NotImplementedError under colocate
because (a) no synchronous training loop existed and (b) the engine
side had no way to discover the trainer-computed union-world
master_addr/master_port — every actor was self-computing it, which
deadlocked the collective rendezvous. This commit closes both gaps.

Driver-side union spec + env-var injection:
* RayTrainGroup now exposes its master_addr/master_port (set by the
  rank-0 trainer's setup_master) so the driver can derive the
  union-world endpoint (master_port + 5000).
* train_entry.py computes TORCHSPEC_COLOCATE_TRANSFER_MODE=nccl plus
  the four UNION_* env vars on the driver process BEFORE async_init
  fires, then forwards them via prepare_inference_engines'
  extra_env_vars argument into the SglEngine actors' runtime_env.
  The patched sglang TP scheduler subprocess inherits this env from
  its actor and joins the union world as ranks [N, 2N).
* TrainerActor._init_distributed_colocate now reads from these env
  vars when the driver has set them, falling back to the legacy
  self-computed spec for tests that spin up a TrainerActor in
  isolation. Both sides now agree on the same rendezvous endpoint.
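
The driver-side contract can be sketched as below. Only the `TORCHSPEC_COLOCATE_TRANSFER_MODE` key and the `master_port + 5000` derivation are fixed by this commit; the `UNION_*` key names here are illustrative stand-ins for the four real env vars:

```python
UNION_PORT_OFFSET = 5000  # union-world rendezvous port = trainer master_port + 5000


def build_colocate_env(master_addr: str, master_port: int, num_trainers: int) -> dict:
    """Env vars the driver computes once and forwards into every actor's
    runtime_env so both sides rendezvous at the same endpoint."""
    return {
        "TORCHSPEC_COLOCATE_TRANSFER_MODE": "nccl",
        # The four UNION_* names below are illustrative, not the real contract:
        "TORCHSPEC_COLOCATE_UNION_MASTER_ADDR": master_addr,
        "TORCHSPEC_COLOCATE_UNION_MASTER_PORT": str(master_port + UNION_PORT_OFFSET),
        "TORCHSPEC_COLOCATE_UNION_NUM_TRAINERS": str(num_trainers),
        "TORCHSPEC_COLOCATE_UNION_WORLD_SIZE": str(2 * num_trainers),
    }
```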

Drop the deadlock-prone init-order fence:
* The previous "trainer-first" fence (await train_init_refs before
  starting engines) is fundamentally incompatible with a collective
  init_process_group(world_size=2N) — every trainer would block
  forever waiting for engines that hadn't been spawned. Both sides
  now run init() in parallel; both block on the rendezvous and
  unblock together. Memory contention under MPS is handled by
  expandable_segments + the train_frac/infer_frac budget split.

Synchronous loop body (torchspec/controller/colocate_loop.py):
* run_colocate_training_loop drives one batch per step:
  1. Pull dp_size prompts from the controller.
  2. For each (engine, trainer) pair: push ColocateTrainSample(
     tensor_specs computed from prompt seq_len + engine hidden_size)
     to trainer's queue, then dispatch engine.generate as a Ray
     remote.
  3. Concurrently fire train_from_queue on every trainer.
  4. Await both — the engine's spec_training callback NCCL-sends to
     its paired trainer, the trainer's NcclMultiTensorFetcher
     recv_step picks it up.
  5. Log step metrics, advance counter.
* Hard-requires draft_accumulation_steps=1 and per_dp_rank_batch_size=1
  for now; multi-step accumulation + multi-sample-per-rank batching
  threaded through controller dispatch is parked as a Phase-5 follow-up.

Engine-side cleanup:
* SglEngine.generate short-circuits in colocate (NCCL) mode after
  sgl.Engine.generate completes. The post-processing was building
  Mooncake-key output dicts that don't apply when the data plane is
  NCCL P2P; previously it logged a noisy "no mooncake keys returned"
  error every step.
* The colocate.patch's _send_hidden_states_to_nccl now sends the
  same 3-tensor dict shape that the disagg Mooncake path stores
  (hidden_states already aux-concatenated by sglang's
  spec_training, plus input_ids and optional last_hidden_states).
  The previous version sent a separate aux_hidden_states key that
  the trainer fetcher didn't know how to consume.

One-step e2e test (tests/colocate/test_one_step.py):
* Runs the colocate Qwen3-8B example through train_entry with
  num_train_steps=1 on H100:4 and asserts the loop reports
  "completed_steps=1 / num_steps=1". This is the maximal e2e check
  that catches: rendezvous deadlocks, MPS daemon misconfigurations,
  tensor-spec mismatches between trainer fetcher and engine sender,
  and aux-layer count mismatch on hidden_states' last dim.

AI assistance was used to draft the orchestration plumbing and the
loop body; the human submitter reviewed and ran the verification on
Modal H100:4.

Co-authored-by: Claude

Phase-4 one-step on Modal H100:4 surfaced a real bug: once the MPS
control daemon is up on a node, *every* CUDA context on that node has
to set CUDA_MPS_PIPE_DIRECTORY or CUDA fails with error 805 ("MPS
client failed to connect to the MPS control daemon or the MPS
server"). Our previous order was

  ray.init -> create AsyncTrainingController -> setup_for_colocate

so:
- the controller actor inherited an os.environ without the MPS pipe
  and crashed on the first torch.cuda.is_available() inside
  dataset / tokenizer code;
- trainer / engine actors that *did* have CUDA_MPS_PIPE_DIRECTORY in
  their runtime_env still raced with daemon startup because
  start_mps_daemon returned before the control pipe file existed.

This commit:
- moves setup_for_colocate to step [0] of train_async_no_generation,
  before any Ray actor (including the controller) is created, and
  exports the client env into the driver's os.environ so all child
  processes inherit it;
- adds mps_client_env() to the controller's runtime_env_vars
  (defense-in-depth — the explicit override is independent of
  os.environ inheritance);
- polls for /tmp/nvidia-mps/control inside start_mps_daemon with a
  10s timeout so callers can rely on "function returned ==> daemon
  is reachable" and don't race with daemon init.
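
The poll amounts to a deadline loop over the control-pipe path. Illustrative helper; the real logic lives in `torchspec/colocate/mps.py`:

```python
import os
import time


def wait_for_control_pipe(path: str = "/tmp/nvidia-mps/control",
                          timeout: float = 10.0,
                          interval: float = 0.1) -> bool:
    """Block until `path` exists or `timeout` elapses. True on success,
    so callers can rely on "returned True ==> daemon is reachable"."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.exists(path):
            return True
        time.sleep(interval)
    return False
```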

Also fills out the remaining Phase-7 test bodies:
- tests/colocate/test_grad_parity.py: one-step smoke that asserts a
  finite, non-zero training loss came out of the colocate loop (the
  full per-parameter byte-equality test is parked as a follow-up
  because it needs deterministic-seed plumbing across both transfer
  modes plus a gradient-snapshot hook in the trainer);
- tests/colocate/test_convergence.py: short-horizon (50-step
  default, configurable) convergence test asserting the late-window
  average loss is below the early-window average — a cheap e2e
  signal that gradients are actually flowing.
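
The early/late-window signal is a pure function (window size and helper name illustrative):

```python
def loss_decreased(losses: list, window: int = 10) -> bool:
    """Cheap convergence signal: average loss over the last `window`
    steps must be strictly below the average over the first `window`."""
    early = sum(losses[:window]) / window
    late = sum(losses[-window:]) / window
    return late < early
```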

Modal smoke script now also overlays examples/ so the colocate config
can resolve dataset.train_data_path inside the container, and
test_one_step.py invokes train_entry directly instead of relying on
the example run.sh script.

When trainer setup_gpu hits 'MPS client failed to connect to the
control daemon' the actual root cause lives in
/tmp/nvidia-log/control.log on the node. Surface a tail of it (plus
the relevant CUDA env) at error time so we don't have to re-run with
extra logging.

…sharing

Modal sandbox H100 nodes (and any container without --ipc=host /
matching capabilities) start the MPS control daemon cleanly but the
per-GPU server crashes on first client connect with 'Failed to start
: operation not supported' — every actor that subsequently does
torch.cuda.* then dies with CUDA error 805.

This commit makes the colocate driver detect that case and gracefully
degrade:

- setup_for_colocate now eagerly probes the MPS server (sends a
  get_server_list to force the daemon to spawn its per-GPU server,
  then watches /tmp/nvidia-log/server.log for either a clean start
  or 'operation not supported');
- on probe failure we tear down the daemon and return (None, {}) so
  the caller knows MPS isn't usable here;
- train_entry / train_group / inference.factory all check the new
  args.colocate_mps_unavailable flag and skip injecting
  CUDA_MPS_PIPE_DIRECTORY into actor runtime_envs when MPS is dead;
- TORCHSPEC_DISABLE_MPS=1 lets ops force the same fallback.

Functional outcome: colocate still works (trainer + engine still claim
fractional GPU resources via Ray and end up sharing the same physical
GPU) — they just can't run kernels concurrently on Modal sandbox so
throughput is lower than a real DGX-style box.
…etCount)

The previous probe used 'get_server_list' which doesn't actually
spawn a per-GPU server, so the 'operation not supported' failure was
only visible after the first real CUDA client connected — too late.
Replace it with a 30s subprocess probe that calls cuInit +
cuDeviceGetCount via libcuda.so.1; that exercises the real client
codepath and surfaces the failure (or success) before any actor
spawns. Run in an isolated process so the driver's CUDA state is
untouched.
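
A minimal sketch of that isolated-process probe, assuming nothing beyond the stdlib; the exit codes and function name are illustrative, not the real `_mps_probe` implementation:

```python
import subprocess
import sys

_PROBE_SRC = r"""
import ctypes, sys
try:
    cuda = ctypes.CDLL("libcuda.so.1")
except OSError:
    sys.exit(2)          # no driver library at all
rc = cuda.cuInit(0)      # exercises the real CUDA/MPS client codepath
if rc != 0:
    sys.exit(3)          # e.g. CUDA error 805 when the MPS server is broken
count = ctypes.c_int(0)
rc = cuda.cuDeviceGetCount(ctypes.byref(count))
sys.exit(0 if rc == 0 and count.value > 0 else 3)
"""

def cuda_client_probe(timeout_s=30):
    """Run cuInit + cuDeviceGetCount in a child interpreter so the
    driver process's CUDA state stays untouched. True only when a
    device is actually usable. (Name and exit codes hypothetical.)"""
    try:
        proc = subprocess.run([sys.executable, "-c", _PROBE_SRC],
                              timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return False
    return proc.returncode == 0
```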
…e startup

Phase-4 one-step on Modal H100:4 surfaced the next issue after the
MPS fallback: the engine's sglang TP scheduler subprocess takes
~1 minute to initialise (fork sglang scheduler, download Qwen3-8B
weights from HF, build the worker, etc.), but
``init_process_group(device_id=...)`` switches NCCL into
*eager init* mode where every rank's NCCL listener has to be
reachable within the socketPollConnect 35-retry window (~30 s).
Trainers were ready in seconds and timed out polling the engine's
not-yet-bound NCCL listener with

  socketPollConnect: connect ... returned Connection refused,
  exceeded error retry count after 35 attempts

The fix is to drop ``device_id=`` from both sides of the union-world
``init_process_group`` so NCCL falls back to lazy init: the
handshake happens on the first collective op, which inherits the
10-minute ``timeout=`` we already pass. Slow engines now have plenty
of slack to catch up; the only thing we lose is a microscopic
init-latency optimisation for fast-startup workloads, which we
don't have here.

This commit:

- in torchspec/colocate/world.py: drop ``device_id=`` from
  init_process_group (and add a comment explaining why);
- in patches/sglang/.../colocate.patch: same change to
  ``init_union_default_pg`` (the engine side, called inside the
  sglang scheduler subprocess via the patch).
Phase-4 one-step on Modal sandbox surfaced that working colocate
fundamentally needs functioning NVIDIA MPS — without it, two
processes (trainer + engine) on the same physical GPU can't reliably
do inter-process NCCL P2P, the rendezvous never completes, and we
burn the test's 30-minute budget on a doomed run. Modal sandbox H100
nodes start the MPS daemon successfully but the per-GPU server
crashes with 'operation not supported' (containers don't have
--ipc=host or the matching capability).

This commit:

- adds tests/colocate/_mps_probe.py with two helpers
  (has_h100_quad, mps_works) so each phase test can fail fast
  rather than hanging;
- gates test_one_step (Phase 4), test_stability (Phase 6),
  test_grad_parity (Phase 7), and test_convergence (Phase 7) behind
  the new mps_works skip — when MPS is broken the test is
  ``pytest.skip`` with a clear "needs --ipc=host host" message
  rather than running for 30 minutes and timing out.

The probe itself spawns ``nvidia-cuda-mps-control -d`` (idempotent)
and a tiny libcuda.so.1 ``cuInit + cuDeviceGetCount`` subprocess to
exercise the actual MPS client codepath. False positive (probe says
working but real run fails) requires the MPS server to die *between*
the probe and the real run; in that case the existing
phase4 fallback path (TORCHSPEC_DISABLE_MPS / colocate_mps_unavailable
flag) still kicks in and the test fails for the right reason
instead of hanging.
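
The fail-fast gating pattern, stripped of pytest for a self-contained sketch (the real tests use `pytest.skip`; the decorator and return convention here are purely illustrative):

```python
import functools

def skip_unless(gate, reason):
    """Dependency-free sketch of the mps_works gate: when the gate
    fails, the wrapped test reports a skip immediately instead of
    running into a 30-minute timeout. (Names are illustrative; the
    real implementation calls pytest.skip.)"""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not gate():
                return ("SKIPPED", reason)
            return fn(*args, **kwargs)
        return wrapper
    return decorate

@skip_unless(lambda: False, "needs an --ipc=host host with a working MPS server")
def test_one_step():
    return ("PASSED", None)
```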
…er unit tests

The 4×H100 + Qwen3-8B Phase-4/6/7 tests are gated behind has_h100_quad()
because Modal sandbox can't run them (no MPS server support — see the
"Modal sandbox MPS limitation" note in implementation_log.md). Renting
4×H100 elsewhere is expensive enough to be a barrier for one-shot
correctness validation.

This adds a single-GPU + tiny-model variant that exercises the same
colocate code path (MPS daemon, fractional GPU sharing, NCCL P2P union
world, NcclMultiTensorFetcher, sglang colocate.patch hidden-state hook)
on a $0.20–$2/hr cheap GPU rental:

  configs/colocate_qwen0p6b_tiny.yaml       — Qwen3-0.6B-Base, 1-GPU layout
  tests/colocate/test_colocate_tiny.py       — Phase-4 one-step + Phase-7
                                                mini convergence (1 GPU)
  scripts/colocate/run_smoke_host.sh         — cheap-host runner: clones
                                                sglang, applies patches,
                                                pip-installs deps, runs the
                                                tiny test
  scripts/modal/modal_colocate_smoke.py      — phase_tiny entry point
                                                (verifies skip behaviour
                                                on Modal sandbox)

Also extends tests/colocate/_mps_probe.py with has_n_gpus(n) so the
tiny tests can gate on >=1 GPU instead of the 4-GPU has_h100_quad().
The same _mps_probe.mps_works() check skips cleanly on hosts where
the MPS server reports "operation not supported" (e.g. Modal sandbox)
so a SKIP outcome means *the host* doesn't support MPS, not that the
colocate code is broken.

Re-verified Modal sandbox correctness while we were here:
  probe                — patch surface 4/4 OK (35 s)
  phase1_placement     — 1 passed, 4 skipped (40 s)
  phase4_multi_tensor  — 2/2 PASSED (69 s, no MPS dependency)
  phase4_one_step      — 1 SKIPPED in 1.5 s (clean MPS skip)
  phase_tiny           — 2 SKIPPED in 0.7 s (clean MPS skip)

Drive-by fix: tests/colocate/test_phase1_mps_helper.py had two pre-
existing local-test failures introduced by the earlier MPS-fallback
work — the tests didn't account for (a) start_mps_daemon polling for
the control pipe file (added when CUDA error 805 turned out to be a
  race during Modal-sandbox debugging), or (b) setup_for_colocate now
doing a real cuInit/cuDeviceGetCount probe and falling back to
(None, {}) when CUDA isn't available. Both tests now mock the right
surface and a new test_setup_for_colocate_falls_back_when_probe_fails
pins down the graceful-degradation contract that the Modal-sandbox
SKIPs depend on. All 46 local pytest now pass (vs 44 passed + 2
failed before).

Co-authored-by: Claude
…ndoff

Adds a self-contained handoff doc so another agent can pick up the
MPS-required Phase-4/6/7 validation on a non-Modal host without
re-deriving the rationale or the failure modes:

  docs/colocate/cheap_host_test_plan.md  — TL;DR, cost-tier matrix
                                            (RunPod / Vast.ai / Lambda /
                                            Hyperstack), explicit
                                            RunPod + Vast.ai recipes,
                                            success-criteria checklist,
                                            failure-mode diagnostic
                                            table, report-back checklist
  scripts/colocate/README.md             — short pointer to the plan

Extends scripts/colocate/run_smoke_host.sh with:
  --full           run tiny + every 4×H100 Phase-4/6/7 test in one go
                   (each test self-skips if its preconditions miss)
  --tests=A,B,C    explicit test-file override
  GPU-aware default for CUDA_VISIBLE_DEVICES (0,1,2,3 when --full + 4 GPUs;
  0 otherwise)
  Documented exit codes + new env knobs (PHASE6_STABILITY_STEPS,
  PHASE7_CONVERGE_STEPS) so callers can dial up confidence.

implementation_log.md gets a backlink to the new test plan so future
readers find the cheap-host workflow instead of re-discovering it.

Co-authored-by: Claude
- Pre-flight (nvidia-smi, GPU count, MPS server probe) now runs
  before the 5-10 min pip-install setup, so a host without working
  MPS exits 1 in ~30s instead of after a SKIP-only pytest run.
  Probes via `python -m tests.colocate._mps_probe` with a verbose
  reason (cuInit rc + server.log tail). Bypass:
  COLOCATE_SKIP_MPS_PROBE=1.
- Stale-state cleanup at pre-flight: `ray stop -f` plus a guarded
  `rm -rf /tmp/nvidia-{mps,log}` (only when no daemon is running).
  Both were documented as manual recipes in the failure-modes table.
- Pytest output tee'd to colocate-smoke-pytest.log; a structured
  colocate-smoke-report.txt is written on exit with everything the
  plan's "Reporting back" section asks for (host details, exit code,
  pytest summary, [colocate_loop] loss lines, skipped tests, failure
  tails + /tmp/nvidia-log/{server,control}.log tails).
- bash EXIT trap best-effort-sends `quit` to the MPS daemon so it
  doesn't leak (skip with COLOCATE_KEEP_MPS=1).

tests/colocate/_mps_probe.py: split mps_works() into
mps_works_verbose() -> (bool, reason) so callers can surface the
actual failure mode instead of a bare False. Added a
`python -m tests.colocate._mps_probe` CLI for the new pre-flight and
for humans following the doc's "Quick MPS sanity check".

Docs updated to reflect new pre-flight exit-code semantics, the
auto-report, and retire two now-automated manual cleanup recipes.

No colocate code path changes.
mooncake.store's native .so dlopens libibverbs.so.1 + libnuma.so.1 +
librdmacm + libnl-3 at import time. On RunPod's stock pytorch-2.4
template none of those are installed, and a top-level
`from mooncake.store import MooncakeDistributedStore` brings down the
entire torchspec.training.trainer import chain — including the
colocate MPS+NCCL path (transfer_mode=nccl), which by design never
touches Mooncake.

Wrap in try/except and define a stub class on failure. The stub
satisfies the Optional[MooncakeDistributedStore] type annotation on
_store and raises RuntimeError with an actionable apt-get hint if
the disagg path actually tries to instantiate the store at runtime.
The lazy ReplicateConfig import in _build_replicate_config() (line
~300) was already structured the same way; this just extends the
pattern to the one remaining top-level mooncake import.
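
The guarded-import shape, as a sketch (the exact error text and the apt hint wording are illustrative, not the committed strings):

```python
# Sketch of the guarded top-level import: a missing native dependency
# (libibverbs.so.1 etc.) must not take down import of the whole trainer
# module; the disagg path gets an actionable error only when it
# actually instantiates the store.
try:
    from mooncake.store import MooncakeDistributedStore  # type: ignore
except (ImportError, OSError):
    class MooncakeDistributedStore:  # stub satisfies Optional[...] annotation
        def __init__(self, *args, **kwargs):
            raise RuntimeError(
                "mooncake native deps missing; the colocate nccl path "
                "never needs them — for disagg, apt-get install "
                "libibverbs1 libnuma1 (illustrative hint)"
            )
```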

Surfaced by the cheap-host smoke on RunPod A100 SXM (a96eaef): MPS
pre-flight + daemon spawn pass cleanly, but train_entry fails at
module load with libibverbs.so.1 missing — then libnuma.so.1 after
we apt-install libibverbs1, confirming the dependency chain doesn't end there.
Companion to knowledge.zh-en.md covering SMs, contexts, streams,
MPS internals, caching allocator fragmentation, NCCL communicators,
and the colocate stack end-to-end. Cross-references main doc sections.
…aces

Several modules — torchspec/colocate/{world,mps}.py,
torchspec/training/nccl_data_fetcher.py,
torchspec/inference/engine/nccl_hidden_states_connector.py — create
their own loggers via `logging.getLogger("torchspec.X.Y")` instead of
importing the central `logger` from torchspec.utils.logging. These
child loggers' default level is the root logger's WARNING, so every
INFO they emit gets silently dropped.

This bites hard when debugging the colocate path: TrainerActor calls
init_union_world which has a "Initialising union world: ..." INFO log
right before dist.init_process_group, and the patched sglang ModelRunner
similarly has "Joining TorchSpec union world: ..." right before its
rendezvous. Both are invisible at default config, so a stuck NCCL
rendezvous looks like complete actor silence — exactly the failure mode
that surfaced on the RunPod H100 PCIe smoke run (a96eaef / 3f7e708):
TrainerActor's worker-out had 2 lines and SglEngine's stopped at the
"BEFORE init" line, even though both processes had ~14 minutes of
runtime before the harness killed them.

Fix: in setup_logger(), also attach the same handler to
`logging.getLogger("torchspec")` (lowercase namespace) with
propagate=False. Every child logger in that hierarchy inherits the
handler via standard propagation, so the previously-silenced INFO logs
become visible in actor stdout/stderr without changing any submodule.

Guarded by `if not _ts_logger.handlers:` so re-entrant calls to
setup_logger don't pile on duplicate handlers.
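
The fix boils down to a few lines of standard logging wiring; this sketch uses a hypothetical helper name and verifies the propagation behaviour the commit relies on:

```python
import io
import logging

def attach_namespace_handler(handler):
    """Sketch of the setup_logger() fix: attach the shared handler to
    the lowercase 'torchspec' namespace root with propagate=False, so
    every child created via logging.getLogger('torchspec.X.Y') inherits
    it and previously-silenced INFO records become visible.
    (Helper name is illustrative.)"""
    ts = logging.getLogger("torchspec")
    if not ts.handlers:          # guard: re-entrant setup_logger calls
        ts.addHandler(handler)   # must not pile on duplicate handlers
        ts.setLevel(logging.INFO)
        ts.propagate = False     # don't double-emit via the root logger
    return ts
```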
Documents the 2026-05-13 RunPod validation session in
implementation_log.md, including:

- The four sequential failures hit on real hardware: libibverbs ->
  libnuma -> sgl_kernel SM gap (sm80/sm86/sm89 not in wheel) ->
  TP scheduler subprocess silent hang.
- Why commit 3f7e708 (mooncake lazy-import) and 0089ad3 (logger
  visibility for torchspec.X.Y namespace) were needed and what they
  unblocked.
- Run-by-run timeline (A100 -> H100 PCIe -> H100 SXM) with cost,
  outcomes, and the per-layer "what worked vs what blocked" matrix.
- Hypothesis space for the remaining TP-scheduler hang
  (init_union_default_pg gating, NCCL TCPStore rendezvous timeout
  defaults, or silent exception in the patched scheduler init) plus
  the specific action items for the next iteration.

Updates cheap_host_test_plan.md:

- Cost-tier matrix now requires SM90+ GPUs (H100/H200/B200) because
  the bundled sgl_kernel 0.3.21 wheel ships only sm90+sm100 binaries.
  Strikes through the A6000/4090/L40S "cheap" tiers that the original
  plan recommended — they're unusable without a source build.
- Pre-flight requirements bump CUDA capability floor from 8.0 to 9.0.
- New explicit "libnuma.so.1 required" note: RunPod's stock
  runpod-torch-v240 image doesn't ship it; runner bootstrap
  apt-installs it. (libibverbs etc. no longer required for the
  colocate path thanks to 3f7e708.)
- runpodctl orchestration tips: correct H100 PCIe gpu-id string,
  don't retry pod-create in a tight loop (race window can leak pods).

No code changes in this commit; pure documentation of what
running the cheap-host smoke on a paid GPU actually surfaces.
…s hang

The Phase-4 tiny smoke hangs on real H100 inside `sgl.Engine(...)` with
no diagnostic output from the TP scheduler subprocess. After 14 minutes
the pytest harness kills it. Both the trainer and the engine end up
silent because:

1. The patched sglang code in init_union_default_pg and ModelRunner
   uses `logger.info(...)`, but sglang's TP scheduler subprocess
   inherits the engine's log_level which torchspec wires to "warning"
   in sgl_engine.py:309. INFO logs are silenced.
2. NCCL_DEBUG output only appears after the NCCL backend initialises;
   since we're stuck in (or before) the TCPStore rendezvous, NCCL
   itself stays quiet.

Two changes here, surgical and zero functional impact when the
diagnostic env is not set:

* `torchspec/inference/engine/sgl_engine.py`: `log_level` becomes
  env-overridable via `SGLANG_LOG_LEVEL`. Default unchanged.
* `patches/sglang/v0.5.8.post1/colocate.patch`: eight unconditional
  `print(..., flush=True)` checkpoints with `[TS-COLOCATE-TRACE
  pid=N]` prefix at:
  - is_colocate_active() (every call records the env-var value)
  - init_union_default_pg ENTRY
  - init_union_default_pg after read_colocate_env succeeds
  - init_union_default_pg right before dist.init_process_group
  - init_union_default_pg right after dist.init_process_group returns
  - ModelRunner.init_torch_distributed at the is_colocate_active()
    dispatch point and entering the colocate branch
  These bypass Python logging entirely so the captured output
  survives any sglang log-level config and any silent exception
  handling in the subprocess.

Next iteration: rerun on H100 with `SGLANG_LOG_LEVEL=info` and
`NCCL_DEBUG=INFO`. The captured [TS-COLOCATE-TRACE] checkpoints will
pinpoint whether the TP scheduler reaches init_union_default_pg,
gets stuck in dist.init_process_group, or crashes silently between
the two.
…tions

d99b599 added 8 print() blocks but left the @@ -X,Y +A,B @@ header
counts at their pre-edit values, which made `git apply --recount`
choke ("warning: recount: unexpected line: 2.51.2", "error: corrupt
patch at line 707") because --recount tried to fix up Y/B but ran
past where it expected hunks to end, hitting the git-format-patch
trailer.

Updated headers:
  @@ -0,0 +1,257 @@  ->  @@ -0,0 +1,292 @@   (torchspec_colocate.py)
  @@ -782,21 +787,59 @@  ->  @@ -782,21 +787,75 @@  (ModelRunner)

Counts measured by `grep -c '^+'` within each hunk's body. --recount
will still adjust if I'm a line or two off; this just gets us close
enough that the patch applies cleanly.
Iter 2 on H100 SXM applied the patch cleanly but my print() output
was missing from the captured log even though the sibling logger.info
"Joining TorchSpec union world: ..." DID appear. That means sglang's
TP scheduler subprocess redirects/suppresses stdout but keeps stderr
flowing — print() defaults to stdout, logger.info goes to stderr
through Python logging's StreamHandler.

Convert all 8 TS-COLOCATE-TRACE checkpoints to logger.warning() so
they (a) go through the same stream as the visible logger.info and
(b) survive any log_level config (WARNING is always shown — even
sgl.Engine's default "warning" floor). Both model_runner.py and
torchspec_colocate.py have `logger = logging.getLogger(__name__)`
already in scope, so no new imports needed.

Also adjusts the @@ hunk line counts:
  @@ -0,0 +1,292 @@      ->  @@ -0,0 +1,287 @@   (5 fewer lines)
  @@ -782,21 +787,75 @@   ->  @@ -782,21 +787,72 @@ (3 fewer lines)

(removing 8 `flush=True,` lines, one per checkpoint.)
…eadlock

Iter 3's TS-COLOCATE-TRACE markers proved the union NCCL rendezvous
itself works — `dist.init_process_group` returned cleanly on both
sides. The real hang is one layer deeper:

  Trainer (world.py:init_union_world after init_process_group):
    1. (skipped for 1 trainer)  new_group(nccl, trainer_ranks)
    2. new_group(gloo, all 2N ranks)   <-- meta_group barrier

  Engine TP scheduler (after init_union_default_pg returns):
    1. init_distributed_environment(...)  ~ no-op
    2. initialize_model_parallel(...)
       -> GroupCoordinator.__init__ for each TP/MoE_EP/MoE_TP/PP:
            new_group(nccl, engine_ranks)
            new_group(gloo, engine_ranks)
       8 calls total, ranks=[1] each (tiny config: tp=ep=pp=1).

The world-collective default for `dist.new_group` means *every* rank
in the world group must call it with matching args, even if not in
`ranks`. Trainer is waiting at the gloo meta_group call expecting all
2N ranks; engine is busy with the first sglang TP new_group expecting
the same. Args don't match -> both block forever.

Fix in init_union_default_pg, immediately after init_process_group
returns: monkey-patch `dist.new_group` to default
`use_local_synchronization=True`. From PyTorch docs:

  > use_local_synchronization (bool, optional) - perform a group-local
  > barrier at the end of the process group creation. This is
  > different in that non-member ranks don't need to call into API
  > and don't join the barrier.

This:
* Only applies inside the engine subprocess (each TP scheduler is a
  separate process — the trainer is untouched).
* Is a `setdefault`, so any sglang call that explicitly passes
  `use_local_synchronization=False` (none currently do) continues to
  behave as the caller intended.
* Leaves the meta_group call (collective on all 2N) unaffected
  because it's done by the trainer side via world.py before any
  engine new_group; with sglang now using local-sync, the trainer's
  meta_group barrier completes when the engine reaches its sibling
  meta_group call inside init_union_default_pg (next commit will
  add that explicit call for ordering symmetry, but with
  use_local_synchronization=True we don't actually need it).

For Phase 4+ (multi-trainer FSDP), the trainer's own fsdp_group
new_group call will need use_local_synchronization=True too. Tracked
as a follow-up; the tiny test only has 1 trainer so fsdp_group is
skipped today.
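
The setdefault monkey-patch is a generic wrapper; a dependency-free sketch (the helper name is hypothetical, and the commented application line shows where the real patch applies it to `dist.new_group`):

```python
import functools

def patch_default_kwarg(fn, name, value):
    """Generic sketch of the init_union_default_pg monkey-patch: wrap a
    function so keyword argument `name` defaults to `value` unless the
    caller passes it explicitly — the setdefault trick applied to
    dist.new_group's use_local_synchronization.
    (Helper name is hypothetical.)"""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        kwargs.setdefault(name, value)  # explicit caller values win
        return fn(*args, **kwargs)
    return wrapper

# How the engine subprocess applies it (illustrative):
#   dist.new_group = patch_default_kwarg(
#       dist.new_group, "use_local_synchronization", True)
```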

Adjusts @@ -0,0 +1,287 @@ -> @@ -0,0 +1,318 @@ (31 lines added
to the torchspec_colocate.py new-file hunk: 1 logger.warning + the
monkey-patch wrapper + comments).
Iter 4 got past `dist.new_group(ranks=[engine_only_ranks], ...)` thanks
to the use_local_synchronization=True monkey-patch (0a96522), but the
next hang surfaced inside sglang's `init_distributed_environment`:
`init_world_group(ranks=[0..world_size), local_rank, backend)` creates
a GroupCoordinator whose __init__ calls two world-spanning new_groups:

    new_group(ranks=[0..2N), backend=nccl)   # _WORLD device_group
    new_group(ranks=[0..2N), backend=gloo)   # _WORLD cpu_group

These are on *every* world rank, so use_local_synchronization=True
doesn't help — every member must still call. The trainer's
init_union_world (world.py) currently calls only ONE new_group at the
world level (`new_group(ranks=all, gloo)` for its meta_group), and
nowhere else. Engine calls 2, trainer calls 1, in different
positions — c10d deadlock.

Fix: make the call sequence identical on both sides.

torchspec/colocate/world.py — before the existing meta_group call,
add the two world-spanning new_groups that mirror sglang's
init_world_group. Their handles are discarded; the trainer doesn't
use sglang's _WORLD, but the collective bookkeeping must match.
Also flips fsdp_group's new_group call to
use_local_synchronization=True so multi-trainer Phase-4+ runs don't
need the engine to participate.

colocate.patch — in the patched ModelRunner.init_torch_distributed
colocate branch, after `init_distributed_environment` (which creates
the 2 world-spanning new_groups via init_world_group), add one more
new_group(ranks=all, gloo) so the engine catches up to the trainer's
meta_group. For ranks covering the whole world the
use_local_synchronization=True monkey-patch default is equivalent
to a world-collective call.

Resulting matched sequence (each row is one collective barrier):

    Trainer (world.py)         | Engine (ModelRunner colocate path)
    ---------------------------|-----------------------------------
    init_process_group(nccl)   | init_process_group(nccl)  via init_union_default_pg
    new_group(nccl, all)       | (sglang init_world_group _WORLD device_group)
    new_group(gloo, all)       | (sglang init_world_group _WORLD cpu_group)
    new_group(gloo, all)       | new_group(gloo, all)      explicit meta_group
                               |                           barrier in colocate.patch
    [done]                     | initialize_model_parallel
                               | (engine-local groups via use_local_synchronization)

This pattern works for any 2N union world (tiny test = 2 ranks, Phase
4+ = 8 ranks).

@@ -782,21 +787,72 @@ -> @@ -782,21 +787,88 @@ (+16 lines for
the meta_group call + comments).
iter5 hit `error: corrupt patch at line 750` again. Off-by-4 in the
@@ +line count of the last hunk — I forgot the 6 context lines wrap
around the 86 added lines, so the +side count is 86+6=92, not 88.
(0a96522 went 75 -> 88 to account for +13 monkey-patch lines, but
that already missed the same +/-3-context adjustment.)

Counted directly with `grep -c '^+' < hunk` -> 86 strictly-added `+` lines.
Add 6 context lines (3 before, 3 after init_distributed_environment
block) -> 92 total +side hunk lines.
Iter 6 hit the next mismatch: the engine TP scheduler subprocess has
a monkey-patched dist.new_group that defaults
use_local_synchronization=True. The trainer's world.py was passing
use_local_synchronization=False (default) for the sglang-paired
nccl/gloo world new_groups and the meta_group. c10d's rendezvous
semantics don't reconcile across two ranks passing different flag
values, so the very first paired new_group hung.

Make all of world.py's union-world subgroups pass
use_local_synchronization=True, matching the engine's monkey-patch
default. For groups whose `ranks` covers the entire world (2N ranks,
every rank a member) this is equivalent to a world-collective call —
every rank participates either way — but the semantics agree on
both sides now.

Also add INFO log lines around the new_group sequence so the next
iteration can confirm the trainer side actually completed.
Iter 7 cleared the new_group rendezvous deadlock; iter 8 surfaces the
next bug: sglang's initialize_dp_attention computes its attn_tp group
ranks from `range(0, pp_size * tp_size)`, which lands in the trainer
half [0, N) of the union world. Engine ranks are at [N, 2N), so
GroupCoordinator's `self.rank in ranks` membership check fails on
every engine and the `assert self.cpu_group is not None` at the end
of __init__ fires:

  File ".../sglang/srt/layers/dp_attention.py", line 296, in initialize_dp_attention
    _ATTN_TP_GROUP = GroupCoordinator(
  File ".../sglang/srt/distributed/parallel_state.py", line 292, in __init__
    assert self.cpu_group is not None
AssertionError

I attempted to fix this with a colocate.patch hunk on dp_attention.py
but the unified-diff context+line-count was finicky enough that
--recount kept choking on the format-patch trailer. Switched to a
post-`git apply` Python string-substitution step inside
`setup_sglang` in `scripts/colocate/run_smoke_host.sh`:

  1. Match the literal anchor "    _ATTN_TP_GROUP = GroupCoordinator(".
  2. Inject 14 lines of code above it that compute
     `_ts_offset = read_colocate_env().n_per_role` if
     is_colocate_active() else 0.
  3. Rewrite the offending list comprehension to use the offset:
     `list(range(_ts_offset + head, _ts_offset + head + _ATTN_TP_SIZE))`.

Both substitutions are string-stable across sglang 0.5.x; the script
asserts the anchor exists and that a substitution actually happened
(no silent drift).

Default offset 0 means non-colocate runs (where
`is_colocate_active()` returns False) are byte-identical.
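
The substitution step can be sketched as a pure string transform; the injected code is abbreviated to two lines here, and the pre-patch range expression is inferred from the commit text rather than copied from sglang:

```python
def offset_attn_tp_ranks(src):
    """Sketch of the post-`git apply` substitution in setup_sglang:
    anchor on a literal line, inject the offset computation above it,
    and rewrite the rank range — asserting both edits actually landed
    so drift in sglang fails loudly. (Injected code is abbreviated;
    the pre-patch `old` expression is an assumption.)"""
    anchor = "    _ATTN_TP_GROUP = GroupCoordinator("
    old = "list(range(head, head + _ATTN_TP_SIZE))"
    new = "list(range(_ts_offset + head, _ts_offset + head + _ATTN_TP_SIZE))"
    assert anchor in src, "dp_attention.py anchor drifted"
    inject = ("    _ts_offset = (read_colocate_env().n_per_role\n"
              "                  if is_colocate_active() else 0)\n")
    patched = src.replace(anchor, inject + anchor, 1)
    assert old in patched, "rank-range expression drifted"
    return patched.replace(old, new, 1)
```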
Iter 9 made it past every previous deadlock and reached
trainer.py:_setup_device_mesh, where it ran:

    self.mesh = init_device_mesh(
        "cuda", mesh_shape=(dist.get_world_size(),), ...
    )

Under colocate the default PG is the 2N-rank union world (1 trainer
+ 1 engine for the tiny test), but the trainer's DP mesh should only
span the trainer half [0, N). The unconditional dist.get_world_size()
made the trainer try to build a 2-rank DP mesh that includes the
engine — FSDP collectives on that mesh would deadlock on the first
all-reduce because the engine isn't an FSDP participant.

Fix _setup_device_mesh to:
  1. Prefer args.world_size / args.rank (which trainer_actor.py
     overrides to the trainer-subgroup values when colocate is on)
     over dist.get_*; falls back to the dist values for the
     non-colocate path.
  2. When the resulting world_size is smaller than the dist world
     (i.e. we're in a sub-world), build a trainer-only NCCL group
     via dist.new_group(use_local_synchronization=True) and attach
     a DeviceMesh.from_group rather than the default init_device_mesh
     path that requires the default PG to match the mesh shape.

The non-colocate path is byte-identical: dist_world_size ==
args_world_size, so the existing init_device_mesh branch fires.

Logged the mesh kind ("1D" vs "1D-colocate-sub") + both world sizes
for diagnosability.
Documents the 10-iteration chain that peeled off NCCL deadlock layers
between trainer rank 0 and engine TP scheduler subprocess rank 1 in
the colocate union world. Each iter found one more collective
new_group mismatch or rank-offset bug; each fix progressed the test
further. End state: both sides past every previously-known
deadlock + dp_attention rank offset + trainer DP mesh world_size
confusion. The remaining hang is in model load / first NCCL
collective on a 1-rank NCCL group (a regime the original world.py
comment explicitly warned about).

Lists every commit landed this session, lays out the next-session
debug plan, and the cost ($8.46 spent on 10 iters at $2.99/hr on
H100 SXM SECURE Iceland).

No code changes in this commit, just documentation of what we
learned and where we left it.
Iter 10 found the hang: eagle3_trainer.py:82 calls
`dist.barrier(group=get_gloo_group())` after the rank-0-only
load_embedding. In colocate mode `GLOO_GROUP` was bound to
`union_world.meta_group` — the 2N-rank gloo group spanning both
trainer and engine. The engine never enters the trainer's
init_model code path, so it never enters that barrier, and the
trainer hangs forever waiting.

Three changes:

1. `colocate/world.py`: build a `trainer_gloo_group` alongside
   `meta_group` — gloo, ranks=`[0, N)`, use_local_synchronization=True
   so engines skip the call. Expose it on `UnionWorld`. For
   1-trainer tiny test, this is a 1-rank gloo group; gloo handles
   1-rank groups cleanly (unlike NCCL).

2. `training/trainer_actor.py`: bind
   `_dist_utils.GLOO_GROUP = self._union_world.trainer_gloo_group`
   instead of `.meta_group`. Now every
   `dist.barrier(get_gloo_group())` in the trainer's init / step path
   syncs only the trainer half. Comment updated to spell out the
   pitfall.

3. `training/trainer.py:_setup_device_mesh`: switch the
   colocate-sub-world DP group's backend from NCCL to GLOO for the
   1-trainer case. NCCL 1-rank groups have known eager-init issues
   (the original world.py comment flagged this); gloo doesn't.
   For >=2 trainers we keep NCCL so DP all-reduce stays on the GPU
   path.

Also adds `[TS-COLOCATE-TRACE-T]` `logger.warning` markers around
eagle3_trainer.init_model's major phases (draft model build, load
embedding, gloo barrier, FSDP wrap, full state dict load) so the
next iter pinpoints any new hang within model init at sub-second
resolution.
Iter 11 with the trainer-only gloo group fix (08976e5) cleared every
previously-known deadlock and got the trainer all the way to
`apply_fsdp2` (~10s after device-mesh creation). The next hang is
`fsdp2_load_full_state_dict`:

  File ".../torchspec/training/fsdp.py", line 137
      for _name, buf in model.named_buffers():
          dist.broadcast(buf, src=0)        # <-- uses DEFAULT PG (union world)

In colocate mode the default PG is the 2N-rank union world (1
trainer + 1 engine for the tiny test). The engine never enters this
code path, so the trainer blocks forever waiting for engine
participation in the buffer broadcast.

Fix: pull the mesh group out of `device_mesh.get_group()` (which is
the trainer-only sub-mesh under colocate, and the full world under
non-colocate) and pass it explicitly to `dist.broadcast`. Also
translate `src=0` to the global rank corresponding to mesh-rank 0
via `dist.get_global_rank(mesh_group, 0)`, so the broadcast source
is correct regardless of the mesh-vs-default-PG offset (under
colocate the trainer's mesh-rank-0 is also union-rank-0, so no
translation happens; under non-colocate this is also a no-op).

Adds [TS-COLOCATE-TRACE-T] markers around the three steps
(set_model_state_dict, buffer broadcasts, finish) so the next iter
can pinpoint any remaining hang.

Note: `set_model_state_dict(broadcast_from_rank0=True)` uses the
FSDP-wrapped module's mesh internally (FSDP attached `mesh` at
apply_fsdp2 time), so it should already broadcast only on the
trainer sub-mesh. If iter 12 hangs at AFTER set_model_state_dict
that assumption is wrong and we need to override its group too.
Iter 12 confirmed: the trainer-only-gloo + mesh-scoped-broadcast
fixes cleared everything up to `fsdp2_load_full_state_dict`, which
now hangs *inside* `set_model_state_dict`. PyTorch's
set_model_state_dict with `broadcast_from_rank0=True` broadcasts the
rank-0 state dict across the **default** process group — which under
colocate is the 2N-rank union world. The engine never enters this
code path, so the broadcast deadlocks.

For the tiny smoke (dp_size=1) the FSDP mesh is a single trainer
rank. There is nothing to broadcast — rank 0 already holds the full
state dict — so disable broadcast_from_rank0 when the mesh has one
rank and let rank 0 load locally. The buffer-broadcast loop below
(already scoped to mesh_group in 2d44799) is a no-op on a 1-rank
group.

Multi-trainer colocate (dp_size>=2) still passes
broadcast_from_rank0=True; that path would need set_model_state_dict
to accept an explicit group (it doesn't today) or a manual
per-param broadcast on mesh_group. Tracked as a follow-up — the
1-GPU tiny smoke is what we're unblocking right now, and it's
dp_size=1.

Iter 13 cleared the FSDP state-dict load; the trainer then hangs
right after `init_model`'s fsdp2_load_full_state_dict with a NCCL
`dist.barrier()` warning ("Guessing device ID based on global rank,
this can cause a hang"). Root cause is the same family of bug as
every prior iter: bare `dist.*` collectives default to the union-
world PG, the engine never enters the trainer code path, deadlock.

Swept the trainer-side init + train-loop code for unscoped
collectives and scoped them all to `get_gloo_group()` (which
trainer_actor.py now binds to the trainer-only subgroup):

* eagle3_trainer.py `_init_target_lm_head`:
  - `dist.broadcast(has_norm, src=0)`         -> group=trainer
  - `dist.barrier()`                          -> group=trainer
  - `dist.broadcast(param.data, src=0)` (loop) -> group=trainer
* eagle3_trainer.py metric aggregation (train + eval paths):
  - 2x `dist.all_reduce(avg_vlosses/avg_acces, AVG)` -> group=trainer
* checkpoint.py: 4x bare `dist.barrier()` -> group=trainer
  (+ added the get_gloo_group import)
* trainer.py save path: `dist.barrier()` -> group=trainer
  (+ added the get_gloo_group import)

On the 1-trainer tiny config the trainer group is a single rank, so
every one of these collectives is a no-op — they were only hanging
because they were (wrongly) ranging over the 2-rank union world. On
>=2 trainers they become real trainer-replica collectives, which is
the correct semantics (these sync trainer state, the engine has no
business participating).
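The trainer-scoped aggregation pattern, sketched (hypothetical `aggregate_metric`; note gloo has no `ReduceOp.AVG`, so this sketch averages via SUM over the group size):

```python
import os
import torch
import torch.distributed as dist

def aggregate_metric(value: torch.Tensor, trainer_group) -> torch.Tensor:
    # Assumes trainer_group is the trainer-only subgroup that
    # get_gloo_group() now returns. The engine ranks never participate.
    dist.all_reduce(value, op=dist.ReduceOp.SUM, group=trainer_group)
    return value / dist.get_world_size(group=trainer_group)

if __name__ == "__main__":
    # 1-rank trainer group, as in the tiny config: the collective is a
    # no-op instead of hanging on the 2-rank union world.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29513")
    dist.init_process_group("gloo", rank=0, world_size=1)
    avg_vloss = aggregate_metric(torch.tensor([3.0]), dist.group.WORLD)
    print("avg_vloss:", avg_vloss.item())
    dist.destroy_process_group()
```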

Iter 14 cleared every distributed-init deadlock — the run now FAILS
FAST (55s, not a 15-min timeout) with a real error:

  KeyError: 'Key lm_head.weight not found in
  .../Qwen3-0.6B-Base/model.safetensors'

Qwen3-0.6B-Base (like Llama-3.2, Gemma, and most small models) has
`tie_word_embeddings=True` — there is no standalone `lm_head.weight`
tensor; the LM head shares the input-embedding matrix stored under
`model.embed_tokens.weight`.

`TargetLMHead._load_lm_head` only ever looked for the literal
`lm_head_key` ("lm_head.weight" by default) and KeyError'd when it
was absent. Fix: when the model config has
`tie_word_embeddings=True`, fall back to `model.embed_tokens.weight`.
`_load_key_from_file` now takes an optional `fallback_key` and tries
both in order. Applies to both the sharded (index.json) and
single-file checkpoint paths.

This is a general correctness fix, not colocate-specific — any
tied-embedding target model was previously unloadable by
TargetLMHead.
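The fallback logic, sketched as a standalone helper (hypothetical `resolve_lm_head_key`; the real change threads an optional `fallback_key` through `_load_key_from_file`):

```python
def resolve_lm_head_key(tie_word_embeddings: bool, available_keys,
                        lm_head_key: str = "lm_head.weight") -> str:
    # Tied-embedding checkpoints ship no standalone lm_head.weight; the
    # LM head shares the input-embedding matrix, so fall back to it.
    if lm_head_key in available_keys:
        return lm_head_key
    if tie_word_embeddings:
        fallback = "model.embed_tokens.weight"
        if fallback in available_keys:
            return fallback
    raise KeyError(f"Key {lm_head_key} not found and no usable fallback")

# Qwen3-0.6B-Base-style checkpoint: only the embedding tensor exists.
print(resolve_lm_head_key(True, {"model.embed_tokens.weight"}))
```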

Iter 15 cleared init_model entirely (tied-embedding fix landed) —
trainer logs "Eagle3 model initialized with FSDP2". The engine TP
scheduler now hangs right after initialize_dp_attention, in
model_runner.py:

  min_per_gpu_memory = get_available_gpu_memory(
      self.device, self.gpu_id,
      distributed=get_world_group().world_size > 1,
      cpu_group=get_world_group().cpu_group,
  )

sglang's `_WORLD` was built by init_distributed_environment spanning
the full 2N-rank union world (the patch deliberately passed
world_size=colocate_env.world_size=2). So `get_world_group().world_size
> 1` is True and get_available_gpu_memory does a distributed memory
sync on `_WORLD.cpu_group` — a collective across all 2N ranks. The
trainer ranks [0, N) never run sglang code, so it hangs.

The patch's original comment claimed downstream sglang "expects" a
2N-rank _WORLD. That's wrong: the engine is one half of the union
world; sglang's notion of "world" should be the engine ranks only.

Fix: new `rebuild_world_group_engine_only(env, local_rank, backend)`
in torchspec_colocate.py. Called from the ModelRunner colocate branch
right after init_distributed_environment — it drops the 2N-rank
_WORLD GroupCoordinator and rebuilds it spanning only
build_engine_tp_ranks(env) = [N, 2N). The new_group calls inside
init_world_group inherit the use_local_synchronization=True
monkey-patch, so only engine ranks participate.

After this, get_world_group().world_size == 1 (tiny config) so the
memory sync short-circuits, and every other sglang world-level
collective is engine-scoped.
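The rank split can be sketched as (hypothetical signature; the real `build_engine_tp_ranks` takes the colocate env rather than a bare trainer count):

```python
def build_engine_tp_ranks(num_trainers: int) -> list:
    # Union world has 2N ranks: trainers hold [0, N), engine holds [N, 2N).
    n = num_trainers
    return list(range(n, 2 * n))

# Tiny config (N=1): the rebuilt _WORLD spans only engine rank 1, so
# get_world_group().world_size == 1 and the memory sync short-circuits.
print(build_engine_tp_ranks(1))
```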

@@ headers updated: torchspec_colocate.py new-file +1,318 -> +1,354;
model_runner import hunk +58,11 -> +58,12; ModelRunner hunk
+787,92 -> +787,105.

Iter 16 with the engine-only _WORLD rebuild (8bdc8d4) got the engine
all the way through model load + KV cache allocation
("KV Cache is allocated. #tokens: 315251"). It then fast-failed (55s,
no hang) with:

  File ".../sglang/srt/managers/tp_worker.py", line 292, in __init__
    self.random_seed = broadcast_pyobj(
  IndexError: list index out of range

Root cause is in the sglang callsite, not our code. tp_worker.py's
random-seed sync does:

  self.random_seed = broadcast_pyobj(
      [server_args.random_seed],
      self.tp_size * self.pp_rank + tp_rank,   # tp-local rank = 0
      self.world_group.cpu_group,
      src=self.world_group.ranks[0],            # global rank = 1
  )[0]

broadcast_pyobj's docstring: the `rank` arg is the *global* rank.
In standalone sglang the engine owns the whole world so tp-local
rank == global rank — works. Under colocate the engine's global
rank is N (=1) but its tp-local rank is 0, so `rank(0) != src(1)`:
the engine wrongly takes broadcast_pyobj's *receiver* branch,
receives size 0, returns `[]`, and the trailing `[0]` raises
IndexError.

Fix: pass `self.world_group.rank` (the GroupCoordinator's global
rank, == torch.distributed.get_rank()) instead of the tp-local
arithmetic. Correct for both colocate (global rank 1 == src 1 ->
sender path) and standalone (global == tp-local, unchanged).

Implemented as a second post-`git apply` string-substitution step
in setup_sglang (alongside the dp_attention one) — the unified-diff
format has been too fragile for these small sglang tweaks.
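A toy model of the branch selection (hypothetical `seed_sync_rank`, modeling only the rank arithmetic, not the actual broadcast):

```python
def seed_sync_rank(tp_size: int, pp_rank: int, tp_rank: int,
                   global_rank: int, fixed: bool) -> int:
    # Buggy path: sglang's tp-local arithmetic. Fixed path: the
    # GroupCoordinator's global rank (== torch.distributed.get_rank()).
    return global_rank if fixed else tp_size * pp_rank + tp_rank

# Colocate tiny config: engine global rank 1, tp-local rank 0, src rank 1.
src = 1
buggy = seed_sync_rank(1, 0, 0, 1, fixed=False)
print("buggy takes sender path:", buggy == src)   # receiver branch -> IndexError
print("fixed takes sender path:", seed_sync_rank(1, 0, 0, 1, fixed=True) == src)
```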

The colocate training loop sizes the NCCL hidden-states transfer buffer
up front and needs aux_hidden_states_layers on args before it starts.
DFlash already had an auto-resolver; Eagle3 colocate runs hit a hard
RuntimeError. Resolve via get_default_eagle3_aux_layer_ids — the same
function sgl_engine falls back to — so trainer and engine agree.

In colocate mode the trainer and engine share one physical GPU, so the
union world's NCCL backend cannot form a communicator spanning both
ranks — NCCL hard-errors "Duplicate GPU detected". The Phase-3 P2P
smoke validated on 2 separate GPUs (1 rank each), which never exercised
this.

Route the engine->trainer hidden-state transfer over the existing
all-rank gloo meta_group with host-memory staging instead:
- NcclHiddenStatesConnector.send / NcclMultiTensorFetcher.recv_step
  branch on the group backend; gloo path stages through CPU and uses
  tagged dist.send/recv (no batch_isend_irecv, which is NCCL-only).
- trainer passes union_world.meta_group to the fetcher.
- colocate.patch exposes the engine-side meta_group via
  set/get_union_meta_group so build_hidden_states_writer can hand it
  to the connector.

The NCCL batched path is kept for the separate-GPU dummy P2P tests.

The colocate loop logged metrics.get("train/loss"), but both the
Eagle3 and DFlash trainers' _aggregate_metrics emit "train/avg_loss"
(matching the disagg loop in loop.py). The wrong key made every step
log loss=None even though training ran fine — which also tripped
test_phase7_tiny_loss_decreases, whose log parser found zero loss
points.

Documents the iter chain that took the 1-GPU colocate smoke from the
end-of-iter-10 hang to both tiny tests passing on 1xH100, and records
the key architectural correction: NCCL cannot form a communicator with
two ranks on one physical GPU, so the colocate hidden-state plane runs
over gloo (host-staged), not the NCCL union world.

test_stability parses `step=N ... peak_alloc...=X` from the colocate
loop's per-step log line, but the line only carried step/step_time/
loss/lr. Add peak_alloc from metrics["perf/peak_bytes_allocated"]
(which the trainer already populates every step via
TrainProfiler.peak_alloc_metrics) so the Phase-6 leak check has data
points to compare.
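A sketch of the extended per-step log line (hypothetical `format_step_line` and `peak_alloc_bytes` field name; the other fields are the ones the loop already logs):

```python
def format_step_line(step: int, step_time: float, loss: float,
                     lr: float, metrics: dict) -> str:
    # Append peak_alloc from the metric the trainer already populates
    # via TrainProfiler.peak_alloc_metrics, so the Phase-6 leak check's
    # `step=N ... peak_alloc...=X` parser has data points.
    peak = metrics["perf/peak_bytes_allocated"]
    return (f"step={step} step_time={step_time:.2f}s "
            f"loss={loss:.4f} lr={lr:.2e} peak_alloc_bytes={peak}")

print(format_step_line(1, 8.77, 9.74, 1e-4,
                       {"perf/peak_bytes_allocated": 123456789}))
```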
zhubohao911 force-pushed the feature/colocate-training-inference branch from bf2d468 to 927beaa on May 14, 2026 at 20:28.