From 8299c7097f9a8f8b9dcf24302e54353e5f4d5f6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BA=AC=E6=9D=AD?=
Date: Thu, 23 Oct 2025 19:39:51 +0800
Subject: [PATCH 1/6] support multi-layers eagle

---
 specforge/core/eagle3.py                 | 10 +++-
 specforge/modeling/draft/llama3_eagle.py | 68 ++++++++++++++----------
 2 files changed, 48 insertions(+), 30 deletions(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index d5cb9a4b..32ed1cb1 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -225,7 +225,10 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            cache_hidden = [[], []]
+            num_hidden_layers = getattr(
+                self.draft_model.config.to_dict(), "num_hidden_layers", 1
+            )
+            cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
@@ -599,7 +602,10 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            cache_hidden = [[], []]
+            num_hidden_layers = getattr(
+                self.draft_model.config.to_dict(), "num_hidden_layers", 1
+            )
+            cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 54dc3269..86790586 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -338,7 +338,7 @@ def forward(self, x, position_ids):
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config):
+    def __init__(self, config, layer_idx: Optional[int] = 0):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -350,6 +350,7 @@ def __init__(self, config):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.layer_idx = layer_idx
 
         self.q_proj = nn.Linear(
             self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
@@ -599,7 +600,7 @@ def forward(
             key_cache, value_cache = past_key_values.update(
                 key_states,
                 value_states,
-                layer_idx=0,  # TODO: support multiple layers
+                layer_idx=self.layer_idx,
                 cache_kwargs=cache_kwargs,
             )
 
@@ -706,23 +707,22 @@ def forward(self, hidden_states):
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config, attention_backend: str = "sdpa"):
+    def __init__(self, config, attention_backend: str = "sdpa", layer_idx: int = 0):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx
 
         if attention_backend == "sdpa":
-            self.self_attn = LlamaAttention(config=config)
+            self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
         elif attention_backend == "flex_attention":
             print_with_rank("Using flex attention on draft model training!")
-            self.self_attn = LlamaFlexAttention(config=config)
+            self.self_attn = LlamaFlexAttention(config=config, layer_idx=layer_idx)
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")
 
         self.mlp = LlamaMLP(config)
-        # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        # if self.index!=0:
 
         self.post_attention_layernorm = LlamaRMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -797,7 +797,16 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         self.embed_tokens = nn.Embedding(
             config.vocab_size, config.hidden_size, config.pad_token_id
         )
-        self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
+        # Support configurable number of layers
+        self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, attention_backend=attention_backend, layer_idx=layer_idx
+                )
+                for layer_idx in range(self.num_hidden_layers)
+            ]
+        )
 
         if hasattr(config, "target_hidden_size"):
             self.fc = torch.nn.Linear(
@@ -859,16 +868,17 @@ def forward(
         # fc
         hidden_states = self.fc(hidden_states)
 
-        hidden_states = self.midlayer(
-            input_emb=inputs_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=None,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for layer in self.layers[: self.num_hidden_layers]:
+            hidden_states = layer(
+                input_emb=inputs_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=None,
+                output_attentions=False,
+                use_cache=False,
+            )
 
         # norm
         hidden_states = self.norm(hidden_states)
@@ -897,13 +907,15 @@ def backbone(
         past_key_values: Optional[Cache] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
-        return self.midlayer(
-            input_emb=input_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for layer in self.layers[: self.num_hidden_layers]:
+            hidden_states = layer(
+                input_emb=input_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden[layer] if cache_hidden else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=False,
+                use_cache=use_cache,
+            )
+        return hidden_states

From 195aaa14eaad16cb23b2a500ec065714d9e3bd54 Mon Sep 17 00:00:00 2001
From: Ximingwang-09 <72070413+Ximingwang-09@users.noreply.github.com>
Date: Thu, 23 Oct 2025 19:57:06 +0800
Subject: [PATCH 2/6] Update specforge/modeling/draft/llama3_eagle.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 specforge/modeling/draft/llama3_eagle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 86790586..2f998da6 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -911,7 +911,7 @@ def backbone(
             hidden_states = layer(
                 input_emb=input_embeds,
                 hidden_states=hidden_states,
-                cache_hidden=cache_hidden[layer] if cache_hidden else None,
+                cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,

From cd44752362911ac45e4c11c309fd6a1095453a75 Mon Sep 17 00:00:00 2001
From: Ximingwang-09 <72070413+Ximingwang-09@users.noreply.github.com>
Date: Thu, 23 Oct 2025 19:57:28 +0800
Subject: [PATCH 3/6] Update specforge/modeling/draft/llama3_eagle.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 specforge/modeling/draft/llama3_eagle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 2f998da6..9281e7cd 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -872,7 +872,7 @@ def forward(
             hidden_states = layer(
                 input_emb=inputs_embeds,
                 hidden_states=hidden_states,
-                cache_hidden=cache_hidden,
+                cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=None,

From 2e050dcd5547152cf83fbb1a0e08756315b0301b Mon Sep 17 00:00:00 2001
From: Ximingwang-09 <72070413+Ximingwang-09@users.noreply.github.com>
Date: Thu, 23 Oct 2025 19:57:59 +0800
Subject: [PATCH 4/6] Update specforge/core/eagle3.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 specforge/core/eagle3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 32ed1cb1..2ccfd9a3 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -226,7 +226,7 @@ def forward(
         acces = []
         if self.attention_backend == "sdpa":
             num_hidden_layers = getattr(
-                self.draft_model.config.to_dict(), "num_hidden_layers", 1
+                self.draft_model.config, "num_hidden_layers", 1
             )
             cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None

From 229a7f3d5a589f303266398e39bcd34c3e84dca2 Mon Sep 17 00:00:00 2001
From: Ximingwang-09 <72070413+Ximingwang-09@users.noreply.github.com>
Date: Thu, 23 Oct 2025 19:58:14 +0800
Subject: [PATCH 5/6] Update specforge/core/eagle3.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 specforge/core/eagle3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 2ccfd9a3..d87a2f08 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -603,7 +603,7 @@ def forward(
         acces = []
         if self.attention_backend == "sdpa":
             num_hidden_layers = getattr(
-                self.draft_model.config.to_dict(), "num_hidden_layers", 1
+                self.draft_model.config, "num_hidden_layers", 1
             )
             cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None

From 21cd38a17a7e27c469e75800d2d365d2a0010224 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BA=AC=E6=9D=AD?=
Date: Thu, 23 Oct 2025 20:03:21 +0800
Subject: [PATCH 6/6] lint

---
 specforge/core/eagle3.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index d87a2f08..f35b9acf 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -225,9 +225,7 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            num_hidden_layers = getattr(
-                self.draft_model.config, "num_hidden_layers", 1
-            )
+            num_hidden_layers = getattr(self.draft_model.config, "num_hidden_layers", 1)
             cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
@@ -602,9 +600,7 @@ def forward(
         vlosses = []
        acces = []
         if self.attention_backend == "sdpa":
-            num_hidden_layers = getattr(
-                self.draft_model.config, "num_hidden_layers", 1
-            )
+            num_hidden_layers = getattr(self.draft_model.config, "num_hidden_layers", 1)
             cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
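
A minimal, standalone sketch of the cache layout this series converges on: one `[key_cache, value_cache]` slot per draft layer, built as `cache_hidden = [[[], []] for _ in range(num_hidden_layers)]` in eagle3.py and selected by `cache_hidden[layer.layer_idx]` in the draft model. `ToyLayer` and `ToyDraftModel` below are hypothetical stand-ins, not SpecForge classes; only the cache shape and the `layer_idx`-based indexing mirror the diffs above.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyLayer:
    layer_idx: int

    def __call__(
        self, hidden_states: List[float], cache_hidden: Optional[list]
    ) -> List[float]:
        # Each layer writes into its own [key_cache, value_cache] slot,
        # analogous to LlamaAttention passing self.layer_idx to the KV cache.
        if cache_hidden is not None:
            cache_hidden[0].append(f"k{self.layer_idx}")
            cache_hidden[1].append(f"v{self.layer_idx}")
        return [h + 1.0 for h in hidden_states]


class ToyDraftModel:
    def __init__(self, num_hidden_layers: int = 2):
        self.num_hidden_layers = num_hidden_layers
        self.layers = [ToyLayer(i) for i in range(num_hidden_layers)]

    def backbone(self, hidden_states, cache_hidden):
        # Same indexing pattern as the final backbone()/forward():
        # each layer gets cache_hidden[layer.layer_idx], not a shared cache.
        for layer in self.layers[: self.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                cache_hidden[layer.layer_idx] if cache_hidden else None,
            )
        return hidden_states


model = ToyDraftModel(num_hidden_layers=2)
cache_hidden = [[[], []] for _ in range(model.num_hidden_layers)]
out = model.backbone([0.0, 0.0], cache_hidden)
print(out)           # [2.0, 2.0] after two layers
print(cache_hidden)  # [[['k0'], ['v0']], [['k1'], ['v1']]]
```

Indexing by `layer.layer_idx` (patches 2 and 3) rather than by the layer object matters because a Python list cannot be subscripted with a module instance; the `layer_idx` attribute introduced in patch 1 likewise lets each attention module address its own slot in `past_key_values.update` instead of the previously hard-coded `layer_idx=0`.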