@@ -57,22 +57,7 @@ def init_model_fn(
5757 model_state = sharding_utils .shard_replicated (model_state )
5858 params = sharding_utils .shard_replicated (params )
5959 return params , model_state
60-
61- def model_fn (
62- self ,
63- params : spec .ParameterContainer ,
64- augmented_and_preprocessed_input_batch : Dict [str , spec .Tensor ],
65- model_state : spec .ModelAuxiliaryState ,
66- mode : spec .ForwardPassMode ,
67- rng : spec .RandomState ,
68- update_batch_norm : bool ,
69- use_running_average_bn : Optional [bool ] = None
70- ) -> Tuple [spec .Tensor , spec .ModelAuxiliaryState ]:
71-
72- model_fn_sharded = shard_map (model_fn_ref ,
73- self .mesh ,
74- )
75-
60+
7661 def model_fn_ref (
7762 self ,
7863 params : spec .ParameterContainer ,
@@ -104,6 +89,34 @@ def model_fn_ref(
10489 mutable = False )
10590 return (logits , logit_paddings ), model_state
10691
92+ def model_fn (
93+ self ,
94+ params : spec .ParameterContainer ,
95+ augmented_and_preprocessed_input_batch : Dict [str , spec .Tensor ],
96+ model_state : spec .ModelAuxiliaryState ,
97+ mode : spec .ForwardPassMode ,
98+ rng : spec .RandomState ,
99+ update_batch_norm : bool ,
100+ use_running_average_bn : Optional [bool ] = None
101+ ) -> Tuple [spec .Tensor , spec .ModelAuxiliaryState ]:
102+
103+ model_fn_partial = jax .tree_util .Partial (self .model_fn_ref ,
104+ mode = mode ,
105+ rng = rng ,
106+ update_batch_norm = update_batch_norm ,
107+ use_running_average_bn = use_running_average_bn )
108+
109+ model_fn_sharded = shard_map (model_fn_partial ,
110+ sharding_utils .get_mesh (),
111+ in_specs = (None , P ('batch' ), None ),
112+ out_specs = (P ('batch' ), None ),
113+ )
114+
115+ model_fn_sharded = model_fn_partial
116+ return model_fn_sharded (params ,
117+ augmented_and_preprocessed_input_batch ,
118+ model_state ,)
119+
107120 def is_output_params (self , param_key : spec .ParameterKey ) -> bool :
108121 return param_key == 'Dense_0'
109122
0 commit comments