diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..9888861 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +from torch.utils.data import DataLoader +from transformers import LlamaConfig + +from speculators.train.data import ( + AddUniformNoise, + Eagle3SampleFileDataset, + create_collate_fn, + split_files, +) +from speculators.train.distributed_batch_sampler import ( + MultipackDistributedBatchSamplerV2, +) +from speculators.train.eagle3.core import Eagle3DraftModel, Eagle3VerifierLMHead +from speculators.train.logger import setup_metric_logger, setup_root_logger +from speculators.train.trainer import Trainer +from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed + +local_rank, world_size, rank, is_distributed = maybe_setup_distributed() + + +DEVICE = torch.device(local_rank) +EPOCHS = 102 +draft_vocab_size = 32000 +total_seq_len = 8192 +datapath = "./data" +verifier_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct" + + +# TEMP MODEL SETUP +llama_config = LlamaConfig.from_pretrained(verifier_model_name_or_path) +hidden_size = llama_config.hidden_size +verifier_vocab_size = llama_config.vocab_size +llama_config = LlamaConfig(hidden_size=hidden_size, vocab_size=verifier_vocab_size) +llama_config._attn_implementation = "simple_flex_attention" + +# d2t = torch.zeros(draft_vocab_size, dtype=torch.long).to(DEVICE) +# t2d = ( +# torch.cat( +# [ +# torch.ones(draft_vocab_size), +# torch.zeros(llama_config.vocab_size - draft_vocab_size), +# ] +# ) +# .to(torch.bool) +# .to(DEVICE) +# ) +d2t = torch.from_numpy(np.load("d2t.npy")).to(DEVICE) +t2d = torch.from_numpy(np.load("t2d.npy")).to(DEVICE) + +setup_metric_logger(loggers="trackio", run_name=None, output_dir="./logs") +setup_root_logger() +# END TEMP MODEL SETUP + +draft_model = Eagle3DraftModel( + verifier_model_name_or_path=verifier_model_name_or_path, + hidden_size=hidden_size, + t2d=t2d, + d2t=d2t, + decoder_layer_config=llama_config, + verifier_vocab_size=verifier_vocab_size, + verifier_pad_token_id=None, + num_layers=1, + ttt_steps=3, +) + +verifier_lm_head = Eagle3VerifierLMHead( + hidden_size=hidden_size, draft_vocab_size=draft_vocab_size +) +verifier_lm_head.load_verifier_lm_head(verifier_model_name_or_path, t2d) + +### TMP +draft_model.lm_head.weight.data = verifier_lm_head.lm_head.weight.data.to(t2d.device) +### +noise_transform = AddUniformNoise( + std=0.2, tensors=("hidden_states", "verifier_last_hidden_states") +) + +train_files, val_files = split_files(datapath, ratio=0.9) +train_dataset = Eagle3SampleFileDataset( + file_list=train_files, max_len=total_seq_len, transform=noise_transform +) +train_batch_sampler = MultipackDistributedBatchSamplerV2( + batch_max_length=total_seq_len, + lengths=train_dataset.approx_lengths(), + num_replicas=world_size, + rank=local_rank, +) +train_loader = DataLoader( + train_dataset, + batch_sampler=train_batch_sampler, + num_workers=32, + prefetch_factor=8, + pin_memory=True, + collate_fn=create_collate_fn(total_seq_len), + persistent_workers=True, +) + +val_dataset = Eagle3SampleFileDataset(file_list=val_files, max_len=total_seq_len) +val_batch_sampler = MultipackDistributedBatchSamplerV2( + batch_max_length=total_seq_len, + lengths=val_dataset.approx_lengths(), + num_replicas=world_size, + rank=local_rank, +) +val_loader = DataLoader( + val_dataset, + batch_sampler=val_batch_sampler, + num_workers=32, + prefetch_factor=8, + pin_memory=True, + collate_fn=create_collate_fn(total_seq_len), 
+ persistent_workers=True, +) + + +# todo: make config better +config = { + "num_epochs": EPOCHS, + "save_path": "./checkpoints", + "lr": 1e-4, + "total_seq_len": total_seq_len, + "datapath": "./data", + "resume_from_checkpoint": True, +} + + +trainer = Trainer( + draft_model, + verifier_lm_head, + config, + train_loader, + val_loader, + is_distributed, + local_rank, + world_size, +) +trainer.run_training() + +maybe_destroy_distributed() + + +# RUN WITH: +# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 scripts/train.py diff --git a/src/speculators/train/checkpointer.py b/src/speculators/train/checkpointer.py new file mode 100644 index 0000000..d55382c --- /dev/null +++ b/src/speculators/train/checkpointer.py @@ -0,0 +1,151 @@ +from abc import abstractmethod +from pathlib import Path + +import torch +import torch.distributed as dist +from safetensors import safe_open +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) +from transformers.modeling_utils import PreTrainedModel + + +class BaseCheckpointer: + """Helper class to save and load checkpoints. + + Checkpoint file structure: + ../path/ + 0/ # epoch number + model.safetensors + optimizer_state_dict.pt + 1/ + model.safetensors + optimizer_state_dict.pt + ... + """ + + def __init__(self, path: Path | str, try_load_last_checkpoint: bool = True): + self.path = Path(path) + if try_load_last_checkpoint: + self.previous_epoch: int = self._get_previous_epoch() + else: + self.previous_epoch: int = -1 + + @abstractmethod + def load_model_state_dict(self, model: PreTrainedModel): + raise NotImplementedError + + @abstractmethod + def load_optimizer_state_dict( + self, model: PreTrainedModel, optimizer: torch.optim.Optimizer + ): + raise NotImplementedError + + @abstractmethod + def save_checkpoint( + self, model: PreTrainedModel, optimizer: torch.optim.Optimizer, epoch: int + ): + raise NotImplementedError + + def _get_previous_epoch(self) -> int: + if not self.path.exists(): + return -1 + last_checkpoint_num = -1 + for d in self.path.iterdir(): + if d.is_dir(): + try: + last_checkpoint_num = max(last_checkpoint_num, int(d.name)) + except ValueError: + continue + return last_checkpoint_num + + def model_path(self, epoch: int): + model_fname = "model.safetensors" + return self.path / str(epoch) / model_fname + + def optimizer_path(self, epoch: int): + optimizer_fname = "optimizer_state_dict.pt" + return self.path / str(epoch) / optimizer_fname + + +def load_safetensors_state_dict(path: Path, device: str) -> dict[str, torch.Tensor]: + full_state_dict = {} + with safe_open(path, framework="pt", device=device) as f: + for key in f.keys(): + full_state_dict[key] = f.get_tensor(key) + return full_state_dict + + +class SingleGPUCheckpointer(BaseCheckpointer): + def load_model_state_dict(self, model: PreTrainedModel): + full_state_dict = load_safetensors_state_dict( + self.model_path(self.previous_epoch), "cuda:0" + ) + model.load_state_dict(full_state_dict) + + def load_optimizer_state_dict( + self, model: PreTrainedModel, optimizer: torch.optim.Optimizer + ): + full_state_dict = torch.load( + self.optimizer_path(self.previous_epoch), + weights_only=True, + map_location="cuda:0", # todo: make this configurable + ) + optimizer.load_state_dict(full_state_dict) + + def save_checkpoint( + self, model: PreTrainedModel, optimizer: torch.optim.Optimizer, epoch: int + ): + model.save_pretrained(self.path / 
str(epoch)) + torch.save(optimizer.state_dict(), self.optimizer_path(epoch)) + + +class DistributedCheckpointer(BaseCheckpointer): + def load_model_state_dict(self, model: PreTrainedModel): + full_state_dict = load_safetensors_state_dict( + self.model_path(self.previous_epoch), "cpu" + ) + set_model_state_dict( + model, + full_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + dist.barrier() + + def load_optimizer_state_dict(self, model, optimizer: torch.optim.Optimizer): + full_state_dict = torch.load( + self.optimizer_path(self.previous_epoch), + mmap=True, + weights_only=True, + map_location="cpu", + ) + set_optimizer_state_dict( + model, + optimizer, + full_state_dict, + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + ) + dist.barrier() + + def save_checkpoint( + self, model: PreTrainedModel, optimizer: torch.optim.Optimizer, epoch: int + ): + model_state_dict = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + optimizer_state_dict = get_optimizer_state_dict( + model, + optimizer, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + if dist.get_rank() == 0: + # Only rank 0 saves the checkpoint + model.save_pretrained(self.path / str(epoch), state_dict=model_state_dict) + torch.save(optimizer_state_dict, self.optimizer_path(epoch)) + + dist.barrier() diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py new file mode 100644 index 0000000..6f996b2 --- /dev/null +++ b/src/speculators/train/data.py @@ -0,0 +1,219 @@ +import math +import os +import random +from functools import lru_cache +from typing import Any + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + +BatchType = dict[str, Any] + + +class TransformTensors: + def __init__(self, tensors): + self.tensors = tensors + + def __call__(self, data): + for tensor in self.tensors: + data[tensor] = self.transform(data[tensor]) + return data + + def transform(self, tensor: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Subclasses must implement this method") + + +class AddGaussianNoise(TransformTensors): + def __init__(self, mean=0.0, std=0.2, tensors=("hidden_states",)): + super().__init__(tensors) + self.mean = mean + self.std = std + + def transform(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor + torch.randn_like(tensor) * self.std + self.mean + + +class AddUniformNoise(TransformTensors): + def __init__(self, std=0.2, tensors=("hidden_states",)): + super().__init__(tensors) + self.std = std + + def transform(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor + (torch.rand_like(tensor) - 0.5) * self.std + + +def list_files(path): + datapath = [] + for root, _directories, files in os.walk(path): + for file in files: + file_path = os.path.join(root, file) + datapath.append(file_path) + + return datapath + + +def slice_and_pad_to_length(tensor, length): + sliced_tensor = tensor[:length] + padding = [0, 0] * sliced_tensor.dim() + padding[-1] = length - sliced_tensor.shape[0] + return F.pad(sliced_tensor, padding) + + +def shift_batch(batch: BatchType): + input_ids = batch["input_ids"] # shape: [seq_len] + # [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] + hidden_states = batch["hidden_states"] # shape: [seq_len, hidden_size] + # [g0, g1, g2, g3, g4, g5, g6, g7, g8, g9] + verifier_last_hidden_states = batch[ + "verifier_last_hidden_states" + ] # shape: [seq_len, hidden_size] + # [y0, y1, y2, y3, y4, y5, y6, y7, 
y8, y9] + loss_mask = batch["loss_mask"] # shape: [seq_len] + # [l0, l1, l2, l3, l4, l5, l6, l7, l8, l9] + lengths = batch["lengths"] # shape: [1] + # [10] + position_ids = batch["position_ids"] # shape: [seq_len] + # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + # Need to align (x1, g0, y1, l1) + # todo: verify loss mask shift is correct + + # Drop x0, g(-1), y0, l0, reduce seq_len by 1 + + input_ids = input_ids[1:] + hidden_states = hidden_states[:-1] + verifier_last_hidden_states = verifier_last_hidden_states[1:] + loss_mask = loss_mask[1:] + lengths = lengths - 1 + position_ids = position_ids[1:] # Note: position_ids now start at 1 + + return { + "input_ids": input_ids, + "hidden_states": hidden_states, + "verifier_last_hidden_states": verifier_last_hidden_states, + "loss_mask": loss_mask, + "lengths": lengths, + "position_ids": position_ids, + } + + +def split_files(datapath: str, ratio: float = 0.9, seed: int = 0): + """Given a datapath, split the files into a training and validation set + ratio is the proportion of files to put in the training set + 1 - ratio is the proportion of files to put in the validation set + """ + random.seed(seed) + file_list = list_files(datapath) + random.shuffle(file_list) + num_files = len(file_list) + num_train_files = int(num_files * ratio) + train_files = file_list[:num_train_files] + val_files = file_list[num_train_files:] + return train_files, val_files + + +class Eagle3SampleFileDataset(Dataset): + def __init__( + self, + max_len: int, + datapath: str | None = None, + file_list: list[str] | None = None, + transform=None, + hidden_states_dtype=torch.float, + ): + if datapath is not None and file_list is not None: + raise ValueError("Only one of datapath or file_list may be provided") + + if datapath is not None: + file_list = list_files(datapath) + elif file_list is None: + raise ValueError("Either datapath or file_list must be provided") + + self.data = file_list + self.max_len = max_len + self.transform = transform + self.hidden_states_dtype = hidden_states_dtype + + def __len__(self): + return len(self.data) + + @lru_cache(maxsize=1) + def approx_lengths(self): + lengths_0 = self.__getitem__(0)["lengths"] + # this is a single sample so there is only one length + lengths_0 = lengths_0[0].item() + size_0 = os.path.getsize(self.data[0]) + + approx_lengths = [ + math.ceil(os.path.getsize(fname) / size_0 * lengths_0) + for fname in self.data + ] + return approx_lengths + + def __getitem__(self, index) -> BatchType: + data = torch.load(self.data[index]) + + # todo: standardize names during data generation and then remove this + data["hidden_states"] = data["hidden_state"] + data["verifier_last_hidden_states"] = data["target"] + del data["hidden_state"] + del data["target"] + + # todo: standardize dtypes during data generation and then remove this + data = { + k: v.to(self.hidden_states_dtype) if "hidden_states" in k else v + for k, v in data.items() + } + + seq_len = data["input_ids"].shape[0] + # Add lengths tensor + data["lengths"] = torch.tensor([seq_len], dtype=torch.long) + + if self.transform: + data = self.transform(data) + + data["position_ids"] = torch.arange(seq_len, dtype=torch.long) + # shape: [seq_len] + + # data structure: { + # "hidden_states": [seq_len, 3 * hidden_size], + # "input_ids": [seq_len], + # "verifier_last_hidden_states": [seq_len, hidden_size], + # "loss_mask": [seq_len], + # "lengths": [1], + # "position_ids": [seq_len], + # } + + # Note: shift_batch will reduce seq_len by 1 + data = shift_batch(data) + + return data + + +def 
create_collate_fn(max_len: int): + def collate_fn(batch: list[BatchType]) -> BatchType: + collated_data = {} + for key in batch[0].keys(): + collated_data[key] = torch.cat([b[key] for b in batch], dim=0) + + if key != "lengths": + collated_data[key] = slice_and_pad_to_length( + collated_data[key], max_len + ).unsqueeze(0) + # shape: [1, max_len, ...] + + # Handle lengths update + lengths = collated_data["lengths"] + new_lengths = [] + cum_length = 0 + for length in lengths: + if length + cum_length >= max_len: + new_lengths.append(max_len - cum_length) + break + new_lengths.append(length) + cum_length += length + collated_data["lengths"] = torch.tensor(new_lengths, dtype=torch.long) + return collated_data + + return collate_fn diff --git a/src/speculators/train/distributed_batch_sampler.py b/src/speculators/train/distributed_batch_sampler.py new file mode 100644 index 0000000..cb0c452 --- /dev/null +++ b/src/speculators/train/distributed_batch_sampler.py @@ -0,0 +1,211 @@ +""" +MIT License + +Copyright (c) 2023 One + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Adapted from https://github.com/imoneoi/multipack_sampler. +""" + +# Standard +import warnings +from functools import lru_cache +from heapq import heapreplace +from typing import NamedTuple + +import numpy as np + +# Third Party +from numpy.typing import ArrayLike, NDArray +from torch.utils.data import Sampler + + +## Multipack Distributed Batch Sampler +class _Bin(NamedTuple): + """Helper named tuple for `lpt_packed_batch`""" + + fill: int # sum of items in _Bin + rank: int # device rank _Bin is associated with + + +def _lpt_packed_batch( + lengths: np.ndarray, max_len: int, num_replicas: int, start_index: int, rank: int +) -> None | list: + """ + Check if lengths can be distributed into `num_replicas` machines with at most `max_len` tokens per machine and return local rank's batch. + + Uses the LPT (Longest processing time first scheduling) algorithm + Time: O(|lengths| log |lengths| + |lengths| log replicas) + + Returns: + `None` if unable to find a valid packing. Otherwise return the batch indices that correspond to `rank`. + """ + + # Greedily assign lengths (in decreasing order) to the least full rank until they are all assigned or + # we run out of space. 
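+    # Illustrative walk-through (example values, not from the original code):
+    # with lengths=[6, 5, 4, 3], max_len=10, num_replicas=2, items are placed
+    # largest-first into the least-full rank, so rank 0 ends up with indices
+    # [0, 3] (6 + 3 = 9 tokens) and rank 1 with [1, 2] (5 + 4 = 9 tokens).
+    # Any placement that would push a rank past max_len aborts the packing.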
+ local_batch = [] + heap = [_Bin(0, i) for i in range(num_replicas)] + + # sort in descending order + indices = np.argsort(lengths)[::-1] + + for idx, size in zip(indices, lengths[indices]): + new_fill = heap[0].fill + size + if new_fill > max_len: + # Size doesn't fit in least full batch (or any others), can't satisfy requirements + return None + + if heap[0].rank == rank: + # minimum bucket corresponds to the local rank -> add idx to local batch + local_batch.append(start_index + idx) + + _ = heapreplace(heap, _Bin(new_fill, heap[0].rank)) + + return local_batch + + +def _assign_to_packed_batches( + lengths: np.ndarray, max_len: int, rank: int, replicas: int +) -> list[NDArray]: + """Distribute lengths to batches across all ranks, while respecting batch_max_length. Uses a binary search + LPT algorithm + + Args: + lengths (np.ndarray): array of dataset sample lengths + max_len (int): maximum allowed sum of lengths in batch + rank (int): local rank to collect batches for + replicas (int): world size to distribute batches to + + Returns: + tuple[list, int, int]: + - list of np.arrays containing the indices for each batch on the local rank + - sum of dataset lengths included (total sum of lengths in dataset minus any that were dropped at end of dataset) + - total token capacity if each batch maxed out batch_max_length + """ + + lengths_so_far = 0 + ind = 0 + result = [] + lengths_cumsum = np.cumsum(lengths) + + # binary search for max integer x such that the next x elements in shuffled lengths array can be packed into `num_replicas` batches + # Add the local rank's batch to `result` and repeat until end of dataset + while True: + if len(lengths) - ind < replicas: + # Not enough lengths left to pack into `num_replicas` batches + # Break and drop whatever lengths we have left + break + + # binary search in [1, 1 + upper bound for x) + left = 1 + right = 1 + np.searchsorted( + lengths_cumsum[ind:], lengths_so_far + max_len * replicas, "right" + ) + + batch = None + while right - left > 1 and right > replicas: + mid = (left + right) // 2 + batch = _lpt_packed_batch( + lengths[ind : ind + mid], max_len, replicas, ind, rank + ) + if batch is None: + right = mid + else: + left = mid + + if batch is None: + batch = _lpt_packed_batch( + lengths[ind : ind + left], max_len, replicas, ind, rank + ) + + ind += left + lengths_so_far = lengths_cumsum[ind - 1] + + # append only result for local rank (already filtered in lpt_packed_batch) + result.append(batch) + + return result + + +class MultipackDistributedBatchSamplerV2(Sampler): + def __init__( + self, + batch_max_length: int, + lengths: ArrayLike, + num_replicas: int, + rank: int, + seed: int = 0, + ): + """Efficient distributed packing sampler for linear attention style models + + Args: + batch_max_length (int): max number of tokens in a single batch per device + lengths (ArrayLike[int]): the lengths of each sample in the dataset + num_replicas (int): The number of replicas to split the dataset across. + rank (int): The local rank to collect batches for. + seed (int, optional): Seed for RNG, must be the same on all ranks. Defaults to 0. 
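+
+        Example (illustrative sketch; `dataset`, `world_size`, `local_rank`,
+        `collate_fn`, and `num_epochs` are placeholders for the caller's objects):
+
+            sampler = MultipackDistributedBatchSamplerV2(
+                batch_max_length=8192,
+                lengths=dataset.approx_lengths(),
+                num_replicas=world_size,
+                rank=local_rank,
+            )
+            loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
+            for epoch in range(num_epochs):
+                sampler.set_epoch(epoch)
+                for batch in loader:
+                    ...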
+ """ + self.num_replicas = num_replicas + self.rank = rank + self.seed = seed + self.epoch = 0 + self.batch_max_length = batch_max_length + self.lengths = np.array(lengths) + + self.valid_indices = np.nonzero(self.lengths <= self.batch_max_length)[0] + if self.rank == 0 and len(self.valid_indices) < len(self.lengths): + msg = ( + f"Dropping {len(self.lengths) - len(self.valid_indices)}" + f"/{len(self.lengths)} samples longer than batch_max_length. " + "Ensure that the right max_batch_length is used during data processing." + ) + warnings.warn(msg) + + def __iter__(self): + batches = self._generate_batches(self.epoch) + return iter(batches) + + def __len__(self): + batches = self._generate_batches(self.epoch) + return len(batches) + + def set_epoch(self, epoch: int): + self.epoch = epoch + + @lru_cache(maxsize=1) + def _generate_batches(self, epoch: int) -> list[NDArray]: + """Generate batches for local rank + + Returns: + list[NDArray]: list of np.arrays containing the indices for each batch on the local rank + """ + + rng = np.random.default_rng(seed=self.seed + epoch) + indices = rng.permutation(self.valid_indices) + + batches = _assign_to_packed_batches( + self.lengths[indices], self.batch_max_length, self.rank, self.num_replicas + ) + + # The indices in batches are relative to the shuffled self.lengths[indices] + # Translate them so that they are instead relative to the overall unshuffled self.lengths array + batches = [indices[batch] for batch in batches] + + # Cache result + return batches diff --git a/src/speculators/train/eagle3/__init__.py b/src/speculators/train/eagle3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/speculators/train/eagle3/attention.py b/src/speculators/train/eagle3/attention.py new file mode 100644 index 0000000..c1f721c --- /dev/null +++ b/src/speculators/train/eagle3/attention.py @@ -0,0 +1,193 @@ +import torch +from torch.nn.attention.flex_attention import ( + BlockMask, + and_masks, + flex_attention, + or_masks, +) +from transformers.integrations.flex_attention import repeat_kv +from transformers.modeling_utils import AttentionInterface + +flex_attention = torch.compile(flex_attention) + + +def create_combined_mask_mod(lengths: torch.Tensor, total_seq_len: int): + document_ids = torch.repeat_interleave( + torch.arange(lengths.shape[0], device=lengths.device, dtype=torch.long), lengths + ) + # Pad ids with -1 to indicate padding + document_ids = torch.cat( + [ + document_ids, + -1 + * torch.ones( + total_seq_len - document_ids.shape[0], + device=lengths.device, + dtype=torch.long, + ), + ] + ).contiguous() + + N = document_ids.shape[0] + + def causal_mask_mod(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + def document_mask_mod(b, h, q_idx, kv_idx): + # Exclude padding tokens in attention mask + return torch.logical_and( + document_ids[q_idx] != -1, document_ids[q_idx] == document_ids[kv_idx % N] + ) + + def diagonal_draft_mask_mod(b, h, q_idx, kv_idx): + return kv_idx % total_seq_len == q_idx + + return or_masks( + and_masks(causal_mask_mod, document_mask_mod), diagonal_draft_mask_mod + ) + + +def extend_mask_for_draft_tokens(block_mask): + """ + Extend the block mask to include new draft tokens. Concatenates a diagonal mask for the new draft tokens. + + Assumptions: + - block_mask BLOCK_SIZE := KV_BLOCK_SIZE == Q_BLOCK_SIZE + - The number of query values is the original total_seq_len (or equivalently the number of query blocks is the original total_seq_len // BLOCK_SIZE) + + i.e. 
if block_mask is: + [ + [ + [1, 0, 0], + [1, 1, 0], + [0, 0, 1], + ] + ] + the result will be: + [ + [ + [1, 0, 0, 1, 0, 0], + [1, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1], + ] + ] + and then callinga again will give: + [ + [ + [1, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 1, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 1], + ] + ] + + """ + kv_num_blocks = block_mask.kv_num_blocks + # shape: [B, H, Q_LEN // BLOCK_SIZE] + + kv_indices = block_mask.kv_indices + # shape: [B, H, Q_LEN // BLOCK_SIZE, KV_LEN // BLOCK_SIZE] + b, h, q_blocks, kv_blocks = kv_indices.shape + + # extend kv indices if needed + kv_indices = torch.cat( + [kv_indices, kv_indices.new_zeros((b, h, q_blocks, q_blocks))], dim=-1 + ) + new_block_indices = torch.arange( + kv_blocks, + kv_blocks + q_blocks, + dtype=kv_indices.dtype, + device=kv_indices.device, + ).reshape(1, 1, q_blocks, 1) + kv_indices.scatter_( + dim=-1, index=kv_num_blocks.unsqueeze(-1), src=new_block_indices + ) + + kv_num_blocks = kv_num_blocks + 1 + if block_mask.full_kv_indices is not None: + extended_full_kv_indices = torch.cat( + [ + block_mask.full_kv_indices, + block_mask.full_kv_indices.new_zeros((b, h, q_blocks, q_blocks)), + ], + dim=-1, + ) + else: + extended_full_kv_indices = None + return BlockMask.from_kv_blocks( + kv_num_blocks, + kv_indices, + block_mask.full_kv_num_blocks, + extended_full_kv_indices, + mask_mod=block_mask.mask_mod, + ) + + +def block_mask_to_dense_attention_mask( + block_mask: BlockMask, device: torch.device, dtype: torch.dtype +): + attention_mask = torch.ones(block_mask.shape, device=device, dtype=dtype) + + for q_idx in range(attention_mask.shape[2]): + attention_mask[0, 0, q_idx, :] = block_mask.mask_mod( + torch.zeros(1, device=device, dtype=torch.long), + torch.zeros(1, device=device, dtype=torch.long), + torch.ones(1, device=device, dtype=torch.long) * q_idx, + torch.arange(attention_mask.shape[3], device=device, dtype=torch.long), + ) + return attention_mask + + +def flex_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask, + scaling: float | None = None, + softcap: float | None = None, + head_mask: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor | None]: + block_mask = attention_mask + enable_gqa = False + + num_local_query_heads = query.shape[1] + # When running TP this helps: + if (num_local_query_heads & (num_local_query_heads - 1)) != 0: + key = repeat_kv(key, query.shape[1] // key.shape[1]) + value = repeat_kv(value, query.shape[1] // value.shape[1]) + + return_lse = query.device.type != "cpu" + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + flex_attention_output = flex_attention( + query, + key, + value, + score_mod=None, + block_mask=block_mask, + enable_gqa=enable_gqa, + scale=scaling, + kernel_options=None, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. 
+ return_lse=return_lse, + ) + # lse is returned in float32 + if return_lse: + attention_output, lse = flex_attention_output # type: ignore[misc] + lse = lse.to(value.dtype) + else: + attention_output = flex_attention_output # type: ignore[assignment] + lse = None + + attention_output = attention_output.transpose(1, 2).contiguous() + return attention_output, lse + + +ALL_ATTENTION_FUNCTIONS = AttentionInterface() # Singleton class used for registry +ALL_ATTENTION_FUNCTIONS.register("simple_flex_attention", flex_attention_forward) diff --git a/src/speculators/train/eagle3/core.py b/src/speculators/train/eagle3/core.py new file mode 100644 index 0000000..0d5df18 --- /dev/null +++ b/src/speculators/train/eagle3/core.py @@ -0,0 +1,316 @@ +from typing import ClassVar + +import torch +from torch.nn.attention.flex_attention import create_block_mask +from transformers import AutoModelForCausalLM, DynamicCache +from transformers.configuration_utils import PretrainedConfig + +from speculators.model import SpeculatorModel +from speculators.models.eagle3 import Eagle3SpeculatorConfig +from speculators.train.eagle3.attention import ( + create_combined_mask_mod, + extend_mask_for_draft_tokens, +) +from speculators.train.eagle3.model_definitions import model_classes + + +def load_verifier_embeddings(verifier_model_name_or_path: str): + verifier_model = AutoModelForCausalLM.from_pretrained(verifier_model_name_or_path) + return verifier_model.model.embed_tokens.state_dict() + + +class Eagle3VerifierLMHead(torch.nn.Module): + def __init__(self, hidden_size: int, draft_vocab_size: int): + super().__init__() + self.lm_head = torch.nn.Linear(hidden_size, draft_vocab_size, bias=False) + self.lm_head.weight.requires_grad = False + + def load_verifier_lm_head( + self, verifier_model_name_or_path: str, t2d: torch.Tensor + ): + verifier_model = AutoModelForCausalLM.from_pretrained( + verifier_model_name_or_path + ) + verifier_lm_head_data = verifier_model.lm_head.weight.data.to(t2d.device) + trucated_data = verifier_lm_head_data[t2d, :] + if trucated_data.shape[0] != self.lm_head.weight.shape[0]: + raise ValueError( + f"Truncated verifier lm head data shape {trucated_data.shape} does not match draft lm head shape {self.lm_head.weight.shape}" + ) + self.lm_head.weight.data = trucated_data + + @torch.no_grad() + def forward(self, verifier_last_hidden_states: torch.Tensor): + return self.lm_head(verifier_last_hidden_states) + + +def align_for_step( + logits: torch.Tensor, # shape: [batch_size, total_seq_len, draft_vocab_size] + targets: torch.Tensor, # shape: [batch_size, total_seq_len, draft_vocab_size] + loss_mask: torch.Tensor | None, # shape: [batch_size, total_seq_len] + ttt_step: int, +): + # We don't have target values for the last ttt_step tokens, so we mask them out on the logit side + # We shift the target values by ttt_step + 1 to the left because that's the position the generated tokens correspond to + # e.g. 
+ # targets_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # logits_indices_ttt_step_0 = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # logits_indices_ttt_step_1 = [2, 3, 4, 5, 6, 7, 8, 9, 10] + # logits_indices_ttt_step_2 = [3, 4, 5, 6, 7, 8, 9, 10, 11] + # The indices for the loss_mask need to be kept in line with the targets indices + logits = logits[:, :-ttt_step] if ttt_step > 0 else logits + # shape: [batch_size, total_seq_len - ttt_step, draft_vocab_size] + targets = targets[:, ttt_step:] + # shape: [batch_size, total_seq_len - ttt_step, draft_vocab_size] + if loss_mask is not None: + loss_mask = loss_mask[:, ttt_step:] + # shape: [batch_size, total_seq_len - ttt_step] + return logits, targets, loss_mask + + +@torch.no_grad() +def compute_accuracy( + logits: torch.Tensor, # shape: [batch_size, total_seq_len - ttt_step, draft_vocab_size] + targets: torch.Tensor, # shape: [batch_size, total_seq_len - ttt_step, draft_vocab_size] + loss_mask: torch.Tensor | None, # shape: [batch_size, total_seq_len - ttt_step] +): + # Note: logits, targets, and loss_mask are already aligned for the current ttt_step + target_tokens = torch.argmax(targets, dim=-1) + predicted_tokens = torch.argmax(logits, dim=-1) + # shape: [batch_size, total_seq_len - ttt_step] + + correct = predicted_tokens == target_tokens + if loss_mask is not None: + correct = torch.masked_select(correct, loss_mask.to(torch.bool)) + acc = correct.float().sum() / ( + correct.numel() + 1e-5 + ) # avoid NaNs when loss_mask is all False + return acc + + +def loss_function( + logits: torch.Tensor, # shape: [batch_size, total_seq_len - ttt_step, draft_vocab_size] + targets: torch.Tensor, # shape: [batch_size, total_seq_len - ttt_step, draft_vocab_size] + loss_mask: torch.Tensor | None, # shape: [batch_size, total_seq_len - ttt_step] +): + # Note: logits, targets, and loss_mask are already aligned for the current ttt_step + logits = torch.nn.functional.log_softmax(logits, dim=-1) + target_p = torch.nn.functional.softmax(targets, dim=-1) + elementwise_loss = torch.nn.functional.kl_div( + logits, target_p, reduction="none", log_target=False + ) + + if loss_mask is not None: + elementwise_loss = elementwise_loss * loss_mask.unsqueeze(-1) + denominator = loss_mask.sum(dim=1) + 1e-5 + else: + denominator = logits.shape[1] # total_seq_len - ttt_step + batch_loss = torch.sum(elementwise_loss, dim=(1, 2)) / denominator + # shape: [batch_size] + return batch_loss.mean() + + +@SpeculatorModel.register("eagle3_draft") +class Eagle3DraftModel(SpeculatorModel): + config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc] + _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc] + "embed_tokens.weight", + "lm_head.weight", + "d2t", + "t2d", + ] + _keys_to_ignore_on_save: ClassVar[list[str]] = [] # type: ignore[misc,assignment] + + def __init__( + self, + verifier_model_name_or_path: str, + hidden_size: int, # Must be same for verifier and draft + # Vocab mappings + t2d: torch.Tensor, + d2t: torch.Tensor, + decoder_layer_config: PretrainedConfig, + # Verifier + verifier_vocab_size: int, + verifier_pad_token_id: int | None, + # Draft config + num_layers: int = 1, + ttt_steps: int = 3, + ): + norm_before_residual = True + from speculators.config import SpeculatorsConfig, VerifierConfig + from speculators.proposals.greedy import GreedyTokenProposalConfig + + speculator_config = Eagle3SpeculatorConfig( + transformer_layer_config=decoder_layer_config, + draft_vocab_size=t2d.sum(dtype=torch.long).item(), + 
norm_before_residual=norm_before_residual, + speculators_config=SpeculatorsConfig( + algorithm="eagle3", + proposal_methods=[ + GreedyTokenProposalConfig( + proposal_type="greedy", + speculative_tokens=ttt_steps, + ) + ], + default_proposal_method="greedy", + verifier=VerifierConfig( + name_or_path=verifier_model_name_or_path, + architectures=["LlamaForCausalLM"], # todo: fix + ), + ), + ) + super().__init__( + config=speculator_config, + verifier=None, + verifier_attachment_mode="train_only", + ) + self.verifier_model_name_or_path = verifier_model_name_or_path + self.hidden_size = hidden_size + self.num_layers = num_layers + self.decoder_layer_config = decoder_layer_config + self.ttt_steps = ttt_steps + self.register_buffer("t2d", t2d) # shape: [verifier_vocab_size], bool + self.register_buffer("d2t", d2t) # shape: [draft_vocab_size], int offsets + self.draft_vocab_size = t2d.sum(dtype=torch.long).item() + model_definitions = model_classes[decoder_layer_config.model_type] + + self.fc = torch.nn.Linear(3 * hidden_size, hidden_size, bias=False) + self.layers = torch.nn.ModuleList( + [ + model_definitions.decoder_layer_class( + decoder_layer_config, + layer_idx, + norm_before_residual=norm_before_residual, + ) + for layer_idx in range(num_layers) + ] + ) + self.norm = model_definitions.norm_class( + hidden_size, eps=decoder_layer_config.rms_norm_eps + ) + self.rotary_emb = model_definitions.rotary_emb_class(decoder_layer_config) + self.embed_tokens = torch.nn.Embedding( + verifier_vocab_size, hidden_size, padding_idx=verifier_pad_token_id + ) + # shape: [verifier_vocab_size, hidden_size] + self.embed_tokens.load_state_dict( + load_verifier_embeddings(verifier_model_name_or_path) + ) + self.embed_tokens.weight.requires_grad = False + + self.lm_head = torch.nn.Linear(hidden_size, self.draft_vocab_size, bias=False) + # shape: [hidden_size, draft_vocab_size] + + def forward( + self, + hidden_states: torch.Tensor, # shape: [1, total_seq_len, 3 * hidden_size] + input_ids: torch.Tensor, # shape: [1, total_seq_len] + lengths: torch.Tensor | None = None, # shape: [batch_size] + loss_mask: torch.Tensor | None = None, # shape: [1, total_seq_len] + position_ids: torch.Tensor | None = None, # shape: [1, total_seq_len] + target_logits: torch.Tensor + | None = None, # shape: [1, total_seq_len, draft_vocab_size] + ttt_steps: int | None = None, + use_off_policy_tokens: bool = False, + **kwargs, + ): + device = hidden_states.device + total_seq_len = hidden_states.shape[1] + return_loss = target_logits is not None + + if ttt_steps is None: + ttt_steps = self.ttt_steps + if lengths is None: + lengths = torch.tensor([total_seq_len], dtype=torch.long, device=device) + + past_key_values = DynamicCache(config=self.decoder_layer_config) + + combined_mask_mod = create_combined_mask_mod(lengths.to(device), total_seq_len) + block_mask = create_block_mask( + combined_mask_mod, + B=None, + H=None, + Q_LEN=total_seq_len, + KV_LEN=total_seq_len, + device=device, + ) + + hidden_states = self.fc(hidden_states) + # shape: [1, total_seq_len, hidden_size] + + original_input_ids = input_ids.detach().clone() + + loss = torch.tensor(0.0, device=device) + draft_tokens = [] + accuracy_list = [] + for ttt_step in range(ttt_steps): + with torch.no_grad(): + input_embeds = self.embed_tokens(input_ids) + # shape: [1, total_seq_len, hidden_size] + cache_position = torch.arange( + ttt_step * total_seq_len, + (ttt_step + 1) * total_seq_len, + dtype=torch.long, + device=device, + ) + # shape: [total_seq_len] + + hidden_states = 
torch.cat([input_embeds, hidden_states], dim=-1) + # shape: [1, total_seq_len, 2 * hidden_size] + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=block_mask, # block_mask_to_dense_attention_mask(block_mask, device, torch.bool), + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + logits = self.lm_head(self.norm(hidden_states)) + # shape: [1, total_seq_len, draft_vocab_size] + + if return_loss: + s_logits, s_targets, s_loss_mask = align_for_step( + logits, target_logits, loss_mask, ttt_step + ) + loss += loss_function(s_logits, s_targets, s_loss_mask) + accuracy_list.append(compute_accuracy(s_logits, s_targets, s_loss_mask)) + + input_ids = torch.argmax(logits, dim=-1) + draft_tokens.append(input_ids.detach().clone()) + # shape: [1, total_seq_len] + # Use d2t to map draft tokens to verifier tokens. + # Must be in verifier vocabulary space because we use full verifier vocabulary in embedding + input_ids = input_ids + self.d2t[input_ids] + + if use_off_policy_tokens: + # Overwrite input_ids with ground truth tokens + # shift input_ids by 1 to the left and pad with 0 + # note: inputs_ids will no longer line up with verifier_last_hidden_state + # the draft logits generated from the padded tokens are ignored sliced out for loss calculation + input_ids = torch.cat( + [ + original_input_ids[:, 1 + ttt_step :], + original_input_ids.new_zeros(1, 1 + ttt_step), + ], + dim=-1, + ) + # shape: [1, total_seq_len] + + block_mask = extend_mask_for_draft_tokens(block_mask) + position_ids = position_ids + 1 + # shape: [1, total_seq_len] + + if return_loss: + return ( + draft_tokens, + loss, + torch.tensor(accuracy_list, device=device, dtype=torch.float), + ) + else: + return draft_tokens diff --git a/src/speculators/train/eagle3/model_definitions.py b/src/speculators/train/eagle3/model_definitions.py new file mode 100644 index 0000000..186812d --- /dev/null +++ b/src/speculators/train/eagle3/model_definitions.py @@ -0,0 +1,118 @@ +import copy +from typing import NamedTuple, Optional + +import torch +from transformers import Cache, LlamaConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.processing_utils import Unpack +from transformers.utils.generic import TransformersKwargs + + +class LlamaConcatInputDecoderLayer(LlamaDecoderLayer): + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + norm_before_residual: bool = False, + ): + super().__init__(config, layer_idx) + + ##### CHANGES START ##### + self.norm_before_residual = norm_before_residual + if layer_idx == 0: + self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn.q_proj = torch.nn.Linear( + 2 * config.hidden_size, # previous: config.hidden_size + config.num_attention_heads * config.head_dim, + bias=config.attention_bias, + ) + self.self_attn.k_proj = torch.nn.Linear( + 2 * config.hidden_size, # previous: config.hidden_size + config.num_key_value_heads * config.head_dim, + bias=config.attention_bias, + ) + self.self_attn.v_proj = torch.nn.Linear( + 2 * config.hidden_size, # previous: config.hidden_size + config.num_key_value_heads * config.head_dim, + bias=config.attention_bias, + ) + self.layer_idx = layer_idx + 
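+        # Illustrative shapes, assuming a Llama-3.1-8B-style config (hidden_size=4096,
+        # 32 attention heads, 8 KV heads, head_dim=128): layer 0 consumes
+        # cat([embeds, hidden]) of width 8192, so q_proj maps 8192 -> 4096 while
+        # k_proj / v_proj map 8192 -> 1024.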
##### CHANGES END ##### + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + ##### CHANGES START ##### + # previous: residual = hidden_states + if self.layer_idx == 0: + # hidden_states are cat([embeds, hidden], dim=-1) + # so residual should be hidden part only, and embeds should be normalized + mid = hidden_states.shape[2] // 2 + embeds, hidden = hidden_states.split(mid, dim=-1) + residual = hidden + + # Apply norms + embeds = self.input_layernorm(embeds) + hidden = self.hidden_norm(hidden) + if self.norm_before_residual: + residual = hidden # set residual to normalized hidden + hidden_states = torch.cat([embeds, hidden], dim=-1) + else: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + ##### CHANGES END ##### + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaConcatInputRotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, config: LlamaConfig, device=None): + config = copy.copy(config) + config.hidden_size = config.hidden_size * 2 + super().__init__(config, device) + + +class ModelComponents(NamedTuple): + decoder_layer_class: type + norm_class: type + rotary_emb_class: type + + +model_classes: dict[str, ModelComponents] = { + "llama": ModelComponents( + LlamaConcatInputDecoderLayer, LlamaRMSNorm, LlamaConcatInputRotaryEmbedding + ), +} diff --git a/src/speculators/train/logger.py b/src/speculators/train/logger.py new file mode 100644 index 0000000..6c956dd --- /dev/null +++ b/src/speculators/train/logger.py @@ -0,0 +1,549 @@ +"""Logging utilities for the Speculators training module. + +This module provides a logging system for training machine learning models, +supporting multiple logging backends including TensorBoard (tensorboard), Weights & Biases (wandb). 
+ +Example Usage: + ```python + from speculators.train.logger import setup_metric_logger + + # Setup logging with TensorBoard and wandb + setup_metric_logger( + loggers=["tensorboard", "wandb"], + run_name="my_training_run", + output_dir="logs" + ) + + # Log metrics + import logging + logger = logging.getLogger("speculators.metrics") + + # Log a simple metric + logger.info({"loss": 0.5, "accuracy": 0.95}, extra={"step": 100}) + + # Log nested metrics + logger.info({ + "training": { + "loss": 0.5, + "accuracy": 0.95 + }, + "validation": { + "loss": 0.6, + "accuracy": 0.92 + } + }, extra={"step": 100}) + + # Log hyperparameters + logger.info({ + "learning_rate": 0.001, + "batch_size": 32, + "model": { + "hidden_size": 512, + "num_layers": 6 + } + }, extra={"hparams": True}) + ``` +""" + +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import importlib +import logging +import os +import warnings +from collections.abc import Mapping +from datetime import datetime, timezone +from logging.config import dictConfig +from pathlib import Path +from typing import Any, Union + +import torch + +# Third Party +from rich.logging import RichHandler + +### Helper functions + +LogDict = Mapping[str, Union[str, int, float, "LogDict"]] + + +def _substitute_placeholders( + run_name: str | None, default_template: str = "{time}" +) -> str: + """Replace placeholders in the run name with actual values. + + This function supports dynamic run name generation by replacing placeholders + with actual values from the environment or current time. This is particularly + useful for distributed training scenarios where you want unique run names + for each process. + + Supported placeholders: + - {time}: Current local timestamp in ISO format + - {utc_time}: Current UTC timestamp in ISO format + - {rank}: Process rank from RANK environment variable + - {local_rank}: Local process rank from LOCAL_RANK environment variable + + Args: + run_name: String containing placeholders to be replaced. If None, uses default_template + default_template: Default template to use if run_name is None + + Returns: + String with all placeholders replaced by their values + + Example: + ```python + # With default template + name = _substitute_placeholders(None) + # Result: "2024-03-14T10:30:00_rank0" + + # With custom template + name = _substitute_placeholders("experiment_{time}_rank{rank}") + # Result: "experiment_2024-03-14T10:30:00_rank0" + ``` + """ + if run_name is None: + run_name = default_template + + substitutions = { + "{time}": datetime.now().isoformat(timespec="seconds"), + "{utc_time}": datetime.now(timezone.utc).isoformat(timespec="seconds"), + "{rank}": os.environ.get("RANK", 0), + "{local_rank}": os.environ.get("LOCAL_RANK", 0), + } + for placeholder_pat, value in substitutions.items(): + run_name = run_name.replace(placeholder_pat, str(value)) + + return run_name + + +def _flatten_dict(log_dict: LogDict, sep: str = "/", prefix: str = "") -> dict: + """Flatten a nested dictionary into a single-level dictionary. + + This function recursively traverses a nested dictionary and creates a new + dictionary with keys that represent the path to each value in the original + dictionary. 
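+
+    For example, `{"train": {"loss": 0.5}, "epoch": 3}` is flattened to
+    `{"train/loss": 0.5, "epoch": 3}` with the default separator.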
+ + Args: + d: The dictionary to flatten + sep: Separator to use between nested keys + prefix: Prefix to add to all keys + + Returns: + A flattened dictionary with keys joined by the separator + """ + flattened = {} + + for k, v in log_dict.items(): + if isinstance(v, Mapping): + flattened |= _flatten_dict(v, sep=sep, prefix=f"{prefix}{k}{sep}") + else: + flattened[prefix + k] = v + + return flattened + + +### Filters +class IsMappingFilter(logging.Filter): + """Filter that only allows log records with dictionary messages. + + This filter ensures that only log records containing dictionary messages + are processed by the handler. This is useful for metric logging where + we want to ensure all logged messages are structured data. + """ + + def filter(self, record): + """Check if the log record's message is a dictionary. + + Args: + record: The log record to check + + Returns: + bool: True if the message is a dictionary, False otherwise + """ + return isinstance(record.msg, Mapping) + + +class IsRank0Filter(logging.Filter): + """Filter that only allows log records from rank 0 in distributed training. + + This filter is useful in distributed training scenarios where you want to + ensure that only the main process (rank 0) logs metrics to avoid duplicate + logging. The rank can be determined from various sources in order of precedence: + 1. Explicitly provided rank value + 2. Record's rank attribute + 3. Record's message dictionary + 4. Environment variables + 5. PyTorch distributed rank + + Args: + rank_val: Optional explicit rank value to use + local_rank: If True, use local_rank instead of global rank + """ + + def __init__(self, rank_val: int | None = None, local_rank: bool = False): + self.rank_val = rank_val + if local_rank: + self.rank_attr = "local_rank" + else: + self.rank_attr = "rank" + + def _get_rank(self, record): + rank = ( + self.rank_val + or getattr(record, self.rank_attr, None) + or (isinstance(record.msg, Mapping) and record.msg.get(self.rank_attr)) + or os.environ.get(self.rank_attr.upper(), None) + or ( + self.rank_attr == "rank" + and torch.distributed.is_initialized() + and torch.distributed.get_rank() + ) + or 0 + ) + + return int(rank) + + def filter(self, record): + return self._get_rank(record) == 0 + + +class FormatDictFilter(logging.Filter): + """Reformats dictionary messages for prettier printing. + + This filter processes dictionary messages to create a more readable string + representation. It handles different types of values appropriately: + - Floats are formatted with 3 decimal places or scientific notation + - Integers are formatted as decimal numbers + - Other types are converted to their string representation + + Note: This is not a true filter, but a processing step as described in the + Python logging cookbook: https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information + """ + + @staticmethod + def _format_value(v): + if isinstance(v, float): + if abs(v) < 0.001 or abs(v) > 999: + return f"{v:.2e}" + return f"{v:.3f}" + elif isinstance(v, int): + return f"{v:d}" + else: + return repr(v) + + def filter(self, record): + if not isinstance(record.msg, Mapping): + return True + flat_dict = _flatten_dict(record.msg) + + record.msg = ", ".join( + f"{k}={self._format_value(v)}" for k, v in flat_dict.items() + ) + + return True + + +### Handlers +class TensorBoardHandler(logging.Handler): + """Logger that writes metrics to TensorBoard. 
+ + This handler expects a (nested) dictionary of metrics or text to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. + To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + log_dir: str | os.PathLike = "logs", + **tboard_init_kwargs: Any, + ): + """Initialize the TensorBoard logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + log_dir: Directory where TensorBoard logs should be stored + """ + super().__init__(level) + + self.tboard_init_kwargs = tboard_init_kwargs.copy() + self.tboard_init_kwargs.setdefault( + "log_dir", Path(log_dir) / _substitute_placeholders(run_name) + ) + + self._tboard_writer = None + + def _setup(self): + """Create the TensorBoard log directory and initialize the writer. + + Raises: + RuntimeError: If tensorboard package is not installed + """ + + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError as e: + msg = ( + "Could not initialize TensorBoardHandler because package tensorboard could not be imported.\n" + "Please ensure it is installed by running 'pip install tensorboard' or configure the logger to use a different backend." + ) + raise RuntimeError(msg) from e + + os.makedirs(self.tboard_init_kwargs["log_dir"], exist_ok=True) + self._tboard_writer = SummaryWriter(**self.tboard_init_kwargs) + + def emit(self, record: logging.LogRecord): + """Emit a log record to TensorBoard. + + This method handles both scalar metrics and text logs, automatically + detecting the type of data being logged. + + Args: + record: The log record to emit + """ + if self._tboard_writer is None: + self._setup() + + if not isinstance(record.msg, Mapping): + warnings.warn( + f"TensorBoardHandler expected a mapping, got {type(record.msg)}. Skipping log. Please ensure the handler is configured correctly to filter out non-mapping objects." + ) + return + + flat_dict = _flatten_dict(record.msg) + step = getattr(record, "step", None) + if getattr(record, "hparams", None): + self._tboard_writer.add_hparams( + flat_dict, {}, run_name=".", global_step=step + ) + return + + for k, v in flat_dict.items(): + try: + # Check that `v` can be converted to float + float(v) + except ValueError: + # Occurs for strings that cannot be converted to floats (e.g. "3.2.3") and aren't "inf" or "nan" + self._tboard_writer.add_text(k, v, global_step=step) + except TypeError: + warnings.warn( + f"TensorBoardHandler expected a scalar or text, got {type(v)}. Skipping log. Please ensure metric logger is only called with mappings containing scalar values or text." + ) + else: + self._tboard_writer.add_scalar(k, v, global_step=step) + + def flush(self): + """Flush the TensorBoard writer.""" + if self._tboard_writer is not None: + self._tboard_writer.flush() + + def close(self): + """Close the TensorBoard writer and cleanup resources.""" + if self._tboard_writer is not None: + self._tboard_writer.close() + self._tboard_writer = None + super().close() + + +class WandbHandler(logging.Handler): + """Logger that sends metrics to Weights & Biases (wandb). + + This handler expects a (nested) dictionary of metrics or text to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. 
+ To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + log_dir: str | os.PathLike = "logs", + **init_kwargs: Any, + ): + """Initialize the wandb logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + log_dir: Directory where wandb logs should be stored + """ + super().__init__(level) + + self.init_kwargs = init_kwargs.copy() + self.init_kwargs.setdefault("dir", Path(log_dir)) + self.init_kwargs.setdefault("name", _substitute_placeholders(run_name)) + self.init_kwargs.setdefault("config", {}) + + self._package_name = "wandb" + self._run = None + + def _setup(self): + try: + wandb = importlib.import_module(self._package_name) + except ImportError as e: + msg = ( + f"Could not initialize {self.__class__.__name__} because package {self._package_name} could not be imported.\n" + f"Please ensure it is installed by running 'pip install {self._package_name}' or configure the logger to use a different backend." + ) + raise RuntimeError(msg) from e + + self._run = wandb.init(**self.init_kwargs) + + def emit(self, record: logging.LogRecord): + if self._run is None: + self._setup() + + if not isinstance(record.msg, Mapping): + warnings.warn( + f"{self.__class__.__name__} expected a mapping, got {type(record.msg)}. Skipping log. Please ensure the handler is configured correctly to filter out non-mapping objects." + ) + return + + flat_dict = _flatten_dict(record.msg) + step = getattr(record, "step", None) + if getattr(record, "hparams", None): + for k, v in flat_dict.items(): + self._run.config[k] = v + return + + self._run.log(flat_dict, step=step) + + +class TrackioHandler(WandbHandler): + """Logger that sends metrics to Trackio. + + This handler expects a (nested) dictionary of metrics or text to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. + To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + log_dir: str | os.PathLike = "logs", + **init_kwargs: Any, + ): + """Initialize the trackio logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + """ + super().__init__(level) + + self.init_kwargs = init_kwargs.copy() + self.init_kwargs.setdefault("name", _substitute_placeholders(run_name)) + self.init_kwargs.setdefault("config", {}) + self.init_kwargs.setdefault("project", "speculators") + + # Trackio doesn't support the dir keyword argument so we ignore log_dir + + self._package_name = "trackio" + self._run = None + + +### Main functions + + +def setup_root_logger(level="INFO"): + """Configure the root logger with rich formatting. + + This function sets up the root logger with a RichHandler for + console output and adds the FormatDictFilter for better dictionary message + formatting. 
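+
+    Example (illustrative):
+        setup_root_logger(level="INFO")
+        logging.getLogger("speculators").info("starting training")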
+ """ + handler = RichHandler() + handler.addFilter(FormatDictFilter()) + handler.addFilter(IsRank0Filter(local_rank=True)) + logging.basicConfig( + level=level, format="%(message)s", datefmt="[%X]", handlers=[handler] + ) + + +def setup_metric_logger(loggers, run_name, output_dir): + """Configure the metric logging system with specified backends. + + This function sets up a comprehensive logging configuration that supports + multiple logging backends simultaneously. It configures filters, handlers, + and loggers for structured metric logging. + + Args: + loggers: A string or list of strings specifying which logging backends to use. + Supported values: "tensorboard", "wandb", "trackio" + run_name: Name for the current training run. Can include placeholders like + {time}, {rank}, {utc_time}, {local_rank}. + output_dir: Directory where log files will be stored + + Example: + ```python + # Setup logging with multiple backends + setup_metric_logger( + loggers=["tensorboard", "wandb", "trackio"], + run_name="experiment_{time}", + output_dir="logs" + ) + + # Setup logging with a single backend + setup_metric_logger( + loggers="tensorboard", + run_name="my_run", + output_dir="logs" + ) + ``` + """ + if isinstance(loggers, str): + loggers = loggers.split(",") + loggers = [logger.strip() for logger in loggers] + + logging_config = { + "version": 1, + "disable_existing_loggers": False, + "filters": { + "is_mapping": { + "()": IsMappingFilter, + }, + "is_rank0": { + "()": IsRank0Filter, + }, + }, + "handlers": { + "tensorboard": { + "()": TensorBoardHandler, + "log_dir": output_dir, + "run_name": run_name, + "filters": ["is_mapping", "is_rank0"], + }, + "wandb": { + "()": WandbHandler, + "log_dir": output_dir, + "run_name": run_name, + "filters": ["is_mapping", "is_rank0"], + }, + "trackio": { + "()": TrackioHandler, + "log_dir": output_dir, + "run_name": run_name, + "filters": ["is_mapping", "is_rank0"], + }, + }, + "loggers": { + "speculators.metrics": { + "handlers": loggers, + "filters": ["is_mapping"], + "level": "INFO", + "propagate": True, + }, + "speculators": { + "level": "INFO", + "propagate": True, + }, + }, + } + dictConfig(logging_config) diff --git a/src/speculators/train/trainer.py b/src/speculators/train/trainer.py new file mode 100644 index 0000000..a597235 --- /dev/null +++ b/src/speculators/train/trainer.py @@ -0,0 +1,201 @@ +import logging + +import torch +import torch.distributed as dist +from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard +from torch.utils.data import DataLoader +from tqdm.rich import tqdm # todo: requries tqdm and rich + +from speculators.train.checkpointer import ( + DistributedCheckpointer, + SingleGPUCheckpointer, +) + +root_logger = logging.getLogger("speculators") +metric_logger = logging.getLogger("speculators.metrics") + + +def apply_fully_sharded(model: torch.nn.Module): + fsdp_kwargs = { + "mp_policy": MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + ) + } + + for layer in model.layers: # todo: this is hardcoded to the Eagle3DraftModel definition, should be made more general + # we apply fully_shard to each DecoderLayer + layer.to_empty(device="meta") + fully_shard(layer, **fsdp_kwargs) + + fully_shard(model, **fsdp_kwargs) + + return model + + +class Trainer: + def __init__( + self, + model: torch.nn.Module, + verifier_lm_head: torch.nn.Module, + config: dict, + train_loader: DataLoader, + val_loader: DataLoader | None = None, + is_distributed: bool = False, + local_rank: int = 0, + 
world_size: int = 1, + ): + self.model = model + self.verifier_lm_head = verifier_lm_head + self.config = config + self.train_loader = train_loader + self.val_loader = val_loader + self.is_distributed = is_distributed + self.local_rank = local_rank + self.world_size = world_size + checkpointer_class = ( + DistributedCheckpointer if is_distributed else SingleGPUCheckpointer + ) + self.checkpointer = checkpointer_class( + config["save_path"], + try_load_last_checkpoint=config.get("resume_from_checkpoint", False), + ) + + self.setup_trainer() + self.setup_model() + self.setup_optimizer() + + def setup_trainer(self): + self.current_epoch = self.checkpointer.previous_epoch + 1 + self.global_step = 0 + + def setup_model(self): + if self.is_distributed: + apply_fully_sharded(self.model) + + if self.checkpointer.previous_epoch != -1: + self.checkpointer.load_model_state_dict(self.model) + else: + for m in self.model.layers.children(): # todo: generalize + if not isinstance(m, FSDPModule): + continue + m.to_empty(device="cuda") # todo: generalize + for sub_module in m.modules(): + if hasattr(sub_module, "reset_parameters"): + sub_module.reset_parameters() + # todo: We need to make sure we're loading lm_head and embed_tokens after this reset + else: + self.model.to(self.local_rank) + if self.checkpointer.previous_epoch != -1: + self.checkpointer.load_model_state_dict(self.model) + self.verifier_lm_head = self.verifier_lm_head.to(self.local_rank) + + def setup_optimizer(self): + self.opt = torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + if self.checkpointer.previous_epoch != -1: + self.checkpointer.load_optimizer_state_dict(self.model, self.opt) + + def train_epoch(self, epoch: int): + self.model.train() + self.train_loader.batch_sampler.set_epoch( + epoch + ) # todo: check if this is safe to call + + if self.local_rank == 0: + train_loader = tqdm(self.train_loader, desc=f"Epoch {epoch}") + else: + train_loader = self.train_loader + root_logger.info(f"Training Epoch {epoch} started") + + for batch in train_loader: + batch = { + k: v.to(self.local_rank) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + target_logits = self.verifier_lm_head(batch["verifier_last_hidden_states"]) + del batch["verifier_last_hidden_states"] + + _draft_tokens, loss, draft_accuracies = self.model( + **batch, target_logits=target_logits, use_off_policy_tokens=False + ) # set this in a better way + + self.opt.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.opt.step() + + loss = loss.detach().clone() + if self.is_distributed: + # Note: this is not needed for training, just for logging + dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG) + dist.reduce(draft_accuracies, dst=0, op=dist.ReduceOp.AVG) + + acc_values = { + f"acc_{i}": acc.item() for i, acc in enumerate(draft_accuracies) + } + metric_logger.info( + {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch}, + extra={"step": self.global_step}, + ) + self.global_step += 1 + + root_logger.info(f"Training Epoch {epoch} completed") + + @torch.no_grad() + def val_epoch(self, epoch: int): + if self.val_loader is None: + root_logger.warning("No val loader, skipping validation") + return + self.model.eval() + self.val_loader.batch_sampler.set_epoch(epoch) + root_logger.info(f"Validation Epoch {epoch} started") + if self.local_rank == 0: + val_loader = tqdm(self.val_loader, desc=f"Epoch {epoch}") + else: + val_loader = self.val_loader + val_loss = torch.zeros(1, device=self.local_rank) + 
val_accuracies = torch.zeros(
+            (), device=self.local_rank
+        )  # initialize to tensor of shape ()
+        for batch in val_loader:
+            batch = {
+                k: v.to(self.local_rank) if isinstance(v, torch.Tensor) else v
+                for k, v in batch.items()
+            }
+            target_logits = self.verifier_lm_head(batch["verifier_last_hidden_states"])
+            del batch["verifier_last_hidden_states"]
+
+            _draft_tokens, loss, draft_accuracies = self.model(
+                **batch, target_logits=target_logits, use_off_policy_tokens=False
+            )  # set this in a better way
+
+            loss = loss.detach().clone()
+            if self.is_distributed:
+                # Note: reductions are only needed so rank 0 logs cross-rank averages
+                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+                dist.reduce(draft_accuracies, dst=0, op=dist.ReduceOp.AVG)
+
+            val_loss += loss
+            # Can't use += here because val_accuracies is a tensor of shape () on first iteration
+            val_accuracies = val_accuracies + draft_accuracies.detach()
+
+        val_loss /= len(val_loader)
+        val_accuracies /= len(val_loader)
+        acc_values = {
+            f"acc_{i}_epoch": acc.item() for i, acc in enumerate(val_accuracies)
+        }
+        metric_logger.info(
+            {"val": {"loss_epoch": val_loss.item(), **acc_values}, "epoch": epoch},
+            extra={"step": self.global_step},
+        )
+        root_logger.info(f"Validation Epoch {epoch} completed")
+
+    def save_checkpoint(self, epoch: int):
+        self.checkpointer.save_checkpoint(self.model, self.opt, epoch)
+        root_logger.info(f"Checkpoint saved to {self.checkpointer.path / str(epoch)}")
+
+    def run_training(self):
+        for epoch in range(self.current_epoch, self.config["num_epochs"]):
+            self.train_epoch(epoch)
+            if self.is_distributed:
+                dist.barrier()
+            self.val_epoch(epoch)
+            self.save_checkpoint(epoch)
diff --git a/src/speculators/train/utils.py b/src/speculators/train/utils.py
new file mode 100644
index 0000000..f1b3346
--- /dev/null
+++ b/src/speculators/train/utils.py
@@ -0,0 +1,41 @@
+import os
+
+import torch
+import torch.distributed as dist
+
+local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+
+def maybe_setup_distributed():
+    # Based on https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#initialize-ddp-with-torch-distributed-run-torchrun
+    if "LOCAL_RANK" not in os.environ:
+        # No distributed training
+        return 0, 1, 0, False
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+    torch.accelerator.set_device_index(local_rank)
+    acc = torch.accelerator.current_accelerator()
+    backend = torch.distributed.get_default_backend_for_device(acc)
+    dist.init_process_group(backend)
+
+    rank = dist.get_rank()
+
+    print(
+        f"Started DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}"
+    )
+    return local_rank, world_size, rank, True
+
+
+def maybe_destroy_distributed():
+    if "LOCAL_RANK" not in os.environ:
+        # No distributed training
+        return
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    rank = dist.get_rank()
+
+    dist.destroy_process_group()
+    print(
+        f"Destroyed DDP with local_rank={local_rank}, world_size={world_size}, rank={rank}"
+    )
diff --git a/tests/unit/train/test_eagle3_attention.py b/tests/unit/train/test_eagle3_attention.py
new file mode 100644
index 0000000..7bd0b29
--- /dev/null
+++ b/tests/unit/train/test_eagle3_attention.py
@@ -0,0 +1,129 @@
+import pytest
+import torch
+from torch.nn.attention.flex_attention import BlockMask
+
+from speculators.train.eagle3.attention import (
+    create_combined_mask_mod,
+    extend_mask_for_draft_tokens,
+)
+
+
+def test_create_combined_mask_mod():
+    lengths = torch.tensor([1, 2, 3])
+    mask_mod = create_combined_mask_mod(lengths, total_seq_len=lengths.sum().item())
+
+    # Creates causal document mask mod that supports extended diagonals
+    # lengths -> document ids [0, 1, 1, 2, 2, 2]
+    # Expected mask mod values for q_idx (row), kv_idx (column):
+    expected_mask_mod = [
+        [1, 0, 0, 0, 0, 0],
+        [0, 1, 0, 0, 0, 0],
+        [0, 1, 1, 0, 0, 0],
+        [0, 0, 0, 1, 0, 0],
+        [0, 0, 0, 1, 1, 0],
+        [0, 0, 0, 1, 1, 1],
+    ]
+    t0 = torch.tensor(0)
+
+    for q_idx in range(len(expected_mask_mod)):
+        for kv_idx in range(len(expected_mask_mod[q_idx])):
+            assert mask_mod(t0, t0, q_idx, kv_idx) == expected_mask_mod[q_idx][kv_idx]
+
+
+@pytest.mark.parametrize(
+    "lengths", [torch.tensor([1, 2, 3]), torch.tensor([2, 2, 2]), torch.tensor([5])]
+)
+def test_diagonal_draft_tokens_mask_mod(lengths):
+    # Causal    Diagonal
+    # ⌄ ⌄ ⌄ | ⌄ ⌄ ⌄ ⌄ ⌄ ⌄
+    # 1 0 0 | 1 0 0 1 0 0
+    # 1 1 0 | 0 1 0 0 1 0
+    # 1 1 1 | 0 0 1 0 0 1
+    # If kv_idx >= N (N = original seq len = num query indices), only the diagonal tokens are in the mask
+    # Diagonal tokens are those where kv_idx % N == q_idx
+
+    mask_mod = create_combined_mask_mod(lengths, total_seq_len=lengths.sum().item())
+
+    N = lengths.sum().item()
+
+    t0 = torch.tensor(0)
+    for q_idx in range(N):
+        for kv_idx in range(N, 3 * N):
+            assert mask_mod(t0, t0, q_idx, kv_idx) == (kv_idx % N == q_idx)
+
+
+@pytest.mark.parametrize(
+    "kv_num_blocks, kv_indices, expected_kv_indices",
+    [
+        # Test 1: Dense matrix shown in comments in test code
+        (
+            torch.tensor([2, 2, 1]),
+            torch.tensor([[0, 2, -1], [0, 1, -1], [1, -1, -1]]),
+            torch.tensor([[0, 2, 3], [0, 1, 4], [1, 5, -1]]),
+        ),
+        # Test 2: Dense matrix below
+        # 0 1 1 0
+        # 1 0 1 1
+        # 1 0 0 1
+        # 1 1 1 1
+        (
+            torch.tensor([2, 3, 2, 4]),
+            torch.tensor([[1, 2, -1, -1], [0, 2, 3, -1], [0, 3, -1, -1], [0, 1, 2, 3]]),
+            torch.tensor(
+                [
+                    [1, 2, 4, -1, -1],
+                    [0, 2, 3, 5, -1],
+                    [0, 3, 6, -1, -1],
+                    [0, 1, 2, 3, 7],
+                ]
+            ),
+        ),
+    ],
+)
+def test_extend_mask_for_draft_tokens(kv_num_blocks, kv_indices, expected_kv_indices):
+    # The block mask is stored in Block Compressed Sparse Row (BCSR) format
+    # This means storing:
+    # - kv_num_blocks (shape: [batch, head, q_blocks]): the number of active kv blocks for each batch, head, and query block
+    # - kv_indices (shape: [batch, head, q_blocks, kv_blocks]): the kv block (column) indices of those active blocks for each batch, head, and query block
+    # Only the first kv_num_blocks entries of each row of kv_indices are defined
+    # e.g. To store (ignoring batch and head dimensions):
+    # 1 0 1
+    # 1 1 0
+    # 0 1 0
+    # There are 2 blocks for the first query row (0, 2), 2 blocks for the second query row (0, 1), and 1 block for the third query row (1)
+    # Therefore:
+    # kv_num_blocks = [2, 2, 1]
+    # kv_indices = [[[0, 2, U], [0, 1, U], [1, U, U]]] where U indicates the value is undefined
+    # Note: batch and head indices currently aren't considered in our mask functions, so we just treat them as 1 when storing the BlockMask
+
+    # During ttt, we extend the mask to accommodate the new draft tokens. The tokens included will be those on the diagonal (see the diagonal test above),
+    # and therefore we need to include blocks on the newly added diagonal.
+
+    # Therefore, we expect `kv_num_blocks` to increase by 1 for each query row because only the diagonal block will be added to each row.
+    # We also expect `kv_indices` to include the new diagonal block for each query row.
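+    # Illustrative sketch (added for clarity): extending the 3x3 example above appends
+    # an identity block on the kv axis, one diagonal block per query row:
+    # 1 0 1 | 1 0 0
+    # 1 1 0 | 0 1 0
+    # 0 1 0 | 0 0 1
+    # so kv_num_blocks becomes [3, 3, 2] and kv_indices gains the diagonal blocks 3, 4, 5:
+    # kv_indices = [[[0, 2, 3], [0, 1, 4], [1, 5, U]]] (matches Test 1's expected values)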
+
+    kv_num_blocks = kv_num_blocks.reshape(1, 1, *kv_num_blocks.shape)
+    kv_indices = kv_indices.reshape(1, 1, *kv_indices.shape)
+    expected_kv_indices = expected_kv_indices.reshape(1, 1, *expected_kv_indices.shape)
+
+    def dummy_mask_mod(b, h, q_idx, kv_idx):
+        return True
+
+    block_mask = BlockMask.from_kv_blocks(
+        kv_num_blocks=kv_num_blocks.clone(),
+        kv_indices=kv_indices.clone(),
+        mask_mod=dummy_mask_mod,
+    )
+
+    extended_mask = extend_mask_for_draft_tokens(block_mask)
+
+    for q_idx in range(kv_num_blocks.shape[2]):
+        num_defined_blocks_in_row = extended_mask.kv_num_blocks[0, 0, q_idx].item()
+        # Only the first num_defined_blocks_in_row entries of each row of kv_indices are defined; the rest can have any value
+        # Check that the defined blocks match the expected values
+        assert torch.equal(
+            extended_mask.kv_indices[0, 0, q_idx, :num_defined_blocks_in_row],
+            expected_kv_indices[0, 0, q_idx, :num_defined_blocks_in_row],
+        )
+
+    assert extended_mask.mask_mod == block_mask.mask_mod
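+
+    # Note (added for clarity): the mask_mod itself is reused unchanged because
+    # create_combined_mask_mod already admits the appended diagonal draft positions
+    # (see test_diagonal_draft_tokens_mask_mod above); only the block-sparse
+    # structure has to grow.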