Draft
Changes from all commits
49 commits
1d81455
Merge pull request #847 from mlcommons/dev
priyakasimbeg Feb 27, 2025
da5f85a
first LM commit
Niccolo-Ajroldi Mar 11, 2025
a12a364
lm data pipeline
Niccolo-Ajroldi Mar 12, 2025
ca83ab8
testing
Niccolo-Ajroldi Mar 14, 2025
e3e78dc
LM workload tested torch pipeline
Niccolo-Ajroldi Mar 17, 2025
e619495
LM workload - fix torch tests
Niccolo-Ajroldi Mar 17, 2025
d8e9c56
add LM tests, remove dev files
Niccolo-Ajroldi Mar 18, 2025
6b4ff12
add LM tests, remove dev files
Niccolo-Ajroldi Mar 18, 2025
3c5c847
Stop tracking .gitignore
Niccolo-Ajroldi Mar 18, 2025
20d841b
Remove dev/ from repo, keep locally
Niccolo-Ajroldi Mar 18, 2025
f3ba059
fix comments
Niccolo-Ajroldi Mar 18, 2025
381451f
add class specifications
Niccolo-Ajroldi Mar 18, 2025
f111d2e
add workload LM info
Niccolo-Ajroldi Mar 18, 2025
808d398
restore data_utils.py tree map
Niccolo-Ajroldi Mar 18, 2025
35f8f89
fixed NFS bug
Niccolo-Ajroldi Mar 18, 2025
cbb6ee6
train/val split before concat
Niccolo-Ajroldi Mar 18, 2025
868987c
renamed datasets to avoid conflict with HF
Niccolo-Ajroldi Mar 19, 2025
8191f6d
Merge remote-tracking branch 'upstream/lm_workload' into lm_workload
Niccolo-Ajroldi Mar 19, 2025
dd59ded
renamed datasets to dataset
Niccolo-Ajroldi Mar 19, 2025
496b9c3
fix style
Niccolo-Ajroldi Mar 20, 2025
50989eb
fix formatting
Niccolo-Ajroldi Mar 20, 2025
5af0fdc
fix style
Niccolo-Ajroldi Mar 20, 2025
2683099
fix style
Niccolo-Ajroldi Mar 20, 2025
6b7ee29
fix yapf
Niccolo-Ajroldi Mar 20, 2025
46b645b
fix style
Niccolo-Ajroldi Mar 20, 2025
b3ae647
HF datasets pipeline
rka97 Mar 27, 2025
f095d4b
Testing with linear model
rka97 Mar 27, 2025
4189ae0
Merge branch 'jit_switch' into lm_workload
rka97 Mar 27, 2025
0c22f3d
lm workload with linear model
rka97 Apr 3, 2025
99c7b9b
add nanodo model
rka97 Apr 3, 2025
706d9f7
torch model
rka97 Apr 3, 2025
c335e34
lm workload dataset integration in jax
rka97 May 29, 2025
2d54365
lm workload dataset integration in jax
rka97 May 29, 2025
af8cce4
set package versions for transformers and datasets
priyakasimbeg Jun 5, 2025
d68c54e
use train_test_split method to shuffle and split fineweb-edu dataset
priyakasimbeg Jun 5, 2025
9737367
modifications to fwedu datasetup
priyakasimbeg Jun 9, 2025
1bf0750
rename fwedu data dir
priyakasimbeg Jun 9, 2025
a333391
fix
priyakasimbeg Jun 9, 2025
05dc4dd
add back batch mapping in tokenization for fwedu
priyakasimbeg Jun 9, 2025
b374cf8
debugging
priyakasimbeg Jun 10, 2025
c0c1e3c
debugging
priyakasimbeg Jun 10, 2025
f76dc39
debugging
priyakasimbeg Jun 10, 2025
e805fa7
use tfds to shuffle and split dataset
priyakasimbeg Jun 10, 2025
362cbda
Merge remote-tracking branch 'origin/dev' into lm_workload
rka97 Sep 11, 2025
c9e9abc
add command for fineweb-edu
priyakasimbeg Oct 2, 2025
e4323de
fix
priyakasimbeg Oct 2, 2025
f0c6e75
update calls to sharing utils
priyakasimbeg Oct 3, 2025
f4ffbe7
Fix torch sharding issue, update input pipeline and workload classes …
rka97 Oct 6, 2025
5c85c7e
test working, lm workload training not working (debugging)
rka97 Oct 6, 2025
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,4 +25,4 @@ scoring/plots/
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

algoperf/_version.py
algoperf/_version.py
2 changes: 2 additions & 0 deletions algoperf/param_utils.py
@@ -44,6 +44,8 @@ def pytorch_param_types(
      param_types[name] = spec.ParameterType.ATTENTION_BIAS
    elif 'in_proj' in name:
      param_types[name] = spec.ParameterType.ATTENTION_QKV
    elif 'qkv' in name:
      param_types[name] = spec.ParameterType.ATTENTION_QKV
    elif 'kv_proj' in name:
      param_types[name] = spec.ParameterType.ATTENTION_KV
    elif 'k_proj' in name or 'key' in name:
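The new 'qkv' branch lets parameters from models that use a single fused query/key/value projection be classified as ATTENTION_QKV. A minimal sketch of the resulting behavior; the parameter name below is hypothetical and the import path is assumed from the surrounding file, not shown in this diff.

from algoperf import spec

# 'in_proj' is tested first, so the new branch only catches fused projections
# that follow the 'qkv' naming convention.
name = 'blocks.0.attn.qkv.weight'  # hypothetical parameter name
param_types = {}
if 'in_proj' in name:
  param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'qkv' in name:
  param_types[name] = spec.ParameterType.ATTENTION_QKV
assert param_types[name] == spec.ParameterType.ATTENTION_QKV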
Empty file.
154 changes: 154 additions & 0 deletions algoperf/workloads/lm/input_pipeline.py
@@ -0,0 +1,154 @@
"""Input pipeline for a LM dataset."""
import functools
import os
from typing import Optional

import jax
import jax.numpy as jnp
import tensorflow as tf
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer

from algoperf import data_utils
from algoperf.pytorch_utils import pytorch_setup
from datasets import load_dataset
from datasets import load_from_disk

RANK = pytorch_setup()[1]
# Avoid multithreading in all processes but the first (rank 0).
# This ensures that only the primary process (RANK == 0) uses TensorFlow's
# automatic optimization (AUTOTUNE), while other processes disable it (None).
# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine
# the optimal number of elements to prefetch or parallelize for dataset
# operations, improving performance.
AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None


def get_hf_dataloader(cache_dir: str,
                      data_rng: jax.random.PRNGKey,
                      batch_size: int = 8,
                      seq_len: int = 32,
                      framework: str = "torch",
                      split: str = "train"):
  """
  Create a data loader from HuggingFace's FineWeb-Edu dataset.

  Args:
    cache_dir: Directory to cache the dataset
    data_rng: JAX PRNG key; its last element seeds shuffling
    batch_size: Number of sequences per batch
    seq_len: Length of each sequence
    framework: Either "torch" or "jax" to specify output tensor type
    split: Dataset split to load
  """
  # Initialize tokenizer and get vocab size
  tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
  vocab_size = tokenizer.vocab_size
  # Load the FineWeb-Edu dataset in streaming mode
  fw = load_dataset(
      "HuggingFaceFW/fineweb-edu",
      name="sample-10BT",
      split=split,
      streaming=True,
      cache_dir=cache_dir)
  fw = fw.batch(batch_size=batch_size, drop_last_batch=True)
  if split in ['train', 'eval_train']:
    fw = fw.shuffle(seed=int(data_rng[-1]))

  def _tokenize(x):
    """Tokenize and right-pad text to seq_len + 1 tokens."""
    if framework == "torch":
      tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze()
      pad_length = seq_len + 1 - tokens.shape[0]
      if pad_length > 0:
        # Right-pad so every document yields exactly seq_len + 1 tokens.
        tokens = F.pad(tokens, (0, pad_length), value=tokenizer.pad_token_id)
    elif framework == "jax":
      tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze()
      pad_length = seq_len + 1 - tokens.shape[0]
      if pad_length > 0:
        # Right-pad so every document yields exactly seq_len + 1 tokens.
        tokens = jnp.pad(
            tokens, (0, pad_length),
            mode="constant",
            constant_values=tokenizer.pad_token_id)
    return tokens[:seq_len + 1]

  def batch_iterator():
    for doc in fw:
      if framework == "torch":
        token_ids = torch.stack([_tokenize(x) for x in doc['text']])
        # Take first seq_len + 1 tokens and convert to one-hot
        tokens = F.one_hot(token_ids, num_classes=vocab_size).float()
        # Split into input/target
        inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :]
        inputs, targets = inputs.to("cuda"), targets.to("cuda")
      elif framework == "jax":
        token_ids = jnp.stack([_tokenize(x) for x in doc['text']])
        tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size)
        inputs, targets = tokens[:, :-1], tokens[:, 1:]
        inputs, targets = jax.device_put(inputs), jax.device_put(targets)
      yield {'inputs': inputs, 'targets': targets}

  return batch_iterator()


def get_lm_dataset(data_rng: jax.random.PRNGKey,
                   split: str,
                   data_dir: str,
                   global_batch_size: int,
                   num_batches: Optional[int] = None):
  """Load HF dataset and return a TF dataset."""

  dataset_path = os.path.join(data_dir, split)
  dataset = load_from_disk(dataset_path)

  is_training = split == "train"
  shuffle = split in ['train', 'eval_train']

  dataset.set_format("tensorflow")  # tf.int64  # TODO (nico): is this needed?

  def tf_generator():
    """Generates data in a TensorFlow-friendly format."""
    for example in dataset:
      yield {
          "inputs": example["input_ids"][:-1],
          "targets": example["input_ids"][1:],
      }

  # Create a TensorFlow dataset
  ds = tf.data.Dataset.from_generator(
      tf_generator,
      output_signature={
          "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32),
          "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32),
      })

  # Avoid creating too many threads when using PyTorch DDP.
  # Limits TensorFlow's threading for non-primary processes (RANK != 0).
  if RANK != 0:
    options = tf.data.Options()
    options.threading.private_threadpool_size = 1
    ds = ds.with_options(options)

  if shuffle:
    ds = ds.shuffle(buffer_size=1024, seed=data_rng[0])

  if is_training:
    ds = ds.repeat()

  # Batch the dataset, grouping consecutive elements into fixed-size chunks.
  ds = ds.batch(global_batch_size, drop_remainder=is_training)
  ds = ds.prefetch(AUTOTUNE)

  # Limit the dataset to a fixed number of batches if `num_batches` is specified.
  if num_batches:
    ds = ds.take(num_batches)

  # Shard the dataset across multiple GPUs/TPUs if necessary.
  ds = map(
      functools.partial(
          data_utils.shard_and_maybe_pad_np,
          global_batch_size=global_batch_size),
      ds)

  return ds
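For orientation, here is a minimal sketch of how the new pipeline might be consumed on the JAX side. The data directory, batch size, and the assumption that a pre-tokenized 'train' split has already been saved to disk are illustrative, not part of this diff.

import jax

from algoperf.workloads.lm import input_pipeline

rng = jax.random.PRNGKey(0)
# Hypothetical directory produced by the fineweb-edu dataset setup.
data_dir = '/data/fineweb_edu'

ds = input_pipeline.get_lm_dataset(
    data_rng=rng,
    split='train',
    data_dir=data_dir,
    global_batch_size=8)

# Each element is a dict with 'inputs' and 'targets' shifted by one token,
# sharded across local devices by shard_and_maybe_pad_np.
batch = next(iter(ds))
print(batch['inputs'].shape, batch['targets'].shape)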
Empty file.
19 changes: 19 additions & 0 deletions algoperf/workloads/lm/lm_jax/models.py
@@ -0,0 +1,19 @@
from flax import linen as nn
import jax.numpy as jnp


class LinearModel(nn.Module):
  vocab_size: int

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    x = nn.Dense(
        10,
        kernel_init=nn.initializers.normal(0.02),
        bias_init=nn.initializers.zeros)(inputs)
    return nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.normal(0.02),
        bias_init=nn.initializers.zeros,
        name="output")(x)
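And a quick sketch of how the placeholder LinearModel might be exercised, assuming one-hot inputs of shape (batch, seq_len, vocab_size) like those produced by get_hf_dataloader; the batch and sequence sizes are arbitrary examples.

import jax
import jax.numpy as jnp

from algoperf.workloads.lm.lm_jax.models import LinearModel

vocab_size = 50257  # GPT-2 vocabulary size, matching the tokenizer above
model = LinearModel(vocab_size=vocab_size)

# Dummy one-hot batch: 2 sequences of 32 tokens each.
token_ids = jnp.zeros((2, 32), dtype=jnp.int32)
x = jax.nn.one_hot(token_ids, vocab_size)

params = model.init(jax.random.PRNGKey(0), x)
logits = model.apply(params, x)  # shape (2, 32, vocab_size)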