Commit a3a9b9f

simplify changes in cifar jax
1 parent b14174b commit a3a9b9f

File tree

2 files changed: +47, -63 lines

algoperf/workloads/cifar/cifar_jax/input_pipeline.py

Lines changed: 3 additions & 3 deletions

@@ -8,6 +8,7 @@
 import functools
 from typing import Dict, Iterator, Tuple

+from flax import jax_utils
 import jax
 import tensorflow as tf
 import tensorflow_datasets as tfds

@@ -170,7 +171,6 @@ def create_input_iter(
       functools.partial(
           shard_and_maybe_pad_np, global_batch_size=global_batch_size),
       ds)
-  # FIXME(rka97): Figure out how to do prefetching+sharding.
-  # TODO (kasimbeg)
-  # it = jax_utils.prefetch_to_device(it, 2)
+
+  it = jax_utils.prefetch_to_device(it, 2)
   return it
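
The restored call keeps batches moving toward the accelerators while the host prepares the next ones. A rough sketch of the pattern (the wrapper function and its arguments are illustrative; only flax.jax_utils.prefetch_to_device comes from the diff):

    from flax import jax_utils

    def make_device_iterator(host_batches, prefetch_count=2):
      # host_batches yields pytrees of numpy arrays with a leading device
      # axis, as shard_and_maybe_pad_np produces. prefetch_to_device copies
      # the next `prefetch_count` batches onto the accelerators ahead of
      # time, overlapping host-to-device transfer with the current step.
      return jax_utils.prefetch_to_device(host_batches, prefetch_count)

create_input_iter above applies the same call directly to the mapped dataset iterator.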

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 44 additions & 60 deletions

@@ -28,10 +28,9 @@ def _build_cifar_dataset(
       data_dir: str,
       batch_size: int,
       cache: Optional[bool] = None,
-      repeat_final_dataset: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None
   ) -> Iterator[Dict[str, spec.Tensor]]:
-    data_dir = data_dir + "/cifar10"
-    ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir)
+    ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
     train = split == 'train'
     assert self.num_train_examples + self.num_validation_examples == 50000
     if split in ['train', 'eval_train']:
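
Dropping the manual data_dir + "/cifar10" concatenation is safe because tfds.builder resolves the dataset's own subdirectory (cifar10/3.0.2) under data_dir itself. A hedged sketch of that flow, with the split handling simplified relative to the workload:

    import tensorflow_datasets as tfds

    def load_cifar_split(data_dir, split):
      # The builder reads (or writes) records under <data_dir>/cifar10/3.0.2,
      # so callers pass the top-level data directory unchanged.
      ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
      ds_builder.download_and_prepare()
      return ds_builder.as_dataset(split=split)  # examples of {'image', 'label'}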
@@ -92,15 +91,17 @@ def init_model_fn(
     model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
     self._model = model
     input_shape = (1, 32, 32, 3)
-    variables = jax.jit(model.init)({"params": rng},
+    variables = jax.jit(model.init)({'params': rng},
                                     jnp.ones(input_shape, model.dtype))
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
+    model_state = jax_sharding_utils.replicate(model_state)
+    params = jax_sharding_utils.replicate(params)
     return params, model_state

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
-    return param_key == "Dense_0"
+    return param_key == 'Dense_0'

   def model_fn(
       self,
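
The two replicate calls place the freshly initialized params and model_state on every local device before training starts. jax_sharding_utils.replicate is the repo's own helper and its body is not shown in this diff; a plausible sketch in plain JAX (the Mesh/NamedSharding construction here is an assumption, not the helper's actual code):

    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    def replicate(pytree):
      # A fully replicated NamedSharding: PartitionSpec() partitions nothing,
      # so every device in the mesh holds a complete copy of each leaf.
      mesh = Mesh(jax.devices(), axis_names=('batch',))
      replicated = NamedSharding(mesh, PartitionSpec())
      return jax.device_put(pytree, replicated)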
@@ -114,19 +115,19 @@ def model_fn(
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
-    variables = {"params": params, **model_state}
+    variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
           variables,
-          augmented_and_preprocessed_input_batch["inputs"],
+          augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
           mutable=['batch_stats'],
           use_running_average_bn=use_running_average_bn)
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
-          augmented_and_preprocessed_input_batch["inputs"],
+          augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
           mutable=False,
           use_running_average_bn=use_running_average_bn)
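
For reference, mutable=['batch_stats'] is the standard Flax pattern for batch norm: when statistics are being updated, apply returns the mutated collection alongside the logits, which is why the first branch returns a new model_state. A small illustrative module (not the workload's ResNet) showing the same pattern:

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class TinyBNNet(nn.Module):
      @nn.compact
      def __call__(self, x, update_batch_norm=True):
        x = nn.Dense(16)(x)
        # use_running_average=False updates the 'batch_stats' collection.
        x = nn.BatchNorm(use_running_average=not update_batch_norm)(x)
        return nn.Dense(10)(x)

    model = TinyBNNet()
    variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    logits, new_state = model.apply(
        variables, jnp.ones((4, 8)), update_batch_norm=True,
        mutable=['batch_stats'])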
@@ -139,15 +140,13 @@ def loss_fn(
       label_batch: spec.Tensor,  # Dense or one-hot labels.
       logits_batch: spec.Tensor,
       mask_batch: Optional[spec.Tensor] = None,
-      label_smoothing: float = 0.0,
-  ) -> Dict[str, spec.Tensor]:  # differentiable
+      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:  # differentiable
     """Evaluate the (masked) loss function at (label_batch, logits_batch).

-    Return {'summed': scalar summed loss,
-            'n_valid_examples': scalar number of
-    valid examples in batch, 'per_example': 1-d array of per-example losses}
-    (not synced across devices).
-    """
+    Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
+    valid examples in batch, 'per_example': 1-d array of per-example losses}
+    (not synced across devices).
+    """
     one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
     smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
     per_example_losses = -jnp.sum(
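
The per_example_losses expression is truncated by the hunk boundary, so its full form is not shown here. A self-contained sketch of the label-smoothed cross entropy it plausibly computes, using only public jax/optax calls (the function name and mask handling are illustrative):

    import jax
    import jax.numpy as jnp
    import optax

    def smoothed_xent(label_batch, logits_batch, mask_batch=None,
                      label_smoothing=0.0, num_classes=10):
      one_hot = jax.nn.one_hot(label_batch, num_classes)
      smoothed = optax.smooth_labels(one_hot, label_smoothing)
      # Per-example cross entropy: -sum_c q_c * log p_c.
      per_example = -jnp.sum(smoothed * jax.nn.log_softmax(logits_batch), axis=-1)
      if mask_batch is not None:
        per_example = per_example * mask_batch
      n_valid = mask_batch.sum() if mask_batch is not None else len(per_example)
      return {'summed': per_example.sum(),
              'n_valid_examples': n_valid,
              'per_example': per_example}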
@@ -160,66 +159,51 @@ def loss_fn(
     n_valid_examples = len(per_example_losses)
     summed_loss = per_example_losses.sum()
     return {
-        "summed": summed_loss,
-        "n_valid_examples": n_valid_examples,
-        "per_example": per_example_losses,
+        'summed': summed_loss,
+        'n_valid_examples': n_valid_examples,
+        'per_example': per_example_losses,
     }

   def _compute_metrics(self,
                        logits: spec.Tensor,
                        labels: spec.Tensor,
                        weights: spec.Tensor) -> Dict[str, spec.Tensor]:
-    summed_loss = self.loss_fn(labels, logits, weights)["summed"]
+    summed_loss = self.loss_fn(labels, logits, weights)['summed']
     # Number of correct predictions.
     accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
-    return jnp.array(summed_loss), jnp.array(accuracy)
+    metrics = {
+        'loss': summed_loss,
+        'accuracy': accuracy,
+    }
+    return metrics

+  @functools.partial(
+      jax.jit,
+      in_shardings=(
+          jax_sharding_utils.get_replicated_sharding(),  # params
+          jax_sharding_utils.get_batch_sharding(),  # batch
+          jax_sharding_utils.get_replicated_sharding(),  # model_state
+          jax_sharding_utils.get_batch_sharding(),  # rng
+      ),
+  )
   def _eval_model(
       self,
       params: spec.ParameterContainer,
       batch: Dict[str, spec.Tensor],
       model_state: spec.ModelAuxiliaryState,
-      rng: spec.RandomState,
-  ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+      rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
     """Return the mean accuracy and loss as a dict."""
-
-    @functools.partial(
-        jax.jit,
-        in_shardings=(
-            jax_sharding_utils.get_replicated_sharding(),  # params
-            jax_sharding_utils.get_batch_sharding(),  # batch
-            jax_sharding_utils.get_replicated_sharding(),  # model_state
-            jax_sharding_utils.get_batch_sharding(),  # rng
-        ),
-    )
-    def _per_device_eval_model(
-        params: spec.ParameterContainer,
-        batch: Dict[str, spec.Tensor],
-        model_state: spec.ModelAuxiliaryState,
-        rng: spec.RandomState,
-    ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-      logits, _ = self.model_fn(
-          params,
-          batch,
-          model_state,
-          spec.ForwardPassMode.EVAL,
-          rng,
-          update_batch_norm=False,
-      )
-      weights = batch.get("weights")
-      if weights is None:
-        weights = jnp.ones(len(logits))
-      return self._compute_metrics(logits, batch["targets"], weights)
-
-    losses, accuracies = _per_device_eval_model(params, batch, model_state, rng)
-    metrics = {
-        "loss":
-            jnp.mean(losses, axis=0) if losses.ndim > 0 else losses,
-        "accuracy":
-            (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies
-            ),
-    }
-    return metrics
+    logits, _ = self.model_fn(
+        params,
+        batch,
+        model_state,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        update_batch_norm=False)
+    weights = batch.get('weights')
+    if weights is None:
+      weights = jnp.ones(len(logits))
+    return self._compute_metrics(logits, batch['targets'], weights)

   def _normalize_eval_metrics(
       self, num_examples: int, total_metrics: Dict[str,
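
The main simplification in this hunk is moving the jax.jit + in_shardings decoration from a nested _per_device_eval_model closure onto _eval_model itself, so the metrics dict comes back directly instead of being averaged over a leading device axis. A standalone sketch of in_shardings with explicit NamedShardings (the mesh, shardings, and eval_step function below are illustrative, not the jax_sharding_utils implementation):

    import functools

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(jax.devices(), axis_names=('batch',))
    replicated = NamedSharding(mesh, PartitionSpec())            # e.g. params
    batch_sharded = NamedSharding(mesh, PartitionSpec('batch'))  # e.g. inputs

    @functools.partial(jax.jit, in_shardings=(replicated, batch_sharded))
    def eval_step(params, batch):
      # params are replicated on every device; batch rows are split along the
      # 'batch' mesh axis, so each device computes on its own shard and XLA
      # inserts any cross-device reduction needed for the mean.
      logits = batch @ params['w'] + params['b']
      return {'loss': jnp.mean(jnp.sum(logits ** 2, axis=-1))}

    # Example call (shapes are arbitrary):
    # params = {'w': jnp.zeros((8, 10)), 'b': jnp.zeros((10,))}
    # metrics = eval_step(params, jnp.ones((16, 8)))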
