diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index d5cb9a4b..f35b9acf 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -225,7 +225,8 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            cache_hidden = [[], []]
+            num_hidden_layers = getattr(self.draft_model.config, "num_hidden_layers", 1)
+            cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
@@ -599,7 +600,8 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            cache_hidden = [[], []]
+            num_hidden_layers = getattr(self.draft_model.config, "num_hidden_layers", 1)
+            cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 54dc3269..9281e7cd 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -338,7 +338,7 @@ def forward(self, x, position_ids):
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config):
+    def __init__(self, config, layer_idx: Optional[int] = 0):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -350,6 +350,7 @@ def __init__(self, config):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.layer_idx = layer_idx
 
         self.q_proj = nn.Linear(
             self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
@@ -599,7 +600,7 @@ def forward(
             key_cache, value_cache = past_key_values.update(
                 key_states,
                 value_states,
-                layer_idx=0,  # TODO: support multiple layers
+                layer_idx=self.layer_idx,
                 cache_kwargs=cache_kwargs,
             )
 
@@ -706,23 +707,22 @@ def forward(self, hidden_states):
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config, attention_backend: str = "sdpa"):
+    def __init__(self, config, attention_backend: str = "sdpa", layer_idx: int = 0):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx
 
         if attention_backend == "sdpa":
-            self.self_attn = LlamaAttention(config=config)
+            self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
         elif attention_backend == "flex_attention":
             print_with_rank("Using flex attention on draft model training!")
-            self.self_attn = LlamaFlexAttention(config=config)
+            self.self_attn = LlamaFlexAttention(config=config, layer_idx=layer_idx)
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")
 
         self.mlp = LlamaMLP(config)
-        # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        # if self.index!=0:
 
         self.post_attention_layernorm = LlamaRMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -797,7 +797,16 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         self.embed_tokens = nn.Embedding(
             config.vocab_size, config.hidden_size, config.pad_token_id
         )
-        self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
+        # Support configurable number of layers
+        self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, attention_backend=attention_backend, layer_idx=layer_idx
+                )
+                for layer_idx in range(self.num_hidden_layers)
+            ]
+        )
 
         if hasattr(config, "target_hidden_size"):
             self.fc = torch.nn.Linear(
@@ -859,16 +868,17 @@ def forward(
 
         # fc
         hidden_states = self.fc(hidden_states)
-        hidden_states = self.midlayer(
-            input_emb=inputs_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=None,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for layer in self.layers[: self.num_hidden_layers]:
+            hidden_states = layer(
+                input_emb=inputs_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=None,
+                output_attentions=False,
+                use_cache=False,
+            )
 
         # norm
         hidden_states = self.norm(hidden_states)
@@ -897,13 +907,15 @@ def backbone(
         past_key_values: Optional[Cache] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
-        return self.midlayer(
-            input_emb=input_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for layer in self.layers[: self.num_hidden_layers]:
+            hidden_states = layer(
+                input_emb=input_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=False,
+                use_cache=use_cache,
+            )
+        return hidden_states
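
For reference, a minimal sketch (not part of the patch) of how the per-layer cache_hidden layout built in eagle3.py lines up with the layer_idx indexing passed to each decoder layer above. num_hidden_layers = 2 is only an illustrative value, and the append shown in the comment is an assumption about how the SDPA attention path fills these lists:

    # One [key_list, value_list] pair per decoder layer, replacing the single
    # shared [[], []] that assumed a one-layer draft model.
    num_hidden_layers = 2  # illustrative; the real code reads it from the draft config
    cache_hidden = [[[], []] for _ in range(num_hidden_layers)]

    for layer_idx in range(num_hidden_layers):
        key_cache, value_cache = cache_hidden[layer_idx]  # slice handed to this layer
        # the attention module would append key/value tensors here, e.g.
        # key_cache.append(key_states); value_cache.append(value_states)

    assert len(cache_hidden) == num_hidden_layers
    assert cache_hidden[0] is not cache_hidden[1]  # each layer gets independent lists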