Support USP sequence parallel attention for eagle3 training #93
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 13b468bf78
    position_ids: Optional[torch.LongTensor] = None,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    del attention_mask, use_cache
Honor padding masks in USP attention
In USP training the collator pads each rank-local shard to a 256-token bucket and passes an attention_mask with zeros for that padding, but this attention implementation immediately discards the mask. When sp_ulysses_size or sp_ring_size is greater than 1, the sequence all-to-all/ring attention treats those padded zero hidden states as real past keys for tokens on later shards, changing the softmax denominator and corrupting training outputs for any sequence whose local shard length is not already exactly the padded bucket size.
This should be fine for the current path: main attention is causal, so real tokens cannot attend to later padded tokens. The padded positions themselves are filtered by loss_mask, so they do not contribute to loss or gradients.
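For reference, a minimal standalone check of that argument (plain SDPA rather than the USP kernels; the shapes and the 250-to-256 padding are illustrative): with a causal mask, trailing padded keys/values cannot change the outputs of the real positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, D = 1, 4, 64
real_len, padded_len = 250, 256  # rank-local shard padded to a 256-token bucket

q = torch.randn(B, H, padded_len, D)
k = torch.randn(B, H, padded_len, D)
v = torch.randn(B, H, padded_len, D)
k[:, :, real_len:] = 0.0  # padded positions arrive as zero hidden states
v[:, :, real_len:] = 0.0

out_padded = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_real = F.scaled_dot_product_attention(
    q[:, :, :real_len], k[:, :, :real_len], v[:, :, :real_len], is_causal=True
)

# Causality means position i only attends to keys <= i, so the real tokens
# never see the trailing padding and the two results match exactly.
torch.testing.assert_close(out_padded[:, :, :real_len], out_real)
```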
Signed-off-by: Yu Feng <admin@fengyu.org>
Thanks for the PR. Doing a verification locally.
💡 Codex Review
Reviewed commit: 4f16c2b4df
Signed-off-by: Yu Feng <admin@fengyu.org>
💡 Codex Review
Reviewed commit: 849d27ecdb
    if self.attention_backend == "usp":
        usp_chunk_size = seq_length - self.length
        if usp_chunk_size <= 0:
            raise ValueError(
                f"USP local seq_length ({seq_length}) must be larger than ttt_length ({self.length})"
            )
        if position_ids is None:
            device = hidden_states.device
            ring_chunk_size = usp_chunk_size * self._usp_ulysses_world_size
            position_start = get_sp_ring_rank() * ring_chunk_size + past_key_values_length
Derive USP positions from the unbucketed shard length
When a USP shard length is not already on the collator's 256-token bucket, seq_length here is the padded batch length rather than the true chunk_size + ttt_length. That makes usp_chunk_size and ring_chunk_size too large, so ring rank > 0 starts RoPE positions after the padding gap (for example a 250-token shard padded to 256 starts the next ring at 510 instead of 500), and the attention path also treats the bucket padding as interior sequence tokens. This affects the normal training path because _forward does not pass the position_ids computed by the data fetcher, so the fallback below is used after collation.
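To make the numbers in this comment concrete, here is a small sketch of the position-start formula from the diff above, comparing the true local length against the padded bucket length. The `ttt_length=1` and `sp_ulysses_size=2` values are assumptions chosen to reproduce the 500-vs-510 example.

```python
def usp_position_start(sp_ring_rank: int, local_seq_length: int,
                       ttt_length: int, ulysses_world_size: int,
                       past_key_values_length: int = 0) -> int:
    # Mirrors the diff: per-rank chunk size, then the ring-rank offset.
    usp_chunk_size = local_seq_length - ttt_length
    ring_chunk_size = usp_chunk_size * ulysses_world_size
    return sp_ring_rank * ring_chunk_size + past_key_values_length

true_len = 251    # 250-token shard plus ttt_length
padded_len = 256  # same shard after the collator's 256-token bucket padding

print(usp_position_start(1, true_len, ttt_length=1, ulysses_world_size=2))    # 500
print(usp_position_start(1, padded_len, ttt_length=1, ulysses_world_size=2))  # 510
```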
Signed-off-by: Yu Feng <admin@fengyu.org>
💡 Codex Review
Reviewed commit: 1098e20c1f
    bucket = 256
    padded_global_len = ((global_len + bucket - 1) // bucket) * bucket
    return (padded_global_len + self._sp_world_size - 1) // self._sp_world_size
Use the writer's USP shard length
When USP is enabled, the SGLang writer stores shards via EagleMooncakeStore.put_usp_shards()/split_usp_batch(), whose chunk size is ceil(global_len / sp_size) and does not round to a 256-token bucket. This reader requests ceil(round_up_256(global_len) / sp_size) + ttt_length instead; for example global_len=1000, sp_size=4, ttt_length=1 stores length 251 but reads length 257, so EagleMooncakeStore.get() raises a Mooncake size mismatch before training can start. Please make the reader use the same shard length formula or explicit shard metadata from the writer.
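For reference, the two formulas side by side with the numbers from this comment. The `+ ttt_length` on the writer side is inferred from the stored length of 251 in the example; the function names are illustrative, not the actual APIs.

```python
def writer_shard_len(global_len: int, sp_size: int, ttt_length: int) -> int:
    # split_usp_batch side: ceil(global_len / sp_size), no bucket rounding
    return (global_len + sp_size - 1) // sp_size + ttt_length

def reader_shard_len(global_len: int, sp_size: int, ttt_length: int,
                     bucket: int = 256) -> int:
    # this reader: round the global length up to the 256 bucket before sharding
    padded_global_len = ((global_len + bucket - 1) // bucket) * bucket
    return (padded_global_len + sp_size - 1) // sp_size + ttt_length

print(writer_shard_len(1000, 4, 1))  # 251 tokens stored per shard
print(reader_shard_len(1000, 4, 1))  # 257 tokens requested -> size mismatch in get()
```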
Signed-off-by: Yu Feng <admin@fengyu.org>
💡 Codex Review
Reviewed commit: f52e69fd36
Signed-off-by: Yu Feng <admin@fengyu.org>
This PR updates USP training to use SGLang-produced, pre-sharded Mooncake tensors instead of loading a full sample on one
training rank and scattering it during prefetch.
The main changes are:
- SGLang writes each rank's shard to Mooncake as `{mooncake_key}_usp{sp_rank}` tensors.
- Reconstruct the local `loss_mask`, `attention_mask`, and `position_ids` in the USP data fetcher.

Motivation
The previous USP path loaded full Mooncake tensors in the training prefetch path and then distributed local shards across SP
ranks. This made prefetch do extra distributed communication and exposed a collective ordering issue when one USP shard had no
local loss tokens while another shard in the same Ulysses group did.
With this change, SGLang writes per-SP-rank tensors directly to Mooncake. Training workers only receive lightweight metadata
through Ray queues and independently load their own shard.
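A rough sketch of the fetcher side described here: the key format comes from the changes above, while the `store` object and its `get()` call are placeholders rather than the actual `EagleMooncakeStore` API.

```python
def fetch_local_shard(store, mooncake_key: str, sp_rank: int) -> dict:
    """Each training rank independently reads only its own pre-sharded tensors."""
    shard_key = f"{mooncake_key}_usp{sp_rank}"   # per-SP-rank Mooncake key
    shard = store.get(shard_key)                 # placeholder get(); no scatter needed
    # Local loss_mask / attention_mask / position_ids are rebuilt from the
    # lightweight metadata (global shapes, packed loss mask) sent via the queue.
    return shard
```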
Data Flow
```mermaid
flowchart TD
    A[SGLang generate / prefill] --> B[hidden_states, input_ids, last_hidden_states]
    B --> C[EagleMooncakeStore.put_usp_shards]
    C --> D{for each sp_rank}
    D --> E[split_usp_batch]
    E --> F0[Mooncake key: key_usp0<br/>input_ids shard<br/>hidden_states shard<br/>target/lhs shard]
    E --> F1[Mooncake key: key_usp1<br/>input_ids shard<br/>hidden_states shard<br/>target/lhs shard]
    E --> FN[Mooncake key: key_uspN<br/>input_ids shard<br/>hidden_states shard<br/>target/lhs shard]
    C --> G[InferenceOutput<br/>mooncake_key=key<br/>tensor_shapes=global shapes<br/>packed_loss_mask<br/>metadata: usp_sharded=true]
    G --> H[AsyncInferenceManager<br/>merge metadata]
    H --> I[AsyncTrainingController]
    I --> J{DP rank's SP group}
    J --> Q0[Queue for train rank sp0]
    J --> Q1[Queue for train rank sp1]
    J --> QN[Queue for train rank spN]
    Q0 --> R0[Trainer rank sp0]
    Q1 --> R1[Trainer rank sp1]
    QN --> RN[Trainer rank spN]
    R0 --> S0[read Mooncake key_usp0]
    R1 --> S1[read Mooncake key_usp1]
    RN --> SN[read Mooncake key_uspN]
    S0 --> T0[reconstruct local loss_mask<br/>attention_mask<br/>position_ids]
    S1 --> T1[reconstruct local loss_mask<br/>attention_mask<br/>position_ids]
    SN --> TN[reconstruct local loss_mask<br/>attention_mask<br/>position_ids]
    T0 --> U[USP forward/backward]
    T1 --> U
    TN --> U
```

Implementation Details
SGLang / Mooncake
Controller
Data Fetcher
Collective Ordering Fix
A USP local shard can have zero local loss tokens while its Ulysses peer has nonzero loss tokens. Previously, the zero-loss
shard could skip the attention backward graph, while its peer still executed attention backward all-to-all collectives. That
caused collective ordering divergence.
This PR keeps the USP zero-loss path connected to hidden_states:
local_sum_loss = local_sum_loss + hidden_states.sum() * 0.0
This does not change the loss value, but preserves the same autograd collective sequence across USP ranks.
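A standalone illustration of the trick (toy shapes and a simplified loss; the real code ties the term into the USP loss computation):

```python
import torch

hidden_states = torch.randn(4, 8, requires_grad=True)
loss_mask = torch.zeros(4)  # this rank's shard has no loss tokens

if loss_mask.any():
    local_sum_loss = (hidden_states.sum(-1) * loss_mask).sum()
else:
    # Without the fix, this branch would produce a loss with no backward graph,
    # while the Ulysses peer still runs attention-backward all-to-all collectives.
    local_sum_loss = torch.zeros((), dtype=hidden_states.dtype)

# The fix: keep hidden_states (and everything upstream of it) in the autograd
# graph without changing the loss value.
local_sum_loss = local_sum_loss + hidden_states.sum() * 0.0

local_sum_loss.backward()
print(hidden_states.grad.abs().max().item())  # 0.0, but backward ran on this rank too
```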
USP attention correctness and microbenchmark
Correctness is validated against `LlamaFlexAttention` with `PYTHONPATH=. python -m unittest tests.test_usp_attention.TestUSPAttention` (Ran 2 tests in 50.9s, OK):

- `sp_ulysses_size=2, sp_ring_size=1`: 1.6e-1, 6.7e-5, <= 7.0e-5, <= 1.9e-6
- `sp_ulysses_size=1, sp_ring_size=2`: 1.8e-1, 4.5e-5, <= 5.5e-5, <= 1.9e-6

Attention-only microbenchmark setup:
`seq_len=8192`, `global_batch_size=2`, `batch_size=1`, 2 GPUs:

- baseline: 35.7-36.9 ms / 36.0-37.2 ms / 1.00x
- `sp_ulysses_size=2, sp_ring_size=1`: 39.6 ms / 40.9 ms / 1.11x
- `sp_ulysses_size=1, sp_ring_size=2`: 39.7 ms / 39.5 ms / 1.08x
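For context, a generic single-GPU timing harness of the kind used for such an attention-only microbenchmark (plain SDPA, not the PR's actual script; shapes are assumptions):

```python
import torch
import torch.nn.functional as F

assert torch.cuda.is_available()
B, H, S, D = 1, 32, 8192, 128
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

def attn():
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

for _ in range(10):  # warm-up
    attn()
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 50
start.record()
for _ in range(iters):
    attn()
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end) / iters:.1f} ms per attention call")
```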
I also ran an end-to-end training comparison between the FlexAttention baseline and USP:

- `llama31_8b_align_flex_match_usp`: FlexAttention baseline
- `l31_usp_u2_bf16`: USP with `sp_ulysses_size=2`

Limitations