Draft
Changes from all commits (25 commits)
0e0d02e  Add Eagle3DraftModel implementation (fynnsu, Oct 1, 2025)
ae94a42  Initial training loop script + batched data loading and collation (fynnsu, Oct 3, 2025)
20cd35f  Add loss masking (fynnsu, Oct 3, 2025)
08c8848  Disable training on `verifier_lm_head` (fynnsu, Oct 3, 2025)
76aa8a1  Add Distributed Data Parallel (DDP) (fynnsu, Oct 3, 2025)
c4103d1  Add checkpoint saving and factor out rank0 logging (fynnsu, Oct 3, 2025)
aaa471a  Add distributed batch sampler (fynnsu, Oct 3, 2025)
b157e69  Refactor training loop into `Trainer` class (fynnsu, Oct 3, 2025)
2e4cf4d  Rename `training_loop.py` to `trainer.py` (fynnsu, Oct 3, 2025)
4088ec7  Move train script to scripts/ dir (fynnsu, Oct 3, 2025)
976fd9d  Add validation dataset (fynnsu, Oct 3, 2025)
d30a0c6  Small fixes (fynnsu, Oct 3, 2025)
595d9a2  Switch to FSDP2 and add Checkpointer class (fynnsu, Oct 3, 2025)
3476e9f  Add general logging implementation with support for multiple metric b… (fynnsu, Oct 3, 2025)
7ff6186  Add flex attention mask tests (fynnsu, Oct 3, 2025)
3d12f28  Move Checkpointer to own file and restore single-device support (fynnsu, Oct 3, 2025)
8e22439  Correct data shifting (fynnsu, Oct 3, 2025)
d9a304c  Add verifier embedding loading (fynnsu, Oct 3, 2025)
cab5a6a  Fix small issues (fynnsu, Oct 15, 2025)
fb22a24  Bug fixes (fynnsu, Oct 15, 2025)
eb039fb  Move accuracy calculation logic into eagle3/core.py and consolidate w… (fynnsu, Oct 15, 2025)
bc051b8  Format (fynnsu, Oct 15, 2025)
a47eb08  Update validation dataset handling (fynnsu, Oct 17, 2025)
5fa9136  Update data noise transforms (fynnsu, Oct 17, 2025)
ac36b32  Fix issue with extend that was causing `torch.compile`d flex_attentio… (fynnsu, Oct 17, 2025)
149 changes: 149 additions & 0 deletions scripts/train.py
@@ -0,0 +1,149 @@
import torch
import numpy as np
from transformers import LlamaConfig

from speculators.train.eagle3.core import Eagle3DraftModel, Eagle3VerifierLMHead
from speculators.train.data import (
Eagle3SampleFileDataset,
create_collate_fn,
split_files,
AddUniformNoise,
)
from speculators.train.distributed_batch_sampler import (
MultipackDistributedBatchSamplerV2,
)
from torch.utils.data import DataLoader

from speculators.train.utils import maybe_setup_distributed, maybe_destroy_distributed
from speculators.train.trainer import Trainer
from speculators.train.logger import setup_metric_logger, setup_root_logger


local_rank, world_size, rank, is_distributed = maybe_setup_distributed()


DEVICE = torch.device(local_rank)
EPOCHS = 102
draft_vocab_size = 32000
total_seq_len = 8192
datapath = "./data"
verifier_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"


# TEMP MODEL SETUP
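# Load the verifier's config to size the draft model, then build a fresh decoder-layer
# config that routes attention through the "simple_flex_attention" implementation.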
llama_config = LlamaConfig.from_pretrained(verifier_model_name_or_path)
hidden_size = llama_config.hidden_size
verifier_vocab_size = llama_config.vocab_size
llama_config = LlamaConfig(hidden_size=hidden_size, vocab_size=verifier_vocab_size)
llama_config._attn_implementation = "simple_flex_attention"

# d2t_vocab = torch.zeros(draft_vocab_size, dtype=torch.long).to(DEVICE)
# t2d_vocab = (
# torch.cat(
# [
# torch.ones(draft_vocab_size),
# torch.zeros(llama_config.vocab_size - draft_vocab_size),
# ]
# )
# .to(torch.bool)
# .to(DEVICE)
# )
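# d2t maps each draft token id to its verifier token id (stored as an offset);
# t2d is a boolean mask over the verifier vocab marking tokens kept in the draft vocab.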
d2t_vocab = torch.from_numpy(np.load("d2t.npy")).to(DEVICE)
t2d_vocab = torch.from_numpy(np.load("t2d.npy")).to(DEVICE)

setup_metric_logger(loggers="trackio", run_name=None, output_dir="./logs")
setup_root_logger()
# END TEMP MODEL SETUP

draft_model = Eagle3DraftModel(
verifier_model_name_or_path=verifier_model_name_or_path,
hidden_size=hidden_size,
t2d_vocab=t2d_vocab,
d2t_vocab=d2t_vocab,
decoder_layer_config=llama_config,
verifier_vocab_size=verifier_vocab_size,
verifier_pad_token_id=None,
num_layers=1,
ttt_steps=3,
)

verifier_lm_head = Eagle3VerifierLMHead(
hidden_size=hidden_size, draft_vocab_size=draft_vocab_size
)
verifier_lm_head.load_verifier_lm_head(verifier_model_name_or_path, t2d_vocab)

### TMP
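# Initialize the draft model's lm_head from the verifier's lm_head weights
# (already restricted to the draft vocab by load_verifier_lm_head).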
draft_model.lm_head.weight.data = verifier_lm_head.lm_head.weight.data.to(
t2d_vocab.device
)
###
noise_transform = AddUniformNoise(
std=0.2, tensors=("hidden_states", "verifier_last_hidden_states")
)

train_files, val_files = split_files(datapath, ratio=0.9)
train_dataset = Eagle3SampleFileDataset(
file_list=train_files, max_len=total_seq_len, transform=noise_transform
)
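# Pack variable-length samples so each batch holds at most `total_seq_len` tokens per rank.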
train_batch_sampler = MultipackDistributedBatchSamplerV2(
batch_max_length=total_seq_len,
lengths=train_dataset.approx_lengths(),
num_replicas=world_size,
rank=local_rank,
)
train_loader = DataLoader(
train_dataset,
batch_sampler=train_batch_sampler,
num_workers=32,
prefetch_factor=8,
pin_memory=True,
collate_fn=create_collate_fn(total_seq_len),
persistent_workers=True,
)

val_dataset = Eagle3SampleFileDataset(file_list=val_files, max_len=total_seq_len)
val_batch_sampler = MultipackDistributedBatchSamplerV2(
batch_max_length=total_seq_len,
lengths=val_dataset.approx_lengths(),
num_replicas=world_size,
rank=local_rank,
)
val_loader = DataLoader(
val_dataset,
batch_sampler=val_batch_sampler,
num_workers=32,
prefetch_factor=8,
pin_memory=True,
collate_fn=create_collate_fn(total_seq_len),
persistent_workers=True,
)


# todo: make config better
config = {
"num_epochs": EPOCHS,
"save_path": "./checkpoints",
"lr": 1e-4,
"total_seq_len": total_seq_len,
"datapath": "./data",
"resume_from_checkpoint": True,
}


trainer = Trainer(
draft_model,
verifier_lm_head,
config,
train_loader,
val_loader,
is_distributed,
local_rank,
world_size,
)
trainer.run_training()

maybe_destroy_distributed()


# RUN WITH:
# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 scripts/train.py
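
The script expects precomputed `d2t.npy` and `t2d.npy` vocab-mapping files in the working directory. A minimal sketch of how they could be generated, assuming the draft vocab is simply the `draft_vocab_size` most frequent verifier tokens (the generation procedure and the `token_counts` input are assumptions, not part of this PR):

import numpy as np

def build_vocab_mappings(token_counts: np.ndarray, draft_vocab_size: int):
    # token_counts[i] = occurrence count of verifier token i in the training data
    keep_ids = np.sort(np.argsort(-token_counts)[:draft_vocab_size])  # kept verifier ids, ascending
    t2d = np.zeros(token_counts.shape[0], dtype=bool)
    t2d[keep_ids] = True                              # verifier id -> kept in draft vocab?
    d2t = keep_ids - np.arange(draft_vocab_size)      # draft id + d2t[draft id] == verifier id
    return d2t, t2d

# Hypothetical usage:
# counts = np.load("verifier_token_counts.npy")
# d2t, t2d = build_vocab_mappings(counts, draft_vocab_size=32000)
# np.save("d2t.npy", d2t)
# np.save("t2d.npy", t2d)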
149 changes: 149 additions & 0 deletions src/speculators/train/checkpointer.py
@@ -0,0 +1,149 @@
from abc import ABC, abstractmethod
import os
from pathlib import Path
import torch
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)

import torch.distributed as dist


class BaseCheckpointer(ABC):
"""Helper class to save and load checkpoints.

Checkpoint file structure:
../path/
0/ # epoch number
model_state_dict.pt
optimizer_state_dict.pt
1/
model_state_dict.pt
optimizer_state_dict.pt
...
"""

def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True):
self.path = Path(path)
if try_load_last_checkpoint:
self.previous_epoch: int = self._get_previous_epoch()
else:
self.previous_epoch: int = -1

@abstractmethod
def load_model_state_dict(self, model: torch.nn.Module):
raise NotImplementedError

@abstractmethod
def load_optimizer_state_dict(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer
):
raise NotImplementedError

@abstractmethod
def save_checkpoint(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
):
raise NotImplementedError

def _get_previous_epoch(self) -> int:
if not self.path.exists():
return -1
last_checkpoint_num = -1
for d in self.path.iterdir():
if d.is_dir():
try:
last_checkpoint_num = max(last_checkpoint_num, int(d.name))
except ValueError:
continue
return last_checkpoint_num

def model_path(self, epoch: int):
model_fname = "model_state_dict.pt"
return self.path / str(epoch) / model_fname

def optimizer_path(self, epoch: int):
optimizer_fname = "optimizer_state_dict.pt"
return self.path / str(epoch) / optimizer_fname


class SingleGPUCheckpointer(BaseCheckpointer):
def load_model_state_dict(self, model: torch.nn.Module):
full_state_dict = torch.load(
self.model_path(self.previous_epoch),
weights_only=True,
map_location="cuda:0", # todo: make this configurable
)
model.load_state_dict(full_state_dict)

def load_optimizer_state_dict(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer
):
full_state_dict = torch.load(
self.optimizer_path(self.previous_epoch),
weights_only=True,
map_location="cuda:0", # todo: make this configurable
)
optimizer.load_state_dict(full_state_dict)

def save_checkpoint(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
):
os.makedirs(self.path / str(epoch), exist_ok=True)
torch.save(model.state_dict(), self.model_path(epoch))
torch.save(optimizer.state_dict(), self.optimizer_path(epoch))


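# Gathers a full (unsharded), CPU-offloaded state dict so that only rank 0 writes the
# checkpoint to disk; on load, the full state dict is broadcast from rank 0 to all ranks.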
class DistributedCheckpointer(BaseCheckpointer):
def load_model_state_dict(self, model: torch.nn.Module):
full_state_dict = torch.load(
self.model_path(self.previous_epoch),
mmap=True,
weights_only=True,
map_location="cpu",
)
set_model_state_dict(
model,
full_state_dict,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
dist.barrier()

def load_optimizer_state_dict(self, model, optimizer: torch.optim.Optimizer):
full_state_dict = torch.load(
self.optimizer_path(self.previous_epoch),
mmap=True,
weights_only=True,
map_location="cpu",
)
set_optimizer_state_dict(
model,
optimizer,
full_state_dict,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
dist.barrier()

def save_checkpoint(
self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int
):
model_state_dict = get_model_state_dict(
model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
)
optimizer_state_dict = get_optimizer_state_dict(
model,
optimizer,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)

if dist.get_rank() == 0:
# Only rank 0 saves the checkpoint
os.makedirs(self.path / str(epoch), exist_ok=True)
torch.save(model_state_dict, self.model_path(epoch))
torch.save(optimizer_state_dict, self.optimizer_path(epoch))

dist.barrier()
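
A minimal usage sketch of how a training loop could resume from the most recent checkpoint; `model`, `optimizer`, `num_epochs`, `train_loader`, and `train_one_epoch` are hypothetical names, not taken from this PR:

checkpointer = DistributedCheckpointer("./checkpoints", try_load_last_checkpoint=True)

start_epoch = checkpointer.previous_epoch + 1
if checkpointer.previous_epoch >= 0:
    # Restore the model first, then the optimizer, so optimizer state maps onto the loaded params
    checkpointer.load_model_state_dict(model)
    checkpointer.load_optimizer_state_dict(model, optimizer)

for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer, train_loader)  # hypothetical per-epoch training step
    checkpointer.save_checkpoint(model, optimizer, epoch)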