Eagle3 Training #143
fynnsu wants to merge 44 commits into main from eagle3_training
This PR introduces Eagle3 model training into the speculators repo. The implementation is specific to Eagle3, but it is designed in a way that enables future generalization to other speculative decoding algorithms.
Components
Example training script (`scripts/train_llama3_8b_drafter.py`)
Shows how to set up and run training. Currently specific to the `meta-llama/Llama-3.1-8B-Instruct` model, but it doesn't require many changes to run with a different model: just update the model and pass in a new dataset and `t2d`/`d2t` tensors.

Flex Attention
Files:
- `src/speculators/train/eagle3/attention.py`
- `tests/unit/train/test_eagle3_attention.py`

The training code uses FlexAttention, which provides substantial speedups and memory savings over full dense attention operations.
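As a rough sketch of how FlexAttention pairs with the packed-sequence batches described below (this is illustrative PyTorch FlexAttention usage, not the code in `attention.py`; the shapes and the `document_id` layout are made up for the example):

```python
# Minimal sketch using torch.nn.attention.flex_attention (PyTorch >= 2.5).
# Shapes and the document_id layout are illustrative, not taken from attention.py.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 8, 512, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# One packed batch of three concatenated sequences of lengths 200, 200, 112.
document_id = torch.cat(
    [torch.full((n,), i) for i, n in enumerate([200, 200, 112])]
)

def causal_document_mask(b, h, q_idx, kv_idx):
    # Allow attention only within the same packed sequence, and only causally.
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(
    causal_document_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cpu"
)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 512, 64])
```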
Functions:
Data processing
Files:
- `src/speculators/train/data.py`

Data is currently expected in the format of one file per data sample. We load these samples and perform a shift to align `input_ids`, `hidden_states`, `loss_mask`, and `verifier_last_hidden_state` correctly. We also automatically collate these samples into batches. Rather than padding and wasting compute on padded tokens, we instead concatenate the sequences along the sequence dimension, keeping track of the boundaries between sequences and setting the attention mask accordingly.
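A minimal sketch of this packing-style collation (the field names and the returned `seq_lens` key are assumptions for illustration, not necessarily what `data.py` uses):

```python
# Illustrative packed-collation sketch; the real collate logic lives in
# src/speculators/train/data.py and may use different field names.
import torch

def collate_packed(samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Concatenate samples along the sequence dim instead of padding."""
    batch = {
        key: torch.cat([s[key] for s in samples], dim=0)
        for key in ("input_ids", "hidden_states", "loss_mask", "verifier_last_hidden_state")
    }
    # Boundaries between the packed sequences, used later to build the
    # attention mask (e.g. document ids for a FlexAttention block mask).
    batch["seq_lens"] = torch.tensor([s["input_ids"].shape[0] for s in samples])
    return batch
```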
Batch sampling
Files:
- `src/speculators/train/distributed_batch_sampler.py`
- `src/speculators/train/data.py`

Due to hardware limitations, we set a maximum sequence length for each batch. We would like each batch of data to be close in size to this max length, so that each batch has a similar number of tokens. We achieve this with `MultipackDistributedBatchSamplerV2`, taken from prior work I did on instructlab/training. This class produces indices of files that, when batched together, come close to reaching the max length without exceeding it. It also does this in a distributed-aware manner, so that there is no overlap in the data each rank sees.
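A toy version of the idea, to make the behavior concrete (the real `MultipackDistributedBatchSamplerV2` algorithm is more sophisticated than this greedy sketch):

```python
# Toy length-aware, distributed batch packing; illustration only.
def pack_batches(lengths: list[int], max_tokens: int, rank: int, world_size: int) -> list[list[int]]:
    """Greedily group sample indices so each batch stays under max_tokens,
    then shard whole batches round-robin across ranks (no overlap)."""
    batches, current, current_tokens = [], [], 0
    for idx in sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True):
        if current and current_tokens + lengths[idx] > max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += lengths[idx]
    if current:
        batches.append(current)
    return batches[rank::world_size]

# Example: 6 samples, at most 1024 tokens per batch, 2 ranks.
lengths = [900, 500, 400, 300, 200, 100]
print(pack_batches(lengths, max_tokens=1024, rank=0, world_size=2))  # [[0], [3, 4, 5]]
print(pack_batches(lengths, max_tokens=1024, rank=1, world_size=2))  # [[1, 2]]
```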
To run the packing algorithm, we need to know the length of each sample in the dataset. Unfortunately, computing exact lengths would require opening every file in the dataset, which is expensive, so we instead approximate the lengths (`_compute_approx_lengths` in `data.py`) using the length of the first sample and the relative file sizes of the samples.
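The approximation is roughly the following (a sketch of the idea; the file format and field name are assumptions, and the real `_compute_approx_lengths` may differ in details):

```python
# Approximate token lengths from file sizes: read the first sample once,
# then scale by relative file size. Sketch only; see _compute_approx_lengths
# in src/speculators/train/data.py for the real implementation.
import os
import torch

def approx_lengths(paths: list[str]) -> list[int]:
    first_len = torch.load(paths[0])["input_ids"].shape[0]  # field name assumed
    first_size = os.path.getsize(paths[0])
    return [round(first_len * os.path.getsize(p) / first_size) for p in paths]
```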
Eagle3DraftModel
Files:
- `src/speculators/train/eagle3/core.py`

The draft model itself. Sets up and loads the verifier components as well as the draft layers/weights. Contains the model `forward()` pass, which (among other steps) computes the verifier target logits from the stored `verifier_last_hidden_state` using the `verifier_lm_head`. Note: this is computed here for data storage efficiency reasons, as otherwise we would need to save the full logits (`[seq_len, vocab_size]`) to disk instead of the last-layer hidden states (`[seq_len, hidden_size]`). The verifier `vocab_size` is often > 100k, whereas `hidden_size` might be around 4-8k.
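A sketch of that storage trade-off with illustrative Llama-3.1-8B-ish shapes (not the actual `forward()` code):

```python
# Recomputing verifier target logits at train time from the stored
# last-layer hidden states, instead of saving full logits to disk.
import torch

seq_len, hidden_size, vocab_size = 2048, 4096, 128_256
verifier_last_hidden_state = torch.randn(seq_len, hidden_size)   # what gets saved to disk
verifier_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

target_logits = verifier_lm_head(verifier_last_hidden_state)     # [seq_len, vocab_size]

# fp16 storage per sequence: hidden states vs. full logits
print(seq_len * hidden_size * 2 / 2**20, "MiB")   # 16.0 MiB
print(seq_len * vocab_size * 2 / 2**20, "MiB")    # 501.0 MiB
```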
Layer definitions
Files:
- `src/speculators/train/eagle3/model_definitions.py`

Currently just contains model definitions for Llama-3-style draft models. Supports `norm_before_residual=True` or `False`. Attempted to keep modifications to the original Llama models minimal.

Distributed training via FSDP
Files:
- `src/speculators/train/utils.py`
- `src/speculators/train/checkpointer.py`
- `src/speculators/train/trainer.py` (`setup_model` fn)

Full support for FSDP training by launching the training script with `torchrun --nnodes --nproc_per_node=N`, where `N` is the number of GPUs. Tested with `N=2, 3, 4, 8` and all work. FSDP training also enables Automatic Mixed Precision (AMP) for improved performance. `checkpointer.py` contains the checkpointing logic for FSDP-sharded model weights (gather all weights on rank 0 before saving).

Note: the way distributed training works in general is that `N` copies of the script are started, all running the same code, with environment variables telling each process its rank. Explicit `dist.barrier()` calls, or implicit ones inside the FSDP forward/backward hooks, then force each process to wait until all of them reach the same point in the code before continuing. It is important that all ranks reach these operations, as this allows them to perform synchronized operations (such as gathering, reducing, etc.). However, we can also limit certain code to only one rank (rank 0), so that we only log once or save a checkpoint once, using simple `if local_rank == 0` statements.
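A bare-bones illustration of that pattern using plain `torch.distributed` (a generic sketch, not code from this PR):

```python
# Launch: torchrun --nnodes=1 --nproc_per_node=N this_script.py
# Generic illustration of the rank-aware pattern described above.
import os
import torch.distributed as dist

def main() -> None:
    dist.init_process_group(backend="gloo")      # "nccl" for GPU training
    local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun

    # ... all ranks run the same training code here ...

    dist.barrier()           # make sure every rank has finished the step
    if local_rank == 0:      # log / checkpoint from a single rank only
        print(f"world_size={dist.get_world_size()}: saving checkpoint on rank 0")
    dist.barrier()           # other ranks wait for rank 0 before continuing

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```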
Logging
Files:
- `src/speculators/train/logger.py`
- `scripts/train_llama3_8b_drafter.py` (setup logger calls at the start of `main()`)
- `src/speculators/train/trainer.py` and other files: usage of `metric_logger` and `root_logger`

Another implementation mostly copied from prior work I did on instructlab/training. This uses Python's standard library `logging` module and extends it to support training metric logging. We can log a nested dict of metrics anywhere in the codebase like so:
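For example, something along these lines, where the `metric_logger` name and the exact call signature are assumptions rather than the confirmed API from `logger.py`:

```python
# Hypothetical usage sketch -- the real metric logging API is defined in
# src/speculators/train/logger.py and may differ from this.
import logging

metric_logger = logging.getLogger("speculators.metrics")  # logger name assumed

# A nested dict of metrics; the custom handlers are expected to fan this out
# to tensorboard / wandb / trackio and pretty-print it to the console.
metric_logger.info({"train": {"loss": 1.234, "lr": 1e-4}, "epoch": 0, "step": 10})
```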
When the user runs the training script, they can select one (or multiple) of `tensorboard`, `wandb`, and `trackio`, and the results will be logged to the respective experiment tracker.

There is also a `root_logger` which can be used for regular update logging; everything logged to either the `root_logger` or the `metric_logger` will be pretty-printed to the console.
Trainer
Files:
- `src/speculators/train/trainer.py`

The `Trainer` class is initialized with the model, the data loaders, and a config, and runs the training and validation loops (`train_epoch` and `val_epoch` respectively).
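For reference, a generic sketch of what a `train_epoch`-style loop does (illustrative only; the actual `Trainer` implementation in `trainer.py` differs):

```python
# Generic training-epoch sketch: forward, loss, backward, clip, step.
# Illustrative only -- the real loop lives in src/speculators/train/trainer.py.
import torch

def train_epoch(model, loader, optimizer, max_grad_norm: float = 1.0) -> float:
    model.train()
    total_loss, num_batches = 0.0, 0
    for batch in loader:
        loss = model(**batch)              # assumes the model returns a scalar loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)
```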
Todos:
- `loss.backward()` + optimizer steps
- Code relocation / merging with existing definitions (currently everything lives under `speculators/train`, but this will need to change) (FUTURE PR)

Essential todos (as of 10/22/2025):
- Implement save-best or save-last logic (currently saving every epoch) (FUTURE PR)
- `lm_head`, `embed_tokens` loading (requires added loading util for specific layers, #144)
- `Eagle3DraftModel.__init__` signature cleanup / better configuration
- Config/argparsing for `scripts/train.py` (FUTURE PR)
- `torch==2.9` and `torch.compile`