Eagle3 Training #143
Merged
Commits (73)
All commits are by fynnsu:
- 2c1c155 Add Eagle3DraftModel implementation
- 1c5caf3 Initial training loop script + batched data loading and collation
- 7e6c1b4 Add loss masking
- e0d92a7 Disable training on `verifier_lm_head`
- 5679c44 Add Distributed Data Parallel (DDP)
- 5423e0c Add checkpoint saving and factor out rank0 logging
- c8919d7 Add distributed batch sampler
- 4abb160 Refactor training loop into `Trainer` class
- 775d27c Rename `training_loop.py` to `trainer.py`
- be29741 Move train script to scripts/ dir
- 7e8b13b Add validation dataset
- 2a5310b Small fixes
- 258e59d Switch to FSDP2 and add Checkpointer class
- 36bdfd0 Add general logging implementation with support for multiple metric b…
- 1c8649a Add flex attention mask tests
- a659cbc Move Checkpointer to own file and restore single-device support
- f35e1ad Correct data shifting
- 718c65a Add verifier embedding loading
- 1942806 Fix small issues
- ebc56a5 Bug fixes
- a9921cf Move accuracy calculation logic into eagle3/core.py and consolidate w…
- 342ab2d Format
- 00d0da6 Update validation dataset handling
- ded4106 Update data noise transforms
- 2c0fd37 Fix issue with extend that was causing `torch.compile`d flex_attentio…
- 2cd3ead Update model definitions / attribute names and structure to match vllm
- d43a5c0 Add support for saving model checkpoints in vllm friendly format with…
- ce4c977 Format
- 590377a Sort imports
- 1bbe9cd Enable gqa when applicable
- 7afca7b Initialize Eagle3DraftModel from Eagle3SpeculatorConfig
- 3775944 Merge verifier_lm_head back into Eagle3DraftModel
- 74923b2 Fix most lint errors
- b996c9e Enable rank0 log filter override and remove distributed print statements
- 6060bc9 Fix lint issues
- 6464b98 Add `rich` and `tqdm` to requirements
- b5420a3 Fix type errors
- 8e0aca3 Switch to efficient Verifier weight loading
- c7af20c Fix barrier() warning
- e676cd3 Ensure dtype of loaded verifier weights matches model weight dtype
- 80a3364 Allow non-strict weight loading
- 9299f64 Format
- 3cbbf4a Cleanup `scripts/train.py`
- 5cde85f Rename `scripts/train.py` to `scripts/train_llama3_8b_drafter.py`
- 78ad320 Add ttt step loss decay
- 8a82bdb Improve trainer status logging
- d41a672 Add support for new and old data formats
- a07f7ef Improve metric calculation and logging
- 0556ede Fix typo
- 4707244 Move data transforms to separate file
- 539f426 Add docstrings to distributed setup utils
- 294b664 `torch.compile` whole Eagle3DraftModel instead of only the attention fn
- cf9a478 Make batch `.to(device)` non-blocking in train/val loop
- 6b67d1a Improvements to data loading
- 735da6b Generalize train script to support other verifier types
- 7d97df1 Add LR scheduler and switch to AdamW
- 14d6b25 Save models in different dtype (default bfloat16)
- 83fde60 Remove attributes with mismatched names from Eagle3DraftModel
- 7d0da23 Remove ttt_steps and ttt_step_loss_decay as attributes of Eagle3Draft…
- ecbbbeb Favor `attention_mask` name over `block_mask`
- 8ace0df Move initialization of variables into `if return_loss` scope
- b0174b0 Update model definitions init comments
- bdef7ca Separate definitions of first decoder block layer from subsequent
- ed46ace Update noise transforms
- 0700ae8 Add Eagle3DraftModel checks where code isn't generalized yet
- bbc1526 Add verification checks for loaded d2t and t2d tensors
- 0a3b8dd Add a comment explaining `prev_correct` tensor
- d2d9d25 Add docstrings / type hints
- 5c901ef Refactor Eagle3 draft model attribute setup
- 877e8fb Extract metric and loss computation into separate fn
- 2b4cdce Add `scripts/TRAINING.md`
- fc005e1 Suggestions
- 45bf922 Fix rebase error
scripts/TRAINING.md (new file):
# Eagle3 Training

`scripts/train.py` provides the main entry point for training Eagle3 models.

## Running the training script

To run in a multi-GPU or multi-node distributed training setup with FSDP, the script should be launched with `torchrun`:
```bash
torchrun --nnodes=1 --nproc_per_node=<num_gpus> scripts/train.py
```

For single-GPU training (useful for debugging), the script can be run directly:
```bash
python scripts/train.py
```
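When launched with `torchrun`, each worker process receives its rank information through environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`), which is how the script can decide between distributed and single-GPU execution. The sketch below only illustrates that detection logic; it is an assumption about what a helper such as `maybe_setup_distributed` in `speculators.train.utils` does, not its actual implementation.

```python
import os

import torch.distributed as dist


def sketch_setup_distributed() -> tuple[int, int, int, bool]:
    """Illustrative only: detect torchrun environment variables and, if present,
    initialize a process group; otherwise fall back to single-GPU training."""
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    is_distributed = world_size > 1
    if is_distributed:
        # init_method="env://" (the default) reads MASTER_ADDR/MASTER_PORT set by torchrun.
        dist.init_process_group(backend="nccl")
    return local_rank, world_size, rank, is_distributed
```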
## Arguments
The script has one required argument, `--verifier-name-or-path`: the name or path of the verifier model to use.

The script also accepts the following optional arguments:
- `--data-path`: The path to the data directory. Defaults to `./data`. The script collects all `.pt` files in this directory or its subdirectories and uses them as training data.
- `--save-path`: The path to save checkpoints to. Defaults to `./checkpoints`. The script creates a subdirectory for each epoch containing the model weights and optimizer state, e.g. `./checkpoints/0/`.
- `--epochs`: The number of epochs to train for. Defaults to 20.
- `--lr`: The learning rate. Defaults to 1e-4.
- `--no-resume-from-checkpoint`: If set, the script will not resume from the last checkpoint if one exists; it will instead start from scratch and overwrite existing checkpoints.
- `--logger`: The logger to use. Defaults to an empty string, which means no logging. Supported loggers are `trackio`, `wandb`, and `tensorboard`; a comma-separated list selects several.
- `--total-seq-len`: The total sequence length to use; this also bounds the packed batch length. Defaults to 8192.
- `--data-format-version`: The version of the data format to train on. Defaults to `1`, the structure produced by the Speculators generation scripts; `0` exists for backwards compatibility with the old data format.
- `--log-dir`: The path to save logs to. Defaults to `./logs`.
- `--run-name`: The name of the run. Defaults to None.
- `--num-layers`: The number of decoder layers in the draft model. Defaults to 1.
- `--d2t-path`: The path to the d2t (draft-to-target vocabulary mapping) tensor. Defaults to `d2t.npy`. See the sketch after this list.
- `--t2d-path`: The path to the t2d (target-to-draft vocabulary mask) tensor. Defaults to `t2d.npy`. See the sketch after this list.
- `--ttt-steps`: The number of TTT steps to unroll during training. Defaults to 3.
- `--ttt-step-loss-decay`: The loss decay factor applied across TTT steps. Defaults to 1.0. A sketch of the weighting appears after the example run command below.
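The `d2t` and `t2d` tensors relate the reduced draft vocabulary to the verifier's full vocabulary (the training script reads `draft_vocab_size` from `d2t.shape[0]`). Their exact contents come from the Speculators data-generation scripts; the sketch below is only a minimal illustration of one common convention, in which `t2d` is a boolean membership mask over the target vocabulary and `d2t` stores per-draft-token offsets such that `target_id = draft_id + d2t[draft_id]`. The vocabulary sizes and the frequency-based selection here are assumptions for illustration.

```python
import numpy as np

# Assumed sizes, for illustration only.
target_vocab_size = 128_256   # e.g. a Llama-3.1 tokenizer
draft_vocab_size = 32_000

# Stand-in for real corpus statistics: token_counts[i] = frequency of target token i.
rng = np.random.default_rng(0)
token_counts = rng.integers(0, 1_000, size=target_vocab_size)

# Draft vocab: ids of the most frequent target tokens, kept in ascending id order.
draft_token_ids = np.sort(np.argsort(token_counts)[-draft_vocab_size:])

# t2d: boolean membership mask over the target vocabulary.
t2d = np.zeros(target_vocab_size, dtype=bool)
t2d[draft_token_ids] = True

# d2t: per-draft-token offset such that target_id = draft_id + d2t[draft_id].
d2t = draft_token_ids - np.arange(draft_vocab_size)

np.save("d2t.npy", d2t)
np.save("t2d.npy", t2d)
```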
## Example run command
```bash
torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py \
--verifier-name-or-path "meta-llama/Llama-3.1-8B" \
--data-path "./data/llama-3.1-8b_sharegpt/gen/" \
--save-path "./checkpoints/llama-3.1-8b.eagle3" \
--epochs 10 \
--lr 1e-4 \
--no-resume-from-checkpoint \
--logger "tensorboard" \
--total-seq-len 8192 \
--data-format-version 1 \
--log-dir "./logs/llama-3.1-8b.eagle3" \
--run-name "llama-3.1-8b.eagle3" \
--num-layers 1 \
--d2t-path "./data/llama-3.1-8b_sharegpt/d2t.npy" \
--t2d-path "./data/llama-3.1-8b_sharegpt/t2d.npy" \
--ttt-steps 3 \
--ttt-step-loss-decay 1.0
```
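`--ttt-steps` controls how many speculative steps the draft model is unrolled for during training, and `--ttt-step-loss-decay` down-weights the losses of later steps. The exact weighting lives in the `Trainer`; the snippet below is only a sketch of the usual convention, where the loss of step `i` is scaled by `decay**i` (so the default of 1.0 weights all steps equally).

```python
import torch


def combine_ttt_losses(step_losses: list[torch.Tensor], decay: float) -> torch.Tensor:
    """Weight the per-step losses by decay**i and sum them (illustrative only)."""
    total = torch.zeros((), dtype=step_losses[0].dtype)
    for i, loss in enumerate(step_losses):
        total = total + (decay**i) * loss
    return total


# With --ttt-step-loss-decay 1.0 (the default) every step contributes equally.
losses = [torch.tensor(2.0), torch.tensor(2.5), torch.tensor(3.0)]
print(combine_ttt_losses(losses, decay=1.0))  # tensor(7.5000)
print(combine_ttt_losses(losses, decay=0.5))  # 2.0 + 1.25 + 0.75 = tensor(4.)
```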
scripts/train.py (new file):
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import LlamaConfig

from speculators.config import SpeculatorsConfig, VerifierConfig
from speculators.models.eagle3 import Eagle3SpeculatorConfig
from speculators.proposals.greedy import GreedyTokenProposalConfig
from speculators.train.data import (
    Eagle3SampleFileDataset,
    create_collate_fn,
    split_files,
    standardize_data_v0,
    standardize_data_v1,
)
from speculators.train.distributed_batch_sampler import (
    MultipackDistributedBatchSamplerV2,
)
from speculators.train.eagle3.core import Eagle3DraftModel
from speculators.train.logger import setup_metric_logger, setup_root_logger
from speculators.train.noise_transforms import AddUniformNoise
from speculators.train.trainer import Trainer, TrainerConfig
from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed

# DRAFTER MODEL HYPERPARAMETERS
NORM_BEFORE_RESIDUAL = True

# Dataloader
NUM_WORKERS = 12
PREFETCH_FACTOR = 4
NOISE_STD = 0.05


def setup_dataloader(
    file_list: list[str],
    world_size: int,
    local_rank: int,
    add_noise: bool = True,
    data_format_version: int = 1,
) -> DataLoader:
    """Setup dataloader for training.

    Args:
        file_list: List of file paths to load data from.
        world_size: Number of processes in the distributed training.
        local_rank: Rank of the current process.
        add_noise: Whether to add noise to the data.
        data_format_version: Version of the data format. Default is 1.

    Returns:
        DataLoader: Dataloader for training.
    """
    if add_noise:
        noise_transform = AddUniformNoise(
            std=NOISE_STD, tensors=("hidden_states", "verifier_last_hidden_states")
        )
    else:
        noise_transform = None

    standardize_fn = (
        standardize_data_v1 if data_format_version == 1 else standardize_data_v0
    )

    dataset = Eagle3SampleFileDataset(
        file_list=file_list,
        max_len=args.total_seq_len,
        transform=noise_transform,
        standardize_fn=standardize_fn,
    )
    batch_sampler = MultipackDistributedBatchSamplerV2(
        batch_max_length=args.total_seq_len,
        lengths=dataset.approx_lengths,
        num_replicas=world_size,
        rank=local_rank,
    )
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        pin_memory=True,
        collate_fn=create_collate_fn(args.total_seq_len),
        persistent_workers=True,
    )


def main(args: argparse.Namespace):
    # Setup logging
    setup_root_logger()
    setup_metric_logger(
        loggers=args.logger, run_name=args.run_name, output_dir=args.log_dir
    )

    # Setup distributed training
    local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
    device = torch.device(local_rank)

    # Setup speculator config
    llama_config = LlamaConfig.from_pretrained(args.verifier_name_or_path)
    llama_config.num_hidden_layers = args.num_layers
    llama_config.model_type = "llama"  # reset to llama (handles non-llama verifiers)
    llama_config._attn_implementation = "simple_flex_attention"  # noqa: SLF001

    # Load t2d and d2t tensors
    d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
    t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
    draft_vocab_size = d2t.shape[0]

    speculator_config = Eagle3SpeculatorConfig(
        transformer_layer_config=llama_config,
        draft_vocab_size=draft_vocab_size,
        norm_before_residual=NORM_BEFORE_RESIDUAL,
        speculators_config=SpeculatorsConfig(
            algorithm="eagle3",
            proposal_methods=[
                GreedyTokenProposalConfig(
                    proposal_type="greedy",
                    speculative_tokens=args.ttt_steps,
                )
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig(
                name_or_path=args.verifier_name_or_path,
                architectures=["LlamaForCausalLM"],
            ),
        ),
    )

    # Setup draft model
    draft_model = Eagle3DraftModel(config=speculator_config, t2d=t2d, d2t=d2t)

    # Setup dataloaders
    train_files, val_files = split_files(args.data_path, ratio=0.9)
    train_loader = setup_dataloader(
        train_files,
        world_size,
        local_rank,
        add_noise=True,
        data_format_version=args.data_format_version,
    )
    val_loader = setup_dataloader(
        val_files,
        world_size,
        local_rank,
        add_noise=False,
        data_format_version=args.data_format_version,
    )

    # Setup trainer
    trainer_config = TrainerConfig(
        num_epochs=args.epochs,
        save_path=args.save_path,
        lr=args.lr,
        resume_from_checkpoint=not args.no_resume_from_checkpoint,
        is_distributed=is_distributed,
        local_rank=local_rank,
        train_call_kwargs={
            "use_off_policy_tokens": False,
            "ttt_steps": args.ttt_steps,
            "ttt_step_loss_decay": args.ttt_step_loss_decay,
        },
        val_call_kwargs={
            "use_off_policy_tokens": False,
            "ttt_steps": args.ttt_steps,
            "ttt_step_loss_decay": args.ttt_step_loss_decay,
        },
    )
    trainer = Trainer(draft_model, trainer_config, train_loader, val_loader)

    # Run training
    trainer.run_training()

    # Cleanup
    maybe_destroy_distributed()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--verifier-name-or-path", type=str, required=True)
    parser.add_argument("--data-path", type=str, default="./data")
    parser.add_argument("--save-path", type=str, default="./checkpoints")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--no-resume-from-checkpoint", action="store_true")
    parser.add_argument(
        "--logger",
        type=str,
        default="",
        help="One of 'trackio', 'wandb', 'tensorboard' or comma separated list of them",
    )
    parser.add_argument("--total-seq-len", type=int, default=8192)
    parser.add_argument("--data-format-version", type=int, default=1)
    parser.add_argument("--log-dir", type=str, default="./logs")
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--d2t-path", type=str, default="d2t.npy")
    parser.add_argument("--t2d-path", type=str, default="t2d.npy")
    parser.add_argument("--ttt-steps", type=int, default=3)
    parser.add_argument("--ttt-step-loss-decay", type=float, default=1.0)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)


# RUN WITH:
# torchrun --nnodes=1 --nproc_per_node=<num_gpus> scripts/train.py
# for FSDP training
# OR
# python scripts/train.py
# for single GPU training
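The dataloader above does not use a fixed batch size; instead `MultipackDistributedBatchSamplerV2` packs variable-length samples into batches whose combined length stays within `--total-seq-len` (`batch_max_length`), using the dataset's `approx_lengths`. The real sampler also balances work across distributed ranks; the snippet below is only a sketch of the greedy packing idea, not the library's algorithm.

```python
def greedy_pack(lengths: list[int], batch_max_length: int) -> list[list[int]]:
    """Greedily pack sample indices into batches whose total length fits the budget.

    Illustrative only: the real MultipackDistributedBatchSamplerV2 additionally
    distributes batches across ranks and reshuffles between epochs.
    """
    batches: list[list[int]] = []
    current: list[int] = []
    current_len = 0
    # Visit the longest samples first so large items are not stranded at the end.
    for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
        if current and current_len + lengths[idx] > batch_max_length:
            batches.append(current)
            current, current_len = [], 0
        current.append(idx)
        current_len += lengths[idx]
    if current:
        batches.append(current)
    return batches


# Example: pack samples of varying length into 8192-token batches.
print(greedy_pack([5000, 3000, 2500, 2000, 1200, 800], batch_max_length=8192))
# [[0, 1], [2, 3, 4, 5]]
```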