@@ -7,12 +7,11 @@
 import jax.numpy as jnp
 import optax
 import tensorflow_datasets as tfds
-from flax import jax_utils
 from flax import linen as nn
 from flax.core import pop
 from jax import lax
 
-from algoperf import param_utils, spec
+from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.cifar.cifar_jax import models
 from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
 from algoperf.workloads.cifar.workload import BaseCifarWorkload
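For readers following the migration: `flax.jax_utils.replicate` stacked a per-device copy of each array (the leading device axis `pmap` expects), whereas the new helpers hand `jax.jit` explicit shardings. A minimal sketch of what `jax_sharding_utils` plausibly provides, assuming a 1-D device mesh named `'batch'` (the actual module may differ):

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed: a single mesh over all local devices, with one axis named 'batch'.
_MESH = Mesh(jax.devices(), ('batch',))

def get_replicate_sharding():
  # Empty PartitionSpec = every device holds a full copy.
  return NamedSharding(_MESH, PartitionSpec())

def get_batch_dim_sharding():
  # Leading (batch) dimension is split across the 'batch' mesh axis.
  return NamedSharding(_MESH, PartitionSpec('batch'))

def replicate(tree):
  # Unlike flax.jax_utils.replicate, this adds no leading device axis; it
  # places each leaf on all devices under a replicated sharding.
  return jax.device_put(tree, get_replicate_sharding())
```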
@@ -29,6 +28,7 @@ def _build_cifar_dataset(
     repeat_final_dataset: Optional[bool] = None,
   ) -> Iterator[Dict[str, spec.Tensor]]:
     ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
+    ds_builder.download_and_prepare()
     train = split == 'train'
     assert self.num_train_examples + self.num_validation_examples == 50000
     if split in ['train', 'eval_train']:
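`tfds.builder(...)` only constructs the builder; without `download_and_prepare()` (a standard, idempotent TFDS call), `as_dataset` fails when the data has not yet been materialized under `data_dir`. The usual flow, with an illustrative path:

```python
import tensorflow_datasets as tfds

ds_builder = tfds.builder('cifar10:3.0.2', data_dir='/tmp/tfds')  # path illustrative
ds_builder.download_and_prepare()  # no-op if the dataset already exists on disk
ds = ds_builder.as_dataset(split='train')
```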
@@ -89,8 +89,8 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+    model_state = jax_sharding_utils.replicate(model_state)
+    params = jax_sharding_utils.replicate(params)
     return params, model_state
 
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
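With the sharding-based `replicate`, each leaf stays a single global array carrying a fully replicated sharding rather than gaining a leading device axis. A quick sanity check one could run after `init_model_fn` (the `workload` handle is an assumed instance of this class):

```python
import jax

params, model_state = workload.init_model_fn(jax.random.PRNGKey(0))
leaf = jax.tree_util.tree_leaves(params)[0]
print(leaf.shape)                         # unchanged: no leading device axis
print(leaf.sharding.is_fully_replicated)  # True once replication succeeded
```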
@@ -105,9 +105,11 @@ def model_fn(
     rng: spec.RandomState,
     update_batch_norm: bool,
     use_running_average_bn: Optional[bool] = None,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
+    del dropout_rate
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
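The new `dropout_rate` argument exists so the shared `model_fn` signature is uniform across workloads; this model has no dropout, hence the `del`. A caller-side sketch (free names are assumed to be in scope from the surrounding training loop):

```python
# Callers can pass dropout_rate uniformly; this workload simply discards it.
logits, new_model_state = workload.model_fn(
  params,
  batch,
  model_state,
  spec.ForwardPassMode.TRAIN,
  rng,
  update_batch_norm=True,
  dropout_rate=0.1,  # accepted for API uniformity, ignored by the CIFAR model
)
```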
@@ -171,15 +173,8 @@ def _compute_metrics(
       'loss': summed_loss,
       'accuracy': accuracy,
     }
-    metrics = lax.psum(metrics, axis_name='batch')
     return metrics
 
-  @functools.partial(
-    jax.pmap,
-    axis_name='batch',
-    in_axes=(None, 0, 0, 0, None),
-    static_broadcasted_argnums=(0,),
-  )
   def _eval_model(
     self,
     params: spec.ParameterContainer,
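Dropping `lax.psum` is correct under the new execution model: inside `pmap`, each device reduced only its own shard, so an explicit cross-device `psum` over `axis_name='batch'` was required; inside `jit`, the function is traced over the global (sharded) array and XLA inserts the cross-device reduction itself. A self-contained demonstration (assumes the device count divides the array length):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ('batch',))
per_example_loss = jax.device_put(
  jnp.arange(16.0), NamedSharding(mesh, PartitionSpec('batch'))
)

# jnp.sum over a batch-sharded array already yields the global sum under jit.
total = jax.jit(jnp.sum)(per_example_loss)
print(total)  # 120.0, regardless of how many devices hold shards
```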
@@ -188,21 +183,41 @@ def _eval_model(
     rng: spec.RandomState,
   ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
     """Return the mean accuracy and loss as a dict."""
-    logits, _ = self.model_fn(
-      params,
-      batch,
-      model_state,
-      spec.ForwardPassMode.EVAL,
-      rng,
-      update_batch_norm=False,
+
+    @functools.partial(
+      jax.jit,
+      in_shardings=(
+        jax_sharding_utils.get_replicate_sharding(),  # params
+        jax_sharding_utils.get_batch_dim_sharding(),  # batch
+        jax_sharding_utils.get_replicate_sharding(),  # model_state
+        jax_sharding_utils.get_batch_dim_sharding(),  # rng
+      ),
     )
-    weights = batch.get('weights')
-    if weights is None:
-      weights = jnp.ones(len(logits))
-    return self._compute_metrics(logits, batch['targets'], weights)
+    def _eval_model_jitted(
+      params: spec.ParameterContainer,
+      batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      rng: spec.RandomState,
+    ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+      """Return the mean accuracy and loss as a dict."""
+      logits, _ = self.model_fn(
+        params,
+        batch,
+        model_state,
+        spec.ForwardPassMode.EVAL,
+        rng,
+        update_batch_norm=False,
+      )
+      weights = batch.get('weights')
+      if weights is None:
+        weights = jnp.ones(len(logits))
+      return self._compute_metrics(logits, batch['targets'], weights)
+
+    metrics = _eval_model_jitted(params, batch, model_state, rng)
+    return jax.tree.map(lambda x: x.item(), metrics)
 
   def _normalize_eval_metrics(
     self, num_examples: int, total_metrics: Dict[str, Any]
   ) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
+    return jax.tree.map(lambda x: x / num_examples, total_metrics)
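End to end, `_eval_model` now returns plain Python numbers (via `.item()`), so per-batch metrics accumulate as host-side floats and `_normalize_eval_metrics` divides directly instead of indexing `x[0]` out of a pmap-replicated leaf. A sketch of the assumed surrounding eval loop (the real loop lives in the shared base workload, not in this diff):

```python
def run_eval(workload, params, model_state, rng, eval_batches):
  """Illustrative only: accumulate per-batch sums, then normalize."""
  total = {'loss': 0.0, 'accuracy': 0.0}
  num_examples = 0
  for batch in eval_batches:
    # Plain float arithmetic: _eval_model already called .item() on each leaf.
    metrics = workload._eval_model(params, batch, model_state, rng)
    total = {k: total[k] + metrics[k] for k in total}
    num_examples += batch['inputs'].shape[0]
  return workload._normalize_eval_metrics(num_examples, total)
```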