
Commit 49864fb

refactoring

1 parent 1edb724

File tree

7 files changed (+34, -35 lines)

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 4 additions & 4 deletions
@@ -106,7 +106,7 @@ def init_model_fn(
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    return jax_sharding_utils.shard(initial_params), None
+    return jax_sharding_utils.shard_along_batch_dim(initial_params), None

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
     return param_key == 'Dense_7'

@@ -132,11 +132,11 @@ def model_fn(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          jax_sharding_utils.get_replicated_sharding(),
-          jax_sharding_utils.get_batch_sharding(),
+          jax_sharding_utils.get_replicate_sharding(),
+          jax_sharding_utils.get_batch_dim_sharding(),
       ),
       static_argnums=(0,),
-      out_shardings=jax_sharding_utils.get_replicated_sharding())
+      out_shardings=jax_sharding_utils.get_replicate_sharding())
   def _eval_batch_jitted(self,
                          params: spec.ParameterContainer,
                          batch: Dict[str, spec.Tensor]) -> spec.Tensor:
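
For context, the renamed helpers read as thin wrappers over jax.sharding.NamedSharding. A minimal sketch, assuming a one-dimensional device mesh with a single 'batch' axis; the actual jax_sharding_utils implementation may differ:

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-ins for the helpers referenced in the diff above.
_MESH = Mesh(jax.devices(), ('batch',))

def get_replicate_sharding():
  # Empty PartitionSpec: every device holds a full copy of the array.
  return NamedSharding(_MESH, P())

def get_batch_dim_sharding():
  # Split the leading (batch) dimension across the 'batch' mesh axis.
  return NamedSharding(_MESH, P('batch'))

def shard_along_batch_dim(pytree):
  # Place every leaf so its first axis is divided across devices.
  return jax.tree_util.tree_map(
      lambda x: jax.device_put(x, get_batch_dim_sharding()), pytree)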

algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 7 additions & 8 deletions
@@ -103,14 +103,13 @@ def init_model_fn(
     model_state, params = pop(variables, "params")
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    mesh = jax_sharding_utils.get_mesh()
     params = jax.tree_map(
         lambda x: jax.device_put(x,
-                                 jax_sharding_utils.get_replicated_sharding(mesh)),
+                                 jax_sharding_utils.get_replicate_sharding()),
         params)
     model_state = jax.tree_map(
         lambda x: jax.device_put(x,
-                                 jax_sharding_utils.get_replicated_sharding(mesh)),
+                                 jax_sharding_utils.get_replicate_sharding()),
         model_state)
     return params, model_state

@@ -120,13 +119,13 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
   @functools.partial(
       jax.jit,
       in_shardings=(
-          jax_sharding_utils.get_replicated_sharding(),  # params
-          jax_sharding_utils.get_batch_sharding(),  # batch
-          jax_sharding_utils.get_replicated_sharding(),  # model_state
-          jax_sharding_utils.get_replicated_sharding(),  # rng
+          jax_sharding_utils.get_replicate_sharding(),  # params
+          jax_sharding_utils.get_batch_dim_sharding(),  # batch
+          jax_sharding_utils.get_replicate_sharding(),  # model_state
+          jax_sharding_utils.get_replicate_sharding(),  # rng
       ),
       static_argnums=(0,),
-      out_shardings=jax_sharding_utils.get_replicated_sharding())
+      out_shardings=jax_sharding_utils.get_replicate_sharding())
   def _eval_model(self,
                   params: spec.ParameterContainer,
                   batch: Dict[str, spec.Tensor],
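
As a usage sketch of the replication pattern above (toy parameter shapes, helper names as assumed in the earlier note), replicating a pytree reduces to a tree_map over jax.device_put:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())  # full copy on every device

params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros((8,))}
# jax.tree_map is deprecated in recent JAX releases;
# jax.tree_util.tree_map is the supported spelling of the same call.
params = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, replicated), params)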

algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 5 additions & 5 deletions
@@ -310,12 +310,12 @@ def greedy_decode(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          jax_sharding_utils.get_replicated_sharding(),  # params
-          jax_sharding_utils.get_batch_sharding(),  # batch
-          jax_sharding_utils.get_replicated_sharding(),  # model_state
-          jax_sharding_utils.get_replicated_sharding(),  # rng
+          jax_sharding_utils.get_replicate_sharding(),  # params
+          jax_sharding_utils.get_batch_dim_sharding(),  # batch
+          jax_sharding_utils.get_replicate_sharding(),  # model_state
+          jax_sharding_utils.get_replicate_sharding(),  # rng
       ),
-      out_shardings=jax_sharding_utils.get_batch_sharding(),
+      out_shardings=jax_sharding_utils.get_batch_dim_sharding(),
       static_argnums=(0,))
   def _eval_step(
       self,
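
The jit sharding pattern shared by these eval functions can be reproduced in isolation. A self-contained sketch with a toy linear model standing in for the real eval computation:

import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P('batch'))

@functools.partial(
    jax.jit,
    in_shardings=(replicated, batch_sharded),  # params, batch
    out_shardings=batch_sharded)  # keep per-example outputs sharded
def eval_step(params, batch):
  # Toy forward pass in place of the real model function.
  return batch @ params

params = jnp.ones((4, 4))
batch = jnp.ones((8 * jax.device_count(), 4))
print(eval_step(params, batch).shape)  # (8 * device_count, 4)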

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def model_fn(
         use_running_average_bn=use_running_average_bn)

     model_fn_sharded = shard_map(model_fn_partial,
-                                 jax_sharding_utils.get_mesh(),
+                                 jax.sharding.Mesh(jax.devices(), ('batch')),
                                  in_specs=(None, P('batch'), None),
                                  out_specs=(P('batch'), None),
                                  )
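
For reference, a shard_map call in the same shape as the change above, with a toy per-shard function; only the Mesh construction mirrors the diff, the rest is illustrative:

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))

def per_shard_fn(w, x):
  # Runs once per device on its local slice of x; w is replicated.
  return x * w

sharded_fn = shard_map(
    per_shard_fn,
    mesh,
    in_specs=(P(), P('batch')),  # replicate w, split x along 'batch'
    out_specs=P('batch'))  # outputs stay split along 'batch'

w = jnp.float32(2.0)
x = jnp.arange(8 * jax.device_count(), dtype=jnp.float32)
print(sharded_fn(w, x))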

algoperf/workloads/mnist/mnist_jax/workload.py

Lines changed: 4 additions & 4 deletions
@@ -103,10 +103,10 @@ def loss_fn(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          jax_sharding_utils.get_replicated_sharding(),  # params
-          jax_sharding_utils.get_batch_sharding(),  # batch
-          jax_sharding_utils.get_replicated_sharding(),  # model_state
-          jax_sharding_utils.get_batch_sharding(),  # rng
+          jax_sharding_utils.get_replicate_sharding(),  # params
+          jax_sharding_utils.get_batch_dim_sharding(),  # batch
+          jax_sharding_utils.get_replicate_sharding(),  # model_state
+          jax_sharding_utils.get_batch_dim_sharding(),  # rng
       ),
       static_argnums=(0,))
   def _eval_model(

algoperf/workloads/ogbg/ogbg_jax/workload.py

Lines changed: 5 additions & 5 deletions
@@ -111,12 +111,12 @@ def _eval_metric(self, labels, logits, masks):

   @functools.partial(
       jax.jit,
-      in_shardings=(jax_sharding_utils.get_replicated_sharding(),
-                    jax_sharding_utils.get_batch_sharding(),
-                    jax_sharding_utils.get_replicated_sharding(),
-                    jax_sharding_utils.get_replicated_sharding()),
+      in_shardings=(jax_sharding_utils.get_replicate_sharding(),
+                    jax_sharding_utils.get_batch_dim_sharding(),
+                    jax_sharding_utils.get_replicate_sharding(),
+                    jax_sharding_utils.get_replicate_sharding()),
       static_argnums=(0,),
-      out_shardings=jax_sharding_utils.get_replicated_sharding(),
+      out_shardings=jax_sharding_utils.get_replicate_sharding(),
   )
   def _eval_batch(self, params, batch, model_state, rng):
     return super()._eval_batch(params, batch, model_state, rng)

algoperf/workloads/wmt/wmt_jax/workload.py

Lines changed: 8 additions & 8 deletions
@@ -100,7 +100,7 @@ def eval_step(self,
   @functools.partial(
       jax.jit,
       in_shardings=(
-          jax_sharding_utils.get_batch_sharding(),  # inputs
+          jax_sharding_utils.get_batch_dim_sharding(),  # inputs
       ),
       static_argnums=(
           0,

@@ -112,9 +112,9 @@ def initialize_cache(self,
     """Initialize a cache for a given input shape and max decode length."""
     config = models.TransformerConfig(deterministic=True, decode=True)
     target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
-    dummy_inputs = jax_sharding_utils.shard_naive(
+    dummy_inputs = jax_sharding_utils.shard_along_batch_dim(
         jnp.ones(inputs.shape, jnp.float32))
-    dummy_targets = jax_sharding_utils.shard_naive(
+    dummy_targets = jax_sharding_utils.shard_along_batch_dim(
         jnp.ones(target_shape, jnp.float32))
     initial_variables = models.Transformer(config).init(
         jax.random.PRNGKey(0), dummy_inputs, dummy_targets)

@@ -196,8 +196,8 @@ def translate_and_calculate_bleu(self,
     jitted_predict_step = jax.jit(
         self.predict_step,
         in_shardings=(
-            jax_sharding_utils.get_batch_sharding(),  # inputs
-            jax_sharding_utils.get_replicated_sharding(),  # params
+            jax_sharding_utils.get_batch_dim_sharding(),  # inputs
+            jax_sharding_utils.get_replicate_sharding(),  # params
             jax_sharding_utils.get_naive_sharding_tree(cache),  # cache
         ),
         static_argnums=(

@@ -260,8 +260,8 @@ def init_model_fn(
     params_rng, dropout_rng = jax.random.split(rng)
     inputs = jnp.ones(input_shape, jnp.float32)
     targets = jnp.ones(target_shape, jnp.float32)
-    sharded_inputs = jax_sharding_utils.shard_naive(inputs)
-    sharded_targets = jax_sharding_utils.shard_naive(targets)
+    sharded_inputs = jax_sharding_utils.shard_along_batch_dim(inputs)
+    sharded_targets = jax_sharding_utils.shard_along_batch_dim(targets)

     initial_variables = jax.jit(
         self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},

@@ -271,7 +271,7 @@ def init_model_fn(
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    params = jax_sharding_utils.shard(initial_params)
+    params = jax_sharding_utils.shard_along_batch_dim(initial_params)
     return initial_params, None

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
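
The dummy-input pattern in init_model_fn above is a common one: device_put the inputs with the desired sharding first, so that a jitted init lays out everything derived from them accordingly. A generic sketch, assuming the batch-dim helper from the earlier note; init_fn is a hypothetical stand-in for models.Transformer(config).init:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
batch_sharded = NamedSharding(mesh, P('batch'))

def init_fn(rng, inputs):
  # Toy initializer; shapes are derived from the sharded dummy inputs.
  del rng
  return {'w': jnp.zeros((inputs.shape[-1], 4))}

dummy_inputs = jax.device_put(
    jnp.ones((8 * jax.device_count(), 16), jnp.float32), batch_sharded)
params = jax.jit(init_fn)(jax.random.PRNGKey(0), dummy_inputs)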
