@@ -28,10 +28,9 @@ def _build_cifar_dataset(
       data_dir: str,
       batch_size: int,
       cache: Optional[bool] = None,
-      repeat_final_dataset: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None
   ) -> Iterator[Dict[str, spec.Tensor]]:
-    data_dir = data_dir + "/cifar10"
-    ds_builder = tfds.builder("cifar10:3.0.2", data_dir=data_dir)
+    ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
     train = split == 'train'
     assert self.num_train_examples + self.num_validation_examples == 50000
     if split in ['train', 'eval_train']:
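Note on this hunk: `tfds.builder` already resolves the prepared dataset to a `cifar10/<version>` subdirectory of `data_dir`, so dropping the manual `"/cifar10"` suffix suggests callers are now expected to pass the TFDS root directly. A minimal standalone check of that assumption, with a placeholder path that is not taken from this change:

```python
import tensorflow_datasets as tfds

# '/data/tensorflow_datasets' is a hypothetical TFDS root, not a path from this PR.
ds_builder = tfds.builder('cifar10:3.0.2', data_dir='/data/tensorflow_datasets')
ds_builder.download_and_prepare()  # no-op if the dataset is already prepared
print(ds_builder.info.splits['train'].num_examples)  # 50000, matching the assert above
```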
@@ -92,15 +91,17 @@ def init_model_fn(
     model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
     self._model = model
     input_shape = (1, 32, 32, 3)
-    variables = jax.jit(model.init)({"params": rng},
+    variables = jax.jit(model.init)({'params': rng},
                                     jnp.ones(input_shape, model.dtype))
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
+    model_state = jax_sharding_utils.replicate(model_state)
+    params = jax_sharding_utils.replicate(params)
     return params, model_state

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
-    return param_key == "Dense_0"
+    return param_key == 'Dense_0'

   def model_fn(
       self,
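The two added `jax_sharding_utils.replicate(...)` calls place the freshly initialized parameters and batch-norm state on every device before they are returned. As a rough sketch only (the helper's name comes from this diff, but the body below is an assumption, not the repo's implementation), such a function can be written with a fully replicated `NamedSharding`:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def replicate(tree):
  """Sketch: place every array in `tree` on all local devices, fully replicated."""
  mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))  # axis name is an assumption
  return jax.device_put(tree, NamedSharding(mesh, P()))  # empty spec => replicated
```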
@@ -114,19 +115,19 @@ def model_fn(
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
-    variables = {"params": params, **model_state}
+    variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
           variables,
-          augmented_and_preprocessed_input_batch["inputs"],
+          augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
           mutable=['batch_stats'],
           use_running_average_bn=use_running_average_bn)
       return logits, new_model_state
     else:
       logits = self._model.apply(
           variables,
-          augmented_and_preprocessed_input_batch["inputs"],
+          augmented_and_preprocessed_input_batch['inputs'],
           update_batch_norm=update_batch_norm,
           mutable=False,
           use_running_average_bn=use_running_average_bn)
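For readers less familiar with Flax, the two branches above are the standard `flax.linen` apply patterns; a minimal illustration with placeholder `model`, `variables`, and `x` (not objects from this file):

```python
# Train-style call: batch-norm statistics are mutable and returned alongside the outputs.
logits, new_state = model.apply(variables, x, mutable=['batch_stats'])

# Eval-style call: nothing is mutable, so only the outputs come back.
logits = model.apply(variables, x, mutable=False)
```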
@@ -139,15 +140,13 @@ def loss_fn(
       label_batch: spec.Tensor,  # Dense or one-hot labels.
       logits_batch: spec.Tensor,
       mask_batch: Optional[spec.Tensor] = None,
-      label_smoothing: float = 0.0,
-  ) -> Dict[str, spec.Tensor]:  # differentiable
+      label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]:  # differentiable
     """Evaluate the (masked) loss function at (label_batch, logits_batch).

-    Return {'summed': scalar summed loss,
-            'n_valid_examples': scalar number of
-    valid examples in batch, 'per_example': 1-d array of per-example losses}
-    (not synced across devices).
-    """
+    Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
+    valid examples in batch, 'per_example': 1-d array of per-example losses}
+    (not synced across devices).
+    """
     one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
     smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
     per_example_losses = -jnp.sum(
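The docstring reflow above does not change behavior; the loss is still the label-smoothed cross-entropy built from `optax.smooth_labels`. A quick worked example of what that helper returns (illustrative values, not taken from this PR):

```python
import jax.numpy as jnp
import optax

one_hot = jnp.array([[0.0, 1.0, 0.0]])
optax.smooth_labels(one_hot, alpha=0.1)
# -> [[0.0333, 0.9333, 0.0333]], i.e. (1 - alpha) * one_hot + alpha / num_classes
```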
@@ -160,66 +159,51 @@ def loss_fn(
     n_valid_examples = len(per_example_losses)
     summed_loss = per_example_losses.sum()
     return {
-        "summed": summed_loss,
-        "n_valid_examples": n_valid_examples,
-        "per_example": per_example_losses,
+        'summed': summed_loss,
+        'n_valid_examples': n_valid_examples,
+        'per_example': per_example_losses,
     }

   def _compute_metrics(self,
                        logits: spec.Tensor,
                        labels: spec.Tensor,
                        weights: spec.Tensor) -> Dict[str, spec.Tensor]:
-    summed_loss = self.loss_fn(labels, logits, weights)["summed"]
+    summed_loss = self.loss_fn(labels, logits, weights)['summed']
     # Number of correct predictions.
     accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
-    return jnp.array(summed_loss), jnp.array(accuracy)
+    metrics = {
+        'loss': summed_loss,
+        'accuracy': accuracy,
+    }
+    return metrics

+  @functools.partial(
+      jax.jit,
+      in_shardings=(
+          jax_sharding_utils.get_replicated_sharding(),  # params
+          jax_sharding_utils.get_batch_sharding(),  # batch
+          jax_sharding_utils.get_replicated_sharding(),  # model_state
+          jax_sharding_utils.get_batch_sharding(),  # rng
+      ),
+  )
   def _eval_model(
       self,
       params: spec.ParameterContainer,
       batch: Dict[str, spec.Tensor],
       model_state: spec.ModelAuxiliaryState,
-      rng: spec.RandomState,
-  ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+      rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
     """Return the mean accuracy and loss as a dict."""
-
-    @functools.partial(
-        jax.jit,
-        in_shardings=(
-            jax_sharding_utils.get_replicated_sharding(),  # params
-            jax_sharding_utils.get_batch_sharding(),  # batch
-            jax_sharding_utils.get_replicated_sharding(),  # model_state
-            jax_sharding_utils.get_batch_sharding(),  # rng
-        ),
-    )
-    def _per_device_eval_model(
-        params: spec.ParameterContainer,
-        batch: Dict[str, spec.Tensor],
-        model_state: spec.ModelAuxiliaryState,
-        rng: spec.RandomState,
-    ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-      logits, _ = self.model_fn(
-          params,
-          batch,
-          model_state,
-          spec.ForwardPassMode.EVAL,
-          rng,
-          update_batch_norm=False,
-      )
-      weights = batch.get("weights")
-      if weights is None:
-        weights = jnp.ones(len(logits))
-      return self._compute_metrics(logits, batch["targets"], weights)
-
-    losses, accuracies = _per_device_eval_model(params, batch, model_state, rng)
-    metrics = {
-        "loss":
-            jnp.mean(losses, axis=0) if losses.ndim > 0 else losses,
-        "accuracy":
-            (jnp.mean(accuracies, axis=0) if accuracies.ndim > 0 else accuracies
-            ),
-    }
-    return metrics
+    logits, _ = self.model_fn(
+        params,
+        batch,
+        model_state,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        update_batch_norm=False)
+    weights = batch.get('weights')
+    if weights is None:
+      weights = jnp.ones(len(logits))
+    return self._compute_metrics(logits, batch['targets'], weights)

   def _normalize_eval_metrics(
       self, num_examples: int, total_metrics: Dict[str,
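With the decorator added above, `_eval_model` is jitted once over the global batch: `in_shardings` tells `jax.jit` that params and model_state arrive replicated while the batch and rng arrive sharded along the leading axis, and any cross-device reductions are inserted by the compiler, which is presumably why the removed per-device averaging (`jnp.mean(..., axis=0)`) is no longer needed. A hedged sketch of what the two helpers could look like; the names come from the diff, but the mesh, axis name, and bodies below are assumptions rather than the repo's actual `jax_sharding_utils`:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical one-axis device mesh; the real utility module may build this differently.
_MESH = Mesh(np.array(jax.devices()), axis_names=('batch',))

def get_replicated_sharding():
  # Same value on every device (used here for params and model_state).
  return NamedSharding(_MESH, P())

def get_batch_sharding():
  # Leading (batch) dimension split across devices (used here for the batch and rng).
  return NamedSharding(_MESH, P('batch'))
```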