Skip to content

Commit 7f35327

Browse files
committed
swap out lstm layer
1 parent 3593463 commit 7f35327

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,8 @@ class BatchRNN(nn.Module):
439439
@nn.compact
440440
def __call__(self, inputs, input_paddings, train):
441441
config = self.config
442+
lengths = jnp.sum(1 - input_paddings, axis=-1, dtype=jnp.int32)
443+
442444

443445
if config.layernorm_everywhere:
444446
inputs = LayerNorm(config.encoder_dim)(inputs)
@@ -452,7 +454,7 @@ def __call__(self, inputs, input_paddings, train):
452454
output, _ = lstm.LSTM(
453455
hidden_size=config.encoder_dim // 2,
454456
bidirectional=config.bidirectional,
455-
num_layers=1)(inputs, input_paddings)
457+
num_layers=1)(inputs, lengths)
456458

457459
return output
458460

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ def model_fn(
8080
train=True,
8181
rngs={'dropout' : rng},
8282
mutable=['batch_stats'])
83-
if 'batch_stats' in new_model_state and new_model_state['batch_stats']:
84-
new_model_state = jax.lax.pmean(new_model_state, 'batch')
8583
return (logits, logit_paddings), new_model_state
8684
else:
8785
logits, logit_paddings = self._model.apply(

scoring/plot_utils/plot_curves.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010
flags.DEFINE_string(
1111
'experiment_dir',
1212
# '/home/kasimbeg/algoperf-runs-internal/experiments/jit_switch_debug_conformer_old_step_hint',
13-
'/home/kasimbeg/submissions_algorithms/logs/external_tuning/baseline',
13+
'/home/kasimbeg/algoperf-runs-internal/experiments/jit_switch_debug_deepspeech_nadamw_jit_branch',
1414
'Path to experiment dir.')
1515
flags.DEFINE_string(
1616
'workloads',
17-
'librispeech_conformer_jax',
17+
'librispeech_deepspeech_jax',
1818
'Filter only for workload e.g. fastmri_jax. If None include all workloads in experiment.'
1919
)
2020
flags.DEFINE_string('project_name',
21-
'visualize-training-curves-legacy-stephint-conformer',
21+
'visualize-training-curves-legacy-stephint-deepspeech',
2222
'Wandb project name.')
23-
flags.DEFINE_string('run_postfix', 'pmap', 'Postfix for wandb runs.')
23+
flags.DEFINE_string('run_postfix', 'jit_legacy_lstm', 'Postfix for wandb runs.')
2424

2525
FLAGS = flags.FLAGS
2626

0 commit comments

Comments
 (0)