Changes from 1 commit
10 changes: 8 additions & 2 deletions specforge/core/eagle3.py
@@ -225,7 +225,10 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            cache_hidden = [[], []]
+            num_hidden_layers = getattr(
+                self.draft_model.config, "num_hidden_layers", 1
+            )
+            cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
@@ -599,7 +602,10 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend == "sdpa":
-            cache_hidden = [[], []]
+            num_hidden_layers = getattr(
+                self.draft_model.config, "num_hidden_layers", 1
+            )
+            cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
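For reference, the sdpa branch above keeps one `[[keys], [values]]` pair per draft layer, with the layer count read straight off the draft config and defaulting to 1 for single-layer checkpoints. A minimal, self-contained sketch of that layout, assuming a `LlamaConfig`-style config object:

```python
from transformers import LlamaConfig

config = LlamaConfig(num_hidden_layers=2)

# getattr reads the attribute off the config object itself; the dict
# returned by config.to_dict() would need .get() instead.
num_hidden_layers = getattr(config, "num_hidden_layers", 1)

# One [[keys], [values]] pair per draft layer for the sdpa backend.
cache_hidden = [[[], []] for _ in range(num_hidden_layers)]
assert len(cache_hidden) == num_hidden_layers == 2
```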
68 changes: 40 additions & 28 deletions specforge/modeling/draft/llama3_eagle.py
@@ -338,7 +338,7 @@ def forward(self, x, position_ids):
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config):
+    def __init__(self, config, layer_idx: Optional[int] = 0):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -350,6 +350,7 @@ def __init__(self, config):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.layer_idx = layer_idx

         self.q_proj = nn.Linear(
             self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
@@ -599,7 +600,7 @@ def forward(
             key_cache, value_cache = past_key_values.update(
                 key_states,
                 value_states,
-                layer_idx=0,  # TODO: support multiple layers
+                layer_idx=self.layer_idx,
                 cache_kwargs=cache_kwargs,
             )
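For context, this follows the Hugging Face `Cache` convention: `update` appends the new key/value states into the slot for `layer_idx` and returns the accumulated tensors for that layer. A minimal sketch using `transformers`' `DynamicCache` (tensor shapes are illustrative):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 8, 4, 64)  # (batch, kv_heads, seq_len, head_dim)
v = torch.randn(1, 8, 4, 64)

# Each layer writes into its own slot; with a hard-coded layer_idx=0,
# a second layer's states would be concatenated onto layer 0's cache.
k0, v0 = cache.update(k, v, layer_idx=0)
k1, v1 = cache.update(k, v, layer_idx=1)
assert len(cache) == 2  # two layers tracked independently
```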

@@ -706,23 +707,22 @@ def forward(self, hidden_states):


 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config, attention_backend: str = "sdpa"):
+    def __init__(self, config, attention_backend: str = "sdpa", layer_idx: int = 0):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.layer_idx = layer_idx

         if attention_backend == "sdpa":
-            self.self_attn = LlamaAttention(config=config)
+            self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
         elif attention_backend == "flex_attention":
             print_with_rank("Using flex attention on draft model training!")
-            self.self_attn = LlamaFlexAttention(config=config)
+            self.self_attn = LlamaFlexAttention(config=config, layer_idx=layer_idx)
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")

         self.mlp = LlamaMLP(config)
-        # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        # if self.index!=0:

         self.post_attention_layernorm = LlamaRMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -797,7 +797,16 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         self.embed_tokens = nn.Embedding(
             config.vocab_size, config.hidden_size, config.pad_token_id
         )
-        self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
+        # Support configurable number of layers
+        self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, attention_backend=attention_backend, layer_idx=layer_idx
+                )
+                for layer_idx in range(self.num_hidden_layers)
+            ]
+        )

         if hasattr(config, "target_hidden_size"):
             self.fc = torch.nn.Linear(
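With the layer count coming from the config, a multi-layer draft model is declared there rather than in code. A sketch, assuming a `LlamaConfig`-style draft config (the constructor call is commented out because the surrounding class name is not shown in this diff):

```python
from transformers import LlamaConfig

# Hypothetical two-layer draft config; num_hidden_layers is the only
# field the ModuleList construction above depends on.
draft_config = LlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=2,
)

# model = DraftModelClass(draft_config, attention_backend="sdpa")
# assert len(model.layers) == draft_config.num_hidden_layers == 2
```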
@@ -859,16 +868,17 @@ def forward(

         # fc
         hidden_states = self.fc(hidden_states)
-        hidden_states = self.midlayer(
-            input_emb=inputs_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=None,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for layer in self.layers[: self.num_hidden_layers]:
+            hidden_states = layer(
+                input_emb=inputs_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=None,
+                output_attentions=False,
+                use_cache=False,
+            )

         # norm
         hidden_states = self.norm(hidden_states)
@@ -897,13 +907,15 @@ def backbone(
         past_key_values: Optional[Cache] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
-        return self.midlayer(
-            input_emb=input_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for layer in self.layers[: self.num_hidden_layers]:
+            hidden_states = layer(
+                input_emb=input_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=False,
+                use_cache=use_cache,
+            )
+        return hidden_states
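To make the per-layer cache selection concrete, here is a toy, self-contained analogue of the loop above (not SpecForge code; the layer body is a stand-in): each module knows its `layer_idx` and pulls only its own `[keys, values]` pair from the shared structure.

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Stand-in for LlamaDecoderLayer: records into its own cache slot."""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, cache_hidden=None):
        if cache_hidden is not None:
            keys, values = cache_hidden
            keys.append(hidden_states)    # toy "key" entry for this step
            values.append(hidden_states)  # toy "value" entry for this step
        return self.proj(hidden_states)

layers = nn.ModuleList(ToyLayer(i) for i in range(2))
cache_hidden = [[[], []] for _ in range(len(layers))]

hidden_states = torch.randn(1, 4, 8)
for layer in layers:
    hidden_states = layer(
        hidden_states,
        cache_hidden=cache_hidden[layer.layer_idx] if cache_hidden else None,
    )

# Each layer touched exactly its own slot.
assert all(len(keys) == 1 and len(vals) == 1 for keys, vals in cache_hidden)
```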