fynnsu commented Oct 3, 2025

This PR introduces Eagle3 model training into the speculators repo. The implementation is specific to Eagle3 but is designed in a way that enables future generalization to other speculative decoding algorithms.

Components

Example training script (scripts/train_llama3_8b_drafter.py)

Shows how to set up and run training. Currently specific to the meta-llama/Llama-3.1-8B-Instruct model, but running with a different model requires only a few changes: update

VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size

and pass in a new dataset and t2d / d2t tensors.

Flex Attention

Files:

  • src/speculators/train/eagle3/attention.py
  • tests/unit/train/test_eagle3_attention.py

The training code uses FlexAttention, which provides substantial speedups and memory savings over full dense attention.

Functions:

  • create_combined_mask_mod(lengths, total_seq_len): Creates the mask function used by FlexAttention.
  • extend_mask_for_draft_tokens(block_mask): Helper that extends the block mask without needing to check each new block's mask value.
  • block_mask_to_dense_attention_mask: Only used for debugging purposes.
  • flex_attention_forward: Lightweight wrapper around the FlexAttention call.
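
For illustration, here is a minimal, hedged sketch of how a packed-sequence causal mask can be expressed as a FlexAttention mask_mod. This is not the repo's exact create_combined_mask_mod; the lengths, head counts, and dtypes below are made-up values.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def packed_causal_mask_mod(lengths: torch.Tensor):
    # doc_ids[i] = index of the packed sample that token i belongs to
    doc_ids = torch.repeat_interleave(
        torch.arange(len(lengths), device=lengths.device), lengths
    )

    def mask_mod(b, h, q_idx, kv_idx):
        # causal within a sample, never attend across sample boundaries
        return (q_idx >= kv_idx) & (doc_ids[q_idx] == doc_ids[kv_idx])

    return mask_mod

lengths = torch.tensor([512, 256, 256], device="cuda")
total_len = int(lengths.sum())
block_mask = create_block_mask(
    packed_causal_mask_mod(lengths), B=None, H=None,
    Q_LEN=total_len, KV_LEN=total_len, device="cuda",
)
# q, k, v: [batch, heads, seq, head_dim]
q = k = v = torch.randn(1, 8, total_len, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask)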

Data processing

Files:

  • src/speculators/train/data.py

Data is currently expected as one file per data sample. We load these samples and apply a shift to correctly align input_ids, hidden_states, loss_mask, and verifier_last_hidden_state. Samples are automatically collated into batches: rather than padding and wasting compute on padded tokens, we concatenate sequences along the sequence dimension, keeping track of the boundaries between sequences and setting the attention mask accordingly.
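
As a rough illustration of that packing-style collation (the actual collate function and per-sample keys in data.py may differ), a sketch could look like:

from typing import Any
import torch

def packed_collate_fn(samples: list[dict[str, torch.Tensor]]) -> dict[str, Any]:
    # Concatenate along the sequence dimension instead of padding, and record
    # per-sample lengths so the attention mask can keep tokens from different
    # samples from attending to each other.
    keys = ("input_ids", "hidden_states", "loss_mask", "verifier_last_hidden_state")
    batch = {k: torch.cat([s[k] for s in samples], dim=0) for k in keys}
    batch["lengths"] = torch.tensor([s["input_ids"].shape[0] for s in samples])
    return batch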

Batch sampling

Files:

  • src/speculators/train/distributed_batch_sampler.py
  • src/speculators/train/data.py

Due to hardware limitations, we set a maximum sequence length for each batch. We would like each batch to come close to this max length, so that batches contain a similar number of tokens. We achieve this with MultipackDistributedBatchSamplerV2, taken from prior work I did on instructlab/training. This class produces batches of file indices that, when packed together, come close to the max length without exceeding it. It also does this in a distributed-aware manner, so there is no overlap in the data each rank sees.

To run the packing algorithm, we need to know the length of every sample in the dataset. Unfortunately, obtaining exact lengths would require opening every file in the dataset, which is expensive, so instead we approximate the lengths (_compute_approx_lengths in data.py) using the length of the first sample and the relative file sizes of the samples.
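
For intuition, a hedged sketch of that approximation (the real _compute_approx_lengths may differ in its details and file format):

import os
import torch

def compute_approx_lengths(paths: list[str]) -> list[int]:
    # Open only the first file to get a reference length (assumes each file
    # stores a dict with an "input_ids" tensor; illustrative only), then scale
    # by each file's size on disk relative to the first file.
    first_len = torch.load(paths[0])["input_ids"].shape[0]
    first_size = os.path.getsize(paths[0])
    return [max(1, round(first_len * os.path.getsize(p) / first_size)) for p in paths]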

Eagle3DraftModel

Files:

  • src/speculators/train/eagle3/core.py

The draft model itself. Sets up and loads verifier components, as well as the draft layers / weights. Contains the model forward() pass, which (see the sketch after this list):

  • sets up the block mask for the batch
  • computes the target logits using the attached verifier_lm_head. Note: this is computed here for data-storage efficiency; otherwise we would need to save the full logits ([seq_len, vocab_size]) to disk instead of the last-layer hidden states ([seq_len, hidden_size]). The verifier vocab_size is often > 100k, whereas hidden_size is typically around 4-8k.
  • For each TTT step:
    • embeds tokens
    • concatenates with hidden_states
    • applies decoder layers
    • computes logits
    • computes loss and step accuracy
    • prepares next-step tokens
    • updates the block mask
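
A highly simplified sketch of that unroll (helper names such as embed_tokens, decoder_layers, lm_head, step_loss, and shift_for_next_step are illustrative assumptions, not the actual Eagle3DraftModel code):

total_loss = 0.0
tokens = input_ids
for step in range(num_ttt_steps):
    embeds = embed_tokens(tokens)                           # embed current draft tokens
    h = torch.cat([embeds, hidden_states], dim=-1)          # fuse with verifier hidden states
    for layer in decoder_layers:                            # draft decoder stack
        h = layer(h, block_mask=block_mask)
    logits = lm_head(h)                                     # draft-vocab logits
    loss = step_loss(logits, target_logits, loss_mask)      # loss + step accuracy
    total_loss = total_loss + loss
    tokens = shift_for_next_step(logits)                    # tokens fed to the next TTT step
    block_mask = extend_mask_for_draft_tokens(block_mask)   # grow mask for new draft positions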

Layer definitions

Files:

  • src/speculators/train/eagle3/model_definitions.py

Currently contains only model definitions for Llama-3-style draft models. Supports norm_before_residual=True or False. Modifications relative to the original Llama model code were kept minimal.
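
To illustrate what norm_before_residual toggles, a minimal sketch under my reading of the flag (not the exact decoder-layer code):

def attn_block(hidden_states, input_norm, self_attn, norm_before_residual: bool):
    if norm_before_residual:
        # normalize first, then keep the normalized activations as the residual
        hidden_states = input_norm(hidden_states)
        residual = hidden_states
    else:
        # standard pre-norm: keep the raw input as the residual
        residual = hidden_states
        hidden_states = input_norm(hidden_states)
    return residual + self_attn(hidden_states)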

Distributed training via FSDP

Files:

  • src/speculators/train/utils.py
  • src/speculators/train/checkpointer.py
  • src/speculators/train/trainer.py (setup_model fn)

Full support for FSDP training: launch the training script with torchrun --nnodes=<num_nodes> --nproc_per_node=N, where N is the number of GPUs per node. Tested with N=2, 3, 4, and 8; all work. FSDP training also enables Automatic Mixed Precision (AMP) for improved performance.

checkpointer.py contains the checkpointing logic: an abstract-ish base Checkpointer class with subclasses for single-device and distributed (FSDP) checkpointing, where the distributed variant gathers all sharded weights on rank 0 before saving.

Note: in general, distributed training starts N copies of the script that all run the same code, with environment variables telling each process its rank. Explicit dist.barrier() calls, or the implicit synchronization inside FSDP forward/backward hooks, force each process to wait until all ranks reach the same point in the code before continuing. It is important that every rank reaches these operations, since they perform collective operations (gathering, reducing, etc.). We can also limit certain code to a single rank (rank 0), e.g. so that we only log or save a checkpoint once, using simple if local_rank == 0 checks.
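
A minimal sketch of that rank-0-only pattern, assuming torchrun has set the usual RANK/LOCAL_RANK environment variables and the process group is already initialized:

import os
import torch.distributed as dist

local_rank = int(os.environ.get("LOCAL_RANK", 0))

if local_rank == 0:
    # only one process logs / writes the checkpoint
    print("rank 0: saving checkpoint")

# every rank must reach the barrier, otherwise the other ranks hang waiting
if dist.is_initialized():
    dist.barrier()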

Logging

Files:

  • src/speculators/train/logger.py
  • scripts/train_llama3_8b_drafter.py: (setup logger calls at start of main())
  • src/speculators/train/trainer.py and other files: usage of metric_logger and root_logger

Another implementation mostly copied from prior work I did on instructlab/training. It uses Python's standard-library logging module and extends it to support training-metric logging. We can log a nested dict of metrics anywhere in the codebase like so:

# Setup once
import logging
metric_logger = logging.getLogger("speculators.metrics")

# Log call
metric_logger.info(
    {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch},
    extra={"step": self.global_step},
)

When running the training script, the user can select one or more of tensorboard, wandb, and trackio, and the metrics will be logged to each selected experiment tracker.

There is also a root_logger that can be used for regular status logging; everything logged to either the root_logger or the metric_logger is pretty-printed to the console.
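
For example, backend selection at the start of main() might look roughly like this (the helper name, signature, and import path are hypothetical placeholders; see the actual setup-logger calls in scripts/train_llama3_8b_drafter.py):

import logging

# Hypothetical helper: the real function in src/speculators/train/logger.py
# may have a different name and signature.
from speculators.train.logger import setup_metric_logger

setup_metric_logger(backends=["tensorboard", "wandb"], run_name="llama3_8b_drafter")

root_logger = logging.getLogger("speculators")
root_logger.info("Training started")  # pretty-printed to the console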

Trainer

Files:

  • src/speculators/train/trainer.py

The Trainer class is initialized with the model, data loaders, and a config and:

  • Sets up model / optimizer (loads weights and configures distributed if needed)
  • Contains the training and validation loops (train_epoch and val_epoch respectively)
  • And the overall training loop, which alternates between training, validation, and saving checkpoints
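
Conceptually, the overall loop is roughly the following (train_epoch and val_epoch are the real method names; config.num_epochs and save_checkpoint are illustrative assumptions, not the exact trainer.py API):

for epoch in range(config.num_epochs):
    trainer.train_epoch(epoch)      # forward/backward, optimizer steps, metric logging
    trainer.val_epoch(epoch)        # validation loss / step accuracy
    trainer.save_checkpoint(epoch)  # currently saved every epoch (see essential todos)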

Todos:

  • Eagle3Draft Model definition with TTT steps and loss calculations
  • Patched Decoder layer definitions
  • Simple data loading from sample files
  • FlexAttention masking and implementation
  • Loss Masking
  • Training loop
    • Train data loader
    • loss.backward() + optimizer steps
    • Distributed loss reduction
    • Val data loader
    • Metric collection/reporting
    • Model checkpointing
  • Data batching
    • Collate fn
    • Batch sampler (dynamic batch size through sample packing)
    • Distributed (rank) aware sampling
  • Distributed support
  • Code relocation / merging with existing definitions (Currently just have everything under speculators/train but this will need to change) FUTURE PR
  • Verify correctness of key components (attention masking, data token alignment, etc).
  • General testing

Essential todos (as of 10/22/2025):

  • Save checkpoints to safetensors format w/ required config info
  • Implement save best or save last logic (currently saving every epoch) FUTURE PR
  • Better Verifier lm_head, embed_tokens loading (requires added loading util for specific layers #144)
  • Eagle3DraftModel.__init__ signature cleanup/better configuration
  • Config/argparsing for scripts/train.py FUTURE PR
  • Ensure flex attention impl works with torch==2.9 and torch.compile
  • Fix lint / quality / type errors and pass CI

