Merged

Commits
73 commits
2c1c155
Add Eagle3DraftModel implementation
fynnsu Oct 1, 2025
1c5caf3
Initial training loop script + batched data loading and collation
fynnsu Oct 3, 2025
7e6c1b4
Add loss masking
fynnsu Oct 3, 2025
e0d92a7
Disable training on `verifier_lm_head`
fynnsu Oct 3, 2025
5679c44
Add Distributed Data Parallel (DDP)
fynnsu Oct 3, 2025
5423e0c
Add checkpoint saving and factor out rank0 logging
fynnsu Oct 3, 2025
c8919d7
Add distributed batch sampler
fynnsu Oct 3, 2025
4abb160
Refactor training loop into `Trainer` class
fynnsu Oct 3, 2025
775d27c
Rename `training_loop.py` to `trainer.py`
fynnsu Oct 3, 2025
be29741
Move train script to scripts/ dir
fynnsu Oct 3, 2025
7e8b13b
Add validation dataset
fynnsu Oct 3, 2025
2a5310b
Small fixes
fynnsu Oct 3, 2025
258e59d
Switch to FSDP2 and add Checkpointer class
fynnsu Oct 3, 2025
36bdfd0
Add general logging implementation with support for multiple metric b…
fynnsu Oct 3, 2025
1c8649a
Add flex attention mask tests
fynnsu Oct 3, 2025
a659cbc
Move Checkpointer to own file and restore single-device support
fynnsu Oct 3, 2025
f35e1ad
Correct data shifting
fynnsu Oct 3, 2025
718c65a
Add verifier embedding loading
fynnsu Oct 3, 2025
1942806
Fix small issues
fynnsu Oct 15, 2025
ebc56a5
Bug fixes
fynnsu Oct 15, 2025
a9921cf
Move accuracy calculation logic into eagle3/core.py and consolidate w…
fynnsu Oct 15, 2025
342ab2d
Format
fynnsu Oct 15, 2025
00d0da6
Update validation dataset handling
fynnsu Oct 17, 2025
ded4106
Update data noise transforms
fynnsu Oct 17, 2025
2c0fd37
Fix issue with extend that was causing `torch.compile`d flex_attentio…
fynnsu Oct 17, 2025
2cd3ead
Update model definitions / attribute names and structure to match vllm
fynnsu Oct 15, 2025
d43a5c0
Add support for saving model checkpoints in vllm friendly format with…
fynnsu Oct 15, 2025
ce4c977
Format
fynnsu Oct 15, 2025
590377a
Sort imports
fynnsu Oct 15, 2025
1bbe9cd
Enable gqa when applicable
fynnsu Oct 15, 2025
7afca7b
Initialize Eagle3DraftModel from Eagle3SpeculatorConfig
fynnsu Oct 15, 2025
3775944
Merge verifier_lm_head back into Eagle3DraftModel
fynnsu Oct 15, 2025
74923b2
Fix most lint errors
fynnsu Oct 22, 2025
b996c9e
Enable rank0 log filter override and remove distributed print statements
fynnsu Oct 23, 2025
6060bc9
Fix lint issues
fynnsu Oct 23, 2025
6464b98
Add `rich` and `tqdm` to requirements
fynnsu Oct 23, 2025
b5420a3
Fix type errors
fynnsu Oct 23, 2025
8e0aca3
Switch to efficient Verifier weight loading
fynnsu Oct 23, 2025
c7af20c
Fix barrier() warning
fynnsu Oct 23, 2025
e676cd3
Ensure dtype of loaded verifier weights matches model weight dtype
fynnsu Oct 15, 2025
80a3364
Allow non-strict weight loading
fynnsu Oct 15, 2025
9299f64
Format
fynnsu Oct 23, 2025
3cbbf4a
Cleanup `scripts/train.py`
fynnsu Oct 15, 2025
5cde85f
Rename `scripts/train.py` to `scripts/train_llama3_8b_drafter.py`
fynnsu Oct 24, 2025
78ad320
Add ttt step loss decay
fynnsu Oct 24, 2025
8a82bdb
Improve trainer status logging
fynnsu Oct 15, 2025
d41a672
Add support for new and old data formats
fynnsu Oct 15, 2025
a07f7ef
Improve metric calculation and logging
fynnsu Oct 15, 2025
0556ede
Fix typo
fynnsu Oct 28, 2025
4707244
Move data transforms to separate file
fynnsu Oct 15, 2025
539f426
Add docstrings to distributed setup utils
fynnsu Oct 15, 2025
294b664
`torch.compile` whole Eagle3DraftModel instead of only the attention fn
fynnsu Oct 15, 2025
cf9a478
Make batch `.to(device)` non-blocking in train/val loop
fynnsu Oct 15, 2025
6b67d1a
Improvements to data loading
fynnsu Oct 15, 2025
735da6b
Generalize train script to support other verifier types
fynnsu Oct 30, 2025
7d97df1
Add LR scheduler and switch to AdamW
fynnsu Nov 4, 2025
14d6b25
Save models in different dtype (default bfloat16)
fynnsu Nov 4, 2025
83fde60
Remove attributes with mismatched names from Eagle3DraftModel
fynnsu Nov 4, 2025
7d0da23
Remove ttt_steps and ttt_step_loss_decay as attributes of Eagle3Draft…
fynnsu Nov 4, 2025
ecbbbeb
Favor `attention_mask` name over `block_mask`
fynnsu Nov 4, 2025
8ace0df
Move initialization of variables into `if return_loss` scope
fynnsu Nov 4, 2025
b0174b0
Update model definitions init comments
fynnsu Nov 4, 2025
bdef7ca
Separate definitions of first decoder block layer from subsequent
fynnsu Nov 4, 2025
ed46ace
Update noise transforms
fynnsu Nov 4, 2025
0700ae8
Add Eagle3DraftModel checks where code isn't generalized yet
fynnsu Nov 4, 2025
bbc1526
Add verification checks for loaded d2t and t2d tensors
fynnsu Nov 4, 2025
0a3b8dd
Add a comment explaining `prev_correct` tensor
fynnsu Nov 4, 2025
d2d9d25
Add docstrings / type hints
fynnsu Nov 4, 2025
5c901ef
Refactor Eagle3 draft model attribute setup
fynnsu Nov 4, 2025
877e8fb
Extract metric and loss computation into separate fn
fynnsu Nov 4, 2025
2b4cdce
Add `scripts/TRAINING.md`
fynnsu Nov 4, 2025
fc005e1
Suggestions
fynnsu Nov 4, 2025
45bf922
Fix rebase error
fynnsu Nov 13, 2025
9 changes: 6 additions & 3 deletions pyproject.toml
@@ -48,8 +48,10 @@ dependencies = [
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"pyyaml>=6.0.0",
"rich",
"safetensors",
"torch",
"tqdm",
"transformers",
"typer-slim>=0.12.0",
]
@@ -102,6 +104,7 @@ dev = [
"types-PyYAML~=6.0.1",
"types-requests~=2.32.0",
"types-toml",
"types-tqdm",

# link checking
"mkdocs-linkcheck~=1.0.6",
@@ -211,9 +214,6 @@ select = [
"UP", # pyupgrade: automatically upgrades syntax for newer versions of Python
"W", # Warning: provides warnings about potential issues in the code
"YTT", # flake8-2020: identifies code that will break with future Python releases

# Code Documentation
"FIX", # flake8-fixme: detects FIXMEs and other temporary comments that should be resolved
]

[tool.ruff.lint.extend-per-file-ignores]
@@ -236,6 +236,9 @@ select = [
"src/speculators/convert/**/*.py" = [
"BLE001", # allow catching Exception for conversion errors
]
"scripts/**/*.py" = [
"INP001",
]

"src/speculators/data_generation/**/*.py" = [
"S106", # false positives for chat template tokens
56 changes: 56 additions & 0 deletions scripts/TRAINING.md
@@ -0,0 +1,56 @@
# Eagle3 Training

`scripts/train.py` provides the main entry point for training Eagle3 models.

## Running the training script

To run in a multi-GPU distributed training setup with FSDP (single- or multi-node), the script should be launched with `torchrun`:
```bash
torchrun --nnodes=1 --nproc_per_node=<num_gpus> scripts/train.py
```

For single GPU training (useful for debugging), the script can be run directly:
```bash
python scripts/train.py
```

## Arguments
The script has one required argument: `--verifier-name-or-path`, the name or path of the verifier model to use.

The script accepts the following optional arguments:
- `--data-path`: The path to the data directory. Defaults to `./data`. The script will collect all `.pt` files in this directory or its subdirectories and use them as training data.
- `--save-path`: The path to save the checkpoints. Defaults to `./checkpoints`. The script will create subdirectories for each epoch to save the model weights and optimizer states. e.g. `./checkpoints/0/`
- `--epochs`: The number of epochs to train for. Defaults to 20.
- `--lr`: The learning rate to use. Defaults to 1e-4.
- `--no-resume-from-checkpoint`: If set, the script will not resume from the last checkpoint if it exists, and will instead start from scratch and overwrite existing checkpoints.
- `--logger`: The logger to use. Defaults to an empty string, which disables logging. Supported loggers are `trackio`, `wandb`, and `tensorboard`; a comma-separated list enables more than one.
- `--total-seq-len`: The total sequence length to use. Defaults to 8192.
- `--data-format-version`: The version of the data format to use. Defaults to `1`, which is the structure produced by the Speculators generation scripts; `0` exists for backwards compatibility with the old data format.
- `--log-dir`: The path to save the logs. Defaults to `./logs`.
- `--run-name`: The name of the run. Defaults to None.
- `--num-layers`: The number of layers to use. Defaults to 1.
- `--d2t-path`: The path to the d2t tensor (`.npy`). Defaults to `d2t.npy`. See the sanity-check sketch after this list.
- `--t2d-path`: The path to the t2d tensor (`.npy`). Defaults to `t2d.npy`.
- `--ttt-steps`: The number of TTT steps to use. Defaults to 3.
- `--ttt-step-loss-decay`: The loss decay factor to use for the TTT steps. Defaults to 1.0.
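
The training script loads the `d2t` and `t2d` tensors with `np.load` and derives the draft vocabulary size from `d2t.shape[0]`. Below is a minimal sketch for sanity-checking these files before launching a run; the paths are the illustrative ones from the example command below, and the `t2d.sum()` check assumes the common Eagle3 convention that `t2d` marks which verifier tokens belong to the draft vocabulary.

```python
# Minimal sketch (not part of the training script): sanity-check the mapping tensors.
import numpy as np

d2t = np.load("./data/llama-3.1-8b_sharegpt/d2t.npy")  # shape: (draft_vocab_size,)
t2d = np.load("./data/llama-3.1-8b_sharegpt/t2d.npy")  # shape: (verifier_vocab_size,), assumed

draft_vocab_size = d2t.shape[0]  # same derivation the script uses
print(f"draft vocab size: {draft_vocab_size}")

# Assumed convention: t2d is a boolean/0-1 mask over the verifier vocab selecting
# exactly the tokens that appear in the draft vocab.
assert int(t2d.sum()) == draft_vocab_size, "t2d should select draft_vocab_size tokens"
```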

## Example run command
```bash
torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py \
--verifier-name-or-path "meta-llama/Llama-3.1-8B" \
--data-path "./data/llama-3.1-8b_sharegpt/gen/" \
--save-path "./checkpoints/llama-3.1-8b.eagle3" \
--epochs 10 \
--lr 1e-4 \
--no-resume-from-checkpoint \
--logger "tensorboard" \
--total-seq-len 8192 \
--data-format-version 1 \
--log-dir "./logs/llama-3.1-8b.eagle3" \
--run-name "llama-3.1-8b.eagle3" \
--num-layers 1 \
--d2t-path "./data/llama-3.1-8b_sharegpt/d2t.npy" \
--t2d-path "./data/llama-3.1-8b_sharegpt/t2d.npy" \
--ttt-steps 3 \
--ttt-step-loss-decay 1.0
```
213 changes: 213 additions & 0 deletions scripts/train.py
@@ -0,0 +1,213 @@
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import LlamaConfig

from speculators.config import SpeculatorsConfig, VerifierConfig
from speculators.models.eagle3 import Eagle3SpeculatorConfig
from speculators.proposals.greedy import GreedyTokenProposalConfig
from speculators.train.data import (
Eagle3SampleFileDataset,
create_collate_fn,
split_files,
standardize_data_v0,
standardize_data_v1,
)
from speculators.train.distributed_batch_sampler import (
MultipackDistributedBatchSamplerV2,
)
from speculators.train.eagle3.core import Eagle3DraftModel
from speculators.train.logger import setup_metric_logger, setup_root_logger
from speculators.train.noise_transforms import AddUniformNoise
from speculators.train.trainer import Trainer, TrainerConfig
from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed

# DRAFTER MODEL HYPERPARAMETERS
NORM_BEFORE_RESIDUAL = True

# Dataloader
NUM_WORKERS = 12
PREFETCH_FACTOR = 4
NOISE_STD = 0.05


def setup_dataloader(
file_list: list[str],
world_size: int,
local_rank: int,
add_noise: bool = True,
data_format_version: int = 1,
) -> DataLoader:
"""Setup dataloader for training.
Args:
file_list: List of file paths to load data from.
world_size: Number of processes in the distributed training.
local_rank: Rank of the current process.
add_noise: Whether to add noise to the data.
data_format_version: Version of the data format. Default is 1.
Returns:
DataLoader: Dataloader for training.
"""
if add_noise:
noise_transform = AddUniformNoise(
std=NOISE_STD, tensors=("hidden_states", "verifier_last_hidden_states")
)
else:
noise_transform = None

standardize_fn = (
standardize_data_v1 if data_format_version == 1 else standardize_data_v0
)

dataset = Eagle3SampleFileDataset(
file_list=file_list,
max_len=args.total_seq_len,
transform=noise_transform,
standardize_fn=standardize_fn,
)
batch_sampler = MultipackDistributedBatchSamplerV2(
batch_max_length=args.total_seq_len,
lengths=dataset.approx_lengths,
num_replicas=world_size,
rank=local_rank,
)
return DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=NUM_WORKERS,
prefetch_factor=PREFETCH_FACTOR,
pin_memory=True,
collate_fn=create_collate_fn(args.total_seq_len),
persistent_workers=True,
)


def main(args: argparse.Namespace):
# Setup logging
setup_root_logger()
setup_metric_logger(
loggers=args.logger, run_name=args.run_name, output_dir=args.log_dir
)

# Setup distributed training
local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
device = torch.device(local_rank)

# Setup speculator config
llama_config = LlamaConfig.from_pretrained(args.verifier_name_or_path)
llama_config.num_hidden_layers = args.num_layers
llama_config.model_type = "llama" # reset to llama (handles non-llama verifiers)
llama_config._attn_implementation = "simple_flex_attention" # noqa: SLF001

# Load t2d and d2t tensors
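    # Assumed Eagle3 convention: d2t maps draft-vocab indices to offsets into the
    # verifier vocab, and t2d marks which verifier tokens are in the draft vocab.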
d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
draft_vocab_size = d2t.shape[0]

speculator_config = Eagle3SpeculatorConfig(
transformer_layer_config=llama_config,
draft_vocab_size=draft_vocab_size,
norm_before_residual=NORM_BEFORE_RESIDUAL,
speculators_config=SpeculatorsConfig(
algorithm="eagle3",
proposal_methods=[
GreedyTokenProposalConfig(
proposal_type="greedy",
speculative_tokens=args.ttt_steps,
)
],
default_proposal_method="greedy",
verifier=VerifierConfig(
name_or_path=args.verifier_name_or_path,
architectures=["LlamaForCausalLM"],
),
),
)

# Setup draft model
draft_model = Eagle3DraftModel(config=speculator_config, t2d=t2d, d2t=d2t)

# Setup dataloaders
train_files, val_files = split_files(args.data_path, ratio=0.9)
train_loader = setup_dataloader(
train_files,
world_size,
local_rank,
add_noise=True,
data_format_version=args.data_format_version,
)
val_loader = setup_dataloader(
val_files,
world_size,
local_rank,
add_noise=False,
data_format_version=args.data_format_version,
)

# Setup trainer
trainer_config = TrainerConfig(
num_epochs=args.epochs,
save_path=args.save_path,
lr=args.lr,
resume_from_checkpoint=not args.no_resume_from_checkpoint,
is_distributed=is_distributed,
local_rank=local_rank,
train_call_kwargs={
"use_off_policy_tokens": False,
"ttt_steps": args.ttt_steps,
"ttt_step_loss_decay": args.ttt_step_loss_decay,
},
val_call_kwargs={
"use_off_policy_tokens": False,
"ttt_steps": args.ttt_steps,
"ttt_step_loss_decay": args.ttt_step_loss_decay,
},
)
trainer = Trainer(draft_model, trainer_config, train_loader, val_loader)

# Run training
trainer.run_training()

# Cleanup
maybe_destroy_distributed()


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--verifier-name-or-path", type=str, required=True)
parser.add_argument("--data-path", type=str, default="./data")
parser.add_argument("--save-path", type=str, default="./checkpoints")
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--no-resume-from-checkpoint", action="store_true")
parser.add_argument(
"--logger",
type=str,
default="",
help="One of 'trackio', 'wandb', 'tensorboard' or comma separated list of them",
)
parser.add_argument("--total-seq-len", type=int, default=8192)
parser.add_argument("--data-format-version", type=int, default=1)
parser.add_argument("--log-dir", type=str, default="./logs")
parser.add_argument("--run-name", type=str, default=None)
parser.add_argument("--num-layers", type=int, default=1)
parser.add_argument("--d2t-path", type=str, default="d2t.npy")
parser.add_argument("--t2d-path", type=str, default="t2d.npy")
parser.add_argument("--ttt-steps", type=int, default=3)
parser.add_argument("--ttt-step-loss-decay", type=float, default=1.0)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
main(args)


# RUN WITH:
# torchrun --nnodes=1 --nproc_per_node=<num_gpus> scripts/train.py
# for FSDP training
# OR
# python scripts/train.py
# for single GPU training