@@ -75,6 +75,7 @@ def _loss_fn(params):
         spec.ForwardPassMode.TRAIN,
         rng,
         update_batch_norm=True,)
+    jax.debug.print("logits: {logits}", logits=logits)
     loss_dict = workload.loss_fn(
         label_batch=batch['targets'],
         logits_batch=logits,
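The added `jax.debug.print` is the right tool here because an ordinary Python `print` inside a jitted/traced function only shows abstract tracers, not runtime values. A minimal, self-contained sketch of the same pattern (the toy `loss_fn` and shapes below are illustrative, not part of this PR):

```python
import jax
import jax.numpy as jnp


@jax.jit
def loss_fn(logits, targets):
  # jax.debug.print emits concrete runtime values from inside jit;
  # the format string uses named fields bound via keyword arguments.
  jax.debug.print("logits: {logits}", logits=logits)
  return jnp.mean((logits - targets) ** 2)


loss_fn(jnp.ones((2, 3)), jnp.zeros((2, 3)))
```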
@@ -140,31 +141,29 @@ def update_params(
   replicated = NamedSharding(mesh, P())  # No partitioning
   sharded = NamedSharding(mesh, P('batch'))  # Partition along batch dimension
 
-  # Define input and output shardings
-  arg_shardings = (
-      # workload is static
-      # opt_update_fn is static
-      replicated,  # model_state
-      replicated,  # optimizer_state
-      replicated,  # current_param_container
-      sharded,  # batch
-      replicated,  # rng
-      replicated,  # grad_clip
-      replicated  # label_smoothing
-  )
-  out_shardings = (
-      replicated,  # new_optimizer_state
-      replicated,  # updated_params
-      replicated,  # new_model_state
-      replicated,  # loss
-      replicated  # grad_norm
-  )
   jitted_train_step = jax.jit(
       train_step,
       static_argnums=(0, 1),
       donate_argnums=(2, 3, 4),
-      in_shardings=arg_shardings,
-      out_shardings=out_shardings)
+      in_shardings=(
+          # workload is static
+          # opt_update_fn is static
+          replicated,  # model_state
+          replicated,  # optimizer_state
+          replicated,  # current_param_container
+          sharded,  # batch
+          replicated,  # rng
+          replicated,  # grad_clip
+          replicated  # label_smoothing
+      ),
+      out_shardings=(
+          replicated,  # new_optimizer_state
+          replicated,  # updated_params
+          replicated,  # new_model_state
+          replicated,  # loss
+          replicated  # grad_norm
+      ))
+  # print(batch)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload,
       opt_update_fn,
       model_state,
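Inlining the tuples into the `jax.jit` call is behavior-preserving: `in_shardings`/`out_shardings` take `NamedSharding` values matching the non-static arguments and the outputs. A minimal, self-contained sketch of the same API on a toy step function (the single `'batch'` mesh axis and the toy `train_step` below are assumptions for illustration, not the workload's real signature):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis named 'batch' spanning all local devices.
mesh = Mesh(np.array(jax.devices()), ('batch',))
replicated = NamedSharding(mesh, P())      # full copy on every device
sharded = NamedSharding(mesh, P('batch'))  # split along the leading axis


def train_step(params, batch):
  # Toy update: shift params by the per-batch mean.
  return params - jnp.mean(batch)


jitted_train_step = jax.jit(
    train_step,
    in_shardings=(replicated, sharded),  # one sharding per (non-static) argument
    out_shardings=replicated)

params = jnp.zeros((4,))
batch = jnp.arange(4.0 * jax.device_count())  # length divisible by the 'batch' axis size
new_params = jitted_train_step(params, batch)
```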
@@ -176,7 +175,7 @@ def update_params(
       label_smoothing)
 
   # Log loss, grad_norm.
-  if global_step % 100 == 0 and workload.metrics_logger is not None:
+  if global_step % 1 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
         {
             'loss': loss.item(),