diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index abf1b2e2..cca4eb3b 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -541,20 +541,11 @@ def main():
                 for k, v in model_state_dict.items()
                 if "draft_model." in k and "embed" not in k.lower()
             }

+            draft_model.save_pretrained(
+                os.path.join(args.output_dir, f"epoch_{epoch}"),
+                state_dict=draft_model_state_dict,
+            )
-            if dist.get_rank() == 0:
-                torch.save(
-                    state_to_save,
-                    os.path.join(epoch_output_dir, "training_state.pt"),
-                )
-                print_on_rank0(
-                    f"Saved full training state to {epoch_output_dir}/training_state.pt"
-                )
-                draft_model.save_pretrained(
-                    epoch_output_dir,
-                    state_dict=draft_model_state_dict,
-                )
-                print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
         dist.barrier()

     # Close the tracker at the end of training
diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py
index 8cab6132..9ea98387 100644
--- a/scripts/train_eagle3_online.py
+++ b/scripts/train_eagle3_online.py
@@ -291,22 +291,17 @@ def main():

     # load model with resume
     if draft_model_last_checkpoint:
-        draft_model = (
-            AutoEagle3DraftModel.from_pretrained(
-                draft_model_last_checkpoint, attention_backend=args.attention_backend,
-                torch_dtype=torch.bfloat16
-            )
-            .cuda()
-
-        )
+        draft_model = AutoEagle3DraftModel.from_pretrained(
+            draft_model_last_checkpoint,
+            attention_backend=args.attention_backend,
+            torch_dtype=torch.bfloat16,
+        ).cuda()
     else:
-        draft_model = (
-            AutoEagle3DraftModel.from_config(
-                draft_model_config, attention_backend=args.attention_backend,
-                torch_dtype=torch.bfloat16
-            )
-            .cuda()
-        )
+        draft_model = AutoEagle3DraftModel.from_config(
+            draft_model_config,
+            attention_backend=args.attention_backend,
+            torch_dtype=torch.bfloat16,
+        ).cuda()
     draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
     draft_model.freeze_embedding()
     print_with_rank("Initialized draft model")
@@ -652,6 +647,12 @@ def main():
                 if "draft_model." in k and "embed" not in k.lower()
             }

+            # The new save_pretrained method handles all TP logic internally.
+            # It ensures only global rank 0 writes to disk.
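+            # Note: every rank must enter this call, because save_pretrained
+            # performs collective gathers across the TP group before rank 0
+            # writes the files, so it must not be wrapped in a rank check.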
+            draft_model.save_pretrained(
+                epoch_output_dir,
+                state_dict=draft_model_state_dict,
+            )
             if dist.get_rank() == 0:
                 torch.save(
                     state_to_save,
@@ -660,10 +661,6 @@
                 print_on_rank0(
                     f"Saved full training state to {epoch_output_dir}/training_state.pt"
                 )
-                draft_model.save_pretrained(
-                    epoch_output_dir,
-                    state_dict=draft_model_state_dict,
-                )
                 print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
         dist.barrier()

diff --git a/specforge/layers/linear.py b/specforge/layers/linear.py
index 51a80ecc..1579cf33 100644
--- a/specforge/layers/linear.py
+++ b/specforge/layers/linear.py
@@ -2,6 +2,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.autograd import Function

 from specforge.distributed import get_tp_group

@@ -125,3 +126,17 @@ def load_state_dict(self, state_dict, strict=True):

     def __repr__(self):
         return f"ColumnParallelLinear(in_features={self.in_features}, out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})"
+
+
+class _AllReduce(Function):
+    @staticmethod
+    def forward(ctx, input, op, group):
+        # ctx is a context object that can be used to stash information for backward computation
+        output = input.clone()
+        dist.all_reduce(output, op=op, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # The gradient of all_reduce is the identity, so we pass grad_output through; op and group receive no gradient
+        return grad_output, None, None
diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py
index 468988fc..9e05f0be 100644
--- a/specforge/modeling/draft/base.py
+++ b/specforge/modeling/draft/base.py
@@ -27,12 +27,15 @@
 from typing import Optional, Tuple

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from transformers.cache_utils import Cache
 from transformers.modeling_utils import PreTrainedModel

+from specforge.distributed import get_tp_group
+from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear
 from specforge.modeling._mask_utils import _expand_mask, _make_causal_mask

@@ -191,3 +194,71 @@ def load_vocab_mapping(self, file_path: str) -> None:
         vocab_mapping = torch.load(file_path)
         self.t2d.copy_(vocab_mapping["t2d"])
         self.d2t.copy_(vocab_mapping["d2t"])
+
+    def save_pretrained(self, save_directory, state_dict=None, **kwargs):
+        """
+        Overrides save_pretrained to handle TP weight aggregation robustly.
+        This method gathers sharded weights from all TP ranks and saves a single,
+        complete checkpoint from the main process.
+        """
+        if not dist.is_initialized():
+            # Standard non-distributed save
+            super().save_pretrained(save_directory, state_dict=state_dict, **kwargs)
+            return
+
+        # Use the provided state_dict or get it from the model
+        if state_dict is None:
+            state_dict = self.state_dict()
+
+        # Get distributed process groups and ranks
+        global_rank = dist.get_rank()
+        tp_group = get_tp_group()
+        tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
+        tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
+
+        # If not using TP, only rank 0 saves and others do nothing.
+        if tp_size <= 1:
+            if global_rank == 0:
+                super().save_pretrained(save_directory, state_dict=state_dict, **kwargs)
+            dist.barrier()
+            return
+
+        # --- Aggregation Logic for TP > 1 ---
+        # Step 1: Each TP group's leader (tp_rank == 0) reconstructs the full state dict.
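+        # Shards from ColumnParallelLinear weights are concatenated along dim 0,
+        # shards from RowParallelLinear weights along dim 1; tensors that are
+        # replicated across ranks (biases, norms, buffers) are taken as-is.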
+        reconstructed_state_dict = None
+        if tp_rank == 0:
+            reconstructed_state_dict = {}
+
+        # All ranks in a TP group participate in gathering shards for each parameter.
+        modules = dict(self.named_modules())
+        for name, param in state_dict.items():
+            # Gather shards from all TP ranks into a list
+            tensor_list = [torch.empty_like(param) for _ in range(tp_size)]
+            dist.all_gather(tensor_list, param.contiguous(), group=tp_group)
+
+            # Let the tp_rank 0 process handle the concatenation
+            if tp_rank == 0:
+                module_name = ".".join(name.split(".")[:-1])
+                module = modules.get(module_name)
+
+                if isinstance(module, ColumnParallelLinear) and name.endswith(
+                    ".weight"
+                ):
+                    # Concat along dimension 0 for ColumnParallel
+                    reconstructed_state_dict[name] = torch.cat(tensor_list, dim=0)
+                elif isinstance(module, RowParallelLinear) and name.endswith(".weight"):
+                    # Concat along dimension 1 for RowParallel
+                    reconstructed_state_dict[name] = torch.cat(tensor_list, dim=1)
+                else:
+                    # Non-parallel layers (biases, norms, etc.) are identical across ranks
+                    reconstructed_state_dict[name] = tensor_list[0]
+
+        # Step 2: Only the global rank 0 process saves the final model.
+        if global_rank == 0:
+            print(f"Rank {global_rank} saving aggregated model checkpoint...")
+            super().save_pretrained(
+                save_directory, state_dict=reconstructed_state_dict, **kwargs
+            )
+
+        # Step 3: Barrier to ensure all processes wait until saving is complete.
+        dist.barrier()
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 22e36a94..3aca1bda 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -3,6 +3,7 @@
 from typing import List, Optional, Tuple

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
@@ -11,6 +12,8 @@
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig

+from specforge.distributed import get_tp_group
+from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear, _AllReduce
 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
     compile_friendly_flex_attention,
@@ -343,27 +346,42 @@ class LlamaAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
+        self.tp_group = get_tp_group()
+        self._tp_size = (
+            dist.get_world_size(self.tp_group) if self.tp_group is not None else 1
+        )
+        self._tp_rank = dist.get_rank(self.tp_group) if self.tp_group is not None else 0
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         if hasattr(config, "head_dim"):
             self.head_dim = config.head_dim
         else:
             self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
+
+        # Adjust the head counts based on the TP size: each rank owns a slice of the heads.
+        self.num_heads = config.num_attention_heads // self._tp_size
+        self.num_key_value_heads = config.num_key_value_heads // self._tp_size
+        assert (
+            config.num_attention_heads % self._tp_size == 0
+        ), "num_attention_heads must be divisible by tp_size"
+        assert (
+            config.num_key_value_heads % self._tp_size == 0
+        ), "num_key_value_heads must be divisible by tp_size"
+
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.q_proj = nn.Linear(
-            self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
+        self.q_proj = ColumnParallelLinear(
+            self.hidden_size * 2, config.num_attention_heads * self.head_dim, bias=False
         )
-        self.k_proj = nn.Linear(
-            self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
+        self.k_proj = ColumnParallelLinear(
+            self.hidden_size * 2, config.num_key_value_heads * self.head_dim, bias=False
         )
-        self.v_proj = nn.Linear(
-            self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
+        self.v_proj = ColumnParallelLinear(
+            self.hidden_size * 2, config.num_key_value_heads * self.head_dim, bias=False
         )
-        self.o_proj = nn.Linear(
-            self.num_heads * self.head_dim, self.hidden_size, bias=False
+        self.o_proj = RowParallelLinear(
+            config.num_attention_heads * self.head_dim, self.hidden_size, bias=False
         )
         self._init_rope()
@@ -512,7 +530,6 @@ def forward(
                 (attn_weights, attn_weightsi[..., None]), dim=-1
             )

-            # upcast attention to fp32
             attn_weights = nn.functional.softmax(
                 attn_weights, dim=-1, dtype=torch.float32
             ).to(query_states.dtype)
@@ -530,7 +547,10 @@ def forward(
         attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)

         attn_output = self.o_proj(attn_output)
-
+        if self._tp_size > 1:
+            attn_output = _AllReduce.apply(
+                attn_output, dist.ReduceOp.SUM, self.tp_group
+            )
         return attn_output


@@ -648,44 +668,35 @@ class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
+
+        self.tp_group = get_tp_group()
+        self._tp_size = (
+            dist.get_world_size(self.tp_group) if self.tp_group is not None else 1
+        )
+
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj = ColumnParallelLinear(
+            self.hidden_size, self.intermediate_size, bias=False
+        )
+        self.up_proj = ColumnParallelLinear(
+            self.hidden_size, self.intermediate_size, bias=False
+        )
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size, self.hidden_size, bias=False
+        )
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
-        if self.config.pretraining_tp > 1:
-            slice = self.intermediate_size // self.config.pretraining_tp
-            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
-            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
-            down_proj_slices = self.down_proj.weight.split(slice, dim=1)
-
-            gate_proj = torch.cat(
-                [
-                    F.linear(x, gate_proj_slices[i])
-                    for i in range(self.config.pretraining_tp)
-                ],
-                dim=-1,
-            )
-            up_proj = torch.cat(
-                [
-                    F.linear(x, up_proj_slices[i])
-                    for i in range(self.config.pretraining_tp)
-                ],
-                dim=-1,
-            )
-            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
-            down_proj = [
-                F.linear(intermediate_states[i], down_proj_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
-            down_proj = sum(down_proj)
-        else:
-            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        # The pretraining_tp > 1 branch is no longer needed: the parallel layers handle sharding.
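+        # gate_proj/up_proj are column-parallel, so their outputs are sharded along
+        # the intermediate dimension; down_proj is row-parallel and produces a
+        # partial sum that is all-reduced below when tp_size > 1.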
+        gate_output = self.gate_proj(x)
+        up_output = self.up_proj(x)
+
+        down_proj = self.down_proj(self.act_fn(gate_output) * up_output)
+        if self._tp_size > 1:
+            down_proj = _AllReduce.apply(down_proj, dist.ReduceOp.SUM, self.tp_group)
         return down_proj


diff --git a/specforge/tracker.py b/specforge/tracker.py
index b5bb1352..cc5beb28 100644
--- a/specforge/tracker.py
+++ b/specforge/tracker.py
@@ -196,7 +196,7 @@ def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
             swanlab.log(log_dict, step=step)

     def close(self):
-        if self.rank == 0 and self.is_initialized and swanlab.is_running():
+        if self.rank == 0 and self.is_initialized:
             swanlab.finish()
             self.is_initialized = False

diff --git a/tests/test_draft_modeling_tp.py b/tests/test_draft_modeling_tp.py
new file mode 100644
index 00000000..d1c0d437
--- /dev/null
+++ b/tests/test_draft_modeling_tp.py
@@ -0,0 +1,320 @@
+# Tensor-parallel correctness tests for the Eagle3 draft LlamaMLP and LlamaAttention layers.
+
+import os
+import tempfile
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from transformers import LlamaConfig
+from transformers.activations import ACT2FN
+
+from specforge.distributed import destroy_distributed, init_distributed
+from specforge.modeling.draft.llama3_eagle import LlamaAttention, LlamaMLP
+
+
+# === Temporary, Non-Parallel Model Definitions (for this test file only) ===
+class VanillaLlamaMLP(nn.Module):
+    """Temporary non-parallel model used to generate the MLP baseline output."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.gate_proj = nn.Linear(
+            config.hidden_size, config.intermediate_size, bias=False
+        )
+        self.up_proj = nn.Linear(
+            config.hidden_size, config.intermediate_size, bias=False
+        )
+        self.down_proj = nn.Linear(
+            config.intermediate_size, config.hidden_size, bias=False
+        )
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+# Helper functions and classes copied so that VanillaLlamaAttention works standalone.
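+# They mirror the reference Llama rotary-embedding and KV-repeat utilities and are
+# used only to produce the TP=1 baseline outputs; nothing here is imported by the
+# library code.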
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    cos = cos.squeeze(1).squeeze(0)
+    sin = sin.squeeze(1).squeeze(0)
+    cos = cos[position_ids].unsqueeze(1)
+    sin = sin[position_ids].unsqueeze(1)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class VanillaLlamaRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
+        )
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._set_cos_sin_cache(seq_len=max_position_embeddings)
+
+    def _set_cos_sin_cache(self, seq_len):
+        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer(
+            "cos_cached", emb.cos()[None, None, :, :], persistent=False
+        )
+        self.register_buffer(
+            "sin_cached", emb.sin()[None, None, :, :], persistent=False
+        )
+
+    def forward(self, x, seq_len=None):
+        return (
+            self.cos_cached[:, :, :seq_len, ...],
+            self.sin_cached[:, :, :seq_len, ...],
+        )
+
+
+class VanillaLlamaAttention(nn.Module):
+    """Temporary non-parallel model to generate the Attention baseline output."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+        # The input for Eagle Attention is hidden_size * 2
+        self.q_proj = nn.Linear(
+            self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.hidden_size, bias=False
+        )
+        self.rotary_emb = VanillaLlamaRotaryEmbedding(self.head_dim)
+
+    def forward(self, hidden_states, position_ids):
+        bsz, q_len, _ = hidden_states.size()
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+            .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+            .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+
+        cos, sin = self.rotary_emb(query_states, seq_len=q_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids
+        )
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states, key_states, value_states, is_causal=True
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        return self.o_proj(attn_output)
+
+
+# === Core Parallel Test Functions ===
+
+
+def run_mlp_tp_test(rank, world_size, temp_dir_name):
+    """
+    Runs the tensor-parallel LlamaMLP and compares its output with the TP=1 reference.
+    Args:
+        rank (int): The rank of the current process, passed automatically by mp.spawn.
+        world_size (int): The total number of processes, passed automatically by mp.spawn.
+        temp_dir_name (str): Path to the temporary directory used to exchange weights and tensors.
+    """
+    os.environ["RANK"], os.environ["WORLD_SIZE"] = str(rank), str(world_size)
+    os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"] = "localhost", "29503"
+    init_distributed(tp_size=world_size)
+    torch.cuda.set_device(rank)
+    config = LlamaConfig(
+        hidden_size=128,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        intermediate_size=512,
+    )
+
+    # Instantiate the tensor-parallel LlamaMLP from specforge and load sharded weights
+    mlp_tp2 = LlamaMLP(config).cuda(rank)
+    full_state_dict = torch.load(os.path.join(temp_dir_name, "mlp_weights.pth"))
+
+    sharded_state_dict = {}
+    for name, param in full_state_dict.items():
+        if "gate_proj.weight" in name or "up_proj.weight" in name:
+            sharded_param = param.chunk(world_size, dim=0)[rank]
+        elif "down_proj.weight" in name:
+            sharded_param = param.chunk(world_size, dim=1)[rank]
+        else:
+            sharded_param = param
+        sharded_state_dict[name] = sharded_param
+    mlp_tp2.load_state_dict(sharded_state_dict)
+    mlp_tp2.eval()
+
+    input_tensor = torch.load(os.path.join(temp_dir_name, "mlp_input.pth")).cuda(rank)
+    output_tp2 = mlp_tp2(input_tensor)
+
+    if rank == 0:
+        output_tp1 = torch.load(os.path.join(temp_dir_name, "mlp_output.pth"))
+        assert torch.allclose(
+            output_tp1, output_tp2.cpu(), rtol=1e-4, atol=1e-5
+        ), "Output mismatch for LlamaMLP between TP=1 and TP=2!"
+        print("✅ LlamaMLP TP correctness test passed!")
+    destroy_distributed()
+
+
+def run_attention_tp_test(rank, world_size, temp_dir_name):
+    """Runs the tensor-parallel LlamaAttention and compares its output with the TP=1 reference."""
+    os.environ["RANK"], os.environ["WORLD_SIZE"] = str(rank), str(world_size)
+    os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"] = "localhost", "29504"
+    init_distributed(tp_size=world_size)
+    torch.cuda.set_device(rank)
+    config = LlamaConfig(
+        hidden_size=128,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        intermediate_size=512,
+    )
+
+    # Instantiate the tensor-parallel LlamaAttention from specforge and load sharded weights
+    attn_tp2 = LlamaAttention(config).cuda(rank)
+    full_state_dict = torch.load(os.path.join(temp_dir_name, "attn_weights.pth"))
+
+    sharded_state_dict = {}
+    for name, param in full_state_dict.items():
+        if "rotary_emb" in name:
+            sharded_param = param
+        elif any(
+            s in name for s in ["q_proj.weight", "k_proj.weight", "v_proj.weight"]
+        ):
+            sharded_param = param.chunk(world_size, dim=0)[rank]
+        elif "o_proj.weight" in name:
+            sharded_param = param.chunk(world_size, dim=1)[rank]
+        else:
+            sharded_param = param
+        sharded_state_dict[name] = sharded_param
+    attn_tp2.load_state_dict(sharded_state_dict, strict=False)
+    attn_tp2.eval()
+
+    input_tensor = torch.load(os.path.join(temp_dir_name, "attn_input.pth")).cuda(rank)
+    pos_ids = torch.load(os.path.join(temp_dir_name, "attn_pos_ids.pth")).cuda(rank)
+    output_tp2 = attn_tp2(input_tensor, position_ids=pos_ids)
+
+    if rank == 0:
+        output_tp1 = torch.load(os.path.join(temp_dir_name, "attn_output.pth"))
+        assert torch.allclose(
+            output_tp1, output_tp2.cpu(), rtol=1e-4, atol=1e-5
+        ), "Output mismatch for LlamaAttention between TP=1 and TP=2!"
+        print("✅ LlamaAttention TP correctness test passed!")
+    destroy_distributed()
+
+
+# === unittest Launcher ===
+class TestTPCorrectness(unittest.TestCase):
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_mlp_correctness(self):
+        world_size = 2
+        temp_dir_path = self.temp_dir.name
+        print("\n--- Running MLP TP Correctness Test ---")
+
+        # Phase 1: Generate the TP=1 reference outputs for the MLP
+        torch.manual_seed(42)
+        config = LlamaConfig(
+            hidden_size=128,
+            num_attention_heads=8,
+            num_key_value_heads=4,
+            intermediate_size=512,
+        )
+        mlp_tp1 = VanillaLlamaMLP(config)
+        mlp_tp1.eval()
+        input_tensor = torch.randn(2, 10, config.hidden_size)
+        output_tp1 = mlp_tp1(input_tensor)
+
+        torch.save(mlp_tp1.state_dict(), os.path.join(temp_dir_path, "mlp_weights.pth"))
+        torch.save(input_tensor, os.path.join(temp_dir_path, "mlp_input.pth"))
+        torch.save(output_tp1, os.path.join(temp_dir_path, "mlp_output.pth"))
+
+        # Phase 2 & 3: Spawn parallel processes
+        mp.spawn(run_mlp_tp_test, nprocs=world_size, args=(world_size, temp_dir_path))
+
+    def test_attention_correctness(self):
+        world_size = 2
+        temp_dir_path = self.temp_dir.name
+        print("\n--- Running Attention TP Correctness Test ---")
+
+        # Phase 1: Generate the TP=1 reference outputs for Attention
+        torch.manual_seed(42)
+        config = LlamaConfig(
+            hidden_size=128,
+            num_attention_heads=8,
+            num_key_value_heads=4,
+            intermediate_size=512,
+        )
+        attn_tp1 = VanillaLlamaAttention(config)
+        attn_tp1.eval()
+
+        input_tensor = torch.randn(2, 10, config.hidden_size * 2)
+        pos_ids = torch.arange(10, dtype=torch.long).unsqueeze(0).expand(2, -1)
+        output_tp1 = attn_tp1(input_tensor, position_ids=pos_ids)
+
+        torch.save(
+            attn_tp1.state_dict(), os.path.join(temp_dir_path, "attn_weights.pth")
+        )
+        torch.save(input_tensor, os.path.join(temp_dir_path, "attn_input.pth"))
+        torch.save(pos_ids, os.path.join(temp_dir_path, "attn_pos_ids.pth"))
+        torch.save(output_tp1, os.path.join(temp_dir_path, "attn_output.pth"))
+
+        # Phase 2 & 3: Spawn parallel processes
+        mp.spawn(
+            run_attention_tp_test, nprocs=world_size, args=(world_size, temp_dir_path)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
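+
+# Usage sketch (assumes a host with at least 2 GPUs; the module path below is
+# inferred from this file's location and may need adjusting):
+#   python -m unittest tests.test_draft_modeling_tp -v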