@@ -2,8 +2,12 @@
 
 from typing import Dict, Optional, Tuple
 
+import jax.numpy as jnp
+from flax import jax_utils
+from algoperf import param_utils
 from algoperf import spec
 from algoperf.workloads.lm.workload import BaseLmWorkload
+from algoperf.workloads.lm.lm_jax.models import LinearModel
 
 
 class LmWorkload(BaseLmWorkload):
@@ -14,18 +18,34 @@ def init_model_fn(
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
       aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
-    """aux_dropout_rate is used as attention_dropout_rate."""
-    pass
+
+    self._model = LinearModel(vocab_size=self._vocab_size)
+    input_shape = (1, self._seq_len, self._vocab_size)
+    variables = self._model.init(rng, jnp.ones(input_shape, jnp.float32))
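+    # Flax's FrozenDict.pop returns (remaining collections, popped entry).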
+    model_state, params = variables.pop('params')
+
+    self._param_shapes = param_utils.jax_param_shapes(params)
+    self._param_types = param_utils.jax_param_types(self._param_shapes)
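+    # Replicate across local devices for pmap-style data parallelism.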
+    model_state = jax_utils.replicate(model_state)
+    params = jax_utils.replicate(params)
+
+    return params, model_state
 
   def model_fn(
       self,
       params: spec.ParameterContainer,
-      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+      batch: Dict[str, spec.Tensor],
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-    pass
+
+    del mode, rng, update_batch_norm  # Not used by the linear model.
+    inputs = batch['inputs']
+    logits = self._model.apply({'params': params, **model_state}, inputs)
+    return logits, model_state
 
   def _eval_batch(self,
                   params: spec.ParameterContainer,
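
For reference, the LinearModel imported from algoperf.workloads.lm.lm_jax.models is not shown in this diff. A minimal sketch consistent with how it is used above (float inputs of shape (batch, seq_len, vocab_size), logits of the same shape, and no mutable variable collections, so popping 'params' leaves an empty model_state) could look like the following; the actual module may differ.

import flax.linen as nn
import jax.numpy as jnp


class LinearModel(nn.Module):
  """Single dense projection from one-hot tokens to next-token logits."""
  vocab_size: int

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    # Maps (batch, seq_len, vocab_size) -> (batch, seq_len, vocab_size).
    return nn.Dense(features=self.vocab_size)(inputs)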