 import flax.linen as nn
 import jax
 from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec as P
+
+from algorithmic_efficiency import sharding_utils
 import jax.numpy as jnp
 import numpy as np
 import optax
 from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \
     models
 
-
 class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
 
   def __init__(self,
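For readers unfamiliar with the new `sharding_utils` module imported above, here is a minimal sketch of how the helpers used throughout this diff (`get_mesh`, `get_replicated_sharding`, `get_naive_sharding_spec`) could plausibly look. This is an assumption for illustration only, not the module's actual code.

```python
# Illustrative only: a plausible shape for the sharding_utils helpers this
# diff relies on. The real algorithmic_efficiency.sharding_utils may differ.
from typing import Optional

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_mesh() -> Mesh:
  # Hypothetical: a 1-D mesh over all available devices, named 'batch'.
  return Mesh(np.asarray(jax.devices()), axis_names=('batch',))


def get_replicated_sharding(mesh: Optional[Mesh] = None) -> NamedSharding:
  # Hypothetical: full replication (empty PartitionSpec) on every device.
  return NamedSharding(mesh or get_mesh(), P())


def get_naive_sharding_spec(mesh: Optional[Mesh] = None) -> NamedSharding:
  # Hypothetical: shard the leading (batch) dimension across the mesh.
  return NamedSharding(mesh or get_mesh(), P('batch'))
```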
@@ -93,8 +95,16 @@ def init_model_fn(
 
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+
+    # Replicate params and model_state across all devices on the global mesh.
+    mesh = sharding_utils.get_mesh()
+    params = jax.tree_map(
+        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
+        params)
+    model_state = jax.tree_map(
+        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)),
+        model_state)
+
     return params, model_state
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
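A quick sanity check of the new replication path (a sketch, using the hypothetical helpers above): unlike `flax.jax_utils.replicate`, `jax.device_put` with a replicated `NamedSharding` keeps the original array shape, so the pmap-era leading device axis disappears.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 4))
x_rep = jax.device_put(x, get_replicated_sharding(get_mesh()))
assert x_rep.shape == (8, 4)  # no extra leading device axis
assert x_rep.sharding.is_fully_replicated
```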
@@ -176,6 +186,7 @@ def _build_input_queue(
           'targets': (targets.numpy(), target_paddings.numpy()),
       }
 
+      # Use data_utils.shard_and_maybe_pad_np to handle sharding
       padded_batch = data_utils.shard_and_maybe_pad_np(
           numpy_batch, padding_value=1.0)
       yield padded_batch
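For context, the general idea behind shard-and-pad, sketched under the same assumptions as above (this is not the actual `data_utils.shard_and_maybe_pad_np` implementation): pad the numpy batch so its leading dimension divides evenly over the devices, then place it with a batch-sharded layout.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_and_shard(batch_np, padding_value=1.0):
  """Pads each leaf to a multiple of the device count and shards it."""
  n_devices = len(jax.devices())

  def _pad_and_put(x):
    remainder = (-x.shape[0]) % n_devices
    if remainder:
      pad_width = [(0, remainder)] + [(0, 0)] * (x.ndim - 1)
      x = np.pad(x, pad_width, constant_values=padding_value)
    # Uses the hypothetical helpers from the earlier sketch.
    return jax.device_put(jnp.asarray(x), get_naive_sharding_spec(get_mesh()))

  return jax.tree_util.tree_map(_pad_and_put, batch_np)
```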
@@ -300,11 +311,16 @@ def greedy_decode(
     return hyp, hyp_paddings
 
   @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0, 0, None),
-      static_broadcasted_argnums=(0,))
-  def eval_step_pmapped(
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # params
+          sharding_utils.get_naive_sharding_spec(),  # batch
+          sharding_utils.get_replicated_sharding(),  # model_state
+          sharding_utils.get_replicated_sharding(),  # rng
+      ),
+      out_shardings=sharding_utils.get_naive_sharding_spec(),
+      static_argnums=(0,))
+  def _eval_step(
       self,
       params: spec.ParameterContainer,
       batch: Dict[str, spec.Tensor],
@@ -322,13 +338,39 @@ def eval_step_pmapped(
     loss = self.loss_fn(batch['targets'], (logits, logit_paddings))
 
     targets, target_paddings = batch['targets']
-    return self.metrics_bundle.gather_from_model_output(
-        loss_dict=loss,
-        decoded=decoded,
-        decoded_paddings=decoded_paddings,
-        targets=targets,
-        target_paddings=target_paddings,
-        axis_name='batch')
+    # Return raw outputs as a plain dict of arrays; the metrics bundle is
+    # assembled outside the jitted function in eval_step below.
+    metrics_dict = {
+        'loss_per_example': loss['per_example'],
+        'decoded': decoded,
+        'decoded_paddings': decoded_paddings,
+        'targets': targets,
+        'target_paddings': target_paddings,
+        'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples']
+    }
+    return metrics_dict
+
+  def eval_step(
+      self,
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      rng: spec.RandomState):
+    """Evaluates the model and returns a metrics bundle."""
+    metrics_dict = self._eval_step(params, batch, model_state, rng)
+
+    # Convert the dictionary back to a metrics bundle.
+    metrics = self.metrics_bundle.single_from_model_output(
+        loss_dict={
+            'summed': metrics_dict['loss_per_example'].sum(),
+            'per_example': metrics_dict['loss_per_example'],
+            'n_valid_examples': metrics_dict['n_valid_examples'].sum()
+        },
+        decoded=metrics_dict['decoded'],
+        decoded_paddings=metrics_dict['decoded_paddings'],
+        targets=metrics_dict['targets'],
+        target_paddings=metrics_dict['target_paddings'])
+
+    return metrics
 
   def _eval_model_on_split(self,
                            split: str,
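The pattern above replaces `jax.pmap` with its `axis_name` collectives by `jax.jit` over globally sharded arrays, and returns a plain dict of arrays from the jitted step so that the metrics bundle can be rebuilt eagerly in `eval_step`. A minimal, self-contained sketch of that migration pattern, again assuming the hypothetical helpers from the first sketch:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(
    jax.jit,
    in_shardings=(get_replicated_sharding(), get_naive_sharding_spec()),
    out_shardings=get_replicated_sharding())
def eval_loss(params, batch):
  # Under jit with global arrays there is no per-device view, so summing over
  # the batch dimension already aggregates across devices; no lax.psum needed.
  logits = batch['inputs'] @ params['w']
  per_example = jnp.mean((logits - batch['targets']) ** 2, axis=-1)
  return {'summed': per_example.sum(),
          'n_valid_examples': jnp.asarray(per_example.shape[0], jnp.float32)}
```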
@@ -353,10 +395,10 @@ def _eval_model_on_split(self,
     metrics_report = None
     for _ in range(num_batches):
       eval_batch = next(self._eval_iters[split])
-      computed_metrics = self.eval_step_pmapped(params,
-                                                eval_batch,
-                                                model_state,
-                                                rng).unreplicate()
+      computed_metrics = self.eval_step(params,
+                                        eval_batch,
+                                        model_state,
+                                        rng)
 
       if metrics_report is None:
         metrics_report = computed_metrics
@@ -368,15 +410,22 @@ def _eval_model_on_split(self,
 
     return computed_metrics
 
+  @functools.partial(
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),  # model_state
+      ),
+      out_shardings=sharding_utils.get_replicated_sharding(),
+      static_argnums=(0,)
+  )
   def sync_batch_stats(
       self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
-    # An axis_name is passed to pmap which can then be used by pmean.
-    # In this case each device has its own version of the batch statistics and
-    # we average them.
-    avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
-    new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
-    return new_model_state
+    """Sync batch statistics across replicas."""
+    # Replace pmean with direct mean across devices
+    new_batch_stats = jax.tree_map(
+        lambda x: jnp.mean(x, axis=0),
+        model_state['batch_stats'])
+    return model_state.copy({'batch_stats': new_batch_stats})
 
 
 class LibriSpeechConformerAttentionTemperatureWorkload(