diff --git a/.gitignore b/.gitignore index 7d35f0ccc..916a29ff4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algoperf/_version.py +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 908ef0f27..26a351bb4 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -44,6 +44,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py new file mode 100644 index 000000000..c010b32af --- /dev/null +++ b/algoperf/workloads/lm/input_pipeline.py @@ -0,0 +1,154 @@ +"""Input pipeline for a LM dataset.""" +import functools +import os +from typing import Optional + +import jax +import jax.numpy as jnp +import tensorflow as tf +import torch +import torch.nn.functional as F +from transformers import GPT2Tokenizer + +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup +from datasets import load_dataset +from datasets import load_from_disk + +RANK = pytorch_setup()[1] +# Avoid multithreading in all processes but the first (rank 0). +# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# automatic optimization (AUTOTUNE), while other processes disable it (None). +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine +# the optimal number of elements to prefetch or parallelize for dataset +# operations, improving performance. +AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None + + +def get_hf_dataloader(cache_dir: str, + data_rng: jax.random.PRNGKey, + batch_size: int = 8, + seq_len: int = 32, + framework: str = "torch", + split="train"): + """ + Create a data loader from HuggingFace's FineWeb dataset. 
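+  Streams the "sample-10BT" subset, tokenizes each document with the GPT-2
+  tokenizer, and yields dict batches with one-hot encoded 'inputs' and
+  'targets' of shape (batch_size, seq_len, vocab_size).
+
+  Usage sketch (the cache path is illustrative, not a real default):
+    loader = get_hf_dataloader(cache_dir="/tmp/hf_cache",
+                               data_rng=jax.random.PRNGKey(0),
+                               batch_size=8, seq_len=32,
+                               framework="jax", split="train")
+    batch = next(loader)  # {'inputs': ..., 'targets': ...}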
+ + Args: + cache_dir: Directory to cache the dataset + batch_size: Number of sequences per batch + seq_len: Length of each sequence + framework: Either "torch" or "jax" to specify output tensor type + split: Dataset split to load + """ + # Initialize tokenizer and get vocab size + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + # Load the FineWeb dataset in streaming mode + fw = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-10BT", + split=split, + streaming=True, + cache_dir=cache_dir) + fw = fw.batch(batch_size=batch_size, drop_last_batch=True) + if split in ['train', 'eval_train']: + fw = fw.shuffle(seed=int(data_rng[-1])) + + def _tokenize(x): + """Tokenize and pad text to seq_len+1 tokens.""" + if framework == "torch": + tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) + elif framework == "jax": + tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = jnp.pad( + tokens, + pad_length, + mode="constant", + constant_values=tokenizer.pad_token_id) + return tokens[:seq_len + 1] + + def batch_iterator(): + for doc in fw: + if framework == "torch": + token_ids = torch.stack([_tokenize(x) for x in doc['text']]) + # Take first seq_len+1 tokens and convert to one-hot + tokens = F.one_hot(token_ids, num_classes=vocab_size).float() + # Split into input/target + inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] + inputs, targets = inputs.to("cuda"), targets.to("cuda") + elif framework == "jax": + token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) + tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) + inputs, targets = tokens[:, :-1], tokens[:, 1:] + inputs, targets = jax.device_put(inputs), jax.device_put(targets) + yield {'inputs': inputs, 'targets': targets} + + return batch_iterator() + + +def get_lm_dataset(data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None): + """Load HF dataset and return a TF dataset.""" + + dataset_path = os.path.join(data_dir, split) + dataset = load_from_disk(dataset_path) + + is_training = split == "train" + shuffle = split in ['train', 'eval_train'] + + dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? + + def tf_generator(): + """Generates data in a TensorFlow-friendly format.""" + for example in dataset: + yield { + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], + } + + # Create a TensorFlow dataset + ds = tf.data.Dataset.from_generator( + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), + }) + + # Avoid creating too many threads when using PyTorch DDP. + # Limits TensorFlow's threading for non-primary processes (RANK != 0) + if RANK != 0: + options = tf.data.Options() + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + + if shuffle: + ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) + + if is_training: + ds = ds.repeat() + + # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
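+  # drop_remainder is only enabled for training, so evaluation keeps the
+  # final (possibly smaller) batch; shard_and_maybe_pad_np below is expected
+  # to pad such remainder batches back up to global_batch_size.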
+ ds = ds.batch(global_batch_size, drop_remainder=is_training) + ds = ds.prefetch(AUTOTUNE) + + # Limit the dataset to a fixed number of batches if `num_batches` is specified + if num_batches: + ds = ds.take(num_batches) + + # Shard the dataset across multiple GPUs/TPUs if necessary + ds = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) + + return ds diff --git a/algoperf/workloads/lm/lm_jax/__init__.py b/algoperf/workloads/lm/lm_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py new file mode 100644 index 000000000..72ee5bd83 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -0,0 +1,19 @@ +from flax import linen as nn +import jax.numpy as jnp + +class LinearModel(nn.Module): + vocab_size: int + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense( + 10, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(inputs) + return nn.Dense( + self.vocab_size, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros, + name="output" + )(x) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py new file mode 100644 index 000000000..d21fd5090 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -0,0 +1,345 @@ +# Self-contained version of the DecoderOnly Transformer from NanoDO + +import dataclasses +from functools import partial + +from flax import linen as nn +import jax +import jax.numpy as jnp + +# =========== Transformer Decoder-only Model ========== + + + +@dataclasses.dataclass +class DoConfig: + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: DoConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD + +@partial(jax.jit, static_argnums=(0,1,2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack([ + jnp.cos(freqs)[None, :, None, :], + jnp.sin(freqs)[None, :, None, :] + ], axis=3) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return 
freqs_cis.transpose(0, 1, 2, 4, 3) + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack([ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] + ], axis=-1) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name="query") + self.multilinear_key = self.multilinear(name="key") + self.multilinear_value = self.multilinear(name="value") + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name="attn_out_proj", + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 + + # Compute attention scores + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: DoConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # 
Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name="output_proj" + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. + + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + batch_size = y_BxL.shape[0] + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + print(f"\nModel Configuration:") + print(f" - Model dimension (D): {cfg.D}") + print(f" - Number of heads (H): {cfg.H}") + print(f" - Max sequence length (L): {cfg.L}") + print(f" - Number of layers (N): {cfg.N}") + print(f" - Vocabulary size (V): {cfg.V}") + print(f" - Feed forward dimension (F): {cfg.F}") + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print("\nInitializing model parameters...") + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Total parameters: {param_count:,}") + + # Make a prediction (forward pass) + print("\nRunning forward pass...") + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") + print(f"Output data type: {logits.dtype}") + + # Print sample logits (first 5 positions of the first sequence) + print("\nSample logits (first sequence, first 5 positions, first 5 values):") + for position in range(min(5, L)): + print(f" Position {position}: {logits[0, position, :5]}") + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, 
first 10 positions):") + print(predictions[0, :10]) + + # Test the predict function + print("\nTesting predict function...") + # Use a shorter + short_seq = x_BxL[:, :10] + print(f"Input sequence shape: {short_seq.shape}") + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py new file mode 100644 index 000000000..5401ad240 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -0,0 +1,152 @@ +"""LM workload implemented in Jax.""" + +from typing import Dict, Optional, Tuple + +import jax +import jax.numpy as jnp +import optax +from flax import jax_utils +from algoperf import param_utils +from algoperf import jax_sharding_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + TransformerDo, DoConfig, init_rope, apply_rope) + + +class LmWorkload(BaseLmWorkload): + """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using pre-cached FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + loader = map(jax_sharding_utils.shard_along_batch_dim, loader) + return loader + + def _build_hf_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using HuggingFace FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="jax", + split=split) + return loader + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + + # Initialize NanoDO transformer model + cfg = DoConfig( + D=512, # model dim + H=8, # num heads + L=self._seq_len, + N=6, # num layers + V=self._vocab_size, + F=2048, # feedforward dim + dtype=jnp.float32 + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(self._model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.int32)) + params = variables['params'] + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + params = jax_sharding_utils.replicate(params) + model_state = None + return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: 
bool, + dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate + inputs = batch['inputs'] + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + # Convert one-hot labels to token IDs if needed + if len(label_batch.shape) == len(logits_batch.shape): # one-hot + label_batch = jnp.argmax(label_batch, axis=-1) + + # Reshape for sequence modeling + logits = logits_batch.reshape(-1, logits_batch.shape[-1]) + labels = label_batch.reshape(-1) + + # Compute cross-entropy loss + loss = -jnp.sum( + jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) + + if mask_batch is not None: + mask = mask_batch.reshape(-1) + loss = loss * mask + n_valid = mask.sum() + else: + n_valid = labels.shape[0] + + return { + 'summed': loss, + 'n_valid_examples': n_valid, + 'per_example': loss / n_valid # Return per-token loss + } + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/lm/lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py new file mode 100644 index 000000000..545763924 --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + +class LinearLayer(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.bottleneck = nn.Linear(vocab_size, 512) + self.output = nn.Linear(512, vocab_size) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.bottleneck.weight, std=0.02) + nn.init.zeros_(self.bottleneck.bias) + nn.init.normal_(self.output.weight, std=0.02) + nn.init.zeros_(self.output.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py new file mode 100644 index 000000000..627a0e16d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -0,0 +1,298 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from dataclasses import dataclass +from typing import Tuple + + + +@dataclass +class ModelConfig: + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = False + + +class MLP(nn.Module): + + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + 
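+    """GLU feed-forward block.
+
+    hidden_dim is rounded up to the nearest multiple of `multiple_of`;
+    fc1 projects to 2 * hidden_dim so that nn.GLU can gate one half with
+    the other before fc2 projects back down to `dim`.
+    """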
super().__init__() + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + condense_ratio: int = 1): + inv_freqs = 1.0 / (theta**(torch.arange( + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) + t = torch.arange(end, dtype=torch.float32, + device=inv_freqs.device) / condense_ratio + freqs = torch.outer(t, inv_freqs).float() + return torch.stack([ + torch.cos(freqs)[None, :, None, :], + torch.sin(freqs)[None, :, None, :] + ], + dim=4) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - + qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads + + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + k = k.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + v = v.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + + out = out.transpose(1, 2).contiguous().view(bsz, seqlen, + d) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)]) + self.out_norm = 
nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer('freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], + persistent=False) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) + self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + # For debugging + predictions = [] + + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Print top 5 tokens for debugging + if i == 0: + print("\nPyTorch detailed prediction:") + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + predictions.append(next_token.item()) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + print(f" Full predictions step by step: {predictions}") + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith("fc2.weight"): # mlp/glu output layer + 
torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + if n.endswith("w_out.weight"): # attn output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print("Initializing transformer model and running forward pass...") + + seq_length = 512 + + # Define model configuration + config = ModelConfig( + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=768, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=12, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True # Tie embedding and output weights + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py new file mode 100644 index 000000000..e5dafdd3c --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -0,0 +1,182 @@ +"""LM workload implemented in PyTorch.""" + +from typing import Dict, Iterator, Optional, Tuple + +import jax +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from itertools import islice +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + + if hasattr(self, '_model'): + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None + + torch.manual_seed(rng[0]) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=512, # Model dimension + expand=4, # MLP expansion factor + n_layers=6, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True + ) + self._model = Transformer(cfg) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None + + def model_fn( + self, + params: 
spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + del model_state, rng, update_batch_norm + model = params + + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + + logits = model(inputs) + return logits, None + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_lm_dataset + local_batch_size = global_batch_size // N_GPUS + + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=local_batch_size, + num_batches=num_batches + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) + seq_len = self._seq_len + weights = None + + dtype = torch.int32 + is_train = split == 'train' + + for batch in loader: + inputs = batch['inputs'] + targets = batch['targets'] + + if USE_PYTORCH_DDP: + if not is_train: + # During eval, the batch size of the remainder might be different + per_device_batch_size = torch.tensor( + targets.shape[0], dtype=dtype, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + local_batch_size = per_device_batch_size.item() + # Broadcast to all devices + #dist.broadcast(inputs, src=0) + #dist.broadcast(targets, src=0) + + if weights is None: + weights = torch.ones((local_batch_size, seq_len), device=DEVICE) + batch = { + 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), + 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), + 'weights': weights, + } + yield batch + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + + # Handle both one-hot and token ID targets + targets = batch['targets'] + if targets.dim() == 3: # one-hot + loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) + else: # token IDs + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction='sum' + ) + return loss + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in PyTorch.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) + loss = -torch.sum(label_batch * log_probs, dim=-1) + else: + # Dense labels + loss = torch.nn.functional.cross_entropy( + logits_batch, + label_batch, + reduction='none') + if mask_batch is not None: + loss = loss * 
mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py new file mode 100644 index 000000000..b9adc70d2 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp + +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads.lm.lm_jax.workload import LmWorkload +import os + +RANK = os.environ.get('RANK', 0) + +def test_dataloader_jax(): + # Test config. + rng_seed = 1996 + data_dir = '/home/ak4605/data/finewebedu/' + split = 'train' + global_batch_size = 64 + dtype = jnp.int32 + seq_len = 2048 + + workload = LmWorkload() + data_rng = jax.random.PRNGKey(rng_seed) + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + for _ in range(1): + + batch = next(input_queue) + print(f"RANK {RANK} got batch") + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") + + jax.debug.inspect_array_sharding(inputs, callback=print) + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (global_batch_size, seq_len) + assert targets.shape == (global_batch_size, seq_len) + + assert jnp.equal(inputs[:, 1:], targets[:, :-1]).all() + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + test_dataloader_jax() + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py new file mode 100644 index 000000000..827272037 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -0,0 +1,86 @@ +import jax +import torch + +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/home/ak4605/data/finewebedu/' + split = 'train' + global_batch_size = 64 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # batch = next(input_queue) + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs: {inputs}") + + # Start test. 
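+  # Under DDP each rank only sees its own shard, so shapes are checked
+  # against local_batch_size = global_batch_size // N_GPUS rather than the
+  # global batch size.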
+ for _ in range(1): + + batch = next(input_queue) + print(f"RANK {RANK} got batch") + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + assert inputs.dtype == dtype + assert targets.dtype == dtype + + print(local_batch_size, seq_len) + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:, 1:], targets[:, :-1]) + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS) + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py new file mode 100644 index 000000000..36bab0d02 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py @@ -0,0 +1,116 @@ +"""Tests for LM HuggingFace input pipeline.""" +import os + +import jax +import jax.numpy as jnp +import torch +from transformers import GPT2Tokenizer + +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + +def main(): + # Setup test environment + cache_dir = "/home/ak4605/data" + if not os.path.exists(cache_dir): + raise FileNotFoundError(f"Cache directory {cache_dir} not found") + + data_rng = jax.random.PRNGKey(42) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + + print("Running JAX output shapes and types test...") + batch_size = 8 + seq_len = 32 + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == jnp.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == jnp.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" + assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" + print("✓ JAX test passed") + + print("\nRunning Torch output shapes and types test...") + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="torch", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == torch.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == 
torch.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" + assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" + print("✓ Torch test passed") + + print("\nTesting consistent batching with same seed...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" + print("✓ Consistent batching test passed") + + print("\nTesting eval split doesn't shuffle...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(999)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" + print("✓ Eval no shuffling test passed") + + print("\nAll tests passed successfully!") + + +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py new file mode 100644 index 000000000..31cd1d577 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_linear_model.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import torch + +TEST_SEQ_LEN = 512 + +def test_pytorch_linear(): + from algoperf.workloads.lm.lm_pytorch.models import LinearLayer + vocab_size = 32000 + model = LinearLayer(vocab_size) + + batch_size = 8 + seq_len = TEST_SEQ_LEN + inputs = torch.randn(batch_size, seq_len, vocab_size) + outputs = model(inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not torch.isnan(outputs).any() + +def test_jax_linear(): + from algoperf.workloads.lm.lm_jax.models import LinearModel + + vocab_size = 32000 + seq_len = TEST_SEQ_LEN + batch_size = 8 + model = LinearModel(vocab_size) + rng = jax.random.PRNGKey(0) + params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) + + inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) + outputs = model.apply(params, inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not jnp.isnan(outputs).any() + +if __name__ == '__main__': + test_pytorch_linear() + test_jax_linear() + print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py new file mode 100644 index 000000000..986a98297 --- /dev/null +++ b/algoperf/workloads/lm/workload.py @@ -0,0 +1,180 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Dict, Optional + +from absl import flags +import jax +import torch.distributed as dist + +from algoperf import spec +from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline 
import get_hf_dataloader + +FLAGS = flags.FLAGS + +USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ + + +class BaseLmWorkload(spec.Workload): + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 5 + warmup_factor: float = 0.1 + + def __init__(self) -> None: + super().__init__() + self._param_shapes = None + self._param_types = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + return 20.0 # Target perplexity + + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return eval_result['test/ppl'] <= self.test_target_value + + @property + def test_target_value(self) -> float: + return 20.0 # Target perplexity + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + return 1000000 # Example size + + @property + def num_eval_train_examples(self) -> int: + return 10000 # Subset for evaluation + + @property + def num_validation_examples(self) -> int: + return 50000 + + @property + def num_test_examples(self) -> int: + return 50000 + + @property + def eval_batch_size(self) -> int: + return 8 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + return 3600 * 4 # 4 hours + + @property + def eval_period_time_sec(self) -> int: + return 600 # 10 minutes + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 7000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue for the given split.""" + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + dropout_rate=None) + + loss_dict = self.loss_fn(batch['targets'], logits) + return loss_dict['summed'] + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. 
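+      # (repeat_final_dataset=True), so the loop below can always draw
+      # exactly `num_batches` batches per evaluation.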
+ self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True) + + loss = 0.0 + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch, model_state, rng) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {'loss': mean_loss} + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + @abc.abstractmethod + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4dd4717e9..114b1adb4 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -9,151 +9,151 @@ BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { - 'cifar': { - 'workload_path': 'cifar/cifar', - 'workload_class_name': 'CifarWorkload', - }, - 'criteo1tb': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', - }, - 'criteo1tb_test': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', - }, - 'criteo1tb_layernorm': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload', - }, - 'criteo1tb_embed_init': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload', - }, - 'criteo1tb_resnet': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload', - }, - 'fastmri': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIWorkload', - }, - 'fastmri_model_size': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIModelSizeWorkload', - }, - 'fastmri_tanh': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRITanhWorkload', - }, - 'fastmri_layernorm': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRILayerNormWorkload', - }, - 'imagenet_resnet': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetWorkload', - }, - 'imagenet_resnet_silu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetSiLUWorkload', - }, - 'imagenet_resnet_gelu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetGELUWorkload', - }, - 'imagenet_resnet_large_bn_init': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', - }, - 'imagenet_vit': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitWorkload', - }, - 'imagenet_vit_glu': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitGluWorkload', - }, - 'imagenet_vit_post_ln': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitPostLNWorkload', - }, - 'imagenet_vit_map': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitMapWorkload', - }, - 'librispeech_conformer': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerWorkload', - }, - 'librispeech_conformer_attention_temperature': { - 'workload_path': 'librispeech_conformer/librispeech', - 
'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload', - }, - 'librispeech_conformer_layernorm': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', - }, - 'librispeech_conformer_gelu': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerGeluWorkload', - }, - 'librispeech_deepspeech': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', - }, - 'librispeech_deepspeech_tanh': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', - }, - 'librispeech_deepspeech_no_resnet': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', - }, - 'librispeech_deepspeech_norm_and_spec_aug': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', - }, - 'mnist': { - 'workload_path': 'mnist/mnist', - 'workload_class_name': 'MnistWorkload', - }, - 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'}, - 'ogbg_gelu': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgGeluWorkload', - }, - 'ogbg_silu': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgSiluWorkload', - }, - 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgModelSizeWorkload', - }, - 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, - 'wmt_post_ln': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadPostLN', - }, - 'wmt_attention_temp': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadAttentionTemp', - }, - 'wmt_glu_tanh': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadGLUTanH', - }, + 'cifar': { + 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload' + }, + 'criteo1tb': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', + }, + 'criteo1tb_test': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + }, + 'criteo1tb_layernorm': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' + }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload' + }, + 'criteo1tb_resnet': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' + }, + 'fastmri': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIWorkload', + }, + 'fastmri_model_size': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIModelSizeWorkload', + }, + 'fastmri_tanh': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRITanhWorkload', + }, + 'fastmri_layernorm': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRILayerNormWorkload', + }, + 'imagenet_resnet': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetWorkload', + }, + 'imagenet_resnet_silu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetSiLUWorkload', + }, + 'imagenet_resnet_gelu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetGELUWorkload', + }, + 
'imagenet_resnet_large_bn_init': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', + }, + 'imagenet_vit': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitWorkload', + }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitMapWorkload', + }, + 'librispeech_conformer': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerWorkload', + }, + 'librispeech_conformer_attention_temperature': { + 'workload_path': + 'librispeech_conformer/librispeech', + 'workload_class_name': + 'LibriSpeechConformerAttentionTemperatureWorkload', + }, + 'librispeech_conformer_layernorm': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', + }, + 'librispeech_conformer_gelu': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerGeluWorkload', + }, + 'librispeech_deepspeech': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', + }, + 'librispeech_deepspeech_tanh': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', + }, + 'librispeech_deepspeech_no_resnet': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', + }, + 'librispeech_deepspeech_norm_and_spec_aug': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', + }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, + 'mnist': { + 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' + }, + 'ogbg': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' + }, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload' + }, + 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, + 'wmt_post_ln': { + 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' + }, + 'wmt_attention_temp': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadAttentionTemp' + }, + 'wmt_glu_tanh': { + 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH' + }, } BASE_WORKLOADS = [ - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'ogbg', - 'wmt', + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'lm', + 'ogbg', + 'wmt' ] diff --git a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py index 761ce5cb1..8fa4e27f6 100644 --- a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py +++ b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py @@ -189,6 +189,8 @@ 
def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/nesterov/jax/submission.py b/algorithms/archived_paper_baselines/nesterov/jax/submission.py index e199fb2b9..cc8eba3c5 100644 --- a/algorithms/archived_paper_baselines/nesterov/jax/submission.py +++ b/algorithms/archived_paper_baselines/nesterov/jax/submission.py @@ -292,6 +292,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..6e40cdab1 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,6 +394,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'lm': + return 128 elif workload_name == 'mnist': return 16 else: diff --git a/datasets/README.md b/dataset/README.md similarity index 99% rename from datasets/README.md rename to dataset/README.md index 1aeb83239..1bfd9bf73 100644 --- a/datasets/README.md +++ b/dataset/README.md @@ -453,3 +453,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### FineWeb-Edu 10B +From the `algorithmic-efficiency` root directory, run: + +```bash +python3 dataset/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--finewebedu +``` \ No newline at end of file diff --git a/datasets/dataset_setup.py b/dataset/dataset_setup.py similarity index 85% rename from datasets/dataset_setup.py rename to dataset/dataset_setup.py index e110930cd..872e2ef0b 100644 --- a/datasets/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -72,16 +72,22 @@ from torchvision.datasets import CIFAR10 from algoperf.workloads.wmt import tokenizer -from algoperf.workloads.wmt.input_pipeline import normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from algoperf.workloads.wmt.input_pipeline import \ + normalize_feature_names +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer + +import datasets as hf_datasets +from transformers import AutoTokenizer import functools +import itertools import os import shutil import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -106,38 +112,38 @@ 'files will be deleted.', ) flags.DEFINE_boolean( - 'all', - False, - 'Whether or not to download all datasets. If false, can download some ' - 'combination of datasets by setting the individual dataset flags below.', -) - -flags.DEFINE_boolean( - 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.' -) -flags.DEFINE_boolean( - 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.'
-) -flags.DEFINE_boolean( - 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' -) -flags.DEFINE_boolean( - 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' -) -flags.DEFINE_boolean( - 'librispeech', - False, - 'If --all=false, whether or not to download LibriSpeech.', -) -flags.DEFINE_boolean( - 'mnist', False, 'If --all=false, whether or not to download MNIST.' -) -flags.DEFINE_boolean( - 'ogbg', False, 'If --all=false, whether or not to download OGBG.' -) -flags.DEFINE_boolean( - 'wmt', False, 'If --all=false, whether or not to download WMT.' -) + 'all', + False, + 'Whether or not to download all datasets. If false, can download some ' + 'combination of datasets by setting the individual dataset flags below.') + +flags.DEFINE_boolean('criteo1tb', + False, + 'If --all=false, whether or not to download Criteo 1TB.') +flags.DEFINE_boolean('cifar', + False, + 'If --all=false, whether or not to download CIFAR-10.') +flags.DEFINE_boolean('fastmri', + False, + 'If --all=false, whether or not to download FastMRI.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') +flags.DEFINE_boolean('imagenet', + False, + 'If --all=false, whether or not to download Imagenet.') +flags.DEFINE_boolean('librispeech', + False, + 'If --all=false, whether or not to download LibriSpeech.') +flags.DEFINE_boolean('mnist', + False, + 'If --all=false, whether or not to download MNIST.') +flags.DEFINE_boolean('ogbg', + False, + 'If --all=false, whether or not to download OGBG.') +flags.DEFINE_boolean('wmt', + False, + 'If --all=false, whether or not to download WMT.') flags.DEFINE_string( 'data_dir', @@ -194,6 +200,7 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') FLAGS = flags.FLAGS @@ -767,6 +774,93 @@ def download_wmt(data_dir): ) +def download_finewebedu(data_dir, + tmp_dir=None, + skip_download=False, + skip_tokenization=False): + """Download FineWebEdu-10B.""" + + if not skip_download: + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ["TMPDIR"] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info("Tokenizing...") + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 
'language_score', + 'token_count', + 'score', + 'int_score' + ], + batched=True, + batch_size=1024, + num_proc=8) + + tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + else: + tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + + # Convert to tensorflow_datasets.Dataset objects + tokenized_dataset = tokenized_dataset.to_tf_dataset() + + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. + train_dataset.save(os.path.join(data_dir, "train")) + val_dataset.save(os.path.join(data_dir, "val")) + + return + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -854,6 +948,10 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 99% rename from datasets/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index 1c216db46..878f10f2a 100644 --- a/datasets/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -14,7 +14,7 @@ from absl import logging from pydub import AudioSegment -from datasets import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py diff --git a/pyproject.toml b/pyproject.toml index e4de98f89..76bcfb7ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,9 @@ version_file = "algoperf/_version.py" ############################################################################### [project.optional-dependencies] # All workloads -full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] +full = [ + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", +] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package @@ -88,6 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] +lm = ["transformers==4.25.4", "datasets==3.6.0"] # Frameworks jax_core_deps = [ diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..1c51ec58f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -253,11 +253,12 @@ def train_once( model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ - 'librispeech_conformer', - 'ogbg', - 'criteo1tb', - 'imagenet_vit', - 'librispeech_deepspeech', + 'librispeech_conformer', + 'ogbg', + 'criteo1tb', + 'imagenet_vit', + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -795,10 +796,11 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] if base_workload in [ - 'librispeech_conformer', - 'librispeech_deepspeech', - 'imagenet_vit', - 'criteo1tb', + 
'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
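
A minimal sketch (not part of the diff above) of how the FineWeb-Edu splits written by `download_finewebedu` could be spot-checked after running `dataset/dataset_setup.py` with `--finewebedu`. It assumes the splits were saved with `tf.data.Dataset.save` under `$DATA_DIR/fineweb_edu_10B/{train,val}` as in the function above; the `data_dir` path used here is hypothetical.

```python
# Minimal sketch: spot-check the FineWeb-Edu splits saved by download_finewebedu.
# Assumption: the splits were written with tf.data.Dataset.save(...) under
# $DATA_DIR/fineweb_edu_10B/{train,val}; the data_dir below is hypothetical.
import os

import tensorflow as tf

data_dir = os.path.expanduser('~/data/fineweb_edu_10B')
train_ds = tf.data.Dataset.load(os.path.join(data_dir, 'train'))

# Each element should be a dict of tokenized features (e.g. 'input_ids')
# produced by the GPT-2 tokenizer; sequence lengths still vary per document
# at this stage.
print(train_ds.element_spec)
for example in train_ds.take(1):
  print({k: v.shape for k, v in example.items()})
```

The printed element structure should contain only the columns that survive the `remove_columns` list in the tokenization step, i.e. the tokenizer output such as `input_ids`.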