
Commit c208cc7

sharding deepspeech
1 parent f1db3d3 · commit c208cc7

4 files changed (+29 / −16 lines)


algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py

Lines changed: 9 additions & 9 deletions

@@ -397,15 +397,15 @@ def __call__(
     seq_lengths_np = np.shape(seq_lengths)

     n = jax.devices()
-    logging.info(f"jax num devices {n}")
-    logging.info(f'inputs shape {inputs_shape}')
-    logging.info(f'h_0 shape {h_0_shape}')
-    logging.info(f'c_0 shape {c_0_shape}')
-    logging.info(f'seq_lengths shape {seq_lengths_np}')
-    logging.info(f'weights_shape {weights_shape}')
-    logging.info(f'input_size {input_size}')
-    logging.info(f'hidden_size {self.features}')
-    logging.info(f'num_layers {self.num_layers}')
+    # logging.info(f"jax num devices {n}")
+    # logging.info(f'inputs shape {inputs_shape}')
+    # logging.info(f'h_0 shape {h_0_shape}')
+    # logging.info(f'c_0 shape {c_0_shape}')
+    # logging.info(f'seq_lengths shape {seq_lengths_np}')
+    # logging.info(f'weights_shape {weights_shape}')
+    # logging.info(f'input_size {input_size}')
+    # logging.info(f'hidden_size {self.features}')
+    # logging.info(f'num_layers {self.num_layers}')

     y, h, c = rnn.lstm(
         x=inputs, h_0=h_0, c_0=c_0, weights=weights,

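The silenced block is per-call debug logging of tensor shapes around the rnn.lstm call. Commenting the lines out works, but a flag-gated variant keeps them toggleable; a minimal sketch, assuming a hypothetical module-level _DEBUG_SHAPES flag that is not part of this commit:

from absl import logging
import jax

_DEBUG_SHAPES = False  # hypothetical toggle, not in the commit

def maybe_log_shapes(inputs, weights):
  # Mirrors the commented-out logging above; runs only when the flag is set.
  if _DEBUG_SHAPES:
    logging.info('jax num devices %d', len(jax.devices()))
    logging.info('inputs shape %s', inputs.shape)
    logging.info('weights shape %s', weights.shape)
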
algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 17 additions & 0 deletions

@@ -4,6 +4,8 @@
 from flax import jax_utils
 import jax
 import jax.numpy as jnp
+from jax.experimental.shard_map import shard_map
+from jax.sharding import PartitionSpec as P
 import numpy as np

 from algoperf import param_utils
@@ -66,6 +68,21 @@ def model_fn(
       update_batch_norm: bool,
       use_running_average_bn: Optional[bool] = None
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+
+    model_fn_sharded = shard_map(model_fn_ref,
+                                 self.mesh,
+                                 )
+
+    def model_fn_ref(
+        self,
+        params: spec.ParameterContainer,
+        augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+        model_state: spec.ModelAuxiliaryState,
+        mode: spec.ForwardPassMode,
+        rng: spec.RandomState,
+        update_batch_norm: bool,
+        use_running_average_bn: Optional[bool] = None
+    ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
     inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
     is_train_mode = mode == spec.ForwardPassMode.TRAIN

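As committed, the shard_map wiring is still a work in progress: shard_map requires in_specs and out_specs arguments that the call omits, model_fn_ref is referenced before the nested def that creates it, and the sharded function is never invoked. A standalone sketch of the intended data-parallel pattern, using hypothetical mesh/forward/params/inputs names that do not appear in the commit:

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1-D mesh over all local devices; 'batch' is the data-parallel axis.
mesh = Mesh(jax.devices(), axis_names=('batch',))

def forward(params, inputs):
  # Runs once per device, on that device's shard of the batch.
  return jnp.tanh(inputs @ params)

forward_sharded = shard_map(
    forward,
    mesh=mesh,
    in_specs=(P(), P('batch')),  # params replicated, inputs split on batch
    out_specs=P('batch'))        # outputs stay split along the batch axis

params = jnp.ones((16, 8))
inputs = jnp.ones((jax.device_count() * 4, 16))
logits = forward_sharded(params, inputs)  # (device_count * 4, 8), sharded
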
pyproject.toml

Lines changed: 2 additions & 6 deletions

@@ -106,15 +106,11 @@ jax_core_deps = [
   "protobuf==4.25.5",
 ]
 jax_cpu = [
-  "jax==0.4.28",
-  "jaxlib==0.4.28",
+  "jax",
   "algoperf[jax_core_deps]",
 ]
 jax_gpu = [
-  "jax==0.4.28",
-  "jaxlib==0.4.28",
-  "jax-cuda12-plugin[with_cuda]==0.4.28",
-  "jax-cuda12-pjrt==0.4.28",
+  "jax[cuda12]",
   "algoperf[jax_core_deps]",
 ]
 pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]

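Dropping the ==0.4.28 pins means pip now resolves the latest compatible jax, and the jax[cuda12] extra pulls in the CUDA 12 plugin and PJRT wheels that the removed jax-cuda12-plugin/jax-cuda12-pjrt entries listed explicitly. A quick post-install sanity check, assuming a CUDA 12 machine for the jax_gpu extra:

import jax

print(jax.__version__)  # whichever version pip resolved; no longer 0.4.28
print(jax.devices())    # jax_gpu installs should list CudaDevice entries
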
reference_algorithms/paper_baselines/adamw/jax/submission.py

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def _loss_fn(params):
         model_state,
         spec.ForwardPassMode.TRAIN,
         rng,
-        update_batch_norm=True)
+        update_batch_norm=True,)
     loss_dict = workload.loss_fn(
         label_batch=batch['targets'],
         logits_batch=logits,
