Commit d3a06fc

deepspeech jit changes
1 parent 2e4cc9e commit d3a06fc

File tree

2 files changed: +30 -17 lines
  • algoperf/workloads/librispeech_deepspeech/librispeech_jax
  • reference_algorithms/paper_baselines/adamw/jax

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 29 additions & 16 deletions
@@ -57,22 +57,7 @@ def init_model_fn(
     model_state = sharding_utils.shard_replicated(model_state)
     params = sharding_utils.shard_replicated(params)
     return params, model_state
-
-  def model_fn(
-      self,
-      params: spec.ParameterContainer,
-      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
-      model_state: spec.ModelAuxiliaryState,
-      mode: spec.ForwardPassMode,
-      rng: spec.RandomState,
-      update_batch_norm: bool,
-      use_running_average_bn: Optional[bool] = None
-  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-
-    model_fn_sharded = shard_map(model_fn_ref,
-                                 self.mesh,
-                                 )
-
+
   def model_fn_ref(
       self,
       params: spec.ParameterContainer,
@@ -104,6 +89,34 @@ def model_fn_ref(
         mutable=False)
     return (logits, logit_paddings), model_state

+  def model_fn(
+      self,
+      params: spec.ParameterContainer,
+      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+      model_state: spec.ModelAuxiliaryState,
+      mode: spec.ForwardPassMode,
+      rng: spec.RandomState,
+      update_batch_norm: bool,
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+
+    model_fn_partial = jax.tree_util.Partial(self.model_fn_ref,
+                                             mode=mode,
+                                             rng=rng,
+                                             update_batch_norm=update_batch_norm,
+                                             use_running_average_bn=use_running_average_bn)
+
+    model_fn_sharded = shard_map(model_fn_partial,
+                                 sharding_utils.get_mesh(),
+                                 in_specs=(None, P('batch'), None),
+                                 out_specs=(P('batch'), None),
+                                 )
+
+    model_fn_sharded = model_fn_partial
+    return model_fn_sharded(params,
+                            augmented_and_preprocessed_input_batch,
+                            model_state,)
+
   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
     return param_key == 'Dense_0'

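The new model_fn above binds the static arguments of model_fn_ref with jax.tree_util.Partial and wraps the result in shard_map over a 'batch' mesh axis (as committed, model_fn_sharded is then reassigned to the unsharded partial, so the shard_map wrapper is not actually invoked). Below is a minimal, self-contained sketch of that Partial + shard_map pattern; the toy function, scale argument, shapes, and mesh construction are illustrative assumptions, not the workload's code. jax.tree_util.Partial is used rather than functools.partial because it is registered as a pytree and composes cleanly with JAX transformations.

# Sketch of the Partial + shard_map pattern; toy names and shapes are
# illustrative assumptions, not the workload's actual code.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1-D device mesh whose single axis is named 'batch'.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

def toy_forward(params, inputs, scale):
  # Stand-in for a per-shard forward pass like model_fn_ref.
  return jnp.dot(inputs, params) * scale

# Bind keyword arguments up front, as model_fn does with mode/rng/etc.
toy_partial = jax.tree_util.Partial(toy_forward, scale=2.0)

# Replicate params (P()), split inputs along the leading 'batch' axis,
# and concatenate the per-shard outputs back along that axis.
toy_sharded = shard_map(toy_partial,
                        mesh,
                        in_specs=(P(), P('batch')),
                        out_specs=P('batch'))

params = jnp.ones((8, 4))
inputs = jnp.ones((16, 8))   # leading dim must divide evenly across devices
print(toy_sharded(params, inputs).shape)  # (16, 4)

The in_specs/out_specs here mirror the role of the committed in_specs=(None, P('batch'), None) and out_specs=(P('batch'), None): only the input batch is split across devices, while parameters and model state stay replicated.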
reference_algorithms/paper_baselines/adamw/jax/submission.py

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ def get_batch_size(workload_name):
   elif workload_name == 'librispeech_conformer':
     return 256
   elif workload_name == 'librispeech_deepspeech':
-    return 32
+    return 256
   elif workload_name == 'ogbg':
     return 512
   elif workload_name == 'wmt':
