torchspec/models/draft/base.py (9 additions, 0 deletions)
@@ -97,6 +97,15 @@ class Eagle3DraftModel(PreTrainedModel, ABC):
     the abstract methods to support training with TTT.
     """
 
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_aux_hidden_states = getattr(config, "num_aux_hidden_states", None)
+        if self.num_aux_hidden_states is None:
+            eagle_config = getattr(config, "eagle_config", None) or {}
+            layer_ids = eagle_config.get("eagle_aux_hidden_state_layer_ids")
+            self.num_aux_hidden_states = len(layer_ids) if layer_ids else 3
+
     @abstractmethod
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         """
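The fallback chain above (explicit num_aux_hidden_states attribute, then the eagle_config layer-id list, then a default of 3) is easy to sanity-check in isolation. A minimal sketch, with SimpleNamespace standing in for the real PretrainedConfig object:

    from types import SimpleNamespace


    def resolve_num_aux_hidden_states(config) -> int:
        # Same fallback chain as the new Eagle3DraftModel.__init__ above.
        num = getattr(config, "num_aux_hidden_states", None)
        if num is None:
            eagle_config = getattr(config, "eagle_config", None) or {}
            layer_ids = eagle_config.get("eagle_aux_hidden_state_layer_ids")
            num = len(layer_ids) if layer_ids else 3
        return num


    # An explicit attribute wins.
    assert resolve_num_aux_hidden_states(SimpleNamespace(num_aux_hidden_states=4)) == 4
    # Otherwise the layer-id list in eagle_config decides the count.
    assert resolve_num_aux_hidden_states(
        SimpleNamespace(eagle_config={"eagle_aux_hidden_state_layer_ids": [1, 23, 44]})
    ) == 3
    # With neither present, fall back to the historical default of 3.
    assert resolve_num_aux_hidden_states(SimpleNamespace()) == 3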
torchspec/models/draft/deepseek_eagle.py (20 additions, 1 deletion)
@@ -519,9 +519,22 @@ def __init__(self, config: DeepseekV3Config, attention_backend: str = "sdpa") ->
         self.midlayer = DeepSeekDecoderLayer(config, attention_backend=attention_backend)
 
         target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size)
-        self.fc = nn.Linear(target_hidden_size * 3, config.hidden_size, bias=False)
+        self.fc = nn.Linear(
+            target_hidden_size * self.num_aux_hidden_states, config.hidden_size, bias=False
+        )
+        use_fc_norm = getattr(config, "fc_norm", None)
+        if use_fc_norm:
+            self.fc_norm = nn.ModuleList(
+                [
+                    LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps)
+                    for _ in range(self.num_aux_hidden_states)
+                ]
+            )
+        else:
+            self.fc_norm = None
 
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm_output = getattr(config, "norm_output", False)
         self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
 
         if self.vocab_size != self.target_vocab_size:
@@ -547,6 +560,12 @@ def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
             raise ValueError(
                 f"Target hidden states size mismatch: {hidden_states.size(-1)} != expected: {expected_size}"
             )
+        if self.fc_norm is not None:
+            chunks = hidden_states.chunk(self.num_aux_hidden_states, dim=-1)
+            hidden_states = torch.cat(
+                [norm(chunk) for norm, chunk in zip(self.fc_norm, chunks)],
+                dim=-1,
+            )
         return self.fc(hidden_states)
 
     def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
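For reference, the new projection path can be exercised standalone: the concatenated auxiliary hidden states are split back into per-layer chunks, each chunk gets its own RMSNorm, and the re-concatenated tensor goes through the fused fc. A minimal sketch, using torch.nn.RMSNorm (PyTorch 2.4+) as a stand-in for LlamaRMSNorm and made-up sizes:

    import torch
    import torch.nn as nn

    num_aux, target_hidden, draft_hidden = 3, 16, 8
    fc = nn.Linear(target_hidden * num_aux, draft_hidden, bias=False)
    fc_norm = nn.ModuleList(
        [nn.RMSNorm(target_hidden, eps=1e-6) for _ in range(num_aux)]
    )

    # (batch, seq, num_aux * target_hidden), as produced by concatenating
    # hidden states from num_aux target-model layers.
    hidden_states = torch.randn(2, 5, target_hidden * num_aux)
    chunks = hidden_states.chunk(num_aux, dim=-1)  # num_aux tensors of (2, 5, target_hidden)
    normed = torch.cat([norm(c) for norm, c in zip(fc_norm, chunks)], dim=-1)
    out = fc(normed)
    print(out.shape)  # torch.Size([2, 5, 8])

Presumably the per-chunk norms keep each source layer's contribution on a comparable scale before the fused projection mixes them, which a single norm over the concatenated vector would not guarantee.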
torchspec/models/draft/llama3_eagle.py (23 additions, 3 deletions)
@@ -2432,12 +2432,26 @@ def __init__(self, config, attention_backend="sdpa") -> None:
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
 
-        if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size, bias=False)
+        target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size)
+
+        self.fc = torch.nn.Linear(
+            target_hidden_size * self.num_aux_hidden_states,
+            config.hidden_size,
+            bias=False,
+        )
+        use_fc_norm = getattr(config, "fc_norm", None)
+        if use_fc_norm:
+            self.fc_norm = nn.ModuleList(
+                [
+                    LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps)
+                    for _ in range(self.num_aux_hidden_states)
+                ]
+            )
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size, bias=False)
+            self.fc_norm = None
 
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm_output = getattr(config, "norm_output", False)
         self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
 
         if self.vocab_size != self.target_vocab_size:
@@ -2454,6 +2468,12 @@ def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
             raise ValueError(
                 f"Target hidden states size mismatch: {hidden_states.size(-1)} != expected: {expected_size}"
             )
+        if self.fc_norm is not None:
+            chunks = hidden_states.chunk(self.num_aux_hidden_states, dim=-1)
+            hidden_states = torch.cat(
+                [norm(chunk) for norm, chunk in zip(self.fc_norm, chunks)],
+                dim=-1,
+            )
         if os.environ.get("TORCHSPEC_EAGLE3_PROJ_FP32", "1") in {"0", "false", "False"}:
             return self.fc(hidden_states.to(self.fc.weight.dtype))
         proj = F.linear(
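The default branch of this projection is cut off in the diff above (the hunk ends at "proj = F.linear(" with the rest collapsed). The following is an illustrative reconstruction, not the repository's code: given the env-var name TORCHSPEC_EAGLE3_PROJ_FP32 and the opt-out branch shown, the assumed intent is to run the matmul in float32 and cast back:

    import os

    import torch
    import torch.nn.functional as F


    def project(fc: torch.nn.Linear, hidden_states: torch.Tensor) -> torch.Tensor:
        # fc is bias-free in these models, so passing only the weight is safe.
        if os.environ.get("TORCHSPEC_EAGLE3_PROJ_FP32", "1") in {"0", "false", "False"}:
            # Opt-out path: stay in the fc weight's dtype (e.g. bfloat16).
            return fc(hidden_states.to(fc.weight.dtype))
        # Assumed default path: accumulate the matmul in float32 for numerical
        # stability, then return to the activation dtype.
        proj = F.linear(hidden_states.float(), fc.weight.float())
        return proj.to(hidden_states.dtype)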
torchspec/models/eagle3.py (5 additions, 0 deletions)
@@ -268,6 +268,11 @@ def forward(
             lm_head_weight=lm_head_weight,
             norm_eps=norm_eps,
         )
+
+        # Model takes its own normed hidden states as input for the next step, so apply norm here if needed.
+        if self.draft_model.norm_output:
+            hidden_states = self.draft_model.norm(hidden_states)
+
         if self.attention_backend == "usp":
             # A shard can have no local loss tokens while its Ulysses peers do.
             # Keep the zero-loss path connected to this layer's activations so
Expand Down
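Why the norm has to happen at this point: in multi-step drafting, each step consumes the hidden states produced by the previous step, so a draft model trained with norm_output enabled expects already-normalized inputs. A toy sketch of that feedback loop; draft_model.step and the rollout helper are hypothetical stand-ins, not the repository's API:

    def rollout(draft_model, hidden_states, num_steps: int):
        outputs = []
        for _ in range(num_steps):
            hidden_states = draft_model.step(hidden_states)  # one draft forward step (hypothetical)
            if draft_model.norm_output:
                # Normalize before feeding back, matching what the next step was trained on.
                hidden_states = draft_model.norm(hidden_states)
            outputs.append(hidden_states)
        return outputs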