Support USP sequence parallel attention for eagle3 training #93

Merged: yubofredwang merged 17 commits into lightseekorg:main from uygnef:dev/usp on May 9, 2026

Conversation

@uygnef (Collaborator) commented May 8, 2026

This PR updates USP training to use SGLang-produced, pre-sharded Mooncake tensors instead of loading a full sample on one
training rank and scattering it during prefetch.

The main changes are:

  • Store USP rank-local shards in Mooncake from the SGLang producer side.
  • Fan out USP sample metadata to all ranks in the corresponding SP group.
  • Make each training rank read only its own {mooncake_key}_usp{sp_rank} tensors.
  • Reconstruct local loss_mask, attention_mask, and position_ids in the USP data fetcher.
  • Fix a USP zero-loss local shard case that could skip attention backward collectives.
  • Add coverage for the local-zero-loss/global-nonzero-loss USP shard case.

Motivation

The previous USP path loaded full Mooncake tensors during training prefetch and then scattered local shards across SP ranks. This added extra distributed communication to prefetch and exposed a collective-ordering issue when one USP shard had no local loss tokens while another shard in the same Ulysses group did.

With this change, SGLang writes per-SP-rank tensors directly to Mooncake. Training workers only receive lightweight metadata
through Ray queues and independently load their own shard.

Data Flow

flowchart TD
     A[SGLang generate / prefill] --> B[hidden_states, input_ids, last_hidden_states]
     B --> C[EagleMooncakeStore.put_usp_shards]

     C --> D{for each sp_rank}
     D --> E[split_usp_batch]
     E --> F0[Mooncake key: key_usp0<br/>input_ids shard<br/>hidden_states shard<br/>target/lhs shard]
     E --> F1[Mooncake key: key_usp1<br/>input_ids shard<br/>hidden_states shard<br/>target/lhs shard]
     E --> FN[Mooncake key: key_uspN<br/>input_ids shard<br/>hidden_states shard<br/>target/lhs shard]

     C --> G[InferenceOutput<br/>mooncake_key=key<br/>tensor_shapes=global shapes<br/>packed_loss_mask<br/>metadata: usp_sharded=true]

     G --> H[AsyncInferenceManager<br/>merge metadata]
     H --> I[AsyncTrainingController]
     I --> J{DP rank's SP group}
     J --> Q0[Queue for train rank sp0]
     J --> Q1[Queue for train rank sp1]
     J --> QN[Queue for train rank spN]

     Q0 --> R0[Trainer rank sp0]
     Q1 --> R1[Trainer rank sp1]
     QN --> RN[Trainer rank spN]

     R0 --> S0[read Mooncake key_usp0]
     R1 --> S1[read Mooncake key_usp1]
     RN --> SN[read Mooncake key_uspN]

     S0 --> T0[reconstruct local loss_mask<br/>attention_mask<br/>position_ids]
     S1 --> T1[reconstruct local loss_mask<br/>attention_mask<br/>position_ids]
     SN --> TN[reconstruct local loss_mask<br/>attention_mask<br/>position_ids]

     T0 --> U[USP forward/backward]
     T1 --> U
     TN --> U

Implementation Details

SGLang / Mooncake

  • Added EagleMooncakeStore.put_usp_shards(...) (a rough sketch follows this list).
  • The producer splits input_ids, hidden_states, and last_hidden_states / target by USP rank.
  • Shards are stored under rank-local keys:
    • {mooncake_key}_usp0
    • {mooncake_key}_usp1
    • ...
  • SGLang sets metadata={"usp_sharded": True} on USP outputs.
  • SGLang patch files are updated so patched SGLang calls put_usp_shards(...) when TORCHSPEC_USP_SHARDED_MOONCAKE=1.
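
To make the key layout concrete, here is a rough sketch of the producer-side sharding. Only the key scheme ({mooncake_key}_usp{sp_rank}), the put_usp_shards / split_usp_batch names, and the per-rank split along the sequence dimension come from this PR; the function signatures and the store.put helper are assumptions for illustration.

```python
import math
from typing import Dict

import torch


def split_usp_batch(tensors: Dict[str, torch.Tensor], sp_size: int, sp_rank: int) -> Dict[str, torch.Tensor]:
    """Slice each tensor along the sequence dimension; shard length is ceil(seq_len / sp_size)."""
    out = {}
    for name, t in tensors.items():
        chunk = math.ceil(t.shape[0] / sp_size)
        out[name] = t[sp_rank * chunk : (sp_rank + 1) * chunk]
    return out


def put_usp_shards(store, mooncake_key: str, tensors: Dict[str, torch.Tensor], sp_size: int) -> None:
    """Write one rank-local shard per SP rank under {mooncake_key}_usp{sp_rank}."""
    for sp_rank in range(sp_size):
        shard = split_usp_batch(tensors, sp_size, sp_rank)
        store.put(f"{mooncake_key}_usp{sp_rank}", shard)  # store API is assumed here
```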

Controller

  • USP training now creates one queue per SP rank.
  • The controller fans out each sample's metadata to all ranks in the DP rank's SP group (sketched after this list).
  • Inference metadata is merged with original sample metadata so fields like has_thinking and usp_sharded are both preserved.
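
A minimal sketch of that fan-out, assuming one Ray-style queue per SP rank; the SampleMeta fields and the queue type are illustrative, only the one-queue-per-SP-rank fan-out itself is described by this PR.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class SampleMeta:
    mooncake_key: str              # shard keys derive from this: f"{mooncake_key}_usp{sp_rank}"
    tensor_shapes: Dict[str, Any]  # global shapes, not shard shapes
    metadata: Dict[str, Any] = field(default_factory=dict)  # e.g. {"usp_sharded": True, "has_thinking": ...}


def fan_out_to_sp_group(sample: SampleMeta, sp_queues: List[Any]) -> None:
    """Every rank in the DP rank's SP group receives the same lightweight metadata;
    each rank later reads only its own Mooncake shard."""
    for queue in sp_queues:  # one queue per SP rank
        queue.put(sample)
```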

Data Fetcher

  • USP mode now only supports sharded Mooncake samples.
  • Each rank reads (a reader sketch follows this list):
    • f"{mooncake_key}_usp{sp_rank}"
  • The old full-sample load + training-side scatter path is removed.
  • The fetcher reconstructs local:
    • loss_mask
    • attention_mask
    • position_ids
  • USP batches require explicit attention_mask; non-USP batches keep the existing all-ones fallback.
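
A rough sketch of the per-rank read and reconstruction, assuming a plain contiguous split across SP ranks and a store.get that returns the shard's tensors; everything except the key format and the three reconstructed tensors (loss_mask, attention_mask, position_ids) is an assumption, and the real fetcher also has to account for the Ulysses/ring layout and collator bucketing discussed in the review comments below.

```python
import torch


def fetch_usp_shard(store, mooncake_key: str, sp_rank: int,
                    global_loss_mask: torch.Tensor):
    """Read this rank's shard and rebuild local loss_mask, attention_mask, position_ids.

    global_loss_mask is assumed to be the unpacked, full-sequence 0/1 loss mask
    (the PR ships it as packed_loss_mask in the sample metadata).
    """
    shard = store.get(f"{mooncake_key}_usp{sp_rank}")  # assumed store API

    local_len = shard["input_ids"].shape[0]
    start = sp_rank * local_len

    loss_mask = global_loss_mask[start : start + local_len]

    # USP batches must carry an explicit attention mask: 1 for real tokens,
    # 0 for any padding the collator adds later. Non-USP batches keep the
    # all-ones fallback.
    attention_mask = torch.ones(local_len, dtype=torch.long)

    # Positions are global, so each rank offsets by everything before its shard.
    position_ids = torch.arange(start, start + local_len, dtype=torch.long)

    return shard, loss_mask, attention_mask, position_ids
```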

Collective Ordering Fix

A USP local shard can have zero local loss tokens while its Ulysses peer has nonzero loss tokens. Previously, the zero-loss
shard could skip the attention backward graph, while its peer still executed attention backward all-to-all collectives. That
caused collective ordering divergence.

This PR keeps the USP zero-loss path connected to hidden_states:

local_sum_loss = local_sum_loss + hidden_states.sum() * 0.0

This does not change the loss value, but preserves the same autograd collective sequence across USP ranks.
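
In context the guard looks roughly like this; per_token_loss and the function boundary are illustrative, only the zero-valued graph-connecting term is from this PR.

```python
import torch


def usp_local_loss(per_token_loss: torch.Tensor,
                   loss_mask: torch.Tensor,
                   hidden_states: torch.Tensor) -> torch.Tensor:
    """Local loss for one USP shard (variable names are illustrative)."""
    local_sum_loss = (per_token_loss * loss_mask).sum()
    if loss_mask.sum() == 0:
        # No local loss tokens: add a zero-valued term that still depends on
        # hidden_states, so this rank's attention backward runs the same
        # all-to-all / ring collectives as its Ulysses peers.
        local_sum_loss = local_sum_loss + hidden_states.sum() * 0.0
    return local_sum_loss
```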

USP attention correctness and microbenchmark

Correctness is validated against LlamaFlexAttention:

  • PYTHONPATH=. python -m unittest tests.test_usp_attention.TestUSPAttention
  • Result: Ran 2 tests in 50.9s, OK
| Mode | Config | Forward output max diff | Reduced loss diff | Projection grad max diff | Input grad max diff |
| --- | --- | --- | --- | --- | --- |
| Ulysses USP | sp_ulysses_size=2, sp_ring_size=1 | 1.6e-1 | 6.7e-5 | <= 7.0e-5 | <= 1.9e-6 |
| Ring USP | sp_ulysses_size=1, sp_ring_size=2 | 1.8e-1 | 4.5e-5 | <= 5.5e-5 | <= 1.9e-6 |

Attention-only microbenchmark setup:

  • seq_len=8192, global_batch_size=2
  • FlexAttention: 2-card DP, each rank processes batch_size=1
  • USP: 2-card SP, each global sequence is sharded across 2 ranks, with gradient accumulation 2
| Mode | Config | Median fwd+bwd time | Mean fwd+bwd time | Relative median time |
| --- | --- | --- | --- | --- |
| FlexAttention DP | 2-card DP | 35.7-36.9 ms | 36.0-37.2 ms | 1.00x |
| Ulysses USP | sp_ulysses_size=2, sp_ring_size=1 | 39.6 ms | 40.9 ms | 1.11x |
| Ring USP | sp_ulysses_size=1, sp_ring_size=2 | 39.7 ms | 39.5 ms | 1.08x |

I also ran an end-to-end training comparison between the FlexAttention baseline and USP:

  • llama31_8b_align_flex_match_usp: FlexAttention baseline
  • l31_usp_u2_bf16: USP with sp_ulysses_size=2
[Figures in the PR: loss and accuracy comparison; gradient norm comparison.]

Limitations

  • Currently only SGLang supports producing USP-sharded Mooncake tensors.

@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review (reviewed commit: 13b468bf78)

        position_ids: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        del attention_mask, use_cache

P1: Honor padding masks in USP attention

In USP training the collator pads each rank-local shard to a 256-token bucket and passes an attention_mask with zeros for that padding, but this attention implementation immediately discards the mask. When sp_ulysses_size or sp_ring_size is greater than 1, the sequence all-to-all/ring attention treats those padded zero hidden states as real past keys for tokens on later shards, changing the softmax denominator and corrupting training outputs for any sequence whose local shard length is not already exactly the padded bucket size.


@uygnef (Collaborator, Author) replied:

This should be fine for the current path: main attention is causal, so real tokens cannot attend to later padded tokens. The padded positions themselves are filtered by loss_mask, so they do not contribute to loss or gradients.

Comment thread torchspec/models/eagle3.py
uygnef added 12 commits May 8, 2026 17:15
@yubofredwang (Collaborator) commented:

Thanks for the PR. Doing a verification locally.

@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review (reviewed commit: 4f16c2b4df)

Comment thread torchspec/training/trainer.py
uygnef added 2 commits May 8, 2026 17:38
@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review (reviewed commit: 849d27ecdb)

Comment on lines +160 to +169
if self.attention_backend == "usp":
    usp_chunk_size = seq_length - self.length
    if usp_chunk_size <= 0:
        raise ValueError(
            f"USP local seq_length ({seq_length}) must be larger than ttt_length ({self.length})"
        )
    if position_ids is None:
        device = hidden_states.device
        ring_chunk_size = usp_chunk_size * self._usp_ulysses_world_size
        position_start = get_sp_ring_rank() * ring_chunk_size + past_key_values_length

P1: Derive USP positions from the unbucketed shard length

When a USP shard length is not already on the collator's 256-token bucket, seq_length here is the padded batch length rather than the true chunk_size + ttt_length. That makes usp_chunk_size and ring_chunk_size too large, so ring rank > 0 starts RoPE positions after the padding gap (for example a 250-token shard padded to 256 starts the next ring at 510 instead of 500), and the attention path also treats the bucket padding as interior sequence tokens. This affects the normal training path because _forward does not pass the position_ids computed by the data fetcher, so the fallback below is used after collation.


@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review (reviewed commit: 1098e20c1f)

Comment thread torchspec/training/data_fetcher.py Outdated
Comment on lines +289 to +291
bucket = 256
padded_global_len = ((global_len + bucket - 1) // bucket) * bucket
return (padded_global_len + self._sp_world_size - 1) // self._sp_world_size

P1: Use the writer's USP shard length

When USP is enabled, the SGLang writer stores shards via EagleMooncakeStore.put_usp_shards()/split_usp_batch(), whose chunk size is ceil(global_len / sp_size) and does not round to a 256-token bucket. This reader requests ceil(round_up_256(global_len) / sp_size) + ttt_length instead; for example global_len=1000, sp_size=4, ttt_length=1 stores length 251 but reads length 257, so EagleMooncakeStore.get() raises a Mooncake size mismatch before training can start. Please make the reader use the same shard length formula or explicit shard metadata from the writer.
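
To make the mismatch concrete, here is a quick restatement of the two formulas above using the comment's own numbers; the variable names are just for illustration.

```python
import math

global_len, sp_size, ttt_length = 1000, 4, 1

# Writer side (put_usp_shards / split_usp_batch): no 256-token bucketing.
writer_len = math.ceil(global_len / sp_size) + ttt_length             # 251

# Reader side (data fetcher before the fix): rounds the global length up to 256 first.
padded_global_len = math.ceil(global_len / 256) * 256                  # 1024
reader_len = math.ceil(padded_global_len / sp_size) + ttt_length       # 257

assert writer_len != reader_len  # 251 vs 257 -> Mooncake size mismatch on get()
```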


@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review (reviewed commit: f52e69fd36)

Comment thread torchspec/utils/usp.py Outdated
@yubofredwang merged commit 7b5473a into lightseekorg:main on May 9, 2026 (2 checks passed).