
Commit 0698e34

librispeech_conformer now running.
Still need to (a) test output losses, (b) test speed, and (c) look into the other librispeech workloads.
1 parent: e6037d6

File tree (2 files changed: 77 additions, 29 deletions)
  • algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax
  • reference_algorithms/paper_baselines/nesterov/jax

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 75 additions & 26 deletions
@@ -6,6 +6,9 @@
 import flax.linen as nn
 import jax
 from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec as P
+
+from algorithmic_efficiency import sharding_utils
 import jax.numpy as jnp
 import numpy as np
 import optax
@@ -21,7 +24,6 @@
 from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \
     models
 
-
 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
 
   def __init__(self,
@@ -93,8 +95,16 @@ def init_model_fn(
 
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+
+    # Add sharding
+    mesh = sharding_utils.get_mesh()
+    params = jax.tree_map(
+        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
+        params)
+    model_state = jax.tree_map(
+        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
+        model_state)
+
     return params, model_state
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
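
Note: sharding_utils is used throughout this diff but is not part of the commit. A minimal sketch of what its helpers might look like, assuming they are thin wrappers around jax.sharding; the function names match the call sites above, the bodies are guesses for illustration only.

# Hypothetical sharding_utils helpers, not from this commit.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_mesh():
  # One-dimensional mesh over all available devices, with axis name 'batch'.
  return Mesh(np.asarray(jax.devices()), ('batch',))


def get_replicated_sharding(mesh=None):
  # Empty PartitionSpec: every device holds a full copy of the array.
  return NamedSharding(mesh or get_mesh(), P())


def get_naive_sharding_spec(mesh=None):
  # Shard the leading (batch) dimension across the 'batch' mesh axis.
  return NamedSharding(mesh or get_mesh(), P('batch'))
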
@@ -176,6 +186,7 @@ def _build_input_queue(
           'targets': (targets.numpy(), target_paddings.numpy()),
       }
 
+      # Use data_utils.shard_and_maybe_pad_np to handle sharding
       padded_batch = data_utils.shard_and_maybe_pad_np(
           numpy_batch, padding_value=1.0)
       yield padded_batch
@@ -300,11 +311,16 @@ def greedy_decode(
     return hyp, hyp_paddings
 
   @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0, 0, None),
-      static_broadcasted_argnums=(0,))
-  def eval_step_pmapped(
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # params
+          sharding_utils.get_naive_sharding_spec(),  # batch
+          sharding_utils.get_replicated_sharding(),  # model_state
+          sharding_utils.get_replicated_sharding(),  # rng
+      ),
+      out_shardings=sharding_utils.get_naive_sharding_spec(),
+      static_argnums=(0,))
+  def _eval_step(
       self,
       params: spec.ParameterContainer,
       batch: Dict[str, spec.Tensor],
@@ -322,13 +338,39 @@ def eval_step_pmapped(
     loss = self.loss_fn(batch['targets'], (logits, logit_paddings))
 
     targets, target_paddings = batch['targets']
-    return self.metrics_bundle.gather_from_model_output(
-        loss_dict=loss,
-        decoded=decoded,
-        decoded_paddings=decoded_paddings,
-        targets=targets,
-        target_paddings=target_paddings,
-        axis_name='batch')
+    # Convert metrics bundle to dictionary
+    metrics_dict = {
+        'loss_per_example': loss['per_example'],
+        'decoded': decoded,
+        'decoded_paddings': decoded_paddings,
+        'targets': targets,
+        'target_paddings': target_paddings,
+        'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples']
+    }
+    return metrics_dict
+
+  def eval_step(
+      self,
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      rng: spec.RandomState):
+    """Evaluates the model and returns a metrics bundle."""
+    metrics_dict = self._eval_step(params, batch, model_state, rng)
+
+    # Convert dictionary back to metrics bundle
+    metrics = self.metrics_bundle.single_from_model_output(
+        loss_dict={
+            'summed': metrics_dict['loss_per_example'].sum(),
+            'per_example': metrics_dict['loss_per_example'],
+            'n_valid_examples': metrics_dict['n_valid_examples'].sum()
+        },
+        decoded=metrics_dict['decoded'],
+        decoded_paddings=metrics_dict['decoded_paddings'],
+        targets=metrics_dict['targets'],
+        target_paddings=metrics_dict['target_paddings'])
+
+    return metrics
 
   def _eval_model_on_split(self,
                            split: str,
@@ -353,10 +395,10 @@ def _eval_model_on_split(self,
     metrics_report = None
     for _ in range(num_batches):
       eval_batch = next(self._eval_iters[split])
-      computed_metrics = self.eval_step_pmapped(params,
-                                                eval_batch,
-                                                model_state,
-                                                rng).unreplicate()
+      computed_metrics = self.eval_step(params,
+                                        eval_batch,
+                                        model_state,
+                                        rng)
 
       if metrics_report is None:
         metrics_report = computed_metrics
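
The per-batch results from the new eval_step are presumably still combined with metrics_report.merge(...) in the unchanged lines below this hunk. A small self-contained sketch of that pattern, assuming metrics_bundle is a clu.metrics.Collection; the concrete metric definition here is invented for illustration.

# Toy clu.metrics Collection merged across eval batches, mirroring how
# single_from_model_output + merge aggregate per-batch results on the host.
import flax
import jax.numpy as jnp
from clu import metrics


@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')


report = None
for batch_loss in [jnp.array(0.5), jnp.array(0.7)]:
  batch_metrics = EvalMetrics.single_from_model_output(loss=batch_loss)
  report = batch_metrics if report is None else report.merge(batch_metrics)

print(report.compute())  # approximately {'loss': 0.6}
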
@@ -368,15 +410,22 @@ def _eval_model_on_split(self,
 
     return computed_metrics
 
+  @functools.partial(
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # model_state
+      ),
+      out_shardings=sharding_utils.get_replicated_sharding(),
+      static_argnums=(0,)
+  )
   def sync_batch_stats(
       self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
-    # An axis_name is passed to pmap which can then be used by pmean.
-    # In this case each device has its own version of the batch statistics and
-    # we average them.
-    avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
-    new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
-    return new_model_state
+    """Sync batch statistics across replicas."""
+    # Replace pmean with direct mean across devices
+    new_batch_stats = jax.tree_map(
+        lambda x: jnp.mean(x, axis=0),
+        model_state['batch_stats'])
+    return model_state.copy({'batch_stats': new_batch_stats})
 
 
 class LibriSpeechConformerAttentionTemperatureWorkload(
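
The core change in this file is swapping jax.pmap for jax.jit with explicit in_shardings and out_shardings. A toy, self-contained illustration of that pattern; none of these names come from the repository.

# Toy pmap-to-jit migration: replicated weights, batch-sharded data, and a
# jitted function. GSPMD inserts the cross-device communication that pmap
# expressed by hand with axis names and pmean.
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), ('batch',))
replicated = NamedSharding(mesh, P())
sharded = NamedSharding(mesh, P('batch'))

w = jax.device_put(jnp.ones((4, 4)), replicated)
x = jax.device_put(jnp.ones((8 * jax.device_count(), 4)), sharded)


@functools.partial(jax.jit, in_shardings=(replicated, sharded), out_shardings=sharded)
def forward(w, x):
  # No explicit device axis: arrays keep their logical (global) shapes.
  return x @ w


out = forward(w, x)  # output stays sharded along its leading dimension
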

reference_algorithms/paper_baselines/nesterov/jax/submission.py

Lines changed: 2 additions & 3 deletions
@@ -159,7 +159,6 @@ def update_params(workload: spec.Workload,
   del eval_results
 
   optimizer_state, opt_update_fn = optimizer_state
-  per_device_rngs = jax.random.split(rng, jax.local_device_count())
   if hasattr(hyperparameters, 'label_smoothing'):
     label_smoothing = hyperparameters.label_smoothing
   else:
@@ -182,7 +181,7 @@ def update_params(workload: spec.Workload,
       replicated,  # optimizer_state
       replicated,  # current_param_container
       sharded,  # batch
-      sharded,  # per_device_rngs
+      replicated,  # rngs
       replicated,  # grad_clip
       replicated  # label_smoothing
   )
@@ -206,7 +205,7 @@ def update_params(workload: spec.Workload,
       optimizer_state,
       current_param_container,
       batch,
-      per_device_rngs,
+      rng,
       grad_clip,
       label_smoothing)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
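
For context on the per_device_rngs removal: under pmap each local device needed its own RNG slice, while the jitted train step now receives the single rng replicated and can derive any sub-keys it needs internally. A hedged sketch of the two styles; the variable names below are illustrative, not from this file.

import jax

rng = jax.random.PRNGKey(0)

# Old pmap style: pre-split one key per local device and pass the stack in.
per_device_rngs = jax.random.split(rng, jax.local_device_count())  # shape (n_devices, 2)

# New jit style: pass the single key; split inside the step if sub-keys are needed.
step_rng, dropout_rng = jax.random.split(rng)
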
