Eagle3 Training #143
Conversation
brian-dellabetta left a comment:
Excellent work! In the future, consider splitting a PR like this up into separate, smaller PRs that can be merged over time. This one looks like it could be split into a few: logging, trainer, dataset class, and the llama-specific code.

A few comments from an outsider's perspective. Since this is all entirely new, I'm sure there will be some validation after this lands, but consider an e2e test here or in a follow-up.
kylesayrs left a comment:
Super beautiful, really great job.

The only part that makes me slightly nervous is the reimplementation of model definitions and the fixed ModelComponents structure, as this pattern makes supporting new models harder and more rigid. If there's any way to make this more general or provide a good programming model, that would be nice. Otherwise, this is good for now.
Yeah, that's fair. This is also my least favorite part :( Unfortunately, we do need to make a few changes to the DecoderLayer, but I did my best to minimize these and clearly mark them with comments. I'm happy to look further into ways of removing these modifications, but right now it isn't clear to me how we could do that. As for supporting other drafter architectures for spec decoding: that is less critical than supporting other verifier architectures, because we don't need to match the drafter architecture to the verifier's; e.g. there's nothing stopping you from training a Llama drafter model on a Qwen verifier.
Commits:
- Only load files ending with `pt`; enforce loading on CPU.
- Although the object is a `BlockMask`, rename it to `attention_mask` so that the naming aligns with what is used in the transformer components (DecoderBlock and attention fn).
- In the Eagle3 algorithm, the first layer needs to be modified to support a larger (2x) hidden dim, while subsequent layers behave as usual. Previously, we used a special decoder layer class that behaved differently depending on the `layer_idx` it received. Now we use the special class only for the first layer and switch to the original `LlamaDecoderLayer` for subsequent layers.
- Simplifies the code while removing the option for a non-zero Gaussian transform mean. Also changes the default values for the standard deviation and fixes a scaling issue for the uniform transform.
This PR introduces Eagle3 model training into the speculators repo. The implementation is specific to Eagle3 but is designed in a way that enables future generalization to other speculative decoding algorithms.
Components
Example training script (`scripts/train_llama3_8b_drafter.py`, now `scripts/train.py`)

Shows how to set up and run training. Currently specific to the `meta-llama/Llama-3.1-8B-Instruct` model, but it doesn't require many changes to run with a different model. You just need to update:

```python
VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size
```

Update: I've generalized the training script. It now has a required CLI arg `--verifier_name_or_path` and supports arbitrary verifier models. Note: this uses `LlamaConfig.from_pretrained(args.verifier_name_or_path)` under the hood, which does work for non-Llama models (e.g. a Qwen model) but prints a warning and may not work for every type of verifier. You will also need to pass in a dataset and `t2d`/`d2t` tensors which correspond to the verifier you are using.

Flex Attention
Files: `src/speculators/train/eagle3/attention.py`, `tests/unit/train/test_eagle3_attention.py`

The training code uses FlexAttention, which provides substantial speedups and memory efficiency over full dense attention operations.
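To make the packed-sequence masking concrete, here is a minimal, hypothetical sketch of how a causal, per-document `BlockMask` can be built and passed to FlexAttention (assumes PyTorch 2.5+; `make_packed_causal_mask` and `document_ids` are illustrative names, not the actual functions in `attention.py`):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def make_packed_causal_mask(document_ids: torch.Tensor):
    """Causal attention restricted to tokens from the same packed sequence."""
    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_ids[q_idx] == document_ids[kv_idx]
        return same_doc & (q_idx >= kv_idx)

    total_len = document_ids.numel()
    # B=None, H=None broadcasts the mask over batch and heads
    return create_block_mask(mask_mod, None, None, total_len, total_len,
                             device=document_ids.device)

# q, k, v: [batch=1, num_heads, total_seq_len, head_dim]
# out = flex_attention(q, k, v, block_mask=make_packed_causal_mask(document_ids))
```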
Data processing
Data is currently expected in the format of one file per data sample. We load these samples and perform a shift to align `input_ids`, `hidden_states`, `loss_mask`, and `verifier_last_hidden_state` correctly. We also automatically collate these samples into batches. Rather than padding and wasting compute on padded tokens, we instead concatenate the sequences along the sequence dimension, keeping track of the boundaries between sequences and setting the attention mask accordingly.
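As an illustration of the concatenation-based collation described above, here is a minimal, hypothetical sketch (field names and the exact shift are illustrative; the real logic lives in `data.py`):

```python
import torch

def collate_packed(samples: list[dict]) -> dict:
    """Concatenate samples along the sequence dim instead of padding."""
    input_ids, hidden_states, loss_mask, document_ids = [], [], [], []
    for doc_id, s in enumerate(samples):
        # Illustrative shift so that position t is paired with its next-token target.
        ids = s["input_ids"][1:]
        input_ids.append(ids)
        hidden_states.append(s["hidden_states"][:-1])
        loss_mask.append(s["loss_mask"][1:])
        # Track which source sequence each token came from (used for the attention mask).
        document_ids.append(torch.full((ids.numel(),), doc_id, dtype=torch.long))
    return {
        "input_ids": torch.cat(input_ids),
        "hidden_states": torch.cat(hidden_states),
        "loss_mask": torch.cat(loss_mask),
        "document_ids": torch.cat(document_ids),
    }
```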
Batch sampling

Files: `src/speculators/train/distributed_batch_sampler.py`, `src/speculators/train/data.py`

Due to hardware limitations, we set a maximum sequence length for each batch. We would like each batch of data to be close in size to this max length, so that each batch has a similar number of tokens. We achieve this through the `MultipackDistributedBatchSamplerV2`, taken from prior work I did on instructlab/training. This class produces indices of files that, when batched together, come close to reaching the max length without exceeding it. It also does this in a distributed-aware manner so that there is no overlap in the data each rank sees.

To run the packing algorithm, we need to know the length of each sample in the dataset. Unfortunately, this would require opening every file in the dataset, which is expensive, so instead we approximate the lengths (`_compute_approx_lengths` in `data.py`) using the length of the first sample and the relative file sizes of the samples.
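A minimal sketch of that approximation idea, assuming each sample file is a `torch.load`-able dict containing an `input_ids` tensor (illustrative only; the actual implementation is `_compute_approx_lengths`):

```python
import os
import torch

def approx_lengths(paths: list[str]) -> list[int]:
    """Estimate token counts from file sizes, opening only the first file."""
    first = torch.load(paths[0], map_location="cpu")
    first_len = first["input_ids"].numel()
    first_size = os.path.getsize(paths[0])
    # Assume tokens-per-byte is roughly constant across sample files.
    return [max(1, round(first_len * os.path.getsize(p) / first_size)) for p in paths]
```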
Eagle3DraftModel

Files: `src/speculators/train/eagle3/core.py`

The draft model itself. Sets up and loads the verifier components, as well as the draft layers / weights. Contains the model `forward()` pass, which among other things computes the verifier logits via `verifier_lm_head`. Note: this is computed here for data storage efficiency reasons, as otherwise we would need to save the full logits (`[seq_len, vocab_size]`) to disk instead of the last-layer hidden states (`[seq_len, hidden_size]`). The verifier `vocab_size` is often > 100k, whereas `hidden_size` might be around 4-8k.
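A small sketch of that storage trade-off, using the Llama-3.1-8B numbers from above (the actual `forward()` wiring in `core.py` may differ):

```python
import torch

hidden_size, vocab_size, seq_len = 4096, 128256, 2048
verifier_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Stored on disk per sample: [seq_len, hidden_size] last-layer hidden states.
verifier_last_hidden_state = torch.randn(seq_len, hidden_size)

# Recomputed on the fly during training instead of being stored:
with torch.no_grad():
    verifier_logits = verifier_lm_head(verifier_last_hidden_state)  # [seq_len, vocab_size]

# Storing logits directly would be vocab_size / hidden_size ≈ 31x larger per token.
```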
Layer definitions

Files: `src/speculators/train/eagle3/model_definitions.py`

Currently just contains model definitions for llama3-style draft models. Supports `norm_before_residual=True` or `False`. Attempted to keep modifications to the original llama models minimal.

Distributed training via FSDP
Files: `src/speculators/train/utils.py`, `src/speculators/train/checkpointer.py`, `src/speculators/train/trainer.py` (`setup_model` fn)

Full support for FSDP training by launching the training script with `torchrun --nnodes --nproc_per_node=N`, where `N` is the number of GPUs. Tested with `N=2,3,4,8`, and all work. FSDP training also enables Automatic Mixed Precision (AMP) for improved performance. `checkpointer.py` contains the checkpointing logic for FSDP-distributed model weights (gather all weights on rank 0 before saving).

Note: the way distributed training works in general is that `N` copies of the script are started and all run the same code, but with some environment variables set which let each process know its rank. Then explicit `dist.barrier()` calls, or implicit calls within FSDP forward/backward hooks, force each process to wait until they all reach the same point in the code before continuing. It is important that all ranks reach these operations, as it allows them to perform synchronized operations (such as gathering, reducing, etc.). However, we can also limit certain code to only one rank (rank 0) so that we only log once, or save a checkpoint once, using simple `if local_rank == 0` statements.
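A minimal sketch of that rank-0-only pattern (illustrative; `checkpointer.py` additionally gathers the full FSDP state dict onto rank 0 before saving):

```python
# Launched via, e.g.: torchrun --nnodes=1 --nproc_per_node=4 scripts/train.py ...
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ.get("LOCAL_RANK", 0))

def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    if dist.is_initialized():
        dist.barrier()  # every rank must reach this point
    if local_rank == 0:
        torch.save(model.state_dict(), path)  # only rank 0 writes to disk
    if dist.is_initialized():
        dist.barrier()  # keep other ranks from racing ahead of the save
```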
Logging

Files: `src/speculators/train/logger.py`, `scripts/train.py` (setup logger calls at the start of `main()`), `src/speculators/train/trainer.py` and other files (usage of `metric_logger` and `root_logger`)

Another implementation mostly copied from prior work I did on instructlab/training. This uses python's std library `logging` module and extends it to support training metric logging. We can log a nested dict of metrics anywhere in the codebase like so:
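A hypothetical sketch of such a call, assuming `metric_logger` behaves like an ordinary `logging` logger extended by `logger.py` (the exact logger name and call shape may differ):

```python
import logging

metric_logger = logging.getLogger("speculators.metrics")

metric_logger.info(
    {
        "train": {"loss": 1.234, "lr": 3e-4},
        "epoch": 2,
        "global_step": 1500,
    }
)
```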
When the user runs the training script, they can select one (or multiple) of `tensorboard`, `wandb`, and `trackio`, and the results will be logged to the respective experiment tracker.

There is also a `root_logger` which can be used for regular update logging, and everything logged to either the `root_logger` or `metric_logger` will be pretty-printed to the console.

Trainer
Files: `src/speculators/train/trainer.py`

The `Trainer` class is initialized with the model, data loaders, and a config, and runs the training and validation loops (`train_epoch` and `val_epoch` respectively).
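A hypothetical wiring sketch; the import path follows the file layout above, but the constructor arguments and entry-point method are illustrative and may not match the actual `Trainer` API:

```python
from torch.utils.data import DataLoader
from speculators.train.trainer import Trainer  # module path assumed from the Files list

def run_training(draft_model, train_loader: DataLoader, val_loader: DataLoader, train_config) -> None:
    # Illustrative constructor; the real signature may differ.
    trainer = Trainer(
        model=draft_model,          # Eagle3DraftModel instance
        train_loader=train_loader,  # DataLoader over packed batches
        val_loader=val_loader,
        config=train_config,        # hyperparameters, logging, checkpointing options
    )
    trainer.train()  # runs train_epoch / val_epoch for each epoch
```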
Todos:

- `loss.backward()` + optimizer steps
- Code relocation / merging with existing definitions (currently just have everything under `speculators/train`, but this will need to change) (FUTURE PR)

Essential todos (as of 10/22/2025):

- Implement save-best or save-last logic (currently saving every epoch) (FUTURE PR)
- `lm_head`, `embed_tokens` loading (requires added loading util for specific layers #144)
- `Eagle3DraftModel.__init__` signature cleanup / better configuration
- Config/argparsing for `scripts/train.py` (FUTURE PR)
- `torch==2.9` and `torch.compile`