Skip to content

Commit 6b55adf

Browse files
committed
refactor evaluation pipeline for lm
1 parent af91b12 commit 6b55adf

File tree

3 files changed

+35
-34
lines changed

3 files changed

+35
-34
lines changed

algoperf/workloads/lm/input_pipeline.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Optional
66

77
import jax
8+
import numpy as np
89
import tensorflow as tf
910

1011
from algoperf import data_utils
@@ -106,7 +107,7 @@ def get_lm_dataset(
106107
repeated_sequences_dataset = shuffled_sequences_ds.repeat()
107108
ds = repeated_sequences_dataset.batch(
108109
global_batch_size, drop_remainder=False
109-
).take(100).prefetch(tf.data.experimental.AUTOTUNE)
110+
).prefetch(tf.data.experimental.AUTOTUNE)
110111
elif split == 'eval_train':
111112
ds = batch_with_padding(
112113
sequences_ds,
@@ -115,7 +116,11 @@ def get_lm_dataset(
115116
'inputs': (global_batch_size, None),
116117
'targets': (global_batch_size, None),
117118
},
118-
).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation
119+
)
120+
ds = ds.map(lambda x: {'inputs': x['inputs'],
121+
'targets': x['targets'],
122+
'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
123+
ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation
119124
elif split == 'validation':
120125
ds = batch_with_padding(
121126
sequences_ds,
@@ -124,6 +129,10 @@ def get_lm_dataset(
124129
'inputs': (global_batch_size, None),
125130
'targets': (global_batch_size, None),
126131
},
127-
).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size
132+
)
133+
ds = ds.map(lambda x: {'inputs': x['inputs'],
134+
'targets': x['targets'],
135+
'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)})
136+
ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size
128137

129138
return ds

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,13 @@ def _build_input_queue(self,
2828
"""Build an input queue using pre-cached FineWeb dataset."""
2929
del num_batches
3030
del repeat_final_dataset
31-
loader = get_data_iter(
31+
ds = get_data_iter(
3232
data_rng=data_rng,
3333
split=split,
3434
data_dir=data_dir,
3535
global_batch_size=global_batch_size)
36-
loader = map(jax_sharding_utils.shard_along_batch_dim, loader)
37-
return loader
38-
39-
def _build_hf_input_queue(self,
40-
data_rng: jax.random.PRNGKey,
41-
split: str,
42-
data_dir: str,
43-
global_batch_size: int,
44-
num_batches: Optional[int] = None,
45-
repeat_final_dataset: bool = False):
46-
"""Build an input queue using HuggingFace FineWeb dataset."""
47-
del num_batches
48-
del repeat_final_dataset
49-
iter = get_data_iter(data_rng, split, data_dir, global_batch_size)
50-
return iter
36+
ds = map(jax_sharding_utils.shard_along_batch_dim, ds)
37+
return ds
5138

5239
def init_model_fn(
5340
self,
@@ -156,9 +143,10 @@ def _eval_batch(self,
156143
"""Evaluate the model on a single batch."""
157144
logits, _ = self.model_fn(
158145
params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
159-
targets = batch['targets']
160-
161146
# Calculate cross-entropy loss
162147
# TODO(kasimbeg): add weights?
163-
loss_metrics = self.compute_weighted_cross_entropy(logits, targets)
164-
return loss_metrics
148+
metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights'])
149+
return {
150+
'loss': metrics['summed'],
151+
'denominator': metrics['n_valid_examples'],
152+
}

algoperf/workloads/lm/workload.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import math
5+
import numpy as np
56
import os
67
from typing import Any, Dict, Optional
78

@@ -44,11 +45,11 @@ def validation_target_value(self) -> float:
4445
return 20.0 # Target perplexity
4546

4647
def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
47-
return eval_result['test/ppl'] <= self.test_target_value
48+
return True # No test targets
4849

4950
@property
5051
def test_target_value(self) -> float:
51-
return 20.0 # Target perplexity
52+
return None # No test targets
5253

5354
@property
5455
def loss_type(self) -> spec.LossType:
@@ -60,19 +61,19 @@ def num_train_examples(self) -> int:
6061

6162
@property
6263
def num_eval_train_examples(self) -> int:
63-
return 10000 # Subset for evaluation
64+
return 500 # Subset for evaluation. # TODO(kasimbeg): update
6465

6566
@property
6667
def num_validation_examples(self) -> int:
67-
return 50000
68+
return 500 # TODO(kasimbeg): update
6869

6970
@property
7071
def num_test_examples(self) -> int:
71-
return 50000
72+
return 0
7273

7374
@property
7475
def eval_batch_size(self) -> int:
75-
return 8
76+
return 32
7677

7778
@property
7879
def train_mean(self):
@@ -84,7 +85,7 @@ def train_stddev(self):
8485

8586
@property
8687
def max_allowed_runtime_sec(self) -> int:
87-
return 3600 * 4 # 4 hours
88+
return 3600 * 5 # 5 hours
8889

8990
@property
9091
def eval_period_time_sec(self) -> int:
@@ -93,7 +94,7 @@ def eval_period_time_sec(self) -> int:
9394
@property
9495
def step_hint(self) -> int:
9596
"""Approx. steps the baseline can do in the allowed runtime budget."""
96-
return 7000
97+
return 54000
9798

9899
@property
99100
def pre_ln(self) -> bool:
@@ -141,7 +142,7 @@ def _eval_batch(
141142
)
142143

143144
loss_dict = self.loss_fn(batch['targets'], logits)
144-
return loss_dict['summed']
145+
return loss_dict
145146

146147
def _eval_model_on_split(
147148
self,
@@ -170,12 +171,15 @@ def _eval_model_on_split(
170171
eval_metrics = {}
171172
for _ in range(num_batches):
172173
eval_batch = next(self._eval_iters[split])
173-
metrics = self._eval_batch(params, eval_batch)
174+
metrics = self._eval_batch(params, eval_batch, model_state, rng)
174175
for metric_name, metric_value in metrics.items():
175176
if metric_name not in eval_metrics:
176177
eval_metrics[metric_name] = 0.0
177178
eval_metrics[metric_name] += metric_value
178-
eval_results = self._normalize_eval_metrics(num_examples, eval_metrics)
179+
180+
eval_results = self._normalize_eval_metrics(num_examples, eval_metrics)
181+
eval_results['ppl'] = np.exp(eval_results['loss'])
182+
print(eval_results)
179183

180184
return eval_results
181185

0 commit comments

Comments
 (0)