34 changes: 32 additions & 2 deletions patches/sglang/v0.5.10.post1/sglang.patch
@@ -623,6 +623,37 @@ index 496cd9665..45ceb0e20 100644
+ req_hidden_states.cpu().clone().tolist()
+ )
+
+ def _put_spec_training_mooncake(
+ self: Scheduler,
+ key: str,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ last_hidden_states: Optional[torch.Tensor],
+ ):
+ import os
+
+ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1":
+ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH")
+ self.eagle_mooncake_store.put_usp_shards(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ last_hidden_states=last_hidden_states,
+ target=None,
+ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]),
+ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")),
+ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")),
+ max_seq_length=int(max_seq_raw) if max_seq_raw else None,
+ )
+ return
+
+ self.eagle_mooncake_store.put(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ last_hidden_states=last_hidden_states,
+ )
+
+ def _send_hidden_states_to_mooncake(
+ self: Scheduler,
+ req: Req,
@@ -652,7 +683,7 @@ index 496cd9665..45ceb0e20 100644
+ if hidden_states.is_cuda and copy_done_event is not None:
+ torch.cuda.current_stream().wait_event(copy_done_event)
+
+ self.eagle_mooncake_store.put(
+ self._put_spec_training_mooncake(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
@@ -1006,4 +1037,3 @@ index 000000000..24af14b7a
+ return len(self.data_ids) == 0
--
2.43.0
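
Note: the `_put_spec_training_mooncake` helper added in this patch keeps the default whole-sequence `put` and routes to `put_usp_shards` only when `TORCHSPEC_USP_SHARDED_MOONCAKE=1`. A minimal sketch of the launcher-side configuration that turns the sharded path on (the variable names and defaults are exactly those read by the helper; the values are illustrative):

```python
import os

# Opt in to USP-sharded Mooncake puts (read by _put_spec_training_mooncake).
os.environ["TORCHSPEC_USP_SHARDED_MOONCAKE"] = "1"
os.environ["TORCHSPEC_USP_SP_SIZE"] = "4"            # required once sharding is on
os.environ["TORCHSPEC_USP_RING_SIZE"] = "2"          # optional, helper defaults to "1"
os.environ["TORCHSPEC_USP_TTT_LENGTH"] = "2"         # optional, helper defaults to "1"
os.environ["TORCHSPEC_USP_MAX_SEQ_LENGTH"] = "8192"  # optional, helper passes None if unset
```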

33 changes: 32 additions & 1 deletion patches/sglang/v0.5.8.post1/sglang.patch
@@ -630,6 +630,37 @@ index d79f929d3..3eebfacb5 100644
+ pass
+ # req.hidden_states.append(req_hidden_states.cpu().clone().tolist())
+
+ def _put_spec_training_mooncake(
+ self: Scheduler,
+ key: str,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ last_hidden_states: Optional[torch.Tensor],
+ ):
+ import os
+
+ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1":
+ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH")
+ self.eagle_mooncake_store.put_usp_shards(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ last_hidden_states=last_hidden_states,
+ target=None,
+ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]),
+ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")),
+ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")),
+ max_seq_length=int(max_seq_raw) if max_seq_raw else None,
+ )
+ return
+
+ self.eagle_mooncake_store.put(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ last_hidden_states=last_hidden_states,
+ )
+
+ def _send_hidden_states_to_mooncake(
+ self: Scheduler,
+ req: Req,
@@ -658,7 +689,7 @@ index d79f929d3..3eebfacb5 100644
+ if hidden_states.is_cuda and copy_done_event is not None:
+ torch.cuda.current_stream().wait_event(copy_done_event)
+
+ self.eagle_mooncake_store.put(
+ self._put_spec_training_mooncake(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
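Note: both versions of the patch gate the put behind `torch.cuda.current_stream().wait_event(copy_done_event)`, which orders the consumer stream after the producer's copy without blocking the host. A self-contained sketch of that pattern, assuming the copy runs on a side stream (the function and variable names here are illustrative, not from the patch):

```python
import torch


def copy_on_side_stream(src: torch.Tensor):
    """Clone src on a dedicated stream; return the clone plus a
    completion event the consumer can wait on."""
    copy_stream = torch.cuda.Stream()
    copy_done = torch.cuda.Event()
    # The side stream must first wait until src is ready on the current stream.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        dst = src.clone()   # enqueued on copy_stream
        copy_done.record()  # marks completion of the clone on copy_stream
    return dst, copy_done


# Consumer side, mirroring the patch: block the stream, not the host.
# dst, copy_done = copy_on_side_stream(hidden_states)
# torch.cuda.current_stream().wait_event(copy_done)
# ... dst is now safe to read from work queued on the current stream ...
```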
35 changes: 33 additions & 2 deletions patches/sglang/v0.5.8.post1/sglang_decode.patch
@@ -834,6 +834,37 @@ index d79f929d3..45c17b76b 100644
+ else:
+ req.hidden_states.append(req_hidden_states.cpu().clone().tolist())
+
+ def _put_spec_training_mooncake(
+ self: Scheduler,
+ key: str,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ last_hidden_states: Optional[torch.Tensor],
+ ):
+ import os
+
+ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1":
+ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH")
+ self.eagle_mooncake_store.put_usp_shards(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ last_hidden_states=last_hidden_states,
+ target=None,
+ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]),
+ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")),
+ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")),
+ max_seq_length=int(max_seq_raw) if max_seq_raw else None,
+ )
+ return
+
+ self.eagle_mooncake_store.put(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ last_hidden_states=last_hidden_states,
+ )
+
+ def _send_hidden_states_to_mooncake(
+ self: Scheduler,
+ req: Req,
@@ -863,7 +894,7 @@ index d79f929d3..45c17b76b 100644
+ if hidden_states.is_cuda and copy_done_event is not None:
+ torch.cuda.current_stream().wait_event(copy_done_event)
+
+ self.eagle_mooncake_store.put(
+ self._put_spec_training_mooncake(
+ key=key,
+ hidden_states=hidden_states,
+ input_ids=input_ids,
@@ -931,7 +962,7 @@ index d79f929d3..45c17b76b 100644
+ f"lhs={tuple(all_last_hidden_states.shape) if all_last_hidden_states is not None else None}, "
+ f"prompt={len(req.origin_input_ids)}, output={len(req.output_ids)}"
+ )
+ self.eagle_mooncake_store.put(
+ self._put_spec_training_mooncake(
+ key=key,
+ hidden_states=all_hidden_states,
+ input_ids=input_ids,
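Note: `put_usp_shards` itself is outside this diff, but the new test below reads shards back under `f"{key}_usp{rank}"` keys, with each of the `sp_size` ranks holding one contiguous slice of the sequence dimension. A sketch of that assumed layout only (it ignores `target`, `sp_ring_size`, `ttt_length`, and `max_seq_length`, which the real method also takes):

```python
import torch


def put_usp_shards_sketch(store, key: str, hidden_states: torch.Tensor,
                          input_ids: torch.Tensor, sp_size: int) -> None:
    """Assumed layout: split (1, S, H) hidden states and (1, S) input ids
    into sp_size contiguous chunks along the sequence dim, one key per rank."""
    hs_shards = hidden_states.chunk(sp_size, dim=1)
    id_shards = input_ids.chunk(sp_size, dim=1)
    for rank, (hs, ids) in enumerate(zip(hs_shards, id_shards)):
        # put_tensors is the mock-store API used by the test below.
        store.put_tensors(
            f"{key}_usp{rank}",
            {"hidden_states": hs.contiguous(), "input_ids": ids.contiguous()},
        )
```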
66 changes: 66 additions & 0 deletions tests/test_data_fetcher.py
@@ -6,6 +6,7 @@

import torch

+ from torchspec.data.utils import pack_loss_mask, serialize_packed_loss_mask
from torchspec.training.data_fetcher import (
MooncakeDataFetcher,
MooncakeDataset,
@@ -142,6 +143,71 @@ def test_stops_on_none_sentinel(self):

assert len(samples) == 1

+ def test_usp_sharded_keeps_local_zero_loss_shard_when_global_loss_exists(self):
+ """A local all-zero USP shard must not be skipped independently.
+
+ With SP=2, the first rank can own only prompt tokens while the second
+ rank owns loss-bearing response tokens. If rank 0 skips locally and
+ rank 1 trains, subsequent USP collectives are ordered differently.
+ """
+ full_loss_mask = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long)
+ packed_loss_mask = serialize_packed_loss_mask(pack_loss_mask(full_loss_mask))
+ sample = TrainSample(
+ mooncake_key="sample",
+ tensor_shapes={
+ "input_ids": (1, 8),
+ "hidden_states": (1, 8, 2),
+ "target": (1, 8, 2),
+ },
+ tensor_dtypes={
+ "input_ids": torch.long,
+ "hidden_states": torch.bfloat16,
+ "target": torch.bfloat16,
+ },
+ packed_loss_mask=packed_loss_mask,
+ metadata={"usp_sharded": True},
+ )
+
+ outputs = []
+ for sp_rank in (0, 1):
+ ray_queue = MockRayQueue()
+ ray_queue.put(sample)
+ ray_queue.put(None)
+
+ store = MockMooncakeStore()
+ store.put_tensors(
+ f"sample_usp{sp_rank}",
+ {
+ "input_ids": torch.arange(sp_rank * 4, sp_rank * 4 + 4).view(1, 4),
+ "hidden_states": torch.zeros(1, 4, 2, dtype=torch.bfloat16),
+ "target": torch.zeros(1, 4, 2, dtype=torch.bfloat16),
+ },
+ )
+
+ dataset = MooncakeDataset(
+ ray_queue,
+ store,
+ torch.device("cpu"),
+ usp_enabled=True,
+ ttt_length=0,
+ )
+ dataset._sp_world_size = 2
+ dataset._sp_rank = sp_rank
+ dataset._sp_ring_size = 1
+
+ tensors, skipped = dataset._usp_get_sharded_item(skip_count=0)
+ outputs.append((tensors, skipped))
+
+ rank0_tensors, rank0_skipped = outputs[0]
+ rank1_tensors, rank1_skipped = outputs[1]
+
+ assert rank0_tensors is not None
+ assert rank1_tensors is not None
+ assert rank0_skipped == 0
+ assert rank1_skipped == 0
+ assert not rank0_tensors["loss_mask"].any()
+ assert rank1_tensors["loss_mask"].any()
+

class TestCreateMooncakeDataloader:
def test_default_batch_size_is_one(self):
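
Note: the new test encodes an invariant of USP data loading: a shard may look loss-free on one rank while another rank of the same sample holds all the loss-bearing tokens, so skip decisions must be agreed on globally, never taken per rank. A sketch of such an agreement rule using an all-reduce over a has-loss flag (illustrative; not the fetcher's actual implementation):

```python
import torch
import torch.distributed as dist


def should_skip_sample(local_loss_mask: torch.Tensor, sp_group=None) -> bool:
    """Skip a sample only when *no* SP rank holds a loss-bearing token,
    so every rank reaches the same decision and later collectives stay aligned."""
    has_loss = torch.tensor([1.0 if local_loss_mask.any() else 0.0])
    if dist.is_initialized():
        dist.all_reduce(has_loss, op=dist.ReduceOp.MAX, group=sp_group)
    return has_loss.item() == 0.0
```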
32 changes: 23 additions & 9 deletions tests/test_eagle3_loss.py
@@ -102,9 +102,11 @@ def test_matches_reference(self):
raw_logits = F.linear(hs.float(), lm_head_weight.float())
target_p = F.softmax(raw_logits + torch.randn_like(raw_logits) * 0.5, dim=-1)

- loss, acc = compiled_forward_kl_loss(
+ loss_sum, correct, count = compiled_forward_kl_loss(
hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps
)
+ loss = loss_sum / count
+ acc = correct / count
ref_loss, ref_acc = _reference_forward_kl_loss(
hs, target_p, norm_weight, lm_head_weight, norm_eps
)
@@ -131,9 +133,11 @@ def test_perfect_prediction_equals_entropy(self):
target_p = F.softmax(logits, dim=-1)
expected_entropy = -(target_p * target_p.log()).sum(-1).mean()

- loss, acc = compiled_forward_kl_loss(
+ loss_sum, correct, count = compiled_forward_kl_loss(
hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps
)
+ loss = loss_sum / count
+ acc = correct / count
torch.testing.assert_close(loss, expected_entropy, atol=1e-3, rtol=1e-3)
self.assertAlmostEqual(acc.item(), 1.0, places=2)

@@ -146,9 +150,11 @@ def test_loss_non_negative_and_finite(self):
target_p = F.softmax(torch.randn(N, V), dim=-1)
valid_idx = torch.arange(N)

- loss, acc = compiled_forward_kl_loss(
+ loss_sum, correct, count = compiled_forward_kl_loss(
hs, target_p, valid_idx, norm_weight, lm_head_weight, 1e-6
)
+ loss = loss_sum / count
+ acc = correct / count
self.assertTrue(torch.isfinite(loss))
self.assertGreaterEqual(loss.item(), 0.0)
self.assertGreaterEqual(acc.item(), 0.0)
@@ -230,7 +236,7 @@ def _run_both_paths(self, device="cpu"):

precomputed = PrecomputedTarget(target_p_padded)
with torch.no_grad():
- plosses_pre, _, acces_pre = model(
+ plosses_pre, _, acces_pre, _ = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
target=precomputed,
@@ -244,7 +250,7 @@ def _run_both_paths(self, device="cpu"):
length,
)
with torch.no_grad():
- plosses_lazy, _, acces_lazy = model(
+ plosses_lazy, _, acces_lazy, _ = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
target=lazy,
@@ -376,26 +382,30 @@ def _check_forward_kl(self, valid_idx):
tp_flat = F.softmax(torch.randn(self.BT, self.V), dim=-1)
norm_eps = 1e-6

- loss, acc = compiled_forward_kl_loss(
+ loss_sum, correct, count = compiled_forward_kl_loss(
hs_flat,
tp_flat,
valid_idx,
norm_weight,
lm_head_weight,
norm_eps,
)
+ loss = loss_sum / count
+ acc = correct / count

hs_valid = hs_flat[valid_idx]
tp_valid = tp_flat[valid_idx]
all_idx = torch.arange(hs_valid.shape[0])
- loss_ref, acc_ref = compiled_forward_kl_loss(
+ loss_sum_ref, correct_ref, count_ref = compiled_forward_kl_loss(
hs_valid,
tp_valid,
all_idx,
norm_weight,
lm_head_weight,
norm_eps,
)
+ loss_ref = loss_sum_ref / count_ref
+ acc_ref = correct_ref / count_ref

torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5)
@@ -409,7 +419,7 @@ def _check_forward_kl_from_hs(self, valid_idx):
target_lm_head_weight = torch.randn(self.V, self.H, dtype=torch.bfloat16)
norm_eps = 1e-6

- loss, acc = compiled_forward_kl_loss_from_hs(
+ loss_sum, correct, count = compiled_forward_kl_loss_from_hs(
hs_flat,
ths_flat,
valid_idx,
@@ -418,11 +428,13 @@ def _check_forward_kl_from_hs(self, valid_idx):
target_lm_head_weight,
norm_eps,
)
+ loss = loss_sum / count
+ acc = correct / count

hs_valid = hs_flat[valid_idx]
ths_valid = ths_flat[valid_idx]
all_idx = torch.arange(hs_valid.shape[0])
- loss_ref, acc_ref = compiled_forward_kl_loss_from_hs(
+ loss_sum_ref, correct_ref, count_ref = compiled_forward_kl_loss_from_hs(
hs_valid,
ths_valid,
all_idx,
@@ -431,6 +443,8 @@ def _check_forward_kl_from_hs(self, valid_idx):
target_lm_head_weight,
norm_eps,
)
+ loss_ref = loss_sum_ref / count_ref
+ acc_ref = correct_ref / count_ref

torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5)
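
Note: the change threaded through these tests, `compiled_forward_kl_loss` returning `(loss_sum, correct, count)` instead of `(loss, acc)`, defers the division to the caller. That makes global averages exact when valid-token counts differ across calls: summing numerators and denominators gives a token-weighted mean, whereas averaging per-call means does not. A sketch of caller-side aggregation (names are illustrative):

```python
import torch


def aggregate_kl_metrics(parts):
    """parts: iterable of (loss_sum, correct, count) tuples from the loss kernel.
    Sum first, divide once: every valid token gets equal weight."""
    loss_sum = sum(p[0] for p in parts)
    correct = sum(p[1] for p in parts)
    count = sum(p[2] for p in parts)
    return loss_sum / count, correct / count  # (mean loss, accuracy)
```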