@@ -192,24 +192,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
                                    workload.param_shapes)
   optimizer_state = opt_init_fn(params_zeros_like)
 
-  return jax_utils.replicate(optimizer_state), opt_update_fn
-
-
-@functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
-    static_broadcasted_argnums=(0, 1),
-    donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
-                       opt_update_fn,
-                       model_state,
-                       optimizer_state,
-                       current_param_container,
-                       batch,
-                       rng,
-                       grad_clip,
-                       label_smoothing):
+  return optimizer_state, opt_update_fn
+
+
+def train_step(workload,
+               opt_update_fn,
+               model_state,
+               optimizer_state,
+               current_param_container,
+               batch,
+               rng,
+               grad_clip,
+               label_smoothing):
 
   def _loss_fn(params):
     """Loss function used for training."""
@@ -232,9 +226,7 @@ def _loss_fn(params):
   grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
   (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
       current_param_container)
-  # Get correct global mean loss and grad.
-  (summed_loss, n_valid_examples, grad) = lax.psum(
-      (summed_loss, n_valid_examples, grad), axis_name='batch')
+  # Compute mean loss and grad.
   loss = summed_loss / n_valid_examples
   grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
 
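The `lax.psum` goes away because `jit` operates on the global (logical) arrays: when the batch is sharded across devices, a plain `jnp.sum` already yields the global sum and XLA inserts the cross-device reduction, whereas `pmap`'s per-device view required an explicit `psum` over the `'batch'` axis. A small self-contained sketch of that behaviour (my own example, not the repo's code; assumes the device count divides the array length):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all visible devices; shard arrays along their leading axis.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
batch_sharding = NamedSharding(mesh, P('batch'))

# A toy per-example loss array, sharded along its leading dimension.
losses = jax.device_put(jnp.arange(16.0), batch_sharding)

@jax.jit
def mean_loss(x):
  # No lax.psum: jit sees the whole array, so the sum is already global.
  return jnp.sum(x) / x.shape[0]

print(mean_loss(losses))  # 7.5
```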
@@ -272,7 +264,6 @@ def update_params(
   del eval_results
 
   optimizer_state, opt_update_fn = optimizer_state
-  per_device_rngs = jax.random.split(rng, jax.local_device_count())
   if hasattr(hyperparameters, 'label_smoothing'):
     label_smoothing = hyperparameters.label_smoothing
   else:
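With `pmap`, the caller had to split the key into one per local device before the call; under `jit` a single key is passed through as a replicated argument, and any further splitting can happen inside the step. A tiny illustration using only standard `jax.random` APIs (nothing repo-specific):

```python
import jax

rng = jax.random.PRNGKey(0)

# pmap style: one key per local device, shape (local_device_count, 2).
per_device_rngs = jax.random.split(rng, jax.local_device_count())

# jit style: hand the single key to the jitted step and split inside it if
# several independent streams (e.g. dropout vs. augmentation) are needed.
step_rng, dropout_rng = jax.random.split(rng)
```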
@@ -281,13 +272,47 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
-  outputs = pmapped_train_step(workload,
+
+  # Get the device mesh.
+  mesh = jax_sharding_utils.get_mesh()
+  # Create shardings for each argument.
+  replicated = jax_sharding_utils.get_replicated_sharding(mesh)  # No partitioning
+  sharded = jax_sharding_utils.get_batch_sharding(
+      mesh)  # Partition along batch dimension
+
+  # Create the sharding rules for each argument.
+  arg_shardings = (
+      # workload is static
+      # opt_update_fn is static
+      replicated,  # model_state
+      replicated,  # optimizer_state
+      replicated,  # current_param_container
+      sharded,  # batch
+      replicated,  # rng
+      replicated,  # grad_clip
+      replicated  # label_smoothing
+  )
+  out_shardings = (
+      replicated,  # new_optimizer_state
+      replicated,  # updated_params
+      replicated,  # new_model_state
+      replicated,  # loss
+      replicated  # grad_norm
+  )
+  # Jit the train step with these shardings.
+  jitted_train_step = jax.jit(
+      train_step,
+      static_argnums=(0, 1),
+      donate_argnums=(2, 3, 4),
+      in_shardings=arg_shardings,
+      out_shardings=out_shardings)
+
+  new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload,
                                opt_update_fn,
                                model_state,
                                optimizer_state,
                                current_param_container,
                                batch,
-                               per_device_rngs,
+                               rng,
                                grad_clip,
                                label_smoothing)
-  new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
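The call site now leans on `jax_sharding_utils` for the mesh and the two shardings. As a rough sketch of what such helpers typically look like when built on `jax.sharding` (the actual utility module may differ in names and details):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_mesh():
  # 1-D mesh over all visible devices with a single 'batch' axis.
  return Mesh(np.array(jax.devices()), axis_names=('batch',))


def get_replicated_sharding(mesh):
  # Empty PartitionSpec: every device holds a full copy of the array.
  return NamedSharding(mesh, P())


def get_batch_sharding(mesh):
  # Split the leading (batch) dimension across the 'batch' mesh axis.
  return NamedSharding(mesh, P('batch'))
```

With these shardings, `jit` both places the inputs and constrains the outputs, so the update returns a single replicated pytree rather than the per-device stacks `pmap` produced.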
@@ -296,8 +321,8 @@ def update_params(
   if global_step % 100 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
         {
-            'loss': loss[0],
-            'grad_norm': grad_norm[0],
+            'loss': loss.item(),
+            'grad_norm': grad_norm.item()
         }, global_step)
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
 
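The logging change follows from the output layout: `pmap` returned per-device stacks, so `loss[0]` picked one replica, while the jitted step returns a single replicated scalar array and `.item()` converts it to a Python float for the metrics logger. Illustrative only:

```python
import jax.numpy as jnp

loss = jnp.asarray(0.25)   # 0-d array, as a jitted step would return
print(loss.item())         # 0.25 as a plain Python float
print(type(loss.item()))   # <class 'float'>
```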