From 9a8afefe12a21572539d1ab6fbbc49ed5b89777e Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 4 Dec 2025 12:22:05 -0800 Subject: [PATCH 01/19] implement eagle+medusa in HF Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 2 - .../speculative/plugins/megatron_eagle.py | 198 +----------------- .../torch/speculative/plugins/transformers.py | 86 +++++--- 3 files changed, 64 insertions(+), 222 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 1a6e5707d..b4ee1c641 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -120,8 +120,6 @@ LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-4} MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} -REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1} -REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1} FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} NUM_GPU=${NUM_GPU:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 651fca587..f36b1f0a5 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -17,7 +17,6 @@ import copy import warnings -from collections import deque import megatron.core import torch @@ -55,12 +54,7 @@ from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel -from ..utils import ( - AcceptanceRateValidation, - Tree, - TreeNode, - get_default_attention_mask_and_position_ids, -) +from ..utils import AcceptanceRateValidation, get_default_attention_mask_and_position_ids from .megatron_medusa import MedusaHead try: @@ -945,17 +939,16 @@ def _eagle_forward( else: eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight) + draft_logits_list = [eagle_logits] if self.eagle_config.parallel_draft_step > 1: # Get additional draft logits from parallel draft heads - draft_logits_list = [eagle_logits] for draft_head in self.eagle_module.parallel_draft_heads: draft_logits, _ = draft_head(eagle_hidden_states) draft_logits_list.append(draft_logits) - eagle_logits = torch.cat(draft_logits_list, dim=0) return ( eagle_hidden_states, - eagle_logits, + draft_logits_list, eagle_hidden_states_pre_final_layernorm, ) @@ -1112,7 +1105,7 @@ def forward( return logits_sbh.transpose(0, 1).contiguous() for i in range(self.eagle_config.parallel_draft_step): - eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] + eagle_logit = eagle_logits[i] if i > 0: loss_ = self._compute_eagle_loss( logits_sbh[i:], labels[:, i:], eagle_logit[:-i] @@ -1129,9 +1122,7 @@ def forward( gathered_base_logits = gather_from_tensor_model_parallel_region(logits_sbh) base_top1 = gathered_base_logits.transpose(0, 1).argmax(dim=-1) for i in range(self.eagle_config.parallel_draft_step): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] - ) + gathered_logits = gather_from_tensor_model_parallel_region(eagle_logits[i]) gathered_logits = gathered_logits[ttt_step : -(1 + i)] eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: @@ -1151,177 +1142,6 @@ def forward( return loss - def tree_decode(self, input_ids: torch.Tensor, tree: Tree): - """Tree-based decoding for EAGLE model 
using a mask-based approach. - - This function implements a tree-based decoding strategy where each path of the tree - represents potential token sequences. The function uses attention masks to control - token dependencies and generate multiple candidate sequences in parallel. - - Args: - input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len] - treepaths (list[list[int]]): List of treepaths to decode - - Returns: - tuple: (base_token, base_draft_node, draft_tokens) - - base_token: The next token predicted by the base model - - base_draft_node: A TreeNode containing the base token prediction with a - hierarchical structure of child nodes, where each child node - represents a draft token generated by EAGLE - - draft_tokens: all the draft tokens generated by EAGLE - """ - # Initial setup and base model forward pass - padded_input_ids, seq_len = right_padding(input_ids) - attention_mask, position_ids = get_default_attention_mask_and_position_ids(padded_input_ids) - - # Get base model hidden states - hidden_states, _ = self._base_model_forward( - padded_input_ids, - position_ids, - attention_mask, - ) - - if not self.post_process: - return hidden_states - - # Generate base token prediction - output_weight = ( - self.shared_embedding_or_output_weight() - if self.share_embeddings_and_output_weights - else None - ) - logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight) - logits_sbh = logits_sbh[:seq_len, :, :] - - base_token = ( - gather_from_tensor_model_parallel_region(logits_sbh)[-1:, :, :] - .argmax(dim=-1) - .transpose(0, 1) - ) - - # Early return if no steps needed - if not tree.root.children: - self._aux_hidden_states.clear() - return base_token, None, None - - # Prepare for tree decoding - eagle_ids = torch.cat((input_ids[:, 1:], base_token), dim=-1) - # EAGLE-3 - # Only the first iteration input_hidden_states are from aux_hidden_state layers - hidden_states = self._get_eagle_input_hidden_states(hidden_states) - - if self.config.sequence_parallel: - hidden_states = gather_from_sequence_parallel_region(hidden_states) - hidden_states = hidden_states[:seq_len, :, :] - - # relative id from [seq_len-1, seq_len] contains draft token position - # [seq_len, seq_len + num_child_level_1] contains the number of children for level 1 and so on - relative_ids = torch.tensor( - [seq_len - 1, *list(tree.num_children.values())], - device=input_ids.device, - ).cumsum(dim=0) - - draft_position_ids = torch.arange(relative_ids[-1], device=input_ids.device) - cur_pos = seq_len - 1 - for idx in range(len(relative_ids) - 1): - draft_position_ids[relative_ids[idx] : relative_ids[idx + 1]] = cur_pos - cur_pos += 1 - - draft_attention_mask = torch.full( - (1, 1, relative_ids[-1], relative_ids[-1]), True, device=input_ids.device - ).triu_(1) - draft_attention_mask[:, :, :, seq_len:] = True - draft_attention_mask[:, :, seq_len - 1 :, seq_len - 1 :] = tree.attention_mask - - draft_rotary_pos_emb = self.eagle_module.rotary_pos_emb(seq_len + tree.max_depth) - draft_rotary_pos_emb = torch.cat( - [draft_rotary_pos_emb[index : index + 1] for index in draft_position_ids], dim=0 - ) - - base_draft_node = TreeNode(base_token) - queue = deque([(base_draft_node, tree.root)]) - draft_tokens = [] - # Tree decoding loop - for step in range(tree.max_depth): - # Prepare inputs for EAGLE forward pass - padded_eagle_ids, seq_len, padded_hidden_states = right_padding( - eagle_ids, hidden_states - ) - - if self.config.sequence_parallel: - padded_hidden_states = 
scatter_to_sequence_parallel_region(padded_hidden_states) - - eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids( - padded_eagle_ids - ) - length = eagle_ids.shape[-1] - eagle_attention_mask[:, :, :length, :length] = draft_attention_mask[ - :, :, :length, :length - ] - eagle_attention_mask[:, :, length:, length:] = True - eagle_position_ids[:length] = draft_position_ids[:length] - padded_rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_eagle_ids.shape[-1]) - padded_rotary_pos_emb[:length] = draft_rotary_pos_emb[:length] - - eagle_inputs = { - "input_ids": padded_eagle_ids, - "embedding": self.embedding( - input_ids=padded_eagle_ids, - position_ids=eagle_position_ids, - ), - "hidden_states": padded_hidden_states, - "attention_mask": eagle_attention_mask, - "rotary_pos_emb": padded_rotary_pos_emb, - } - - # Forward pass through EAGLE - _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( - eagle_inputs, - output_weight, - ) - # Process EAGLE outputs - eagle_logits = eagle_logits[:seq_len, :, :] - if self.config.sequence_parallel: - eagle_next_hidden_states_input = gather_from_sequence_parallel_region( - eagle_next_hidden_states_input - ) - eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :] - # Generate and store top-k tokens for each tree node - for rel_idx in range(relative_ids[step], relative_ids[step + 1]): - draft_node, tree_node = queue.popleft() - n_topk = max(tree_node.children.keys()) + 1 if tree_node.children else 0 - # Get top-k tokens for current position - new_ids = ( - gather_from_tensor_model_parallel_region(eagle_logits)[ - rel_idx : rel_idx + 1, :, : - ] - .topk(n_topk, dim=-1)[1] - .squeeze(0) - ) - - for child_idx, child_node in tree_node.children.items(): - eagle_ids = torch.cat( - (eagle_ids, new_ids[:, child_idx : child_idx + 1]), dim=-1 - ) - # value of the node is token id - new_draft_node = TreeNode(new_ids[:, child_idx]) - draft_tokens.append(new_ids[:, child_idx]) - draft_node.children[child_idx] = new_draft_node - queue.append((new_draft_node, child_node)) - - # Update hidden states for each branch - hidden_states = torch.cat( - ( - hidden_states, - eagle_next_hidden_states_input[rel_idx : rel_idx + 1].repeat( - len(tree_node.children), 1, 1 - ), - ), - dim=0, - ) - draft_tokens = torch.cat(draft_tokens, dim=-1) - return base_token, base_draft_node, draft_tokens - def pseudo_speculative_generate( self, input_ids: torch.Tensor, @@ -1409,14 +1229,10 @@ def pseudo_speculative_generate( if self.eagle_config.parallel_draft_step > 1: parallel_logits = [ - eagle_logits[ - padded_eagle_ids.shape[-1] * i + seq_len - 1 : padded_eagle_ids.shape[-1] - * i - + seq_len - ] + eagle_logits[i][seq_len - 1 : seq_len] for i in range(1, self.eagle_config.parallel_draft_step) ] - eagle_logits = eagle_logits[:seq_len, :, :] + eagle_logits = eagle_logits[0][:seq_len, :, :] if self.config.sequence_parallel: eagle_next_hidden_states_input = gather_from_sequence_parallel_region( eagle_next_hidden_states_input diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1aed13e87..ff9f63cfd 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -251,6 +251,15 @@ def __init__(self, config, decoder_layer_cls, bias=False): # Disable input norm in first layer. We normed embeds and h individually before. 
self.layers[0].input_layernorm = nn.Identity() + if self.config.parallel_draft_step > 1: + self.parallel_draft_heads = torch.nn.ModuleList( + nn.Sequential( + *([ResBlock(config.hidden_size)] * self.config.parallel_draft_heads_num_layers), + nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False), + ) + for _ in range(self.config.parallel_draft_step - 1) + ) + def _eagle3_attention_forward_pre_hook(self, module, args, kwargs): """Concat input_embeds and hidden_states for EAGLE-3's first attention layer.""" if "hidden_states" not in kwargs: @@ -669,8 +678,6 @@ def _base_model_forward( labels, **kwargs, ): - # TODO: This function still use eagle_module. Ideally we should remove it, - # so we can del model.eagle_module on the base model ranks to save memory. with torch.no_grad() if freeze_base_model else contextlib.nullcontext(): outputs = super().forward( input_ids=input_ids, @@ -692,11 +699,6 @@ def _base_model_forward( labels = labels.view(-1) base_model_loss = loss_fct(loss_logits, labels) - # Map the base model logits to the draft vocab - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training: - assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" - base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) - return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values def _map_logits_to_draft_vocab(self, full_logits): @@ -731,7 +733,14 @@ def _eagle_forward( ) eagle_logits = eagle_lm_head(eagle_postnorm_h) - return eagle_postnorm_h, eagle_prenorm_h, eagle_logits, eagle_cache + draft_logits_list = [eagle_logits] + if self.eagle_config.parallel_draft_step > 1: + # Get additional draft logits from parallel draft heads + for draft_head in self.eagle_module.parallel_draft_heads: + draft_logits = draft_head(eagle_postnorm_h) + draft_logits_list.append(draft_logits) + + return eagle_postnorm_h, eagle_prenorm_h, draft_logits_list, eagle_cache def forward( self, @@ -778,8 +787,6 @@ def forward( base_model_logits = base_outputs["base_model_logits"] else: base_model_logits = self.lm_head(base_model_hidden_states) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) base_model_loss = None past_key_values = DynamicCache() # Dummy cache @@ -803,7 +810,7 @@ def forward( # ====Run eagle forward==== eagle_loss = None - train_accs = [] + train_accs = [[] * self.eagle_config.parallel_draft_step] # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers b, seq_length, h = base_model_hidden_states.shape if self.eagle_config.use_aux_hidden_state: @@ -855,24 +862,31 @@ def forward( ), dim=1, ) - classification_loss, acc = self._eagle_loss( - # base model predict +1 tok, while eagle predict +2 - # so we shift base model outputs compared to eagle outputs - base_model_logits[:, 1:], - eagle_logits[:, :-1], - # additionally, we mask the first n tok of eagle outputs at nth TTT step - torch.cat( - ( - torch.zeros(b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device), - loss_mask[:, 1 + ttt_step :], + for i in range(self.eagle_config.parallel_draft_step): + eagle_logit = eagle_logits[i] + classification_loss, acc = self._eagle_loss( + # base model predict +1 tok, while eagle predict +2 + # so we shift base model outputs compared to eagle outputs + base_model_logits[:, 1 + i :], + eagle_logit[:, : -(1 + i)], + # additionally, we mask the first n tok of eagle outputs at nth 
TTT step + torch.cat( + ( + torch.zeros( + b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device + ), + loss_mask[:, 1 + ttt_step :] + if i == 0 + else loss_mask[:, 1 + ttt_step : -i], + ), + dim=1, ), - dim=1, - ), - ) - eagle_loss = ( - classification_loss if eagle_loss is None else eagle_loss + classification_loss - ) - train_accs.append(acc) + ) + classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i) + eagle_loss = ( + classification_loss if eagle_loss is None else eagle_loss + classification_loss + ) + train_accs[i].append(acc) if not self.training: break # Finally, we merge base model loss and eagle loss, raise error if both are None @@ -903,6 +917,9 @@ def _eagle_loss( loss_mask, ): """Function for EAGLE loss computing.""" + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" + base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) loss_mask = loss_mask[:, :, None] classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( eagle_logits @@ -987,7 +1004,12 @@ def pseudo_speculative_generate( position_embeddings, ) - draft_token = eagle_logits[:, -1:, :].argmax(dim=-1) + if self.eagle_config.parallel_draft_step > 1: + parallel_logits = [ + eagle_logits[i][:, -1:, :] + for i in range(1, self.eagle_config.parallel_draft_step) + ] + draft_token = eagle_logits[0][:, -1:, :].argmax(dim=-1) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] draft_tokens.append(draft_token) @@ -998,6 +1020,12 @@ def pseudo_speculative_generate( ) draft_tokens = torch.cat(draft_tokens, dim=-1).to(base_token.device) + if self.eagle_config.parallel_draft_step > 1: + parallel_logits = torch.cat(parallel_logits, dim=1) + parallel_tokens = parallel_logits.argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + parallel_tokens += self.eagle_module.d2t[parallel_tokens] + draft_tokens = torch.cat((draft_tokens, parallel_tokens), dim=-1).to(base_token.device) return base_token, draft_tokens From 732e9e79863999835e33a29500348c92d9bda14b Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 8 Dec 2025 10:49:45 -0800 Subject: [PATCH 02/19] implement eagle+medusa and update EagleTrainingPlot accordingly Signed-off-by: Ye Yu --- examples/speculative_decoding/eagle_utils.py | 18 +++++++++++++----- .../speculative/plugins/megatron_medusa.py | 3 ++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 4c6da77a1..ad88cfbcc 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -516,20 +516,28 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) + parallel_draft_step = average_acc.shape[0] if self.estimate_ar: # Calculate mean training AR since last log - # NOTE: This is only a estimate of the real AR. + # NOTE: This is only an estimate of the real AR. 
est_ar = 1 acc_cumprod = 1 - for step_acc in average_acc: - est_ar += acc_cumprod * step_acc + for step_acc in average_acc[0]: acc_cumprod *= step_acc + est_ar += acc_cumprod + # Parallel draft tokens only used after all eagle tokens + for draft_acc in average_acc[1:]: + acc_cumprod *= draft_acc[-1] + est_ar += acc_cumprod print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}") # log to wandb if wandb and is_master(): - for i, step_acc in enumerate(average_acc): - wandb.log({f"step_{i}_train_acc": step_acc}, step=state.global_step) + for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + wandb.log( + {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step + ) if self.estimate_ar: wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) diff --git a/modelopt/torch/speculative/plugins/megatron_medusa.py b/modelopt/torch/speculative/plugins/megatron_medusa.py index 21f8b51ab..ccc6c7a69 100644 --- a/modelopt/torch/speculative/plugins/megatron_medusa.py +++ b/modelopt/torch/speculative/plugins/megatron_medusa.py @@ -36,7 +36,7 @@ class MedusaLayer(MegatronModule): Medusa layer consists of a column parallel linear following a silu. """ - def __init__(self, config): + def __init__(self, config, bias=True): """Constructor. Args: @@ -53,6 +53,7 @@ def __init__(self, config): self.linear = torch.nn.Linear( config.hidden_size, config.hidden_size, + bias=bias, dtype=config.params_dtype, device=device, ) From 0f8d8bed4216b380e0c47fb6639747978ef99aa6 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 8 Dec 2025 10:54:16 -0800 Subject: [PATCH 03/19] minor Signed-off-by: Ye Yu --- examples/speculative_decoding/eagle_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index ad88cfbcc..f575c4bf7 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -516,7 +516,6 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) - parallel_draft_step = average_acc.shape[0] if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. 
From d3c40a0de98303dd023690287bf94625e919c5a8 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Dec 2025 10:02:14 -0800 Subject: [PATCH 04/19] debug Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index ff9f63cfd..373727b59 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -810,7 +810,7 @@ def forward( # ====Run eagle forward==== eagle_loss = None - train_accs = [[] * self.eagle_config.parallel_draft_step] + train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)] # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers b, seq_length, h = base_model_hidden_states.shape if self.eagle_config.use_aux_hidden_state: From aec7124dea6afb2d536f25c198859bc7fc2510ae Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Dec 2025 11:58:59 -0800 Subject: [PATCH 05/19] move hidden_size and vocab from main to model modify Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 6 ++++++ examples/speculative_decoding/main.py | 21 ++++--------------- .../torch/speculative/eagle/default_config.py | 4 ---- .../speculative/plugins/megatron_eagle.py | 11 +++++++--- .../torch/speculative/plugins/transformers.py | 11 ++++++++++ 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index b4ee1c641..4ce015d03 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -90,6 +90,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi VLM_IMG_DIR="${1#*=}" ;; + --estimate_ar*) + if [[ "$1" != *=* ]]; then shift; fi + ESTIMATE_AR="${1#*=}" + ;; --ar_validate_steps*) if [[ "$1" != *=* ]]; then shift; fi AR_VALIDATE_STEPS="${1#*=}" @@ -128,6 +132,7 @@ DISABLE_TQDM=${DISABLE_TQDM:-False} VLM_PROCESSOR=${VLM_PROCESSOR:-} VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} +ESTIMATE_AR=${ESTIMATE_AR:-False} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -190,6 +195,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ --tf32 True \ --data_path $DATA \ --disable_tqdm $DISABLE_TQDM \ + --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 35c9b8ede..98105b765 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -92,6 +92,9 @@ class TrainingArguments(transformers.TrainingArguments): dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3" + estimate_ar: bool = field( + default=False, metadata={"help": "Whether to estimate AR during training for logging."} + ) ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."}) disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."}) remove_unused_columns: bool = field( @@ -193,22 +196,6 @@ def train(): custom_config = json.load(f) config["eagle_architecture_config"].update(custom_config) - # Hidden size and vocab 
size must match base model - llm_config = ( - model.config.llm_config if hasattr(model.config, "llm_config") else model.config - ) - config["eagle_architecture_config"].update( - { - "hidden_size": llm_config.hidden_size, - "vocab_size": llm_config.vocab_size, - # we also overwrite max_pos_embedding for deployment compatibility - "max_position_embeddings": llm_config.max_position_embeddings, - "draft_vocab_size": custom_config["draft_vocab_size"] - if eagle_args.eagle_config and "draft_vocab_size" in custom_config - else llm_config.vocab_size, - } - ) - mtsp.convert(model, [("eagle", config)]) # read draft vocab cache @@ -238,7 +225,7 @@ def train(): model=model, processing_class=tokenizer, args=training_args, - callbacks=[EagleTrainingPlot(training_args.ar_validate_steps)], + callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)], **data_module, ) diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py index 415b8373f..4a00e4364 100644 --- a/modelopt/torch/speculative/eagle/default_config.py +++ b/modelopt/torch/speculative/eagle/default_config.py @@ -18,9 +18,6 @@ default_eagle_config = { "hidden_act": "silu", "torch_dtype": "bfloat16", - "vocab_size": 128256, - "draft_vocab_size": 128256, - "max_position_embeddings": 8192, "position_embedding_type": "rope", "rope_scaling": { "factor": 8.0, @@ -31,7 +28,6 @@ }, "rope_theta": 500000.0, "num_hidden_layers": 1, - "hidden_size": 4096, "intermediate_size": 14336, "num_attention_heads": 32, "num_key_value_heads": 8, diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index f36b1f0a5..734703766 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -88,7 +88,6 @@ def dict_to_config( params_dtype=getattr(torch, architecture_config["torch_dtype"]), pipeline_dtype=getattr(torch, architecture_config["torch_dtype"]), num_layers=architecture_config.get("num_hidden_layers"), - hidden_size=architecture_config.get("hidden_size"), ffn_hidden_size=architecture_config.get("intermediate_size"), num_attention_heads=architecture_config.get("num_attention_heads"), kv_channels=architecture_config.get( @@ -106,8 +105,6 @@ def dict_to_config( config.transformer_layer_spec = None config.seq_length = 8192 config.gradient_accumulation_fusion = False - config.vocab_size = architecture_config.get("vocab_size") - config.max_sequence_length = architecture_config.get("max_position_embeddings") config.position_embedding_type = architecture_config.get("position_embedding_type") config.rotary_percent = 1.0 config.rotary_base = architecture_config.get("rope_theta") @@ -676,6 +673,14 @@ def modify( self.config.bf16, self.config.sequence_parallel, ) + self.eagle_config.hidden_size = self.config.hidden_size + self.eagle_config.vocab_size = self.vocab_size + self.eagle_config.max_sequence_length = self.max_sequence_length + self.eagle_config.draft_vocab_size = ( + self.vocab_size + if self.eagle_config.draft_vocab_size is None + else self.eagle_config.draft_vocab_size + ) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: assert eagle_self_logit_distillation, ( diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 373727b59..16bbabf1e 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ 
-465,6 +465,17 @@ def modify( eagle_architecture_config=eagle_architecture_config, ) self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config) + # Hidden size and vocab size must match base model + llm_config = self.config.llm_config if hasattr(self.config, "llm_config") else self.config + self.eagle_config.hidden_size = llm_config.hidden_size + self.eagle_config.vocab_size = llm_config.vocab_size + self.eagle_config.max_position_embeddings = llm_config.max_position_embeddings + self.eagle_config.draft_vocab_size = ( + self.eagle_config.vocab_size + if self.eagle_config.draft_vocab_size is None + else self.eagle_config.draft_vocab_size + ) + if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" decoder_cls = ( From c01acc1729e8307ced365f46716063c588fae7fa Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 9 Dec 2025 12:08:16 -0800 Subject: [PATCH 06/19] debug Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/transformers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 16bbabf1e..1f3884e0a 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -470,10 +470,8 @@ def modify( self.eagle_config.hidden_size = llm_config.hidden_size self.eagle_config.vocab_size = llm_config.vocab_size self.eagle_config.max_position_embeddings = llm_config.max_position_embeddings - self.eagle_config.draft_vocab_size = ( - self.eagle_config.vocab_size - if self.eagle_config.draft_vocab_size is None - else self.eagle_config.draft_vocab_size + self.eagle_config.draft_vocab_size = getattr( + self.eagle_config, "draft_vocab_size", self.eagle_config.vocab_size ) if self.eagle_config._attn_implementation is None: From a9494f0eee661fd0165b3de80f21bb7b559466f1 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Dec 2025 09:25:57 -0800 Subject: [PATCH 07/19] fix a bug in default_config Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 10 +++++----- modelopt/torch/speculative/plugins/transformers.py | 4 ++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 734703766..a568eb3ad 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -90,11 +90,7 @@ def dict_to_config( num_layers=architecture_config.get("num_hidden_layers"), ffn_hidden_size=architecture_config.get("intermediate_size"), num_attention_heads=architecture_config.get("num_attention_heads"), - kv_channels=architecture_config.get( - "head_dim", - architecture_config.get("hidden_size") - // architecture_config.get("num_attention_heads"), - ), + kv_channels=architecture_config.get("head_dim"), num_query_groups=architecture_config.get("num_key_value_heads"), init_method_std=architecture_config.get("initializer_range"), layernorm_epsilon=architecture_config.get("rms_norm_eps"), @@ -681,6 +677,10 @@ def modify( if self.eagle_config.draft_vocab_size is None else self.eagle_config.draft_vocab_size ) + if self.eagle_config.kv_channels is None: + self.eagle_config.kv_channels = ( + self.eagle_config.hidden_size // self.eagle_config.num_attention_heads + ) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: assert eagle_self_logit_distillation, ( diff --git 
a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1f3884e0a..13e2a50ce 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -473,6 +473,10 @@ def modify( self.eagle_config.draft_vocab_size = getattr( self.eagle_config, "draft_vocab_size", self.eagle_config.vocab_size ) + if getattr(self.eagle_config, "head_dim", None) is None: + self.eagle_config.head_dim = ( + self.eagle_config.hidden_size // self.eagle_config.num_attention_heads + ) if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" From 35d976d99dfa3a22641a5ea0ec2882e184aace27 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 10 Dec 2025 12:35:24 -0800 Subject: [PATCH 08/19] remove tree decoding test code Signed-off-by: Ye Yu --- .../test_speculative_megatron_modules.py | 181 +----------------- 1 file changed, 1 insertion(+), 180 deletions(-) diff --git a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py b/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py index c1539d101..5a149b77f 100644 --- a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py +++ b/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import deque -from copy import deepcopy from functools import partial import pytest @@ -24,13 +22,10 @@ from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model -from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.eagle.default_config import default_eagle_config -from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel, right_padding +from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel -from modelopt.torch.speculative.utils import Tree, get_default_attention_mask_and_position_ids ALGO_TO_CONFIG = { "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG, @@ -173,177 +168,3 @@ def test_speculative_gpt_model( ), backend="nccl", ) - - -def generate_next_tokens(model, eagle_ids, hidden_states, topk=1): - padded_eagle_ids, seq_len, padded_hidden_states = right_padding(eagle_ids, hidden_states) - eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids( - padded_eagle_ids - ) - - eagle_inputs = {} - eagle_inputs["input_ids"] = padded_eagle_ids - eagle_inputs["embedding"] = model.embedding( - input_ids=padded_eagle_ids, - position_ids=eagle_position_ids, - ) - eagle_inputs["hidden_states"] = padded_hidden_states - eagle_inputs["attention_mask"] = eagle_attention_mask - - eagle_inputs["rotary_pos_emb"] = None - - _, eagle_logits, eagle_next_hidden_states_input = model._eagle_forward(eagle_inputs, None) - - eagle_logits = eagle_logits[seq_len - 1 : seq_len, :, :] - eagle_next_hidden_states_input = eagle_next_hidden_states_input[seq_len - 1 : seq_len, :, :] - - draft_token = ( - gather_from_tensor_model_parallel_region(eagle_logits).topk(topk, dim=-1)[1].transpose(0, 1) - ) - return draft_token, eagle_next_hidden_states_input - 
- -def _test_tree_decode(tree_paths, greedy_steps, rank, size): - activation_func = "squared_relu" - normalization = "RMSNorm" - - num_attention_heads = 8 - num_query_groups = size - max_sequence_length = 32 - vocab_size = 64 - batch_size = 1 - - model = get_mcore_gpt_model( - tensor_model_parallel_size=size, - pipeline_model_parallel_size=1, - initialize_megatron=True, - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - max_sequence_length=max_sequence_length, - vocab_size=vocab_size, - activation_func=activation_func, - normalization=normalization, - ).cuda() - - config = {"eagle_architecture_config": deepcopy(default_eagle_config)} - config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size - config["eagle_architecture_config"]["vocab_size"] = model.vocab_size - config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size - - model = mtsp.convert(model, [("eagle", config)]) - - # Bfloat16 - model = model.to(torch.bfloat16) - - # Prepare inputs for forward. - prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() - attention_mask = torch.tril(torch.ones((1, 1, max_sequence_length, max_sequence_length))).cuda() - position_ids = torch.arange(max_sequence_length, dtype=torch.long).unsqueeze(0).cuda() - attention_mask = attention_mask < 0.5 - - model.eval() - tree = Tree(tree_paths) - - input_id, draft_tokens, pred_tokens = model.tree_decode(prompt_tokens, tree=tree) - - # check for empty tree paths - if not tree_paths: - assert draft_tokens is None, "draft_tokens should be None for empty tree paths" - return - - # check when tree decode is same as greedy decode - if greedy_steps: - spec_input_id, spec_draft_tokens = model.pseudo_speculative_generate( - prompt_tokens, steps=greedy_steps - ) - assert (pred_tokens == spec_draft_tokens[0]).all(), ( - f"pred_tokens should be equal to spec_draft_tokens, {pred_tokens} != {spec_draft_tokens[0]}" - ) - assert input_id == spec_input_id[0], ( - f"spec_input_id should be equal to input_id, {input_id} != {spec_input_id[0]}" - ) - return - - orig_hidden_states, _ = model._base_model_forward( - prompt_tokens, - position_ids, - attention_mask, - ) - - # Get Eagle-specific input hidden states - eagle_hidden_states = model._get_eagle_input_hidden_states(orig_hidden_states) - # Extract tokens for Eagle processing (excluding first token) - eagle_tokens = prompt_tokens[:, 1:] - - # Initialize lists to store draft tokens and hidden states - draft_tokens_list = [input_id] - eagle_hidden_states_list = [eagle_hidden_states] - # Track indices for token and hidden state mapping - index_list = [[[0, 0]]] - - # Initialize queue for breadth-first tree traversal - queue = deque([(draft_tokens, 0)]) - - # Process tree nodes in breadth-first order - while queue: - tree_token, index = queue.popleft() - if not tree_token.children: - continue - # Collect tokens and hidden states for current node - tokens = [] - hidden_states = [] - for token_idx, state_idx in index_list[index]: - tokens.append(draft_tokens_list[token_idx]) - hidden_states.append(eagle_hidden_states_list[state_idx]) - # Concatenate tokens and hidden states for processing - tokens = torch.cat([eagle_tokens, torch.cat(tokens, dim=-1)], dim=-1) - hidden_states = torch.cat(hidden_states, dim=0) - # Generate next token and get updated hidden states - draft_token, eagle_next_hidden_states_input = generate_next_tokens( - model, tokens, hidden_states, topk=len(tree_token.children) - ) - # Verify generated tokens match 
expected tree structure - for child_idx, tree_node in enumerate(tree_token.children.values()): - assert tree_node.value[0] == draft_token[0, 0, child_idx], ( - f"token mismatch at {tree_node.value[0]} != {draft_token[0, 0, child_idx]}" - ) - # Update tracking variables - cur_len = len(draft_tokens_list) - eagle_hidden_states_list.append(eagle_next_hidden_states_input) - # Process children and add them to the queue - for child_idx, child_tree_token in enumerate(tree_token.children.values()): - queue.append([child_tree_token, len(index_list)]) - draft_tokens_list.append(draft_token[:, :, child_idx]) - index_list.append( - [*index_list[index][:], [cur_len + child_idx, len(eagle_hidden_states_list) - 1]] - ) - - -@pytest.mark.parametrize( - ("greedy_steps", "tree_paths"), - [ - (None, []), - (3, [[0], [0, 0], [0, 0, 0]]), - ( - None, - [ - [0], - [1], - [0, 0], - [0, 1], - [1, 1], - [0, 0, 0], - [0, 0, 1], - [1, 0, 0], - [1, 0], - [0, 0, 1, 0], - ], - ), - ], -) -def test_tree_decode_model(greedy_steps, tree_paths): - spawn_multiprocess_job( - size=torch.cuda.device_count(), - job=partial(_test_tree_decode, tree_paths, greedy_steps), - backend="nccl", - ) From 94702ade8a976b42dfbd2dc42188af296b3adaf9 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 11 Dec 2025 13:31:22 -0800 Subject: [PATCH 09/19] clean up config; fix export bug Signed-off-by: Ye Yu --- .../speculative_decoding/eagle_config.json | 8 ----- .../torch/export/unified_export_megatron.py | 32 +++++-------------- .../torch/speculative/eagle/default_config.py | 2 +- .../speculative/plugins/megatron_eagle.py | 9 ++---- .../torch/speculative/plugins/transformers.py | 16 ++++------ 5 files changed, 18 insertions(+), 49 deletions(-) diff --git a/examples/speculative_decoding/eagle_config.json b/examples/speculative_decoding/eagle_config.json index eebf64ce5..0b6ed9ef4 100644 --- a/examples/speculative_decoding/eagle_config.json +++ b/examples/speculative_decoding/eagle_config.json @@ -1,11 +1,3 @@ { - "rope_scaling": { - "factor": 32.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "initializer_range": 0.02, "_attn_implementation": "sdpa" } diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 9b82398d7..7ec6fad25 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -212,26 +212,14 @@ def __init__( ) eagle_config = { - "use_input_layernorm_in_first_layer": mode_cfg["config"][ - "eagle_architecture_config" - ]["use_input_layernorm_in_first_layer"], - "use_last_layernorm": mode_cfg["config"]["eagle_architecture_config"][ - "use_last_layernorm" - ], - "use_mtp_layernorm": mode_cfg["config"]["eagle_architecture_config"][ - "use_mtp_layernorm" - ], - "use_aux_hidden_state": mode_cfg["config"]["eagle_architecture_config"][ - "use_aux_hidden_state" - ], + "use_input_layernorm_in_first_layer": model.eagle_config.use_input_layernorm_in_first_layer, + "use_last_layernorm": model.eagle_config.use_last_layernorm, + "use_mtp_layernorm": model.eagle_config.use_mtp_layernorm, + "use_aux_hidden_state": model.eagle_config.use_aux_hidden_state, "eagle_aux_hidden_state_layer_ids": model.eagle_config.eagle_aux_hidden_state_layer_ids, "next_layer_regular": True, - "parallel_draft_step": mode_cfg["config"]["eagle_architecture_config"][ - "parallel_draft_step" - ], - "parallel_draft_heads_num_layers": mode_cfg["config"][ - "eagle_architecture_config" - 
]["parallel_draft_heads_num_layers"], + "parallel_draft_step": model.eagle_config.parallel_draft_step, + "parallel_draft_heads_num_layers": model.eagle_config.parallel_draft_heads_num_layers, } eagle_config_update = { @@ -243,9 +231,7 @@ def __init__( "max_position_embeddings": self._hf_text_config.max_position_embeddings, "num_attention_heads": model.eagle_module.config.num_attention_heads, "num_key_value_heads": model.eagle_module.config.num_query_groups, - "num_hidden_layers": mode_cfg["config"]["eagle_architecture_config"][ - "num_hidden_layers" - ], + "num_hidden_layers": model.eagle_config.num_hidden_layers, "vocab_size": self._hf_text_config.vocab_size, # Unset any special token ids given that the tokenizer can change here. "bos_token_id": None, @@ -254,9 +240,7 @@ def __init__( "sep_token_id": None, # The following attributes are EAGLE specific "eagle_config": eagle_config, - "draft_vocab_size": mode_cfg["config"]["eagle_architecture_config"][ - "draft_vocab_size" - ], + "draft_vocab_size": model.eagle_config.draft_vocab_size, } self._hf_extra_config.update(eagle_config_update) diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py index 4a00e4364..09d535b61 100644 --- a/modelopt/torch/speculative/eagle/default_config.py +++ b/modelopt/torch/speculative/eagle/default_config.py @@ -43,6 +43,6 @@ "use_mtp_layernorm": False, "parallel_draft_step": 1, "parallel_draft_heads_num_layers": 1, - "has_lm_head": False, + "has_lm_head": True, "head_dim": 128, } diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index a568eb3ad..da5d99791 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -677,10 +677,6 @@ def modify( if self.eagle_config.draft_vocab_size is None else self.eagle_config.draft_vocab_size ) - if self.eagle_config.kv_channels is None: - self.eagle_config.kv_channels = ( - self.eagle_config.hidden_size // self.eagle_config.num_attention_heads - ) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: assert eagle_self_logit_distillation, ( @@ -1203,7 +1199,7 @@ def pseudo_speculative_generate( hidden_states = hidden_states[:seq_len, :, :] draft_tokens = [] - for _ in range(steps): + for step in range(steps): padded_eagle_ids, seq_len, padded_hidden_states = right_padding( eagle_ids, hidden_states ) @@ -1232,7 +1228,8 @@ def pseudo_speculative_generate( output_weight, ) - if self.eagle_config.parallel_draft_step > 1: + # parallel_logits are only used after the last step + if step == steps - 1 and self.eagle_config.parallel_draft_step > 1: parallel_logits = [ eagle_logits[i][seq_len - 1 : seq_len] for i in range(1, self.eagle_config.parallel_draft_step) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 13e2a50ce..d5f00ab4c 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -466,17 +466,12 @@ def modify( ) self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config) # Hidden size and vocab size must match base model - llm_config = self.config.llm_config if hasattr(self.config, "llm_config") else self.config - self.eagle_config.hidden_size = llm_config.hidden_size - self.eagle_config.vocab_size = llm_config.vocab_size - self.eagle_config.max_position_embeddings = llm_config.max_position_embeddings + 
self.eagle_config.hidden_size = self._base_llm_config.hidden_size + self.eagle_config.vocab_size = self._base_llm_config.vocab_size + self.eagle_config.max_position_embeddings = self._base_llm_config.max_position_embeddings self.eagle_config.draft_vocab_size = getattr( self.eagle_config, "draft_vocab_size", self.eagle_config.vocab_size ) - if getattr(self.eagle_config, "head_dim", None) is None: - self.eagle_config.head_dim = ( - self.eagle_config.hidden_size // self.eagle_config.num_attention_heads - ) if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" @@ -992,7 +987,7 @@ def pseudo_speculative_generate( eagle_input_hidden_states = base_model_hidden_states draft_tokens = [] - for _ in range(steps): + for step in range(steps): # Get eagle inputs for the first eagle forward pass _, eagle_attention_mask, eagle_position_ids = self._get_eagle_module_inputs( input_ids, @@ -1017,7 +1012,8 @@ def pseudo_speculative_generate( position_embeddings, ) - if self.eagle_config.parallel_draft_step > 1: + # parallel logits are only used after the last step + if step == steps - 1 and self.eagle_config.parallel_draft_step > 1: parallel_logits = [ eagle_logits[i][:, -1:, :] for i in range(1, self.eagle_config.parallel_draft_step) From 50896bb6bf0cface2e0dc0762e7396ad552cdc0e Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Dec 2025 10:38:20 -0800 Subject: [PATCH 10/19] take care of HF parallel draft export; rm unused mask tokens Signed-off-by: Ye Yu --- modelopt/torch/export/plugins/hf_spec_export.py | 15 +++++++++++++++ modelopt/torch/export/plugins/mcore_llama.py | 12 +++++++++--- modelopt/torch/export/unified_export_megatron.py | 2 +- modelopt/torch/speculative/eagle/eagle_model.py | 6 ------ .../torch/speculative/plugins/transformers.py | 7 ++++++- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 9f89cb269..5b4ca6839 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -78,6 +78,21 @@ def export_spec_ckpt_state_dict(model: nn.Module): if "eagle_lm_head.weight" not in eagle_state: export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"] + # Add parallel draft weights + if model.eagle_config.parallel_draft_step > 1: + for i in range(model.eagle_config.parallel_draft_step - 1): + for j in range(model.eagle_config.parallel_draft_heads_num_layers): + export_state_dict[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.weight"] = ( + eagle_state[f"parallel_draft_heads.{i}.{j}.linear.weight"] + ) + if f"parallel_draft_heads.{i}.{j}.linear.bias" in eagle_state: + export_state_dict[f"parallel_draft_heads.{i}.medusa_layers.{j}.linear.bias"] = ( + eagle_state[f"parallel_draft_heads.{i}.{j}.linear.bias"] + ) + export_state_dict[f"parallel_draft_heads.{i}.lm_head.weight"] = eagle_state[ + f"parallel_draft_heads.{i}.{model.eagle_config.parallel_draft_heads_num_layers}.weight" + ] + return export_state_dict diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py index b11008dde..f85d8b335 100644 --- a/modelopt/torch/export/plugins/mcore_llama.py +++ b/modelopt/torch/export/plugins/mcore_llama.py @@ -98,7 +98,9 @@ "final_layernorm": NameRemapping("norm."), "d2t": NameRemapping("d2t"), "output_layer": NameRemapping("lm_head."), - "parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."), + 
"parallel_draft_heads.medusa_layers": NameRemapping( + "parallel_draft_heads.{}.medusa_layers.{}.linear." + ), "parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."), } @@ -115,7 +117,9 @@ "final_layernorm": NameRemapping("norm."), "d2t": NameRemapping("d2t"), "output_layer": NameRemapping("lm_head."), - "parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."), + "parallel_draft_heads.medusa_layers": NameRemapping( + "parallel_draft_heads.{}.medusa_layers.{}.linear." + ), "parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."), } @@ -133,7 +137,9 @@ "final_layernorm": NameRemapping("norm."), "d2t": NameRemapping("d2t"), "output_layer": NameRemapping("lm_head."), - "parallel_draft_heads.medusa_layers": NameRemapping("parallel_draft_heads.{}.{}.linear."), + "parallel_draft_heads.medusa_layers": NameRemapping( + "parallel_draft_heads.{}.medusa_layers.{}.linear." + ), "parallel_draft_heads.lm_head": NameRemapping("parallel_draft_heads.{}.lm_head."), } diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 7ec6fad25..0a9b58829 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -231,7 +231,7 @@ def __init__( "max_position_embeddings": self._hf_text_config.max_position_embeddings, "num_attention_heads": model.eagle_module.config.num_attention_heads, "num_key_value_heads": model.eagle_module.config.num_query_groups, - "num_hidden_layers": model.eagle_config.num_hidden_layers, + "num_hidden_layers": model.eagle_config.num_layers, "vocab_size": self._hf_text_config.vocab_size, # Unset any special token ids given that the tokenizer can change here. 
"bos_token_id": None, diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index 69051f6e8..8e9203c9c 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -15,8 +15,6 @@ """Eagle model to support eagle decoding.""" -import torch - from modelopt.torch.opt.dynamic import DynamicModule @@ -45,7 +43,3 @@ def modify( self.eagle_report_acc = eagle_report_acc self.eagle_reuse_base_decoder = eagle_reuse_base_decoder self.eagle_loss_decay_factor = eagle_loss_decay_factor - - if eagle_architecture_config.get("parallel_draft_step", 1) > 1: - for i in range(eagle_architecture_config.get("parallel_draft_step") - 1): - self.register_buffer(f"mask_token_{i}", torch.tensor(-1)) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index d5f00ab4c..44462043a 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -254,7 +254,12 @@ def __init__(self, config, decoder_layer_cls, bias=False): if self.config.parallel_draft_step > 1: self.parallel_draft_heads = torch.nn.ModuleList( nn.Sequential( - *([ResBlock(config.hidden_size)] * self.config.parallel_draft_heads_num_layers), + *( + [ + ResBlock(config.hidden_size) + for _ in range(self.config.parallel_draft_heads_num_layers) + ] + ), nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False), ) for _ in range(self.config.parallel_draft_step - 1) From 34f9325ecddb14dd5498b50b13435750a6a639c6 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Dec 2025 11:29:07 -0800 Subject: [PATCH 11/19] update export logic to support multilayer eagle Signed-off-by: Ye Yu --- .../torch/export/plugins/hf_spec_export.py | 83 ++++++++++++++----- 1 file changed, 60 insertions(+), 23 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 5b4ca6839..38031e415 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -18,26 +18,62 @@ import torch import torch.nn as nn -EAGLE_MODELOPT_TO_OFFICIAL = { - "required": { - "layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight", - "layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight", - "layers.0.self_attn.v_proj.weight": "midlayer.self_attn.v_proj.weight", - "layers.0.self_attn.o_proj.weight": "midlayer.self_attn.o_proj.weight", - "layers.0.mlp.gate_proj.weight": "midlayer.mlp.gate_proj.weight", - "layers.0.mlp.up_proj.weight": "midlayer.mlp.up_proj.weight", - "layers.0.mlp.down_proj.weight": "midlayer.mlp.down_proj.weight", - "hidden_norm.weight": "midlayer.hidden_norm.weight", - "input_embeds_norm.weight": "midlayer.input_layernorm.weight", - "layers.0.post_attention_layernorm.weight": "midlayer.post_attention_layernorm.weight", - "norm.weight": "norm.weight", - "fc.weight": "fc.weight", - }, - "optional": { - "d2t": "d2t", - "eagle_lm_head.weight": "lm_head.weight", - }, -} + +def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> list[str]: + """Convert our eagle model state dict key to official format key(s).""" + assert num_hidden_layers >= 1, "num_hidden_layers should be at least 1." 
+ eagle_modelopt_to_official = { + "required": { + "norm.weight": "norm.weight", + "fc.weight": "fc.weight", + }, + "optional": { + "d2t": "d2t", + "eagle_lm_head.weight": "lm_head.weight", + }, + } + if num_hidden_layers == 1: + eagle_modelopt_to_official["required"].update( + { + "hidden_norm.weight": "midlayer.hidden_norm.weight", + "input_embeds_norm.weight": "midlayer.input_layernorm.weight", + } + ) + else: + eagle_modelopt_to_official["required"].update( + { + "hidden_norm.weight": "layers.0.hidden_norm.weight", + "input_embeds_norm.weight": "layers.0.input_layernorm.weight", + } + ) + for i in range(num_hidden_layers): + if num_hidden_layers == 1: + index = "" + else: + index = f".{i}" + eagle_modelopt_to_official["required"].update( + { + "layers.{i}.self_attn.q_proj.weight": "midlayer" + + index + + ".self_attn.q_proj.weight", + "layers.{i}.self_attn.k_proj.weight": "midlayer" + + index + + ".self_attn.k_proj.weight", + "layers.{i}.self_attn.v_proj.weight": "midlayer" + + index + + ".self_attn.v_proj.weight", + "layers.{i}.self_attn.o_proj.weight": "midlayer" + + index + + ".self_attn.o_proj.weight", + "layers.{i}.mlp.gate_proj.weight": "midlayer" + index + ".mlp.gate_proj.weight", + "layers.{i}.mlp.up_proj.weight": "midlayer" + index + ".mlp.up_proj.weight", + "layers.{i}.mlp.down_proj.weight": "midlayer" + index + ".mlp.down_proj.weight", + "layers.{i}.post_attention_layernorm.weight": "midlayer" + + index + + ".post_attention_layernorm.weight", + } + ) + return eagle_modelopt_to_official def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict): @@ -61,15 +97,16 @@ def export_spec_ckpt_state_dict(model: nn.Module): # check the model has only speculative decoding assert spec_opt_only(model), "Not purely eagle model." 
+ eagle_modelopt_to_official = eagle_state_dict_key_convert(model.eagle_config.num_hidden_layers) # Check if the state dict keys match - _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"]) + _check_state_dict_keys_match(model.eagle_module, eagle_modelopt_to_official["required"]) # Convert key names and save the state dict eagle_state = model.eagle_module.state_dict() export_state_dict = {} for ours_key, export_key in { - **EAGLE_MODELOPT_TO_OFFICIAL["required"], - **EAGLE_MODELOPT_TO_OFFICIAL["optional"], + **eagle_modelopt_to_official["required"], + **eagle_modelopt_to_official["optional"], }.items(): if ours_key in eagle_state: export_state_dict[export_key] = eagle_state[ours_key] From b83f66b836360c50121f5ed6535263de89f66638 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Dec 2025 11:31:45 -0800 Subject: [PATCH 12/19] debug Signed-off-by: Ye Yu --- modelopt/torch/export/plugins/hf_spec_export.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 38031e415..e534b936e 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -53,22 +53,22 @@ def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> list[str]: index = f".{i}" eagle_modelopt_to_official["required"].update( { - "layers.{i}.self_attn.q_proj.weight": "midlayer" + f"layers.{i}.self_attn.q_proj.weight": "midlayer" + index + ".self_attn.q_proj.weight", - "layers.{i}.self_attn.k_proj.weight": "midlayer" + f"layers.{i}.self_attn.k_proj.weight": "midlayer" + index + ".self_attn.k_proj.weight", - "layers.{i}.self_attn.v_proj.weight": "midlayer" + f"layers.{i}.self_attn.v_proj.weight": "midlayer" + index + ".self_attn.v_proj.weight", - "layers.{i}.self_attn.o_proj.weight": "midlayer" + f"layers.{i}.self_attn.o_proj.weight": "midlayer" + index + ".self_attn.o_proj.weight", - "layers.{i}.mlp.gate_proj.weight": "midlayer" + index + ".mlp.gate_proj.weight", - "layers.{i}.mlp.up_proj.weight": "midlayer" + index + ".mlp.up_proj.weight", - "layers.{i}.mlp.down_proj.weight": "midlayer" + index + ".mlp.down_proj.weight", - "layers.{i}.post_attention_layernorm.weight": "midlayer" + f"layers.{i}.mlp.gate_proj.weight": "midlayer" + index + ".mlp.gate_proj.weight", + f"layers.{i}.mlp.up_proj.weight": "midlayer" + index + ".mlp.up_proj.weight", + f"layers.{i}.mlp.down_proj.weight": "midlayer" + index + ".mlp.down_proj.weight", + f"layers.{i}.post_attention_layernorm.weight": "midlayer" + index + ".post_attention_layernorm.weight", } From 10f46d71e7ab9e4dba7182845d15db049a7a3bf1 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 12 Dec 2025 11:52:23 -0800 Subject: [PATCH 13/19] formatting Signed-off-by: Ye Yu --- modelopt/torch/export/plugins/hf_spec_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index e534b936e..6563cc661 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -19,7 +19,7 @@ import torch.nn as nn -def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> list[str]: +def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> dict[str, dict[str, str]]: """Convert our eagle model state dict key to official format key(s).""" assert num_hidden_layers >= 1, "num_hidden_layers should be at least 1." 
From b936d7e2448dc8525c13522c0978d77eb8a4342b Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 12 Dec 2025 12:32:27 -0800
Subject: [PATCH 14/19] fix eagle export test

Signed-off-by: Ye Yu
---
 tests/examples/speculative_decoding/test_eagle.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py
index 38dae9a15..4637523c5 100644
--- a/tests/examples/speculative_decoding/test_eagle.py
+++ b/tests/examples/speculative_decoding/test_eagle.py
@@ -19,7 +19,7 @@
 import safetensors.torch
 from _test_utils.examples.run_command import run_example_command
 
-from modelopt.torch.export.plugins.hf_spec_export import EAGLE_MODELOPT_TO_OFFICIAL
+from modelopt.torch.export.plugins.hf_spec_export import eagle_state_dict_key_convert
 
 
 @pytest.fixture(scope="module")
@@ -89,7 +89,8 @@ def test_export_hf_checkpoint(eagle_output_dir):
     )
     # Check the exported checkpoints have required keys
     state_dict = safetensors.torch.load_file(eagle_output_dir / "eagle-tinyllama-export" / "model.safetensors")
-    for required_key in EAGLE_MODELOPT_TO_OFFICIAL["required"].values():
+    eagle_modelopt_to_official = eagle_state_dict_key_convert(num_layers=1)
+    for required_key in eagle_modelopt_to_official["required"].values():
         assert required_key in state_dict, f"Missing key '{required_key}' in state_dict"

From 07aa8612d77a141195425238c220ad5102d13af6 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 12 Dec 2025 13:26:38 -0800
Subject: [PATCH 15/19] fix typo

Signed-off-by: Ye Yu
---
 modelopt/torch/export/plugins/hf_spec_export.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index 6563cc661..1da6e5f5e 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -42,8 +42,8 @@ def eagle_state_dict_key_convert(num_hidden_layers: int = 1) -> dict[str, dict[s
     else:
         eagle_modelopt_to_official["required"].update(
             {
-                "hidden_norm.weight": "layers.0.hidden_norm.weight",
-                "input_embeds_norm.weight": "layers.0.input_layernorm.weight",
+                "hidden_norm.weight": "midlayer.0.hidden_norm.weight",
+                "input_embeds_norm.weight": "midlayer.0.input_layernorm.weight",
             }
         )

From b4b9b50e1c178224e747f381f7d1e11864936037 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 12 Dec 2025 13:59:17 -0800
Subject: [PATCH 16/19] typo

Signed-off-by: Ye Yu
---
 tests/examples/speculative_decoding/test_eagle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py
index 4637523c5..a5223da2c 100644
--- a/tests/examples/speculative_decoding/test_eagle.py
+++ b/tests/examples/speculative_decoding/test_eagle.py
@@ -89,7 +89,7 @@ def test_export_hf_checkpoint(eagle_output_dir):
     )
     # Check the exported checkpoints have required keys
     state_dict = safetensors.torch.load_file(eagle_output_dir / "eagle-tinyllama-export" / "model.safetensors")
-    eagle_modelopt_to_official = eagle_state_dict_key_convert(num_layers=1)
+    eagle_modelopt_to_official = eagle_state_dict_key_convert(num_hidden_layers=1)
     for required_key in eagle_modelopt_to_official["required"].values():
         assert required_key in state_dict, f"Missing key '{required_key}' in state_dict"
From dd2080559e261cf754f287a993fb08bcee4c157f Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 12 Dec 2025 16:38:52 -0800
Subject: [PATCH 17/19] switch back to has_lm_head=False by default as it hurts AR in testing

Signed-off-by: Ye Yu
---
 modelopt/torch/speculative/eagle/default_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py
index 09d535b61..4a00e4364 100644
--- a/modelopt/torch/speculative/eagle/default_config.py
+++ b/modelopt/torch/speculative/eagle/default_config.py
@@ -43,6 +43,6 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "parallel_draft_heads_num_layers": 1,
-    "has_lm_head": True,
+    "has_lm_head": False,
     "head_dim": 128,
 }

From 8c2c6f5b810c0ac7f115553890fdeaaba1d8c675 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 12 Dec 2025 16:51:19 -0800
Subject: [PATCH 18/19] add parallel configs in export

Signed-off-by: Ye Yu
---
 modelopt/torch/export/plugins/hf_spec_export.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index 1da6e5f5e..385f42471 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -172,6 +172,9 @@ def export_spec_ckpt_config(model: nn.Module):
         "use_input_layernorm_in_first_layer": None,
         "use_last_layernorm": None,
         "use_mtp_layernorm": None,
+        "next_layer_regular": True,
+        "parallel_draft_step": None,
+        "parallel_draft_heads_num_layers": None,
     },
 }

From d7139dd18b97cb6435e637809fb57c4b5657a131 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 12 Dec 2025 16:55:13 -0800
Subject: [PATCH 19/19] skip next_layer_regular in eagle_config as it is hardcoded to true

Signed-off-by: Ye Yu
---
 modelopt/torch/export/plugins/hf_spec_export.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index 385f42471..b39d74f34 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -191,7 +191,8 @@ def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
         if isinstance(value, dict):
             # for eagle config, we find it in model.eagle_config
             for sub_key in value:
-                value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
+                if value[sub_key] is None:
+                    value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
         elif value is None:
             # First, we try to load fron eagle config.
             new_value = _get_config_from_eagle_config_or_base_config(key, model)
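To make the intent of the last two patches concrete, the following is a rough sketch of the resolution order (a simplification under stated assumptions, not the actual helper in hf_spec_export.py): template entries that are preset, such as "next_layer_regular": True, are kept as-is, while None entries are first looked up on model.eagle_config and only then fall back to the base model config.

    # Rough sketch of the assumed resolution order used during config export;
    # the real logic lives in export_spec_ckpt_config and
    # _get_config_from_eagle_config_or_base_config.
    def resolve_template_value(key, preset, model):
        if preset is not None:
            # Hard-coded template values (e.g. "next_layer_regular": True) are kept.
            return preset
        if hasattr(model, "eagle_config") and hasattr(model.eagle_config, key):
            # e.g. "parallel_draft_step", "parallel_draft_heads_num_layers"
            return getattr(model.eagle_config, key)
        # Assumed fallback: read the value from the base model config.
        return getattr(model.config, key, None)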