diff --git a/patches/sglang/v0.5.10.post1/sglang.patch b/patches/sglang/v0.5.10.post1/sglang.patch index 8e5a59e0..67e34a4a 100644 --- a/patches/sglang/v0.5.10.post1/sglang.patch +++ b/patches/sglang/v0.5.10.post1/sglang.patch @@ -623,6 +623,37 @@ index 496cd9665..45ceb0e20 100644 + req_hidden_states.cpu().clone().tolist() + ) + ++ def _put_spec_training_mooncake( ++ self: Scheduler, ++ key: str, ++ hidden_states: torch.Tensor, ++ input_ids: torch.Tensor, ++ last_hidden_states: Optional[torch.Tensor], ++ ): ++ import os ++ ++ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1": ++ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH") ++ self.eagle_mooncake_store.put_usp_shards( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ target=None, ++ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]), ++ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")), ++ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")), ++ max_seq_length=int(max_seq_raw) if max_seq_raw else None, ++ ) ++ return ++ ++ self.eagle_mooncake_store.put( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ ) ++ + def _send_hidden_states_to_mooncake( + self: Scheduler, + req: Req, @@ -652,7 +683,7 @@ index 496cd9665..45ceb0e20 100644 + if hidden_states.is_cuda and copy_done_event is not None: + torch.cuda.current_stream().wait_event(copy_done_event) + -+ self.eagle_mooncake_store.put( ++ self._put_spec_training_mooncake( + key=key, + hidden_states=hidden_states, + input_ids=input_ids, @@ -1006,4 +1037,3 @@ index 000000000..24af14b7a + return len(self.data_ids) == 0 -- 2.43.0 - diff --git a/patches/sglang/v0.5.8.post1/sglang.patch b/patches/sglang/v0.5.8.post1/sglang.patch index 62c7bddb..345acb18 100644 --- a/patches/sglang/v0.5.8.post1/sglang.patch +++ b/patches/sglang/v0.5.8.post1/sglang.patch @@ -630,6 +630,37 @@ index d79f929d3..3eebfacb5 100644 + pass + # req.hidden_states.append(req_hidden_states.cpu().clone().tolist()) + ++ def _put_spec_training_mooncake( ++ self: Scheduler, ++ key: str, ++ hidden_states: torch.Tensor, ++ input_ids: torch.Tensor, ++ last_hidden_states: Optional[torch.Tensor], ++ ): ++ import os ++ ++ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1": ++ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH") ++ self.eagle_mooncake_store.put_usp_shards( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ target=None, ++ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]), ++ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")), ++ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")), ++ max_seq_length=int(max_seq_raw) if max_seq_raw else None, ++ ) ++ return ++ ++ self.eagle_mooncake_store.put( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ ) ++ + def _send_hidden_states_to_mooncake( + self: Scheduler, + req: Req, @@ -658,7 +689,7 @@ index d79f929d3..3eebfacb5 100644 + if hidden_states.is_cuda and copy_done_event is not None: + torch.cuda.current_stream().wait_event(copy_done_event) + -+ self.eagle_mooncake_store.put( ++ self._put_spec_training_mooncake( + key=key, + hidden_states=hidden_states, + input_ids=input_ids, diff --git a/patches/sglang/v0.5.8.post1/sglang_decode.patch b/patches/sglang/v0.5.8.post1/sglang_decode.patch index b586250d..0badb4b7 100644 --- 
a/patches/sglang/v0.5.8.post1/sglang_decode.patch +++ b/patches/sglang/v0.5.8.post1/sglang_decode.patch @@ -834,6 +834,37 @@ index d79f929d3..45c17b76b 100644 + else: + req.hidden_states.append(req_hidden_states.cpu().clone().tolist()) + ++ def _put_spec_training_mooncake( ++ self: Scheduler, ++ key: str, ++ hidden_states: torch.Tensor, ++ input_ids: torch.Tensor, ++ last_hidden_states: Optional[torch.Tensor], ++ ): ++ import os ++ ++ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1": ++ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH") ++ self.eagle_mooncake_store.put_usp_shards( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ target=None, ++ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]), ++ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")), ++ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")), ++ max_seq_length=int(max_seq_raw) if max_seq_raw else None, ++ ) ++ return ++ ++ self.eagle_mooncake_store.put( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ ) ++ + def _send_hidden_states_to_mooncake( + self: Scheduler, + req: Req, @@ -863,7 +894,7 @@ index d79f929d3..45c17b76b 100644 + if hidden_states.is_cuda and copy_done_event is not None: + torch.cuda.current_stream().wait_event(copy_done_event) + -+ self.eagle_mooncake_store.put( ++ self._put_spec_training_mooncake( + key=key, + hidden_states=hidden_states, + input_ids=input_ids, @@ -931,7 +962,7 @@ index d79f929d3..45c17b76b 100644 + f"lhs={tuple(all_last_hidden_states.shape) if all_last_hidden_states is not None else None}, " + f"prompt={len(req.origin_input_ids)}, output={len(req.output_ids)}" + ) -+ self.eagle_mooncake_store.put( ++ self._put_spec_training_mooncake( + key=key, + hidden_states=all_hidden_states, + input_ids=input_ids, diff --git a/tests/test_data_fetcher.py b/tests/test_data_fetcher.py index e353aba4..2140a429 100644 --- a/tests/test_data_fetcher.py +++ b/tests/test_data_fetcher.py @@ -6,6 +6,7 @@ import torch +from torchspec.data.utils import pack_loss_mask, serialize_packed_loss_mask from torchspec.training.data_fetcher import ( MooncakeDataFetcher, MooncakeDataset, @@ -142,6 +143,71 @@ def test_stops_on_none_sentinel(self): assert len(samples) == 1 + def test_usp_sharded_keeps_local_zero_loss_shard_when_global_loss_exists(self): + """A local all-zero USP shard must not be skipped independently. + + With SP=2, the first rank can own only prompt tokens while the second + rank owns loss-bearing response tokens. If rank 0 skips locally and + rank 1 trains, subsequent USP collectives are ordered differently. 
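The fetcher therefore must not drop a shard just because its local loss mask is empty: both ranks should return their shard with skipped == 0, and rank 0 simply contributes an all-zero loss mask.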
+ """ + full_loss_mask = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long) + packed_loss_mask = serialize_packed_loss_mask(pack_loss_mask(full_loss_mask)) + sample = TrainSample( + mooncake_key="sample", + tensor_shapes={ + "input_ids": (1, 8), + "hidden_states": (1, 8, 2), + "target": (1, 8, 2), + }, + tensor_dtypes={ + "input_ids": torch.long, + "hidden_states": torch.bfloat16, + "target": torch.bfloat16, + }, + packed_loss_mask=packed_loss_mask, + metadata={"usp_sharded": True}, + ) + + outputs = [] + for sp_rank in (0, 1): + ray_queue = MockRayQueue() + ray_queue.put(sample) + ray_queue.put(None) + + store = MockMooncakeStore() + store.put_tensors( + f"sample_usp{sp_rank}", + { + "input_ids": torch.arange(sp_rank * 4, sp_rank * 4 + 4).view(1, 4), + "hidden_states": torch.zeros(1, 4, 2, dtype=torch.bfloat16), + "target": torch.zeros(1, 4, 2, dtype=torch.bfloat16), + }, + ) + + dataset = MooncakeDataset( + ray_queue, + store, + torch.device("cpu"), + usp_enabled=True, + ttt_length=0, + ) + dataset._sp_world_size = 2 + dataset._sp_rank = sp_rank + dataset._sp_ring_size = 1 + + tensors, skipped = dataset._usp_get_sharded_item(skip_count=0) + outputs.append((tensors, skipped)) + + rank0_tensors, rank0_skipped = outputs[0] + rank1_tensors, rank1_skipped = outputs[1] + + assert rank0_tensors is not None + assert rank1_tensors is not None + assert rank0_skipped == 0 + assert rank1_skipped == 0 + assert not rank0_tensors["loss_mask"].any() + assert rank1_tensors["loss_mask"].any() + class TestCreateMooncakeDataloader: def test_default_batch_size_is_one(self): diff --git a/tests/test_eagle3_loss.py b/tests/test_eagle3_loss.py index aa148b64..8a62420d 100644 --- a/tests/test_eagle3_loss.py +++ b/tests/test_eagle3_loss.py @@ -102,9 +102,11 @@ def test_matches_reference(self): raw_logits = F.linear(hs.float(), lm_head_weight.float()) target_p = F.softmax(raw_logits + torch.randn_like(raw_logits) * 0.5, dim=-1) - loss, acc = compiled_forward_kl_loss( + loss_sum, correct, count = compiled_forward_kl_loss( hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps ) + loss = loss_sum / count + acc = correct / count ref_loss, ref_acc = _reference_forward_kl_loss( hs, target_p, norm_weight, lm_head_weight, norm_eps ) @@ -131,9 +133,11 @@ def test_perfect_prediction_equals_entropy(self): target_p = F.softmax(logits, dim=-1) expected_entropy = -(target_p * target_p.log()).sum(-1).mean() - loss, acc = compiled_forward_kl_loss( + loss_sum, correct, count = compiled_forward_kl_loss( hs, target_p, valid_idx, norm_weight, lm_head_weight, norm_eps ) + loss = loss_sum / count + acc = correct / count torch.testing.assert_close(loss, expected_entropy, atol=1e-3, rtol=1e-3) self.assertAlmostEqual(acc.item(), 1.0, places=2) @@ -146,9 +150,11 @@ def test_loss_non_negative_and_finite(self): target_p = F.softmax(torch.randn(N, V), dim=-1) valid_idx = torch.arange(N) - loss, acc = compiled_forward_kl_loss( + loss_sum, correct, count = compiled_forward_kl_loss( hs, target_p, valid_idx, norm_weight, lm_head_weight, 1e-6 ) + loss = loss_sum / count + acc = correct / count self.assertTrue(torch.isfinite(loss)) self.assertGreaterEqual(loss.item(), 0.0) self.assertGreaterEqual(acc.item(), 0.0) @@ -230,7 +236,7 @@ def _run_both_paths(self, device="cpu"): precomputed = PrecomputedTarget(target_p_padded) with torch.no_grad(): - plosses_pre, _, acces_pre = model( + plosses_pre, _, acces_pre, _ = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], target=precomputed, @@ -244,7 +250,7 @@ def 
_run_both_paths(self, device="cpu"): length, ) with torch.no_grad(): - plosses_lazy, _, acces_lazy = model( + plosses_lazy, _, acces_lazy, _ = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], target=lazy, @@ -376,7 +382,7 @@ def _check_forward_kl(self, valid_idx): tp_flat = F.softmax(torch.randn(self.BT, self.V), dim=-1) norm_eps = 1e-6 - loss, acc = compiled_forward_kl_loss( + loss_sum, correct, count = compiled_forward_kl_loss( hs_flat, tp_flat, valid_idx, @@ -384,11 +390,13 @@ def _check_forward_kl(self, valid_idx): lm_head_weight, norm_eps, ) + loss = loss_sum / count + acc = correct / count hs_valid = hs_flat[valid_idx] tp_valid = tp_flat[valid_idx] all_idx = torch.arange(hs_valid.shape[0]) - loss_ref, acc_ref = compiled_forward_kl_loss( + loss_sum_ref, correct_ref, count_ref = compiled_forward_kl_loss( hs_valid, tp_valid, all_idx, @@ -396,6 +404,8 @@ def _check_forward_kl(self, valid_idx): lm_head_weight, norm_eps, ) + loss_ref = loss_sum_ref / count_ref + acc_ref = correct_ref / count_ref torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5) torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5) @@ -409,7 +419,7 @@ def _check_forward_kl_from_hs(self, valid_idx): target_lm_head_weight = torch.randn(self.V, self.H, dtype=torch.bfloat16) norm_eps = 1e-6 - loss, acc = compiled_forward_kl_loss_from_hs( + loss_sum, correct, count = compiled_forward_kl_loss_from_hs( hs_flat, ths_flat, valid_idx, @@ -418,11 +428,13 @@ def _check_forward_kl_from_hs(self, valid_idx): target_lm_head_weight, norm_eps, ) + loss = loss_sum / count + acc = correct / count hs_valid = hs_flat[valid_idx] ths_valid = ths_flat[valid_idx] all_idx = torch.arange(hs_valid.shape[0]) - loss_ref, acc_ref = compiled_forward_kl_loss_from_hs( + loss_sum_ref, correct_ref, count_ref = compiled_forward_kl_loss_from_hs( hs_valid, ths_valid, all_idx, @@ -431,6 +443,8 @@ def _check_forward_kl_from_hs(self, valid_idx): target_lm_head_weight, norm_eps, ) + loss_ref = loss_sum_ref / count_ref + acc_ref = correct_ref / count_ref torch.testing.assert_close(loss, loss_ref, atol=1e-5, rtol=1e-5) torch.testing.assert_close(acc, acc_ref, atol=1e-5, rtol=1e-5) diff --git a/tests/test_usp_attention.py b/tests/test_usp_attention.py new file mode 100644 index 00000000..24174941 --- /dev/null +++ b/tests/test_usp_attention.py @@ -0,0 +1,256 @@ +import importlib.util +import os +import socket +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from transformers import LlamaConfig + +from torchspec.models.draft.llama3_eagle import LlamaFlexAttention, LlamaUSPFlashAttention +from torchspec.utils.distributed import get_sp_rank, init_usp_groups + + +def _has_usp_runtime() -> bool: + return ( + importlib.util.find_spec("flash_attn") is not None + and importlib.util.find_spec("yunchang") is not None + ) + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + return sock.getsockname()[1] + + +def _build_config() -> LlamaConfig: + return LlamaConfig( + hidden_size=128, + num_attention_heads=8, + num_key_value_heads=2, + max_position_embeddings=4096, + rms_norm_eps=1e-5, + vocab_size=32000, + intermediate_size=688, + hidden_act="silu", + num_hidden_layers=1, + attention_bias=False, + torch_dtype="bfloat16", + ) + + +def _make_full_hidden_steps( + *, + num_steps: int, + batch_size: int, + global_seq_len: int, + hidden_size: int, + dtype: torch.dtype, +) -> 
list[torch.Tensor]: + generator = torch.Generator(device="cpu").manual_seed(20260423) + steps = [] + for _ in range(num_steps): + tensor = torch.randn( + batch_size, + global_seq_len, + hidden_size, + generator=generator, + dtype=torch.float32, + ) + steps.append(tensor.to(dtype)) + return steps + + +def _broadcast_state_dict(rank: int, module: torch.nn.Module) -> dict[str, torch.Tensor]: + state = None + if rank == 0: + state = {name: tensor.detach().cpu() for name, tensor in module.state_dict().items()} + obj = [state] + dist.broadcast_object_list(obj, src=0) + return obj[0] + + +def _run_usp_vs_flex_worker( + rank: int, + world_size: int, + port: int, + sp_ulysses_size: int, + sp_ring_size: int, +) -> None: + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + torch.cuda.set_device(rank) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + init_usp_groups(sp_ulysses_size=sp_ulysses_size, sp_ring_size=sp_ring_size) + + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 + num_steps = 2 + batch_size = 1 + global_seq_len = 32 + local_seq_len = global_seq_len // world_size + config = _build_config() + hidden_size = config.hidden_size * 2 + flex_position_ids = torch.arange(global_seq_len, device=device, dtype=torch.long).unsqueeze(0) + if sp_ulysses_size == world_size: + usp_position_ids = flex_position_ids + else: + start = get_sp_rank() * local_seq_len + usp_position_ids = torch.arange( + start, + start + local_seq_len, + device=device, + dtype=torch.long, + ).unsqueeze(0) + attention_mask = torch.ones(batch_size, global_seq_len, device=device, dtype=torch.bool) + + flex_attention = LlamaFlexAttention(config).to(device).to(dtype) if rank == 0 else None + usp_attention = LlamaUSPFlashAttention(config).to(device).to(dtype) + + state_dict = _broadcast_state_dict(rank, flex_attention if rank == 0 else usp_attention) + usp_attention.load_state_dict(state_dict) + if rank == 0: + flex_attention.load_state_dict(state_dict) + + full_hidden_steps = _make_full_hidden_steps( + num_steps=num_steps, + batch_size=batch_size, + global_seq_len=global_seq_len, + hidden_size=hidden_size, + dtype=dtype, + ) + + local_hidden_steps = [] + for full_hidden in full_hidden_steps: + start = get_sp_rank() * local_seq_len + end = start + local_seq_len + local_hidden = ( + full_hidden[:, start:end, :].to(device=device).clone().detach().requires_grad_(True) + ) + local_hidden_steps.append(local_hidden) + + if rank == 0: + flex_hidden_steps = [ + full_hidden.to(device=device).clone().detach().requires_grad_(True) + for full_hidden in full_hidden_steps + ] + + usp_cache_keys = None + usp_cache_values = None + flex_cache_keys = None + flex_cache_values = None + usp_loss = torch.zeros((), device=device, dtype=torch.float32) + flex_loss = torch.zeros((), device=device, dtype=torch.float32) if rank == 0 else None + max_output_diff = 0.0 + loss_scale = 1.0 / (num_steps * batch_size * global_seq_len * config.hidden_size) + + for step in range(num_steps): + usp_out, usp_cache_keys, usp_cache_values = usp_attention( + hidden_states=local_hidden_steps[step], + cache_keys=usp_cache_keys, + cache_values=usp_cache_values, + attention_mask=None, + position_ids=usp_position_ids, + use_cache=True, + ) + usp_loss = usp_loss + usp_out.float().square().sum() * loss_scale + + gathered_usp_out = [torch.empty_like(usp_out) for _ in range(world_size)] + dist.all_gather(gathered_usp_out, usp_out.detach()) + + if rank == 0: + flex_out, flex_cache_keys, flex_cache_values = flex_attention( 
+ hidden_states=flex_hidden_steps[step], + cache_keys=flex_cache_keys, + cache_values=flex_cache_values, + attention_mask=attention_mask, + position_ids=flex_position_ids, + use_cache=True, + ) + flex_loss = flex_loss + flex_out.float().square().sum() * loss_scale + usp_out_full = torch.cat(gathered_usp_out, dim=1) + step_output_diff = (usp_out_full.float() - flex_out.float()).abs().max().item() + max_output_diff = max(max_output_diff, step_output_diff) + + usp_loss.backward() + if rank == 0: + flex_loss.backward() + + reduced_usp_loss = usp_loss.detach().clone() + dist.all_reduce(reduced_usp_loss, op=dist.ReduceOp.SUM) + if rank == 0: + torch.testing.assert_close( + reduced_usp_loss, + flex_loss.detach(), + atol=2e-2, + rtol=2e-2, + msg=f"USP loss mismatch (max output diff={max_output_diff:.6f})", + ) + + for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"): + usp_grad = getattr(usp_attention, proj_name).weight.grad.detach().float().clone() + dist.all_reduce(usp_grad, op=dist.ReduceOp.SUM) + if rank == 0: + flex_grad = getattr(flex_attention, proj_name).weight.grad.detach().float() + torch.testing.assert_close( + usp_grad, + flex_grad, + atol=3e-2, + rtol=3e-2, + msg=( + f"USP gradient mismatch for {proj_name} (max output diff={max_output_diff:.6f})" + ), + ) + + for step in range(num_steps): + local_grad = local_hidden_steps[step].grad.detach() + gathered_hidden_grad = [torch.empty_like(local_grad) for _ in range(world_size)] + dist.all_gather(gathered_hidden_grad, local_grad) + if rank == 0: + full_hidden_grad = torch.cat(gathered_hidden_grad, dim=1).float() + torch.testing.assert_close( + full_hidden_grad, + flex_hidden_steps[step].grad.detach().float(), + atol=3e-2, + rtol=3e-2, + msg=( + f"USP input gradient mismatch at step {step} " + f"(max output diff={max_output_diff:.6f})" + ), + ) + + dist.barrier() + dist.destroy_process_group() + + +class TestUSPAttention(unittest.TestCase): + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + @unittest.skipUnless(torch.cuda.device_count() >= 2, "Requires at least 2 CUDA devices") + @unittest.skipUnless(_has_usp_runtime(), "USP test requires flash_attn and yunchang") + def test_usp_matches_flex_loss_and_gradients(self): + port = _find_free_port() + mp.spawn( + _run_usp_vs_flex_worker, + args=(2, port, 2, 1), + nprocs=2, + join=True, + ) + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + @unittest.skipUnless(torch.cuda.device_count() >= 2, "Requires at least 2 CUDA devices") + @unittest.skipUnless(_has_usp_runtime(), "USP test requires flash_attn and yunchang") + def test_usp_ring_matches_flex_loss_and_gradients(self): + port = _find_free_port() + mp.spawn( + _run_usp_vs_flex_worker, + args=(2, port, 1, 2), + nprocs=2, + join=True, + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 4c93c72d..e5cb2494 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -106,6 +106,8 @@ class TrainingConfig: # "training_first" (default) or "inference_first". Extensible to custom mappings later. 
placement_strategy: str = "training_first" compile_model: bool = False # torch.compile the full training model + sp_ring_size: int = 1 + sp_ulysses_size: int = 1 gradient_checkpointing: bool = False learning_rate: float = 1e-4 diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py index eaaf0d95..4bc6e6cc 100644 --- a/torchspec/controller/inference_manager.py +++ b/torchspec/controller/inference_manager.py @@ -484,13 +484,16 @@ def _parse_engine_output(self, entry: InferenceInput, output: dict) -> Inference self._metrics.record(output) + metadata = dict(entry.metadata or {}) + metadata.update(output.get("metadata", {}) or {}) + return InferenceOutput( data_id=entry.data_id, mooncake_key=output["mooncake_key"], tensor_shapes=output.get("tensor_shapes", {}), tensor_dtypes=output.get("tensor_dtypes", {}), packed_loss_mask=output.get("packed_loss_mask", entry.packed_loss_mask), - metadata=entry.metadata, + metadata=metadata, ) async def _forward_results(self, results: list[tuple[InferenceInput, Any | Exception]]) -> int: diff --git a/torchspec/controller/training_controller.py b/torchspec/controller/training_controller.py index 771cfffb..1c944619 100644 --- a/torchspec/controller/training_controller.py +++ b/torchspec/controller/training_controller.py @@ -124,6 +124,12 @@ class AsyncTrainingController: def __init__(self, args, dp_size: int): self.args = args self.dp_size = dp_size + self.sp_size = ( + getattr(args, "sp_ulysses_size", 1) * getattr(args, "sp_ring_size", 1) + if getattr(args, "attention_backend", None) == "usp" + else 1 + ) + self.queue_count = dp_size * self.sp_size self.prompt_buffer: deque[InferenceInput] = deque() self._prompt_lock = threading.Lock() @@ -133,7 +139,7 @@ def __init__(self, args, dp_size: int): self._pool_bytes = 0 self._sample_bytes: dict[str, int] = {} - self.train_queues = [Queue() for _ in range(dp_size)] + self.train_queues = [Queue() for _ in range(self.queue_count)] # Eval: separate pool and queues so eval data never mixes with training self.eval_pool: deque[InferenceOutput] = deque() @@ -141,7 +147,7 @@ def __init__(self, args, dp_size: int): self._eval_data_ids: set[str] = set() self._eval_expected_count: int = 0 self._eval_dispatched_samples: int = 0 - self.eval_queues = [Queue() for _ in range(dp_size)] + self.eval_queues = [Queue() for _ in range(self.queue_count)] self.batch_id = 0 self.dispatch_batch_size = args.per_dp_rank_batch_size * dp_size @@ -484,15 +490,20 @@ def _dispatch_to_queues( for result in results: metadata = getattr(result, "metadata", {}) or {} last_turn_loss_only = metadata.get("has_thinking") - queues[dp_rank].put( - TrainSample( - mooncake_key=result.mooncake_key, - tensor_shapes=result.tensor_shapes, - tensor_dtypes=result.tensor_dtypes, - packed_loss_mask=result.packed_loss_mask, - last_turn_loss_only=last_turn_loss_only, - ) + sample = TrainSample( + mooncake_key=result.mooncake_key, + tensor_shapes=result.tensor_shapes, + tensor_dtypes=result.tensor_dtypes, + packed_loss_mask=result.packed_loss_mask, + last_turn_loss_only=last_turn_loss_only, + metadata=metadata, ) + if self.sp_size > 1 and len(queues) == self.queue_count: + start = dp_rank * self.sp_size + for rank in range(start, start + self.sp_size): + queues[rank].put(sample) + else: + queues[dp_rank].put(sample) def push_inference_sample(self, sample: InferenceOutput) -> int: """Add a single inference sample to the training pool. 
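The fan-out above is the core of the USP dispatch change: every trainer rank in a sequence-parallel group must consume the same sample in the same order, so one logical DP queue becomes sp_size physical queues. A minimal sketch of the indexing, assuming a generic put-style queue (the names here are illustrative, not the controller's API):

def fan_out_to_sp_group(sample, queues, dp_rank, sp_size):
    # Ranks dp_rank*sp_size .. dp_rank*sp_size + sp_size - 1 form one
    # sequence-parallel group; each gets an identical copy of the sample.
    start = dp_rank * sp_size
    for rank in range(start, start + sp_size):
        queues[rank].put(sample)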
diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index a8348a02..226123ef 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -47,8 +47,9 @@ def is_local_data_path(path: str, base_dir: str | None = None) -> bool: class DataCollatorWithPadding: - def __init__(self): + def __init__(self, usp_enabled: bool = False): self.sp_degree = 1 + self.usp_enabled = usp_enabled def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: B, n, S = intensors.shape @@ -87,14 +88,16 @@ def _get_loss_mask(self, item: Dict[str, Any]) -> torch.Tensor: def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: max_length = max(item["input_ids"].shape[1] for item in features) max_length = ((max_length + self.sp_degree - 1) // self.sp_degree) * self.sp_degree - # Round up to nearest bucket to reduce unique shapes for torch.compile. - # Without this, every batch gets a different padded length, causing - # FlexAttention recompilation (~1s overhead per new shape). - _BUCKET = 256 - max_length = ((max_length + _BUCKET - 1) // _BUCKET) * _BUCKET - # All real tokens get attention_mask=1; paddingtensor2D zero-pads the rest. - attention_masks = [torch.ones_like(item["input_ids"]).long() for item in features] + if self.usp_enabled: + attention_masks = [item["attention_mask"].long() for item in features] + else: + # Round up to nearest bucket to reduce unique shapes for torch.compile. + # Without this, every batch gets a different padded length, causing + # FlexAttention recompilation (~1s overhead per new shape). + _BUCKET = 256 + max_length = ((max_length + _BUCKET - 1) // _BUCKET) * _BUCKET + attention_masks = [torch.ones_like(item["input_ids"]).long() for item in features] batch_input_ids = torch.cat( [self.paddingtensor2D(item["input_ids"], max_length) for item in features] @@ -113,6 +116,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: "target": None, "last_hidden_states": None, } + if self.usp_enabled: + max_position_length = max(item["position_ids"].shape[1] for item in features) + batch["position_ids"] = torch.cat( + [ + self.paddingtensor2D(item["position_ids"], max_position_length) + for item in features + ] + ) if all("hidden_states" in item for item in features): batch["hidden_states"] = torch.cat( [self.paddingtensor(item["hidden_states"], max_length) for item in features] @@ -225,6 +236,15 @@ def resolve_loss_mask( packed = data.get("packed_loss_mask") if packed is not None: mask = unpack_loss_mask(packed) + input_ids = data.get("input_ids") + if input_ids is not None: + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + expected_len = input_ids.shape[-1] + if mask.shape[0] > expected_len: + mask = mask[:expected_len] + elif mask.shape[0] < expected_len: + mask = torch.nn.functional.pad(mask, (0, expected_len - mask.shape[0])) if not mask.any(): return None data["loss_mask"] = mask diff --git a/torchspec/inference/engine/sgl_engine.py b/torchspec/inference/engine/sgl_engine.py index 39abe1c5..7169ae50 100644 --- a/torchspec/inference/engine/sgl_engine.py +++ b/torchspec/inference/engine/sgl_engine.py @@ -63,6 +63,35 @@ ) +_USP_SHARDED_MOONCAKE_ENV_KEYS = ( + "TORCHSPEC_USP_SHARDED_MOONCAKE", + "TORCHSPEC_USP_SP_SIZE", + "TORCHSPEC_USP_RING_SIZE", + "TORCHSPEC_USP_TTT_LENGTH", + "TORCHSPEC_USP_MAX_SEQ_LENGTH", +) + + +def _configure_usp_sharded_mooncake_env(args: Any, max_seq_length: int | None) -> None: + values: dict[str, str] = {} + if getattr(args, "attention_backend", None) == "usp": + sp_ring_size = 
getattr(args, "sp_ring_size", 1) + values = { + "TORCHSPEC_USP_SHARDED_MOONCAKE": "1", + "TORCHSPEC_USP_SP_SIZE": str(getattr(args, "sp_ulysses_size", 1) * sp_ring_size), + "TORCHSPEC_USP_RING_SIZE": str(sp_ring_size), + "TORCHSPEC_USP_TTT_LENGTH": str(getattr(args, "ttt_length", 1)), + } + if max_seq_length is not None: + values["TORCHSPEC_USP_MAX_SEQ_LENGTH"] = str(max_seq_length) + + for name in _USP_SHARDED_MOONCAKE_ENV_KEYS: + if name in values: + os.environ[name] = values[name] + else: + os.environ.pop(name, None) + + class SglEngine(SglDecodeEngineMixin, InferenceEngine, RayActor): """Ray actor wrapper for sgl.Engine with distributed deployment support. @@ -225,8 +254,10 @@ def init( extra = {k: v for k, v in extra.items() if k not in _PROTECTED_ENGINE_KEYS} engine_kwargs.update(extra) - # Protected keys — always set by TorchSpec, never overridable + # SGLang's patched scheduler reads these process env vars when writing + # Mooncake training tensors. max_seq_length = getattr(self.args, "max_seq_length", None) + _configure_usp_sharded_mooncake_env(self.args, max_seq_length) engine_kwargs.update( { @@ -465,6 +496,8 @@ def generate( "tensor_shapes": tensor_shapes, "tensor_dtypes": self._get_tensor_dtypes(), } + if getattr(self.args, "attention_backend", None) == "usp": + output["metadata"] = {"usp_sharded": True} outputs.append(output) logger.debug( diff --git a/torchspec/inference/engine/sgl_engine_decode.py b/torchspec/inference/engine/sgl_engine_decode.py index fd696fe2..29f770f0 100644 --- a/torchspec/inference/engine/sgl_engine_decode.py +++ b/torchspec/inference/engine/sgl_engine_decode.py @@ -259,6 +259,8 @@ def generate_with_decode( "tensor_dtypes": self._get_tensor_dtypes(), "packed_loss_mask": loss_mask, } + if getattr(self.args, "attention_backend", None) == "usp": + output_dict["metadata"] = {"usp_sharded": True} # Add performance metrics if available (for wandb logging) _METRIC_KEYS = ( diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py index 6200c9ce..2696e866 100644 --- a/torchspec/models/draft/llama3_eagle.py +++ b/torchspec/models/draft/llama3_eagle.py @@ -19,9 +19,11 @@ # SOFTWARE. 
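# Note on the env contract wired up in sgl_engine.py above:
# _configure_usp_sharded_mooncake_env publishes the USP topology as process
# env vars, and the patched SGLang scheduler reads them back in
# _put_spec_training_mooncake (see the sglang patches earlier), roughly:
#
#     if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1":
#         store.put_usp_shards(
#             key=key, ...,
#             sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]),
#             sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")),
#         )
#     else:
#         store.put(key=key, ...)
#
# so trainer args control inference-side sharding without changing the
# scheduler's call sites.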
import math +import os from typing import Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.nn.attention.flex_attention import create_block_mask, flex_attention @@ -34,6 +36,7 @@ compile_friendly_flex_attention, generate_eagle3_mask, ) +from torchspec.utils.distributed import get_sp_ring_group, get_sp_ulysses_group from torchspec.utils.logging import logger, print_with_rank _flash_attn_import_error: ImportError | None = None @@ -1904,6 +1907,340 @@ def backward(ctx, grad_out): return dq.to(q.dtype), dcache_k.to(cache_k.dtype), dcache_v.to(cache_v.dtype), None, None +class _USPFlashCachedMergeFunc(torch.autograd.Function): + @staticmethod + def _merge_dims(q: torch.Tensor, cache_k: torch.Tensor): + bsz, q_len, num_heads, head_dim = q.shape + num_blocks = cache_k.shape[1] + num_kv_heads = cache_k.shape[3] + num_groups = num_heads // num_kv_heads + return bsz, q_len, num_heads, head_dim, num_blocks, num_kv_heads, num_groups + + @staticmethod + def _kernel_lse(lse: torch.Tensor, bsz: int, q_len: int, num_heads: int) -> torch.Tensor: + return lse.reshape(bsz, q_len, num_heads).transpose(1, 2).contiguous() + + @staticmethod + def forward(ctx, q, cache_k, cache_v, softmax_scale: float): + bsz, q_len, num_heads, head_dim, num_blocks, num_kv_heads, num_groups = ( + _USPFlashCachedMergeFunc._merge_dims(q, cache_k) + ) + q_expanded = q.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + + k0 = cache_k[:, 0].contiguous() + v0 = cache_v[:, 0].contiguous() + assert _std_flash_attn_forward is not None and _std_flash_attn_backward is not None, ( + "USP cached merge requires standard flash-attn forward/backward " + f"(standard import error: {_std_flash_attn_import_error!r})" + ) + out0, lse0_kernel = _standard_flash_attn_forward( + q.contiguous(), + k0, + v0, + softmax_scale=softmax_scale, + causal=True, + ) + out0_expanded = out0.view(bsz, q_len, num_kv_heads, num_groups, head_dim).float() + lse0 = lse0_kernel.transpose(1, 2).reshape(bsz, q_len, num_kv_heads, num_groups).float() + lse_terms = [lse0] + attn_terms = [out0_expanded] + for i in range(1, num_blocks): + ki = cache_k[:, i].unsqueeze(-2).float() + vi = cache_v[:, i].unsqueeze(-2).float() + lse_terms.append((q_expanded.float() * ki).sum(-1) * softmax_scale) + attn_terms.append(vi.expand_as(out0_expanded)) + + merged_lse = torch.logsumexp(torch.stack(lse_terms, dim=-1), dim=-1) + out = sum( + term * torch.exp(lse - merged_lse).unsqueeze(-1) + for term, lse in zip(attn_terms, lse_terms) + ) + ctx.save_for_backward(q, cache_k, cache_v, out, merged_lse) + ctx.softmax_scale = softmax_scale + return out.to(q.dtype).reshape_as(q) + + @staticmethod + def backward(ctx, grad_out): + q, cache_k, cache_v, out, merged_lse = ctx.saved_tensors + bsz, q_len, num_heads, head_dim, num_blocks, num_kv_heads, num_groups = ( + _USPFlashCachedMergeFunc._merge_dims(q, cache_k) + ) + scale = ctx.softmax_scale + + grad_out_f = grad_out.float().view(bsz, q_len, num_kv_heads, num_groups, head_dim) + q_f = q.float() + q_expanded = q_f.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + out_f = out.float() + out_expanded = out_f.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + + dq = torch.zeros_like(q_f) + dcache_k = torch.zeros_like(cache_k.float()) + dcache_v = torch.zeros_like(cache_v.float()) + + merged_lse_kernel = _USPFlashCachedMergeFunc._kernel_lse(merged_lse, bsz, q_len, num_heads) + dq0 = torch.empty_like(q) + dk0 = torch.empty_like(cache_k[:, 0]) + dv0 = 
torch.empty_like(cache_v[:, 0]) + _standard_flash_attn_backward_call( + grad_out.contiguous(), + q.contiguous(), + cache_k[:, 0].contiguous(), + cache_v[:, 0].contiguous(), + out.to(q.dtype).reshape_as(q).contiguous(), + merged_lse_kernel, + dq0, + dk0, + dv0, + softmax_scale=scale, + causal=True, + ) + dq += dq0.float() + dcache_k[:, 0] += dk0.float() + dcache_v[:, 0] += dv0.float() + + for i in range(1, num_blocks): + ki = cache_k[:, i].float().unsqueeze(-2) + vi = cache_v[:, i].float().unsqueeze(-2) + lse_i = (q_expanded * ki).sum(-1) * scale + wi = torch.exp(lse_i - merged_lse) + d_out_i = grad_out_f * wi.unsqueeze(-1) + d_lse_i = wi * (grad_out_f * (vi.expand_as(out_expanded) - out_expanded)).sum(-1) + dq += (d_lse_i.unsqueeze(-1) * scale * ki).reshape_as(q) + dcache_k[:, i] += (d_lse_i.unsqueeze(-1) * scale * q_expanded).sum(dim=3) + dcache_v[:, i] += d_out_i.sum(dim=3) + + return dq.to(q.dtype), dcache_k.to(cache_k.dtype), dcache_v.to(cache_v.dtype), None + + +def _update_ring_out_and_lse( + out: torch.Tensor | None, + lse: torch.Tensor | None, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + block_out = block_out.float() + block_lse = block_lse.float() + if out is None or lse is None: + return block_out, block_lse + new_lse = torch.logaddexp(lse, block_lse) + out = out * torch.exp(lse - new_lse).unsqueeze(-1) + block_out * torch.exp( + block_lse - new_lse + ).unsqueeze(-1) + return out, new_lse + + +class _USPRingFlashCachedMergeFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, cache_k, cache_v, ring_group: dist.ProcessGroup, softmax_scale: float): + from yunchang.ring.ring_flash_attn import ring_flash_attn_forward + + bsz, q_len, num_heads, head_dim, num_blocks, num_kv_heads, num_groups = ( + _USPFlashCachedMergeFunc._merge_dims(q, cache_k) + ) + q_expanded = q.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + + out_ring, lse_ring = ring_flash_attn_forward( + ring_group, + q.contiguous(), + cache_k[:, 0].contiguous(), + cache_v[:, 0].contiguous(), + softmax_scale=softmax_scale, + dropout_p=0.0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + ) + if lse_ring.dim() == 3 and lse_ring.shape[1] == num_heads: + lse_ring = lse_ring.transpose(1, 2) + acc_out = out_ring.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + acc_lse = lse_ring.reshape(bsz, q_len, num_kv_heads, num_groups) + + for i in range(1, num_blocks): + ki = cache_k[:, i].unsqueeze(-2).float() + vi = cache_v[:, i].unsqueeze(-2).float() + lse_i = (q_expanded.float() * ki).sum(-1) * softmax_scale + out_i = vi.expand(bsz, q_len, num_kv_heads, num_groups, head_dim) + acc_out, acc_lse = _update_ring_out_and_lse(acc_out, acc_lse, out_i, lse_i) + + ctx.save_for_backward(q, cache_k, cache_v, acc_out, acc_lse) + ctx.ring_group = ring_group + ctx.softmax_scale = softmax_scale + return acc_out.to(q.dtype).reshape_as(q) + + @staticmethod + def backward(ctx, grad_out): + from yunchang.ring.ring_flash_attn import ring_flash_attn_backward + + q, cache_k, cache_v, out, merged_lse = ctx.saved_tensors + bsz, q_len, num_heads, head_dim, num_blocks, num_kv_heads, num_groups = ( + _USPFlashCachedMergeFunc._merge_dims(q, cache_k) + ) + scale = ctx.softmax_scale + + if grad_out.ndim == 3: + grad_out = grad_out.view(bsz, q_len, num_heads, head_dim) + + grad_out_f = grad_out.float().view(bsz, q_len, num_kv_heads, num_groups, head_dim) + q_f = q.float() + q_expanded = q_f.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + 
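# The cached-block terms accumulated in the loop below follow from
# differentiating the log-sum-exp merge: with w_i = exp(lse_i - merged_lse),
# the merged output is out = sum_i w_i * out_i, hence
#   d_out_i = grad_out * w_i
#   d_lse_i = w_i * <grad_out, v_i - out>
# and, since lse_i = scale * <q, k_i> for a single-token cached block,
# d_lse_i flows into q and cache_k[:, i] through that dot product.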
out_expanded = out.float().view(bsz, q_len, num_kv_heads, num_groups, head_dim) + + dcache_k = torch.zeros_like(cache_k.float()) + dcache_v = torch.zeros_like(cache_v.float()) + out_q = out.to(q.dtype).reshape_as(q) + merged_lse_kernel = _USPFlashCachedMergeFunc._kernel_lse(merged_lse, bsz, q_len, num_heads) + + dq, dk0, dv0 = ring_flash_attn_backward( + ctx.ring_group, + grad_out.contiguous(), + q.contiguous(), + cache_k[:, 0].contiguous(), + cache_v[:, 0].contiguous(), + out_q.contiguous(), + merged_lse_kernel, + softmax_scale=scale, + dropout_p=0.0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + ) + dq = dq.float() + dcache_k[:, 0] = dk0.float() + dcache_v[:, 0] = dv0.float() + + for i in range(1, num_blocks): + ki = cache_k[:, i].float().unsqueeze(-2) + vi = cache_v[:, i].float().unsqueeze(-2) + lse_i = (q_expanded * ki).sum(-1) * scale + wi = torch.exp(lse_i - merged_lse) + d_out_i = grad_out_f * wi.unsqueeze(-1) + d_lse_i = wi * (grad_out_f * (vi.expand_as(out_expanded) - out_expanded)).sum(-1) + dq += (d_lse_i.unsqueeze(-1) * scale * ki).reshape_as(q) + dcache_k[:, i] += (d_lse_i.unsqueeze(-1) * scale * q_expanded).sum(dim=3) + dcache_v[:, i] += d_out_i.sum(dim=3) + + return dq.to(q.dtype), dcache_k.to(cache_k.dtype), dcache_v.to(cache_v.dtype), None, None + + +class LlamaUSPFlashAttention(LlamaAttention): + """USP attention with Ulysses all-to-all and local flash attention.""" + + def __init__(self, config): + super().__init__(config) + if not dist.is_initialized(): + raise RuntimeError( + "LlamaUSPFlashAttention requires torch.distributed to be initialized first." + ) + try: + from yunchang.comm import SeqAllToAll4D + except ImportError as exc: + raise RuntimeError( + "USP requires yunchang; please install it to use USP attention." 
+ ) from exc + + self._SeqAllToAll4D = SeqAllToAll4D + self.ring_pg = get_sp_ring_group() + self.ulysses_pg = get_sp_ulysses_group() + if self.ring_pg is None or self.ulysses_pg is None: + raise RuntimeError("USP requires sp ring/ulysses groups to be initialized.") + self.sp_ring_degree = dist.get_world_size(self.ring_pg) + self.sp_ulysses_degree = dist.get_world_size(self.ulysses_pg) + self.scatter_idx = 2 + self.gather_idx = 1 + self.use_sync = False + + def forward( + self, + hidden_states: torch.Tensor, + cache_keys: Optional[torch.Tensor] = None, + cache_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + del attention_mask, use_cache + bsz, q_len, _ = hidden_states.size() + local_q_len = q_len + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + query_states = self._SeqAllToAll4D.apply( + self.ulysses_pg, + query_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + key_states = self.k_proj(hidden_states).view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + key_states = self._SeqAllToAll4D.apply( + self.ulysses_pg, + key_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + value_states = self._SeqAllToAll4D.apply( + self.ulysses_pg, + value_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + q_len = query_states.shape[1] + global_q_len = local_q_len * self.sp_ring_degree * self.sp_ulysses_degree + + lck = 0 if cache_keys is None else cache_keys.shape[1] + cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) + + if cache_keys is not None: + cache_keys = torch.cat([cache_keys, key_states.unsqueeze(1)], dim=1) + cache_values = torch.cat([cache_values, value_states.unsqueeze(1)], dim=1) + else: + cache_keys = key_states.unsqueeze(1) + cache_values = value_states.unsqueeze(1) + + softmax_scale = 1.0 / math.sqrt(self.head_dim) + if self.sp_ring_degree > 1: + attn_output = _USPRingFlashCachedMergeFunc.apply( + query_states, + cache_keys, + cache_values, + self.ring_pg, + softmax_scale, + ) + else: + attn_output = _USPFlashCachedMergeFunc.apply( + query_states, + cache_keys, + cache_values, + softmax_scale, + ) + + attn_output = self._SeqAllToAll4D.apply( + self.ulysses_pg, + attn_output, + self.gather_idx, + self.scatter_idx, + self.use_sync, + ) + attn_output = attn_output.reshape(bsz, local_q_len, self.head_dim * self.num_heads) + attn_output = self.o_proj(attn_output) + return attn_output, cache_keys, cache_values + + def warmup_flash_attention_masked( q_len: int, num_heads: int, @@ -2054,6 +2391,8 @@ def __init__(self, config, attention_backend: str = "sdpa"): self.self_attn = LlamaFlashAttentionMasked(config=config) elif attention_backend == "fa": self.self_attn = LlamaFlashAttention(config=config) + elif attention_backend == "usp": + self.self_attn = LlamaUSPFlashAttention(config=config) else: raise ValueError(f"Unknown attention backend {attention_backend}") @@ -2133,8 +2472,14 @@ def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: raise 
ValueError( f"Target hidden states size mismatch: {hidden_states.size(-1)} != expected: {expected_size}" ) - - return self.fc(hidden_states) + if os.environ.get("TORCHSPEC_EAGLE3_PROJ_FP32", "1") in {"0", "false", "False"}: + return self.fc(hidden_states.to(self.fc.weight.dtype)) + proj = F.linear( + hidden_states.to(torch.float32), + self.fc.weight.to(torch.float32), + None if self.fc.bias is None else self.fc.bias.to(torch.float32), + ) + return proj.to(self.fc.weight.dtype) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: norm_hidden_states = self.norm(hidden_states) diff --git a/torchspec/models/eagle3.py b/torchspec/models/eagle3.py index e6c4aefa..14d22eb0 100644 --- a/torchspec/models/eagle3.py +++ b/torchspec/models/eagle3.py @@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint as torch_checkpoint @@ -29,6 +30,12 @@ from torchspec.models.ops.loss import ( compiled_forward_kl_loss, compiled_forward_kl_loss_from_hs, + compiled_sum_forward_kl_loss_from_hs, +) +from torchspec.utils.distributed import ( + get_draft_sp_group, + get_sp_ring_rank, + get_sp_ulysses_group, ) from torchspec.utils.tensor import padding @@ -63,6 +70,13 @@ def __init__( self.attention_backend = attention_backend self.gradient_checkpointing = gradient_checkpointing self.vocab_pruning = draft_model.vocab_size != draft_model.target_vocab_size + self._usp_sp_group = get_draft_sp_group() if attention_backend == "usp" else None + self._usp_ulysses_group = get_sp_ulysses_group() if attention_backend == "usp" else None + self._usp_ulysses_world_size = ( + dist.get_world_size(self._usp_ulysses_group) + if self._usp_ulysses_group is not None + else 1 + ) def _calculate_loss( self, @@ -74,25 +88,14 @@ def _calculate_loss( norm_weight: torch.Tensor, lm_head_weight: torch.Tensor, norm_eps: float, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute forward-KL loss and accuracy for one TTT step. - - Both paths pass full (B*T, ...) flat views + valid_idx into the - compiled function so torch.compile can fuse index_select with - subsequent ops, avoiding separate (N_valid, V) copies outside. - - - PrecomputedTarget (vocab pruning): compiled_forward_kl_loss - with pre-computed target probs. - - LazyTarget (no pruning): compiled_forward_kl_loss_from_hs - computes target softmax inside the compiled graph. - """ + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: valid_idx = mask.flatten().nonzero().squeeze(-1) if valid_idx.numel() == 0: # FSDP requires every trainable param to participate in gradient # all-reduce/reduce-scatter. total = sum(p.reshape(-1)[0] for p in self.parameters() if p.requires_grad) zero = total * 0.0 - return zero, zero.detach() + return zero, zero.detach(), zero.detach() # Important as it prevents recompilation. 
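# Marking dim 0 of valid_idx as dynamic lets torch.compile reuse a single
# graph as the number of loss-bearing tokens varies across batches.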
torch._dynamo.maybe_mark_dynamic(valid_idx, 0) hs_flat = hidden_states.reshape(-1, hidden_states.shape[-1]) @@ -109,7 +112,6 @@ def _calculate_loss( ) return compiled_forward_kl_loss(*args) else: - # lazy ths_flat = target.hidden_states_padded[:, idx : idx + seq_length, :].reshape( -1, target.lm_head_weight.shape[-1] ) @@ -122,12 +124,17 @@ def _calculate_loss( target.lm_head_weight, norm_eps, ) + use_sum_lazy_loss = self.attention_backend == "usp" if self.gradient_checkpointing and self.training: return torch_checkpoint( - compiled_forward_kl_loss_from_hs, + compiled_sum_forward_kl_loss_from_hs + if use_sum_lazy_loss + else compiled_forward_kl_loss_from_hs, *args, use_reentrant=False, ) + if use_sum_lazy_loss: + return compiled_sum_forward_kl_loss_from_hs(*args) return compiled_forward_kl_loss_from_hs(*args) def forward( @@ -139,19 +146,34 @@ def forward( hidden_states: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.Tensor] = None, - ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: batch_size, seq_length, _ = hidden_states.shape seq_length_with_past = seq_length past_key_values_length = 0 norm_weight, lm_head_weight, norm_eps = self.draft_model.get_lm_head_params() - hidden_states = self.draft_model.project_hidden_states(hidden_states) if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: + if self.attention_backend == "usp": + usp_chunk_size = seq_length - self.length + if usp_chunk_size <= 0: + raise ValueError( + f"USP local seq_length ({seq_length}) must be larger than ttt_length ({self.length})" + ) + if position_ids is None: + device = hidden_states.device + ring_chunk_size = usp_chunk_size * self._usp_ulysses_world_size + position_start = get_sp_ring_rank() * ring_chunk_size + past_key_values_length + position_ids = torch.arange( + position_start, + position_start + ring_chunk_size, + dtype=torch.long, + device=device, + ).unsqueeze(0) + elif position_ids is None: device = hidden_states.device position_ids = torch.arange( past_key_values_length, @@ -172,8 +194,6 @@ def forward( past_key_values_length=past_key_values_length, ) - # position_mask (vocab pruning) is a subset of loss_mask that further - # filters to tokens whose target argmax falls in the draft vocab. if isinstance(target, PrecomputedTarget) and target.position_mask is not None: mask = target.position_mask else: @@ -182,37 +202,55 @@ def forward( plosses = [] vlosses = [] acces = [] + acc_counts = [] cache_keys = None cache_values = None - # Clamp multimodal placeholder IDs (hash-based pad values from SGLang) - # to valid vocab range before embedding lookup. 
input_ids = input_ids.clamp(min=0, max=self.draft_model.target_vocab_size - 1) for idx in range(self.length): is_last = idx == self.length - 1 - inputs_embeds = self.draft_model.embed_input_ids(input_ids) - inputs_embeds = inputs_embeds.to(hidden_states.dtype) + step_input_ids = input_ids + step_hidden_states = hidden_states + step_attention_mask = attention_mask + step_position_ids = position_ids + step_mask = mask + step_seq_length = seq_length + + if self.attention_backend == "usp": + step_seq_length = usp_chunk_size + step_input_ids = input_ids[:, :step_seq_length] + step_hidden_states = hidden_states[:, :step_seq_length, :] + step_mask = mask[:, :step_seq_length] + if attention_mask is not None: + step_attention_mask = attention_mask[:, :step_seq_length] + if position_ids is not None: + step_position_ids = position_ids[ + :, : step_seq_length * self._usp_ulysses_world_size + ] + + inputs_embeds = self.draft_model.embed_input_ids(step_input_ids) + inputs_embeds = inputs_embeds.to(step_hidden_states.dtype) if self.gradient_checkpointing and self.training: hidden_states_out, cache_keys, cache_values = torch_checkpoint( self.draft_model.backbone, inputs_embeds, - hidden_states, - attention_mask, - position_ids, + step_hidden_states, + step_attention_mask, + step_position_ids, cache_keys, cache_values, - True, # use_cache + True, use_reentrant=False, ) else: hidden_states_out, cache_keys, cache_values = self.draft_model.backbone( input_embeds=inputs_embeds, - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + hidden_states=step_hidden_states, + attention_mask=step_attention_mask, + position_ids=step_position_ids, cache_keys=cache_keys, cache_values=cache_values, use_cache=True, @@ -220,23 +258,64 @@ def forward( hidden_states = hidden_states_out - loss, acc = self._calculate_loss( + local_sum_loss, local_correct, local_count = self._calculate_loss( hidden_states=hidden_states, target=target, - mask=mask, + mask=step_mask, idx=idx, - seq_length=seq_length, + seq_length=step_seq_length, norm_weight=norm_weight, lm_head_weight=lm_head_weight, norm_eps=norm_eps, ) + if self.attention_backend == "usp": + # A shard can have no local loss tokens while its Ulysses peers do. + # Keep the zero-loss path connected to this layer's activations so + # autograd still executes the same sequence-parallel collectives. 
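# The 0.0 multiplier leaves the loss value untouched; it only ties the
# graph to hidden_states so every rank issues the same backward collectives.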
+ local_sum_loss = local_sum_loss + hidden_states.sum() * 0.0 + + loss = local_sum_loss / local_count.clamp_min(1.0) + metric_loss = loss.detach() + metric_acc = ( + (local_correct / local_count.clamp_min(1.0)).detach() + if float(local_count.detach().float().cpu()) > 0.0 + else local_correct.detach().float() * 0.0 + ) + + if self._usp_sp_group is not None: + reduced_stats = torch.stack( + ( + local_sum_loss.detach().clone().float(), + local_correct.detach().clone().float(), + local_count.detach().clone().float(), + ) + ) + dist.all_reduce(reduced_stats, op=dist.ReduceOp.SUM, group=self._usp_sp_group) + reduced_sum_loss, reduced_correct, reduced_count = reduced_stats.unbind() + denom = reduced_count.clamp_min(1.0) + loss = (local_sum_loss / denom).to(loss.dtype) + if reduced_count.item() > 0: + metric_loss = (reduced_sum_loss / denom).detach() + metric_acc = (reduced_correct / denom).to( + device=loss.device, dtype=torch.float32 + ) + metric_count = reduced_count.to(device=loss.device, dtype=torch.float32) + else: + metric_loss = reduced_sum_loss.detach() * 0.0 + metric_acc = local_correct.detach().float() * 0.0 + metric_count = reduced_count.to(device=loss.device, dtype=torch.float32) + else: + metric_count = local_count.detach().float().to(device=loss.device) + plosses.append(loss) - acces.append(acc) + vlosses.append(metric_loss) + acces.append(metric_acc) + acc_counts.append(metric_count) if not is_last: input_ids = padding(input_ids, left=False) mask = padding(mask, left=False) - return plosses, vlosses, acces + return plosses, vlosses, acces, acc_counts @torch.no_grad() @@ -249,21 +328,25 @@ def compute_target_p_padded( chunk_size: int = 4096, ) -> PrecomputedTarget: target_lm_head_weight = target_lm_head_weight.detach() - pruned_weight = target_lm_head_weight[t2d] # (V_draft, D) + pruned_weight = target_lm_head_weight[t2d] - B, T, _D = target_hidden_states.shape + bsz, seq_len, hidden_size = target_hidden_states.shape loss_mask_bool = loss_mask.bool() valid_flat_idx = loss_mask_bool.reshape(-1).nonzero(as_tuple=True)[0] - valid_hs = target_hidden_states.reshape(-1, _D)[valid_flat_idx] # (N_valid, D) + valid_hs = target_hidden_states.reshape(-1, hidden_size)[valid_flat_idx] - position_mask_flat = torch.zeros(B * T, device=target_hidden_states.device, dtype=torch.float) + position_mask_flat = torch.zeros( + bsz * seq_len, + device=target_hidden_states.device, + dtype=torch.float, + ) for i in range(0, valid_hs.shape[0], chunk_size): chunk_hs = valid_hs[i : i + chunk_size] chunk_argmax = F.linear(chunk_hs, target_lm_head_weight).argmax(-1) in_draft = t2d[chunk_argmax] position_mask_flat[valid_flat_idx[i : i + chunk_size]] = in_draft.float() - position_mask = position_mask_flat.reshape(B, T) + position_mask = position_mask_flat.reshape(bsz, seq_len) target_logits_pruned = F.linear(target_hidden_states, pruned_weight) target_p = F.softmax(target_logits_pruned.float(), dim=-1) @@ -277,11 +360,6 @@ def compute_lazy_target_padded( target_lm_head_weight: torch.Tensor, length: int, ) -> LazyTarget: - """Build a LazyTarget that defers softmax to the forward loop. - - Used for non-pruning cases to avoid materializing the full - (B, T, V_full) target probability tensor. 
- """ return LazyTarget( hidden_states_padded=F.pad(target_hidden_states, (0, 0, 0, length), value=0.0), lm_head_weight=target_lm_head_weight.detach(), diff --git a/torchspec/models/ops/loss.py b/torchspec/models/ops/loss.py index 476d36ba..11c3fa2e 100644 --- a/torchspec/models/ops/loss.py +++ b/torchspec/models/ops/loss.py @@ -22,6 +22,40 @@ import torch.nn.functional as F +def _forward_kl_from_logits(logits: torch.Tensor, target_p: torch.Tensor) -> torch.Tensor: + logits_f32 = logits.float() + return torch.logsumexp(logits_f32, dim=-1) - (target_p * logits_f32).sum(-1) + + +def _softmax_from_logits(logits: torch.Tensor) -> torch.Tensor: + logits_f32 = logits.float() + return torch.exp(logits_f32 - torch.logsumexp(logits_f32, dim=-1, keepdim=True)) + + +@torch.compile(dynamic=None) +def compiled_sum_forward_kl_loss( + prenorm_hidden_states_flat, + target_p_flat, + valid_idx, + norm_weight, + lm_head_weight, + norm_eps, +): + hs = prenorm_hidden_states_flat.index_select(0, valid_idx) + tp = target_p_flat.index_select(0, valid_idx) + + hs_f32 = hs.float() + variance = hs_f32.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight + + logits = F.linear(norm_hs, lm_head_weight) + token_loss = _forward_kl_from_logits(logits, tp) + correct = (logits.argmax(-1) == tp.argmax(-1)).float() + count = torch.ones_like(token_loss, dtype=torch.float32).sum() + return token_loss.sum(), correct.sum(), count + + @torch.compile(dynamic=None) def compiled_forward_kl_loss( prenorm_hidden_states_flat, @@ -55,14 +89,38 @@ def compiled_forward_kl_loss( logits = F.linear(norm_hs, lm_head_weight) # (N, V_out) - # Forward KL loss - log_p = F.log_softmax(logits.float(), dim=-1) - loss = -(tp * log_p).sum(-1).mean() + token_loss = _forward_kl_from_logits(logits, tp) + correct = (logits.argmax(-1) == tp.argmax(-1)).float() + count = torch.ones_like(token_loss, dtype=torch.float32).sum() + return token_loss.sum(), correct.sum(), count + + +@torch.compile(dynamic=None) +def compiled_sum_forward_kl_loss_from_hs( + prenorm_hidden_states_flat, + target_hidden_states_flat, + valid_idx, + norm_weight, + lm_head_weight, + target_lm_head_weight, + norm_eps, +): + hs = prenorm_hidden_states_flat.index_select(0, valid_idx) + ths = target_hidden_states_flat.index_select(0, valid_idx) + + target_logits = F.linear(ths, target_lm_head_weight) + tp = _softmax_from_logits(target_logits) - # Accuracy - acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean() + hs_f32 = hs.float() + variance = hs_f32.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + norm_eps) + norm_hs = (hs_f32 * rstd).to(hs.dtype) * norm_weight - return loss, acc + logits = F.linear(norm_hs, lm_head_weight) + token_loss = _forward_kl_from_logits(logits, tp) + correct = (logits.argmax(-1) == target_logits.argmax(-1)).float() + count = torch.ones_like(token_loss, dtype=torch.float32).sum() + return token_loss.sum(), correct.sum(), count @torch.compile(dynamic=None) @@ -88,7 +146,8 @@ def compiled_forward_kl_loss_from_hs( ths = target_hidden_states_flat.index_select(0, valid_idx) # Target probs (detached weights → no grad flows through target) - tp = F.softmax(F.linear(ths, target_lm_head_weight).float(), dim=-1) + target_logits = F.linear(ths, target_lm_head_weight) + tp = _softmax_from_logits(target_logits) # RMSNorm hs_f32 = hs.float() @@ -98,11 +157,7 @@ def compiled_forward_kl_loss_from_hs( logits = F.linear(norm_hs, lm_head_weight) - # Forward KL loss - log_p = 
F.log_softmax(logits.float(), dim=-1) - loss = -(tp * log_p).sum(-1).mean() - - # Accuracy - acc = (logits.argmax(-1) == tp.argmax(-1)).float().mean() - - return loss, acc + token_loss = _forward_kl_from_logits(logits, tp) + correct = (logits.argmax(-1) == target_logits.argmax(-1)).float() + count = torch.ones_like(token_loss, dtype=torch.float32).sum() + return token_loss.sum(), correct.sum(), count diff --git a/torchspec/ray/train_group.py b/torchspec/ray/train_group.py index 86379eb0..76326ebc 100644 --- a/torchspec/ray/train_group.py +++ b/torchspec/ray/train_group.py @@ -26,6 +26,7 @@ from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torchspec.utils.distributed import _build_usp_group_ranks from torchspec.utils.env import get_torchspec_env_vars @@ -164,10 +165,7 @@ def set_train_queues(self, queues, mooncake_config, per_dp_rank_batch_size: int mooncake_config: MooncakeConfig object. Each actor initializes its own store. per_dp_rank_batch_size: Number of samples per DP rank per training step. """ - if len(queues) != len(self._actor_handlers): - raise ValueError( - f"Number of queues ({len(queues)}) must match number of actors ({len(self._actor_handlers)})" - ) + queues = self._expand_queues_for_usp(queues) return ray.get( [ actor.set_train_queue.remote( @@ -181,11 +179,7 @@ def set_train_queues(self, queues, mooncake_config, per_dp_rank_batch_size: int def set_eval_queues(self, queues, mooncake_config, per_dp_rank_batch_size: int = 1): """Set eval data queues — mirrors set_train_queues.""" - if len(queues) != len(self._actor_handlers): - raise ValueError( - f"Number of eval queues ({len(queues)}) must match " - f"number of actors ({len(self._actor_handlers)})" - ) + queues = self._expand_queues_for_usp(queues) return ray.get( [ actor.set_eval_queue.remote( @@ -197,6 +191,36 @@ def set_eval_queues(self, queues, mooncake_config, per_dp_rank_batch_size: int = ] ) + def _expand_queues_for_usp(self, queues): + actor_count = len(self._actor_handlers) + if len(queues) == actor_count: + return queues + + attention_backend = getattr(self.args, "attention_backend", None) + if attention_backend != "usp": + raise ValueError( + f"Number of queues ({len(queues)}) must match number of actors ({actor_count})" + ) + + sp_size = getattr(self.args, "sp_ulysses_size", 1) * getattr(self.args, "sp_ring_size", 1) + if sp_size <= 0 or actor_count % sp_size != 0: + raise ValueError(f"Invalid USP topology: actor_count={actor_count}, sp_size={sp_size}") + dp_size = actor_count // sp_size + if len(queues) != dp_size: + raise ValueError( + f"USP expects {dp_size} queues for {actor_count} actors with sp_size={sp_size}, " + f"got {len(queues)}" + ) + draft_sp_groups, _, _ = _build_usp_group_ranks( + world_size=actor_count, + sp_ulysses_size=getattr(self.args, "sp_ulysses_size", 1), + sp_ring_size=getattr(self.args, "sp_ring_size", 1), + ) + rank_to_dp_idx = { + rank: dp_idx for dp_idx, ranks in enumerate(draft_sp_groups) for rank in ranks + } + return [queues[rank_to_dp_idx[rank]] for rank in range(actor_count)] + def cache_eval_samples(self, count: int): """Tell every actor to drain ``count`` individual samples from its eval queue into CPU cache.""" return ray.get([actor.cache_eval_samples.remote(count) for actor in self._actor_handlers]) diff --git a/torchspec/train_entry.py b/torchspec/train_entry.py index 8c0d5cac..a2e8ed99 100644 --- a/torchspec/train_entry.py +++ b/torchspec/train_entry.py @@ -147,6 +147,7 @@ def 
parse_config(): setattr(flat_args, key, value) _resolve_batch_size(flat_args) + _validate_usp_args(flat_args) return flat_args @@ -168,19 +169,67 @@ def _maybe_create_scratch_draft(args, train_group): def _resolve_batch_size(args): """Derive dp_size, per_dp_rank_batch_size, dispatch_batch_size, and global_batch_size.""" - dp_size = ( - getattr(args, "dp_size", None) or args.training_num_nodes * args.training_num_gpus_per_node - ) - args.dp_size = dp_size - sp_size = getattr(args, "sp_size", None) - if sp_size is not None and sp_size != 1: - raise NotImplementedError(f"Sequence parallel is not yet supported (got sp_size={sp_size})") - sp_size = sp_size or 1 + world_size = args.training_num_nodes * args.training_num_gpus_per_node + if getattr(args, "attention_backend", None) == "usp": + sp_size = getattr(args, "sp_ulysses_size", 1) * getattr(args, "sp_ring_size", 1) + if sp_size <= 0: + raise ValueError(f"USP requires positive sp_size, got {sp_size}") + if world_size % sp_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by USP sp_size ({sp_size})" + ) + dp_size = getattr(args, "dp_size", None) or (world_size // sp_size) + if dp_size * sp_size != world_size: + raise ValueError( + f"dp_size ({dp_size}) * sp_size ({sp_size}) must equal world_size ({world_size})" + ) + args.dp_size = dp_size + args.sp_size = sp_size + args.per_dp_rank_batch_size = 1 + else: + dp_size = getattr(args, "dp_size", None) or world_size + args.dp_size = dp_size + sp_size = getattr(args, "sp_size", None) + if sp_size is not None and sp_size != 1: + raise NotImplementedError( + f"Sequence parallel is not yet supported (got sp_size={sp_size})" + ) + sp_size = sp_size or 1 + args.per_dp_rank_batch_size = args.micro_batch_size * sp_size + accumulation_steps = getattr(args, "draft_accumulation_steps", 1) - args.per_dp_rank_batch_size = args.micro_batch_size * sp_size args.global_batch_size = args.per_dp_rank_batch_size * dp_size * accumulation_steps +def _validate_usp_args(args) -> None: + if getattr(args, "attention_backend", None) != "usp": + return + + sp_size = getattr(args, "sp_size", None) + if sp_size is None: + sp_size = getattr(args, "sp_ulysses_size", 1) * getattr(args, "sp_ring_size", 1) + if sp_size <= 1: + raise NotImplementedError(f"USP requires sp_size > 1, got {sp_size}") + + inference_engine_type = getattr(args, "inference_engine_type", "sgl") + if inference_engine_type != "sgl": + raise ValueError( + f"USP currently only supports inference_engine_type=sgl, got {inference_engine_type}" + ) + + fsdp_strategy = getattr(args, "fsdp_strategy", "REPLICATE").upper() + if fsdp_strategy != "REPLICATE": + raise NotImplementedError( + f"USP currently only supports fsdp_strategy=REPLICATE, got {fsdp_strategy}" + ) + + micro_batch_size = getattr(args, "micro_batch_size", 1) + if micro_batch_size != 1: + raise NotImplementedError( + f"USP currently only supports micro_batch_size=1, got {micro_batch_size}" + ) + + def _get_draft_model_config(args): """Resolve draft model config from args or auto-generate from target model.""" diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index 4fd06cc7..9e72c104 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -31,10 +31,17 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple import torch +import torch.distributed as dist +import torch.nn.functional as F from ray.util.queue import Queue as RayQueue from torch.utils.data import DataLoader, IterableDataset 
-from torchspec.data.utils import resolve_loss_mask
+from torchspec.data.utils import deserialize_packed_loss_mask, resolve_loss_mask, unpack_loss_mask
+from torchspec.utils.distributed import (
+    get_draft_sp_group,
+    get_sp_ring_group,
+    get_usp_rank_coords,
+)
 from torchspec.utils.logging import logger
@@ -45,6 +52,7 @@ class TrainSample:
     tensor_dtypes: Optional[Dict[str, torch.dtype]] = None
     packed_loss_mask: Optional[str] = None
     last_turn_loss_only: Optional[bool] = None
+    metadata: Optional[Dict[str, Any]] = None
 
 
 class MooncakeDataset(IterableDataset):
@@ -68,6 +76,9 @@ def __init__(
         skip_after_header: int = 0,
         batch_size: int = 1,
         min_loss_tokens: int = 0,
+        usp_enabled: bool = False,
+        ttt_length: int = 1,
+        max_seq_length: Optional[int] = None,
     ):
         self.ray_queue = ray_queue
         self.mooncake_store = mooncake_store
@@ -81,6 +92,32 @@
         self.skip_after_header = skip_after_header
         self._batch_size = batch_size
         self._min_loss_tokens = min_loss_tokens
+        self.usp_enabled = usp_enabled
+        self.ttt_length = ttt_length
+        self.max_seq_length = max_seq_length
+        self._init_sp_context()
+
+    def _init_sp_context(self) -> None:
+        self._sp_group = None
+        self._sp_world_size = 1
+        self._sp_rank = 0
+        self._sp_ring_size = 1
+        self._sp_ring_rank = 0
+        if not self.usp_enabled:
+            return
+
+        sp_group = get_draft_sp_group()
+        if sp_group is None:
+            return
+
+        self._sp_group = sp_group
+        self._sp_world_size = dist.get_world_size(sp_group)
+        self._sp_rank = dist.get_rank(sp_group)
+
+        ring_group = get_sp_ring_group()
+        if ring_group is not None:
+            self._sp_ring_size = dist.get_world_size(ring_group)
+            self._sp_ring_rank = dist.get_rank(ring_group)
 
     def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]:
         """Load tensors from mooncake key into device memory."""
@@ -150,6 +187,33 @@ def _compute_loss_mask(self, data: Dict[str, Any]) -> torch.Tensor | None:
             skip_after_header=self.skip_after_header,
         )
 
+    def _should_skip_for_loss_mask(
+        self, data: Dict[str, Any], mooncake_key: str, skip_count: int
+    ) -> tuple[bool, int]:
+        mask = self._compute_loss_mask(data)
+        if mask is None:
+            skip_count += 1
+            logger.warning(
+                f"Skipping sample with all-zero loss mask "
+                f"(mooncake_key={mooncake_key}, total_skipped={skip_count})"
+            )
+            return True, skip_count
+
+        if (
+            self._min_loss_tokens > 0
+            and isinstance(mask, torch.Tensor)
+            and mask.sum() < self._min_loss_tokens
+        ):
+            skip_count += 1
+            logger.warning(
+                f"Skipping sample with too few loss-masked tokens "
+                f"({int(mask.sum())} < {self._min_loss_tokens}, "
+                f"mooncake_key={mooncake_key}, total_skipped={skip_count})"
+            )
+            return True, skip_count
+
+        return False, skip_count
+
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         """Iterate over samples synchronously.
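
Aside on the loss.py hunks earlier in this patch: they swap the mean-reduced
`-(tp * log_softmax(logits))` for a per-token `logsumexp(logits) - (tp * logits).sum(-1)`
plus explicit sum/count reductions, which lets ranks holding different numbers of
valid tokens be averaged correctly after a reduce. The two forms are the same
cross-entropy whenever the target probabilities are normalized. A standalone
sanity check of that identity (illustrative shapes only, not part of the patch):

    import torch
    import torch.nn.functional as F

    # log_softmax(z) = z - logsumexp(z), so for rows of p summing to 1:
    #   -(p * log_softmax(z)).sum(-1) == logsumexp(z) - (p * z).sum(-1)
    z = torch.randn(7, 32)                     # stand-in logits
    p = F.softmax(torch.randn(7, 32), dim=-1)  # stand-in target probs
    lhs = -(p * F.log_softmax(z, dim=-1)).sum(-1)
    rhs = torch.logsumexp(z, dim=-1) - (p * z).sum(-1)
    assert torch.allclose(lhs, rhs, atol=1e-5)
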
@@ -159,6 +223,15 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         yield_count = 0
         skip_count = 0
         while True:
+            if self.usp_enabled:
+                data, skipped = self._usp_get_sharded_item(skip_count=skip_count)
+                skip_count += skipped
+                if data is None:
+                    break
+                yield_count += 1
+                yield data
+                continue
+
             logger.debug(f"__iter__: waiting for item from ray_queue (yield_count={yield_count})")
             try:
                 item = self.ray_queue.get(block=True, timeout=self.timeout)
@@ -173,26 +246,10 @@
             logger.debug(f"__iter__: got item, mooncake_key={item.mooncake_key}")
 
             data = self._load_from_mooncake(item)
-            mask = self._compute_loss_mask(data)
-            if mask is None:
-                skip_count += 1
-                logger.warning(
-                    f"Skipping sample with all-zero loss mask "
-                    f"(mooncake_key={item.mooncake_key}, total_skipped={skip_count})"
-                )
-                continue
-
-            if (
-                self._min_loss_tokens > 0
-                and isinstance(mask, torch.Tensor)
-                and mask.sum() < self._min_loss_tokens
-            ):
-                skip_count += 1
-                logger.warning(
-                    f"Skipping sample with too few loss-masked tokens "
-                    f"({int(mask.sum())} < {self._min_loss_tokens}, "
-                    f"mooncake_key={item.mooncake_key}, total_skipped={skip_count})"
-                )
+            should_skip, skip_count = self._should_skip_for_loss_mask(
+                data, item.mooncake_key, skip_count
+            )
+            if should_skip:
                 continue
 
             # Note: target is computed in the collator from last_hidden_states for sglang mode
@@ -222,6 +279,143 @@
             logger.debug(f"__iter__: yielding batch {yield_count}, keys={list(data.keys())}")
             yield data
 
+    def _usp_global_len(self, sample: TrainSample) -> int:
+        global_len = sample.tensor_shapes["input_ids"][-1]
+        if self.max_seq_length is not None:
+            global_len = min(global_len, self.max_seq_length)
+        return global_len
+
+    def _usp_chunk_size(self, global_len: int) -> int:
+        return (global_len + self._sp_world_size - 1) // self._sp_world_size
+
+    def _usp_loss_mask(self, sample: TrainSample, global_len: int) -> torch.Tensor:
+        if sample.packed_loss_mask is None:
+            raise RuntimeError("USP sharded Mooncake reads require packed_loss_mask metadata")
+        loss_mask = unpack_loss_mask(deserialize_packed_loss_mask(sample.packed_loss_mask))
+        loss_mask = loss_mask[:global_len]
+        if loss_mask.shape[0] < global_len:
+            loss_mask = F.pad(loss_mask, (0, global_len - loss_mask.shape[0]))
+        return loss_mask
+
+    def _local_usp_shapes(self, sample: TrainSample) -> dict[str, tuple[int, ...]]:
+        local_len = self._usp_chunk_size(self._usp_global_len(sample)) + self.ttt_length
+        shapes: dict[str, tuple[int, ...]] = {
+            "input_ids": (1, local_len),
+            "hidden_states": (1, local_len, sample.tensor_shapes["hidden_states"][-1]),
+        }
+        if "last_hidden_states" in sample.tensor_shapes:
+            shapes["last_hidden_states"] = (
+                1,
+                local_len,
+                sample.tensor_shapes["last_hidden_states"][-1],
+            )
+        if "target" in sample.tensor_shapes:
+            shapes["target"] = (1, local_len, sample.tensor_shapes["target"][-1])
+        return shapes
+
+    def _local_usp_loss_and_position(
+        self,
+        sample: TrainSample,
+        local_len: int,
+    ) -> dict[str, torch.Tensor]:
+        sp_ulysses_size = max(1, self._sp_world_size // self._sp_ring_size)
+        global_len = self._usp_global_len(sample)
+        chunk_size = self._usp_chunk_size(global_len)
+        start = self._sp_rank * chunk_size
+        end = min(start + local_len, global_len)
+        valid_len = max(0, end - start)
+
+        loss_mask = self._usp_loss_mask(sample, global_len)[start:end].unsqueeze(0)
+        if loss_mask.shape[-1] < local_len:
+            loss_mask = F.pad(loss_mask, (0, local_len - loss_mask.shape[-1]))
+
+        attention_mask = torch.zeros((1, local_len), dtype=torch.long)
+        attention_mask[:, :valid_len] = 1
+
+        usp_chunk_size = max(local_len - self.ttt_length, 0)
+        ring_chunk = usp_chunk_size * sp_ulysses_size
+        _, ring_rank = get_usp_rank_coords(
+            sp_rank=self._sp_rank,
+            sp_ulysses_size=sp_ulysses_size,
+            sp_ring_size=self._sp_ring_size,
+        )
+        ring_start = ring_rank * ring_chunk
+        position_ids = torch.arange(
+            ring_start,
+            ring_start + ring_chunk,
+            dtype=torch.long,
+        ).unsqueeze(0)
+
+        return {
+            "loss_mask": loss_mask.to(self.device),
+            "attention_mask": attention_mask.to(self.device),
+            "position_ids": position_ids.to(self.device),
+        }
+
+    def _should_skip_usp_sharded_sample(self, sample: TrainSample) -> bool:
+        """Return the SP-consistent skip decision for a pre-sharded USP sample."""
+        full_loss_mask = self._usp_loss_mask(sample, self._usp_global_len(sample))
+        min_tokens = max(1, self._min_loss_tokens)
+        return int(full_loss_mask.sum().item()) < min_tokens
+
+    def _usp_get_sharded_item(self, skip_count: int) -> tuple[Dict[str, torch.Tensor] | None, int]:
+        skipped = 0
+        while True:
+            try:
+                item = self.ray_queue.get(block=True, timeout=self.timeout)
+            except Exception as e:
+                logger.warning(
+                    f"_usp_get_sharded_item: Exception waiting for data: {e}, "
+                    f"timeout={self.timeout}"
+                )
+                return None, skipped
+            if item is None:
+                return None, skipped
+
+            metadata = item.metadata or {}
+            if not metadata.get("usp_sharded", False):
+                raise RuntimeError(
+                    "USP sharded data fetcher received a non-sharded Mooncake sample. "
+                    f"mooncake_key={item.mooncake_key}"
+                )
+
+            shapes = self._local_usp_shapes(item)
+            dtypes_raw = item.tensor_dtypes or {}
+            dtypes = {}
+            for key, dtype_val in dtypes_raw.items():
+                if isinstance(dtype_val, str):
+                    dtypes[key] = getattr(torch, dtype_val.replace("torch.", ""))
+                else:
+                    dtypes[key] = dtype_val
+
+            should_skip = self._should_skip_usp_sharded_sample(item)
+            shard_key = f"{item.mooncake_key}_usp{self._sp_rank}"
+            tensors = self.mooncake_store.get(
+                key=shard_key,
+                shapes=shapes,
+                dtypes=dtypes,
+                device=self.device,
+            ).to_tensor_dict()
+            tensors.update(self._local_usp_loss_and_position(item, shapes["input_ids"][-1]))
+
+            self.mooncake_store.remove_eagle3_tensors(
+                shard_key,
+                has_last_hidden_states="last_hidden_states" in shapes,
+                has_target="target" in shapes,
+            )
+
+            if should_skip:
+                skipped += 1
+                total_skipped = skip_count + skipped
+                logger.warning(
+                    f"Skipping USP sharded sample with global all-zero loss mask "
+                    f"(mooncake_key={item.mooncake_key}, sp_rank={self._sp_rank}, "
+                    f"total_skipped={total_skipped})"
+                )
+                continue
+
+            return tensors, skipped
+
 
 def create_mooncake_dataloader(
     ray_queue: RayQueue,
@@ -237,6 +431,9 @@
     last_turn_loss_only: bool = False,
     skip_after_header: int = 0,
     min_loss_tokens: int = 0,
+    usp_enabled: bool = False,
+    ttt_length: int = 1,
+    max_seq_length: Optional[int] = None,
 ) -> DataLoader:
     """Create a DataLoader that fetches from mooncake via queue.
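
Aside: the sharded read path above sizes every rank's slice as
ceil(global_len / sp_world_size) + ttt_length, so adjacent shards overlap by
ttt_length lookahead tokens and attention_mask marks the tail padding. A small
sketch of that arithmetic under assumed values (sp_world_size=4, ttt_length=1,
global_len=10), mirroring _usp_chunk_size and _local_usp_loss_and_position:

    # Values are assumed for illustration only.
    global_len, sp_world_size, ttt_length = 10, 4, 1
    chunk_size = (global_len + sp_world_size - 1) // sp_world_size  # ceil div -> 3
    local_len = chunk_size + ttt_length                             # 4
    for sp_rank in range(sp_world_size):
        start = sp_rank * chunk_size
        end = min(start + local_len, global_len)
        valid_len = max(0, end - start)
        print(sp_rank, (start, end), valid_len)
    # rank 0 -> (0, 4) 4; rank 1 -> (3, 7) 4; rank 2 -> (6, 10) 4; rank 3 -> (9, 10) 1
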
@@ -277,6 +474,9 @@ def create_mooncake_dataloader(
         skip_after_header=skip_after_header,
         batch_size=batch_size,
         min_loss_tokens=min_loss_tokens,
+        usp_enabled=usp_enabled,
+        ttt_length=ttt_length,
+        max_seq_length=max_seq_length,
     )
 
     return DataLoader(
@@ -318,6 +518,9 @@ def __init__(
         last_turn_loss_only: bool = False,
         skip_after_header: int = 0,
         min_loss_tokens: int = 0,
+        usp_enabled: bool = False,
+        ttt_length: int = 1,
+        max_seq_length: Optional[int] = None,
     ):
         self.batch_size = batch_size
         self._dataloader = create_mooncake_dataloader(
@@ -334,6 +537,9 @@
             last_turn_loss_only=last_turn_loss_only,
             skip_after_header=skip_after_header,
             min_loss_tokens=min_loss_tokens,
+            usp_enabled=usp_enabled,
+            ttt_length=ttt_length,
+            max_seq_length=max_seq_length,
         )
 
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
diff --git a/torchspec/training/eagle3_trainer.py b/torchspec/training/eagle3_trainer.py
index 230d4010..ebbac8d4 100644
--- a/torchspec/training/eagle3_trainer.py
+++ b/torchspec/training/eagle3_trainer.py
@@ -104,7 +104,7 @@ def init_model(
         ]
         eagle3_model = apply_fsdp2(
             eagle3_model,
-            mesh=self.dp_mesh,
+            mesh=self.grad_sync_mesh,
             cpu_offload=self.fsdp_cpu_offload,
             args=self.args,
             modules_to_shard=midlayer_modules,
@@ -113,7 +113,7 @@
         eagle3_model = fsdp2_load_full_state_dict(
             eagle3_model,
             full_state,
-            self.dp_mesh,
+            self.grad_sync_mesh,
             cpu_offload=True if self.fsdp_cpu_offload else None,
         )
@@ -288,14 +288,17 @@ def _forward(self, batch: dict) -> Tuple[List[torch.Tensor], List[torch.Tensor]]
         )
         del target_hidden_states
 
-        plosses, _, acces = self.model(
+        plosses, vlosses, acces, acc_counts = self.model(
             input_ids=input_ids,
             attention_mask=batch["attention_mask"].cuda(),
             target=target,
             loss_mask=loss_mask,
             hidden_states=batch["hidden_states"].cuda(),
+            position_ids=batch.get("position_ids").cuda()
+            if batch.get("position_ids") is not None
+            else None,
         )
-        return plosses, acces
+        return plosses, vlosses, acces, acc_counts
 
     def _backward(self, plosses: List[torch.Tensor], accumulation_steps: int = 1) -> torch.Tensor:
         ploss_weight = [0.8**i for i in range(len(plosses))]
@@ -310,10 +313,11 @@
     def eval_forward(self, batch: dict) -> dict:
         """Single forward pass without backward — returns per-position metrics."""
        with torch.no_grad():
-            plosses, acces = self._forward(batch)
+            _plosses, vlosses, acces, acc_counts = self._forward(batch)
         return {
-            "plosses": torch.stack(plosses).detach(),
+            "vlosses": torch.stack(vlosses).detach(),
             "acces": torch.stack(acces).detach(),
+            "acc_counts": torch.stack(acc_counts).detach(),
         }
 
     def eval_from_cache(self) -> dict:
@@ -347,12 +351,14 @@ def _aggregate_eval_metrics(self, all_step_metrics: list[dict]) -> dict:
         if not all_step_metrics:
             return {}
 
-        avg_plosses = torch.stack([m["plosses"] for m in all_step_metrics]).mean(dim=0)
+        avg_vlosses = torch.stack([m["vlosses"] for m in all_step_metrics]).mean(dim=0)
         avg_acces = torch.stack([m["acces"] for m in all_step_metrics]).mean(dim=0)
 
-        dist.all_reduce(avg_plosses, op=dist.ReduceOp.AVG)
+        dist.all_reduce(avg_vlosses, op=dist.ReduceOp.AVG)
         dist.all_reduce(avg_acces, op=dist.ReduceOp.AVG)
 
+        avg_acc_scalar = avg_acces.mean().item()
+
         cumulative = 1.0
         simulated_acc_len = 0.0
         for i in range(avg_acces.shape[0]):
@@ -360,22 +366,22 @@
             simulated_acc_len += cumulative
 
         ploss_weights = torch.tensor(
-            [0.8**i for i in range(avg_plosses.shape[0])], device=avg_plosses.device
+            [0.8**i for i in range(avg_vlosses.shape[0])], device=avg_vlosses.device
         )
-        weighted_avg_loss = (avg_plosses * ploss_weights).sum().item() / ploss_weights.sum().item()
+        weighted_avg_loss = (avg_vlosses * ploss_weights).sum().item() / ploss_weights.sum().item()
 
         metrics: dict = {
             "eval/avg_loss": weighted_avg_loss,
-            "eval/avg_acc": avg_acces.mean().item(),
+            "eval/avg_acc": avg_acc_scalar,
             "eval/simulated_acc_len": simulated_acc_len,
         }
-        for i in range(avg_plosses.shape[0]):
-            metrics[f"eval/ploss_{i}"] = avg_plosses[i].item()
+        for i in range(avg_vlosses.shape[0]):
+            metrics[f"eval/ploss_{i}"] = avg_vlosses[i].item()
             metrics[f"eval/acc_{i}"] = avg_acces[i].item()
 
         if dist.get_rank() == 0:
             logger.info(
-                f"eval: loss={weighted_avg_loss:.4f}, acc={avg_acces.mean().item():.4f}, "
+                f"eval: loss={weighted_avg_loss:.4f}, acc={avg_acc_scalar:.4f}, "
                 f"sim_acc_len={simulated_acc_len:.2f}"
             )
@@ -393,12 +399,14 @@ def _train_step(
         batch_idx: int,
         num_batches: int,
     ) -> dict:
-        plosses, acces = self._forward(batch)
+        plosses, vlosses, acces, acc_counts = self._forward(batch)
         total_loss = self._backward(plosses, accumulation_steps=accumulation_steps)
 
         return {
             "plosses": torch.stack(plosses).detach(),
+            "vlosses": torch.stack(vlosses).detach(),
             "acces": torch.stack(acces).detach(),
+            "acc_counts": torch.stack(acc_counts).detach(),
             "plosses_raw": [p.detach() for p in plosses],
             "acces_raw": [a.detach() for a in acces],
             "total_loss": total_loss.detach(),
@@ -432,15 +440,14 @@
         if not all_step_metrics:
             return {}
 
-        plosses = [m["plosses"] for m in all_step_metrics]
-        acces = [m["acces"] for m in all_step_metrics]
-
-        avg_plosses = torch.stack(plosses).mean(dim=0)
-        avg_acces = torch.stack(acces).mean(dim=0)
+        avg_vlosses = torch.stack([m["vlosses"] for m in all_step_metrics]).mean(dim=0)
+        avg_acces = torch.stack([m["acces"] for m in all_step_metrics]).mean(dim=0)
 
-        dist.all_reduce(avg_plosses, op=dist.ReduceOp.AVG)
+        dist.all_reduce(avg_vlosses, op=dist.ReduceOp.AVG)
         dist.all_reduce(avg_acces, op=dist.ReduceOp.AVG)
 
+        avg_acc_scalar = avg_acces.mean().item()
+
         # Simulated acceptance length: acc_0 + acc_0*acc_1 + acc_0*acc_1*acc_2 + ...
         # Models the expected number of consecutively accepted draft tokens,
         # which better reflects actual speculative decoding performance.
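
Aside: the simulated acceptance length used by both the eval and train
aggregations is the cumulative-product sum over the per-position acceptance
rates, exactly as the loop above computes it. A minimal worked example with
assumed rates:

    import torch

    avg_acces = torch.tensor([0.9, 0.8, 0.7])  # assumed per-position acceptance
    cumulative, simulated_acc_len = 1.0, 0.0
    for i in range(avg_acces.shape[0]):
        cumulative *= avg_acces[i].item()
        simulated_acc_len += cumulative
    # 0.9 + 0.9*0.8 + 0.9*0.8*0.7 = 2.124 expected consecutive accepted tokens
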
@@ -452,13 +459,13 @@ def _aggregate_metrics(
 
         # Compute weighted loss matching _backward's 0.8^i weighting
         ploss_weights = torch.tensor(
-            [0.8**i for i in range(avg_plosses.shape[0])], device=avg_plosses.device
+            [0.8**i for i in range(avg_vlosses.shape[0])], device=avg_vlosses.device
         )
-        weighted_avg_loss = (avg_plosses * ploss_weights).sum().item() / ploss_weights.sum().item()
+        weighted_avg_loss = (avg_vlosses * ploss_weights).sum().item() / ploss_weights.sum().item()
 
         metrics = {
             "train/avg_loss": weighted_avg_loss,
-            "train/avg_acc": avg_acces.mean().item(),
+            "train/avg_acc": avg_acc_scalar,
             "train/simulated_acc_len": simulated_acc_len,
             "train/grad_norm": grad_norm.item() if grad_norm is not None else 0.0,
             "train/global_step": self.global_step,
@@ -466,8 +473,8 @@
             "train/step": step,
         }
 
-        for i in range(avg_plosses.shape[0]):
-            metrics[f"train/ploss_{i}"] = avg_plosses[i].item()
+        for i in range(avg_vlosses.shape[0]):
+            metrics[f"train/ploss_{i}"] = avg_vlosses[i].item()
             metrics[f"train/acc_{i}"] = avg_acces[i].item()
 
         if dist.get_rank() == 0:
diff --git a/torchspec/training/fsdp.py b/torchspec/training/fsdp.py
index 9c5d4471..8a8d4be9 100644
--- a/torchspec/training/fsdp.py
+++ b/torchspec/training/fsdp.py
@@ -28,6 +28,16 @@
 from torchspec.utils.logging import logger
 
 
+def _allreduce_with_divide_factor_hook(state: dict, bucket: dist.GradBucket):
+    tensor = bucket.buffer()
+    tensor.div_(state["divide_factor"])
+    return (
+        dist.all_reduce(tensor, group=state["process_group"], async_op=True)
+        .get_future()
+        .then(lambda fut: fut.value()[0])
+    )
+
+
 @contextmanager
 def _init_on_device(device: torch.device, include_buffers: Optional[bool] = False):
     if include_buffers:
@@ -165,6 +175,27 @@ def apply_fsdp2(
     if strategy == "REPLICATE":
         logger.info("Using REPLICATE strategy (DDP-like, gradient all-reduce only)")
         replicate(model, device_mesh=mesh)
+        if args is not None and getattr(args, "attention_backend", None) == "usp":
+            sp_size = getattr(args, "sp_ulysses_size", 1) * getattr(args, "sp_ring_size", 1)
+            if sp_size > 1:
+                process_group = mesh.get_group() if mesh is not None else dist.group.WORLD
+                divide_factor = dist.get_world_size(process_group) // sp_size
+                if divide_factor <= 0:
+                    raise ValueError(
+                        f"Invalid USP grad divide factor: world_size="
+                        f"{dist.get_world_size(process_group)}, sp_size={sp_size}"
+                    )
+                model.register_comm_hook(
+                    {
+                        "process_group": process_group,
+                        "divide_factor": divide_factor,
+                    },
+                    _allreduce_with_divide_factor_hook,
+                )
+                logger.info(
+                    "Registered USP replicate grad hook "
+                    f"(all_reduce / {divide_factor}, sp_size={sp_size})"
+                )
         return model
     elif strategy != "FULL_SHARD":
         raise ValueError(f"Unknown fsdp_strategy: {strategy}. Use 'FULL_SHARD' or 'REPLICATE'")
diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py
index ec944c40..68a71b76 100644
--- a/torchspec/training/trainer.py
+++ b/torchspec/training/trainer.py
@@ -44,6 +44,7 @@
 from torchspec.training.fsdp import init_empty_weights
 from torchspec.training.optimizer import BF16Optimizer
 from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore
+from torchspec.utils.distributed import get_usp_device_mesh, get_usp_grad_sync_mesh
 from torchspec.utils.logging import logger
 from torchspec.utils.processing import get_assistant_token_ids
 from torchspec.utils.profiling import TrainProfiler
@@ -99,6 +100,26 @@
     def _setup_device_mesh(self) -> None:
         world_size = dist.get_world_size()
         rank = dist.get_rank()
+        self.cache_rank = rank
+
+        usp_mesh = None
+        if getattr(self.args, "attention_backend", None) == "usp":
+            usp_mesh = get_usp_device_mesh()
+
+        if usp_mesh is not None:
+            self.mesh = usp_mesh
+            self.dp_size = getattr(self.args, "dp_size", world_size)
+            self.dp_mesh = usp_mesh["draft_dp"]
+            self.grad_sync_mesh = get_usp_grad_sync_mesh()
+            if self.grad_sync_mesh is None:
+                raise RuntimeError("USP grad sync mesh has not been initialized")
+            self.dp_group = usp_mesh.get_group("draft_dp")
+            self.dp_rank = dist.get_rank(self.dp_group)
+            logger.info(
+                f"[Rank {rank}] Device mesh (USP): world_size={world_size}, dp_size={self.dp_size}, "
+                f"dp_rank={self.dp_rank}, grad_sync_size={world_size}"
+            )
+            return
 
         self.dp_size = world_size
         self.dp_rank = rank
@@ -106,6 +127,7 @@
         self.mesh = init_device_mesh("cuda", mesh_shape=(self.dp_size,), mesh_dim_names=("dp",))
         self.dp_group = self.mesh.get_group("dp")
         self.dp_mesh = self.mesh
+        self.grad_sync_mesh = self.dp_mesh
 
         logger.info(
             f"[Rank {rank}] Device mesh (1D): world_size={world_size}, dp_size={self.dp_size}"
@@ -156,10 +178,13 @@ def set_train_queue(
     ) -> None:
         self.train_queue = queue
         self.per_dp_rank_batch_size = per_dp_rank_batch_size
+        usp_enabled = getattr(self.args, "attention_backend", None) == "usp"
+        if usp_enabled and per_dp_rank_batch_size != 1:
+            raise ValueError("USP requires per_dp_rank_batch_size=1")
         if mooncake_config is not None and self.mooncake_store is None:
             self.init_mooncake_store(mooncake_config)
 
-        collator = DataCollatorWithPadding()
+        collator = DataCollatorWithPadding(usp_enabled=usp_enabled)
         prefetch_depth = getattr(self.args, "prefetch_depth", 0)
         gpu_device = torch.cuda.current_device()
@@ -180,6 +205,9 @@
             last_turn_loss_only=self.last_turn_loss_only,
             skip_after_header=self.skip_after_header,
             min_loss_tokens=getattr(self.args, "min_loss_tokens", 0),
+            usp_enabled=usp_enabled,
+            ttt_length=getattr(self.args, "ttt_length", 1),
+            max_seq_length=getattr(self.args, "max_seq_length", None),
         )
 
         if prefetch_depth > 0:
@@ -209,10 +237,11 @@ def set_eval_queue(
         mooncake_config: Optional[MooncakeConfig] = None,
         per_dp_rank_batch_size: int = 1,
     ) -> None:
+        usp_enabled = getattr(self.args, "attention_backend", None) == "usp"
         if mooncake_config is not None and self.mooncake_store is None:
             self.init_mooncake_store(mooncake_config)
 
-        collator = DataCollatorWithPadding()
+        collator = DataCollatorWithPadding(usp_enabled=usp_enabled)
 
         self._eval_data_fetcher = MooncakeDataFetcher(
             queue=queue,
@@ -226,6 +255,9 @@
             last_turn_loss_only=self.last_turn_loss_only,
             skip_after_header=self.skip_after_header,
             min_loss_tokens=getattr(self.args, "min_loss_tokens", 0),
+            usp_enabled=usp_enabled,
+            ttt_length=getattr(self.args, "ttt_length", 1),
+            max_seq_length=getattr(self.args, "max_seq_length", None),
         )
         self._eval_collator = collator
         self._eval_cache: list[dict] = []
@@ -248,7 +280,7 @@ def save_eval_cache(self, cache_dir: str) -> None:
         self._wait_for_eval_cache_save()
 
         cache_snapshot = list(self._eval_cache)
-        rank = self.dp_rank
+        rank = self.cache_rank
 
         def _save() -> None:
             os.makedirs(cache_dir, exist_ok=True)
@@ -269,7 +301,7 @@ def _wait_for_eval_cache_save(self) -> None:
     def load_eval_cache(self, cache_dir: str) -> int:
         # Safe guard to wait for eval cache save to complete.
         self._wait_for_eval_cache_save()
-        path = os.path.join(cache_dir, f"eval_rank_{self.dp_rank}.pt")
+        path = os.path.join(cache_dir, f"eval_rank_{self.cache_rank}.pt")
         if not os.path.exists(path):
             return 0
         try:
diff --git a/torchspec/training/trainer_actor.py b/torchspec/training/trainer_actor.py
index 524f47e4..09fc38d8 100644
--- a/torchspec/training/trainer_actor.py
+++ b/torchspec/training/trainer_actor.py
@@ -28,7 +28,7 @@
 from torchspec.models.draft.dflash import DFlashConfig
 from torchspec.ray.ray_actor import RayActor
 from torchspec.training.eagle3_trainer import Eagle3Trainer
-from torchspec.utils.distributed import init_gloo_group
+from torchspec.utils.distributed import init_gloo_group, init_usp_groups
 from torchspec.utils.logging import setup_file_logging
@@ -60,6 +60,12 @@ def init(self, args: Namespace, role: str, mooncake_config=None, with_ref: bool
             timeout=timedelta(minutes=getattr(args, "distributed_timeout_minutes", 30)),
         )
 
+        if getattr(args, "attention_backend", None) == "usp":
+            init_usp_groups(
+                sp_ulysses_size=getattr(args, "sp_ulysses_size", 1),
+                sp_ring_size=getattr(args, "sp_ring_size", 1),
+            )
+
         init_gloo_group()
 
         args.rank = dist.get_rank()
diff --git a/torchspec/transfer/mooncake/eagle_store.py b/torchspec/transfer/mooncake/eagle_store.py
index 3366ae3c..718b9360 100644
--- a/torchspec/transfer/mooncake/eagle_store.py
+++ b/torchspec/transfer/mooncake/eagle_store.py
@@ -26,7 +26,9 @@
 from torchspec.transfer.mooncake.helpers import _format_bytes
 from torchspec.transfer.mooncake.store import MooncakeHiddenStateStore
+from torchspec.utils.distributed import get_usp_rank_coords
 from torchspec.utils.logging import logger
+from torchspec.utils.usp import split_usp_batch
 
 if TYPE_CHECKING:
     from torchspec.models.target.eagle3_target_model import Eagle3TargetOutput
@@ -66,6 +68,138 @@ class EagleMooncakeStore(MooncakeHiddenStateStore):
 
     TENSOR_SUFFIXES = ["_hs", "_tgt", "_ids", "_lhs"]
 
+    def _put_raw_tensors(self, keys: List[str], tensors: List[torch.Tensor]) -> None:
+        if self._gpu_direct_available and self._gpu_send_buffer is not None:
+            buf = self._gpu_send_buffer
+            buffer_ptrs, sizes = self._stage_tensors_into_buffer(buf, tensors)
+            self._do_sync_batch_put(keys, buffer_ptrs, sizes)
+        elif self._host_buffer_pool is None or self._async_put_manager is None:
+            raise RuntimeError(
+                "put() requires either GPU Direct (enable_gpu_direct=True) or "
+                "async host-buffer puts (async_put_pool_size > 0). "
+                "Current config has async_put_pool_size=0 and GPU Direct is "
+                f"{'enabled but gpu_send_buffer failed to initialize' if self._gpu_direct_available else 'disabled'}. "
+                "Set async_put_pool_size >= 1 or enable GPU Direct."
+            )
+        else:
+            buf = self._host_buffer_pool.get_buffer()
+            self._async_put_manager.check_last_error()
+            self._async_put_manager.wait_for_buffer(buf.ptr)
+
+            compute_event = torch.cuda.Event()
+            compute_event.record()
+
+            with torch.cuda.stream(self._copy_stream):
+                self._copy_stream.wait_event(compute_event)
+                buffer_ptrs, sizes = self._stage_tensors_into_buffer(buf, tensors)
+                copy_done = torch.cuda.Event()
+                copy_done.record()
+
+            for t in tensors:
+                if t.is_cuda:
+                    t.record_stream(self._copy_stream)
+
+            self._async_put_manager.submit(
+                keys,
+                buffer_ptrs,
+                sizes,
+                buf.ptr,
+                wait_event=copy_done,
+                device_index=self._copy_stream.device.index,
+            )
+
+    def put_usp_shards(
+        self,
+        key: str,
+        hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
+        last_hidden_states: Optional[torch.Tensor],
+        target: Optional[torch.Tensor],
+        *,
+        sp_size: int,
+        sp_ring_size: int,
+        ttt_length: int,
+        max_seq_length: Optional[int],
+    ) -> Dict[str, Any]:
+        self._ensure_initialized()
+        logger.debug("put_usp_shards: starting for key=%s, sp_size=%s", key, sp_size)
+
+        if hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE:
+            hidden_states = hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE)
+        if (
+            last_hidden_states is not None
+            and last_hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE
+        ):
+            last_hidden_states = last_hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE)
+        if target is not None and target.dtype != HIDDEN_STATES_STORAGE_DTYPE:
+            target = target.to(HIDDEN_STATES_STORAGE_DTYPE)
+
+        target_hs = last_hidden_states if last_hidden_states is not None else target
+        if target_hs is None:
+            raise ValueError("USP sharded Mooncake storage requires last_hidden_states or target")
+
+        if input_ids.dim() == 2:
+            loss_mask_len = input_ids.shape[-1]
+        else:
+            loss_mask_len = input_ids.shape[0]
+        dummy_loss_mask = torch.ones((1, loss_mask_len), dtype=torch.long, device=input_ids.device)
+        sp_ulysses_size = max(1, sp_size // sp_ring_size)
+
+        for sp_rank in range(sp_size):
+            _, ring_rank = get_usp_rank_coords(
+                sp_rank=sp_rank,
+                sp_ulysses_size=sp_ulysses_size,
+                sp_ring_size=sp_ring_size,
+            )
+            (
+                shard_input_ids,
+                _attention_mask,
+                _loss_mask,
+                shard_hidden_states,
+                shard_target_hs,
+                _position_ids,
+            ) = split_usp_batch(
+                input_ids=input_ids,
+                loss_mask=dummy_loss_mask,
+                hidden_states=hidden_states,
+                target_hidden_states=target_hs,
+                ttt_length=ttt_length,
+                sp_rank=sp_rank,
+                sp_size=sp_size,
+                ring_rank=ring_rank,
+                sp_ring_size=sp_ring_size,
+                max_len=max_seq_length,
+            )
+
+            shard_key = f"{key}_usp{sp_rank}"
+            keys = [f"{shard_key}_hs", f"{shard_key}_ids"]
+            tensors = [shard_hidden_states, shard_input_ids]
+            if last_hidden_states is not None:
+                keys.append(f"{shard_key}_lhs")
+                tensors.append(shard_target_hs)
+            else:
+                keys.append(f"{shard_key}_tgt")
+                tensors.append(shard_target_hs)
+            self._put_raw_tensors(keys, tensors)
+
+        shapes = {
+            "hidden_states": tuple(hidden_states.shape),
+            "input_ids": tuple(input_ids.shape),
+        }
+        dtypes = {
+            "hidden_states": hidden_states.dtype,
+            "input_ids": input_ids.dtype,
+        }
+        if target is not None:
+            shapes["target"] = tuple(target.shape)
+            dtypes["target"] = target.dtype
+        if last_hidden_states is not None:
+            shapes["last_hidden_states"] = tuple(last_hidden_states.shape)
+            dtypes["last_hidden_states"] = last_hidden_states.dtype
+
+        logger.debug("put_usp_shards: completed key=%s, shapes=%s", key, shapes)
+        return {"shapes": shapes, "dtypes": dtypes}
+
     def put(
         self,
         key: str,
@@ -112,46 +246,7 @@
             keys.append(f"{key}_lhs")
             tensors.append(last_hidden_states)
 
-        if self._gpu_direct_available and self._gpu_send_buffer is not None:
-            buf = self._gpu_send_buffer
-            buffer_ptrs, sizes = self._stage_tensors_into_buffer(buf, tensors)
-            self._do_sync_batch_put(keys, buffer_ptrs, sizes)
-        elif self._host_buffer_pool is None or self._async_put_manager is None:
-            raise RuntimeError(
-                "put() requires either GPU Direct (enable_gpu_direct=True) or "
-                "async host-buffer puts (async_put_pool_size > 0). "
-                "Current config has async_put_pool_size=0 and GPU Direct is "
-                f"{'enabled but gpu_send_buffer failed to initialize' if self._gpu_direct_available else 'disabled'}. "
-                "Set async_put_pool_size >= 1 or enable GPU Direct."
-            )
-        else:
-            buf = self._host_buffer_pool.get_buffer()
-            self._async_put_manager.check_last_error()
-            self._async_put_manager.wait_for_buffer(buf.ptr)
-
-            # Stage DtoH on a dedicated stream so the default (compute) stream
-            # is free to run the next prefill concurrently.
-            compute_event = torch.cuda.Event()
-            compute_event.record()
-
-            with torch.cuda.stream(self._copy_stream):
-                self._copy_stream.wait_event(compute_event)
-                buffer_ptrs, sizes = self._stage_tensors_into_buffer(buf, tensors)
-                copy_done = torch.cuda.Event()
-                copy_done.record()
-
-            for t in tensors:
-                if t.is_cuda:
-                    t.record_stream(self._copy_stream)
-
-            self._async_put_manager.submit(
-                keys,
-                buffer_ptrs,
-                sizes,
-                buf.ptr,
-                wait_event=copy_done,
-                device_index=self._copy_stream.device.index,
-            )
+        self._put_raw_tensors(keys, tensors)
 
         shapes = {
             "hidden_states": tuple(hidden_states.shape),
diff --git a/torchspec/utils/distributed.py b/torchspec/utils/distributed.py
index 12e215d6..60f6d280 100644
--- a/torchspec/utils/distributed.py
+++ b/torchspec/utils/distributed.py
@@ -24,6 +24,11 @@
 _TP_DEVICE_MESH = None
 _TP_GROUP = None
 
+_USP_DEVICE_MESH = None
+_USP_GRAD_SYNC_MESH = None
+_DRAFT_SP_GROUP = None
+_SP_ULYSSES_GROUP = None
+_SP_RING_GROUP = None
 
 def init_gloo_group():
@@ -50,3 +55,180 @@ def get_tp_group():
 def get_tp_device_mesh():
     global _TP_DEVICE_MESH
     return _TP_DEVICE_MESH
+
+
+def get_usp_device_mesh():
+    global _USP_DEVICE_MESH
+    return _USP_DEVICE_MESH
+
+
+def get_usp_grad_sync_mesh():
+    global _USP_GRAD_SYNC_MESH
+    return _USP_GRAD_SYNC_MESH
+
+
+def _build_usp_group_ranks(
+    world_size: int, sp_ulysses_size: int, sp_ring_size: int
+) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+    sp_size = sp_ulysses_size * sp_ring_size
+    if sp_size <= 0:
+        raise ValueError(f"sp_size must be positive, got {sp_size}")
+    if world_size % sp_size != 0:
+        raise ValueError(
+            "world_size must be divisible by sp_ulysses_size * sp_ring_size, "
+            f"got world_size={world_size}, sp_ulysses_size={sp_ulysses_size}, "
+            f"sp_ring_size={sp_ring_size}"
+        )
+
+    draft_sp_groups: list[list[int]] = []
+    ulysses_groups: list[list[int]] = []
+    ring_groups: list[list[int]] = []
+    num_ulysses_pgs = sp_ring_size
+    num_ring_pgs = sp_ulysses_size
+    for base_rank in range(0, world_size, sp_size):
+        draft_sp_groups.append(list(range(base_rank, base_rank + sp_size)))
+        for idx in range(num_ulysses_pgs):
+            ulysses_groups.append(
+                list(
+                    range(
+                        base_rank + idx * sp_ulysses_size,
+                        base_rank + (idx + 1) * sp_ulysses_size,
+                    )
+                )
+            )
+        for idx in range(num_ring_pgs):
+            ring_groups.append(list(range(base_rank + idx, base_rank + sp_size, num_ring_pgs)))
+    return draft_sp_groups, ulysses_groups, ring_groups
+
+
+def init_usp_groups(sp_ulysses_size: int = 1, sp_ring_size: int = 1):
+    global _USP_DEVICE_MESH
+    global _USP_GRAD_SYNC_MESH
+    global _DRAFT_SP_GROUP
+    global _SP_ULYSSES_GROUP, _SP_RING_GROUP
+
+    sp_size = sp_ulysses_size * sp_ring_size
+    if sp_size == 1:
+        _USP_DEVICE_MESH = None
+        _USP_GRAD_SYNC_MESH = None
+        _DRAFT_SP_GROUP = None
+        _SP_ULYSSES_GROUP = None
+        _SP_RING_GROUP = None
+        return None, None, None
+
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+    if world_size % sp_size != 0:
+        raise ValueError(
+            "world_size must be divisible by sp_ulysses_size * sp_ring_size, "
+            f"got world_size={world_size}, sp_ulysses_size={sp_ulysses_size}, "
+            f"sp_ring_size={sp_ring_size}"
+        )
+
+    draft_dp_size = world_size // sp_size
+
+    _DRAFT_SP_GROUP = None
+    _SP_ULYSSES_GROUP = None
+    _SP_RING_GROUP = None
+
+    _USP_DEVICE_MESH = dist.device_mesh.init_device_mesh(
+        "cuda",
+        (draft_dp_size, sp_size),
+        mesh_dim_names=("draft_dp", "draft_sp"),
+    )
+    _DRAFT_SP_GROUP = _USP_DEVICE_MESH.get_group("draft_sp")
+    _USP_GRAD_SYNC_MESH = dist.device_mesh.init_device_mesh(
+        "cuda",
+        (world_size,),
+        mesh_dim_names=("draft_dp_with_sp",),
+    )
+
+    import yunchang
+    from yunchang.globals import PROCESS_GROUP as YUNCHANG_PROCESS_GROUP
+
+    yunchang.set_seq_parallel_pg(
+        sp_ulysses_degree=sp_ulysses_size,
+        sp_ring_degree=sp_ring_size,
+        rank=rank,
+        world_size=world_size,
+        use_ulysses_low=True,
+    )
+    _SP_ULYSSES_GROUP = YUNCHANG_PROCESS_GROUP.ULYSSES_PG
+    _SP_RING_GROUP = YUNCHANG_PROCESS_GROUP.RING_PG
+    _validate_usp_group_composition()
+
+    return _DRAFT_SP_GROUP, _SP_ULYSSES_GROUP, _SP_RING_GROUP
+
+
+def get_draft_sp_group():
+    global _DRAFT_SP_GROUP
+    return _DRAFT_SP_GROUP
+
+
+def get_sp_ulysses_group():
+    global _SP_ULYSSES_GROUP
+    return _SP_ULYSSES_GROUP
+
+
+def get_sp_ring_group():
+    global _SP_RING_GROUP
+    return _SP_RING_GROUP
+
+
+def get_sp_rank() -> int:
+    sp_group = get_draft_sp_group()
+    if sp_group is None:
+        return 0
+    return dist.get_rank(sp_group)
+
+
+def get_sp_ulysses_rank() -> int:
+    ulysses_group = get_sp_ulysses_group()
+    if ulysses_group is None:
+        return 0
+    return dist.get_rank(ulysses_group)
+
+
+def get_sp_ring_rank() -> int:
+    ring_group = get_sp_ring_group()
+    if ring_group is None:
+        return 0
+    return dist.get_rank(ring_group)
+
+
+def get_usp_rank_coords(sp_rank: int, sp_ulysses_size: int, sp_ring_size: int) -> tuple[int, int]:
+    sp_size = sp_ulysses_size * sp_ring_size
+    if sp_rank < 0 or sp_rank >= sp_size:
+        raise ValueError(f"sp_rank must be in [0, {sp_size}), got {sp_rank}")
+    ulysses_rank = sp_rank % sp_ulysses_size
+    ring_rank = sp_rank // sp_ulysses_size
+    return ulysses_rank, ring_rank
+
+
+def _gather_group_members(group) -> tuple[int, ...]:
+    group_world_size = dist.get_world_size(group)
+    members = [None] * group_world_size
+    dist.all_gather_object(members, dist.get_rank(), group=group)
+    return tuple(members)
+
+
+def _validate_usp_group_composition() -> None:
+    sp_group = get_draft_sp_group()
+    ulysses_group = get_sp_ulysses_group()
+    ring_group = get_sp_ring_group()
+    if sp_group is None or ulysses_group is None or ring_group is None:
+        raise RuntimeError("USP groups must be initialized before validating group composition")
+
+    sp_members = _gather_group_members(sp_group)
+    local_record = {
+        "world_rank": dist.get_rank(),
+        "ring_members": _gather_group_members(ring_group),
+        "ulysses_members": _gather_group_members(ulysses_group),
+    }
+    records = [None] * dist.get_world_size(sp_group)
+    dist.all_gather_object(records, local_record, group=sp_group)
+    for record in records:
+        ring_members = tuple(record["ring_members"])
+        ulysses_members = tuple(record["ulysses_members"])
+        if any(member not in sp_members for member in ring_members + ulysses_members):
+            raise RuntimeError("USP ring/ulysses groups include ranks outside the draft SP group")
diff --git a/torchspec/utils/usp.py b/torchspec/utils/usp.py
new file mode 100644
index 00000000..a8679823
--- /dev/null
+++ b/torchspec/utils/usp.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def split_usp_batch(
+    *,
+    input_ids: torch.Tensor,
+    loss_mask: torch.Tensor,
+    hidden_states: torch.Tensor,
+    target_hidden_states: torch.Tensor,
+    ttt_length: int,
+    sp_rank: int,
+    sp_size: int,
+    ring_rank: int,
+    sp_ring_size: int,
+    max_len: Optional[int] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if input_ids.dim() == 1:
+        input_ids = input_ids.unsqueeze(0)
+    if loss_mask.dim() == 1:
+        loss_mask = loss_mask.unsqueeze(0)
+
+    batch_size, input_len = input_ids.shape
+    global_len = min(max_len, input_len) if max_len is not None else input_len
+    chunk_size = (global_len + sp_size - 1) // sp_size
+    sp_ulysses_size = max(1, sp_size // sp_ring_size)
+    start = sp_rank * chunk_size
+    local_len = chunk_size + ttt_length
+    end = min(start + local_len, global_len)
+
+    loss_mask = loss_mask[:, :global_len].clone()
+
+    def _slice_and_pad(tensor: torch.Tensor, axis: int, pad_value: int = 0):
+        if tensor.dim() == 1:
+            tensor = tensor.unsqueeze(0)
+        if axis == 0:
+            tensor = tensor[:global_len, :]
+            sliced = tensor[start : min(end, tensor.shape[0]), :]
+            valid_len = sliced.shape[0]
+            if valid_len < local_len:
+                sliced = F.pad(sliced, (0, 0, 0, local_len - valid_len), value=pad_value)
+        else:
+            tensor = tensor[:, :global_len]
+            sliced = tensor[:, start : min(end, tensor.shape[1])]
+            valid_len = sliced.shape[1]
+            if valid_len < local_len:
+                pad_len = local_len - valid_len
+                if tensor.dim() == 2:
+                    sliced = F.pad(sliced, (0, pad_len), value=pad_value)
+                else:
+                    sliced = F.pad(sliced, (0, 0, 0, pad_len), value=pad_value)
+        return sliced.contiguous(), valid_len
+
+    input_ids, valid_len = _slice_and_pad(input_ids, axis=1, pad_value=0)
+    loss_mask, _ = _slice_and_pad(loss_mask, axis=1, pad_value=0)
+    if hidden_states.dim() == 2:
+        hidden_states, _ = _slice_and_pad(hidden_states, axis=0, pad_value=0)
+        hidden_states = hidden_states.unsqueeze(0)
+    else:
+        hidden_states, _ = _slice_and_pad(hidden_states, axis=1, pad_value=0)
+    if target_hidden_states.dim() == 2:
+        target_hidden_states, _ = _slice_and_pad(target_hidden_states, axis=0, pad_value=0)
+        target_hidden_states = target_hidden_states.unsqueeze(0)
+    else:
+        target_hidden_states, _ = _slice_and_pad(target_hidden_states, axis=1, pad_value=0)
+
+    attention_mask = torch.zeros((batch_size, local_len), dtype=torch.long, device=input_ids.device)
+    attention_mask[:, :valid_len] = 1
+
+    usp_chunk_size = max(local_len - ttt_length, 0)
+    ring_chunk = usp_chunk_size * sp_ulysses_size
+    ring_start = ring_rank * ring_chunk
+    position_ids = torch.arange(
+        ring_start, ring_start + ring_chunk, device=input_ids.device, dtype=torch.long
+    ).unsqueeze(0)
+    if batch_size > 1:
+        position_ids = position_ids.expand(batch_size, -1)
+
+    return (
+        input_ids,
+        attention_mask,
+        loss_mask,
+        hidden_states,
+        target_hidden_states,
+        position_ids,
+    )
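
Aside: a round-trip sketch of split_usp_batch under assumed sizes (sp_size=2,
pure Ulysses with sp_ring_size=1, ttt_length=1). Dropping each shard's
ttt_length lookahead tail and concatenating the shards in rank order should
reconstruct the original sequence; only the tensor sizes here are invented:

    import torch
    from torchspec.utils.usp import split_usp_batch

    seq_len, hidden = 8, 4
    input_ids = torch.arange(seq_len).unsqueeze(0)
    loss_mask = torch.ones(1, seq_len, dtype=torch.long)
    hs = torch.randn(1, seq_len, hidden)
    tgt = torch.randn(1, seq_len, hidden)

    chunks = []
    for sp_rank in range(2):
        ids, attn, _, _, _, pos = split_usp_batch(
            input_ids=input_ids, loss_mask=loss_mask, hidden_states=hs,
            target_hidden_states=tgt, ttt_length=1, sp_rank=sp_rank,
            sp_size=2, ring_rank=0, sp_ring_size=1,
        )
        chunk_size = ids.shape[-1] - 1      # local_len minus the ttt lookahead
        chunks.append(ids[:, :chunk_size])  # rank 1's tail padding is dropped too
    assert torch.equal(torch.cat(chunks, dim=-1), input_ids)
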