
Commit b14174b

refactoring & clean up

1 parent: 99caa03

File tree: 15 files changed, +75 -158 lines


algoperf/data_utils.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,7 @@
 from torch.utils.data import DistributedSampler
 from torch.utils.data import Sampler

-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf import spec


@@ -51,7 +51,6 @@ def shard_and_maybe_pad_np(
   weights = batch.get('weights')
   # The weights will also be padded.
   batch['weights'] = np.ones(mask_shape) if weights is None else weights
-  naive_sharding_spec = sharding_utils.get_naive_sharding_spec()

   def _prepare(x):
     # Use _numpy() for zero-copy conversion between TF and NumPy.
@@ -62,7 +61,7 @@ def _prepare(x):
     if remainder_size != 0 or pad_to_global_batch_size:
       x = pad(x, pad_size, padding_value=padding_value)

-    return jax.device_put(x, naive_sharding_spec)
+    return jax.device_put(x, jax_sharding_utils.get_batch_dim_sharding())

   return jax.tree.map(_prepare, batch)
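
The new call replaces the precomputed naive_sharding_spec with a sharding fetched at use time, splitting each padded host array across devices along its leading dimension. A minimal sketch of the presumed behavior, assuming get_batch_dim_sharding() returns a NamedSharding over a one-dimensional 'batch' mesh (the helper itself lives in the new jax_sharding_utils module, which this page does not expand):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed equivalent of jax_sharding_utils.get_batch_dim_sharding().
mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
batch_dim_sharding = NamedSharding(mesh, PartitionSpec('batch'))

# After pad(), the leading dimension divides evenly by the device count,
# so device_put scatters one contiguous batch slice to each device.
x = np.zeros((8 * jax.device_count(), 32), dtype=np.float32)
x = jax.device_put(x, batch_dim_sharding)
print(x.sharding.spec)  # PartitionSpec('batch',)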

algoperf/sharding_utils.py

Lines changed: 0 additions & 82 deletions
This file was deleted.
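
Its helpers reappear under new names in algoperf/jax_sharding_utils.py, which this commit adds but the page does not expand. A hypothetical reconstruction of that module's surface, inferred solely from the call sites in the diffs below; every name and signature here is an assumption rather than the repository's actual code:

from typing import Optional

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def get_mesh() -> Mesh:
  # A one-dimensional mesh over all devices with a single 'batch' axis.
  return Mesh(np.asarray(jax.devices()), axis_names=('batch',))


def get_replicated_sharding(mesh: Optional[Mesh] = None) -> NamedSharding:
  # An empty PartitionSpec replicates every array dimension.
  mesh = mesh if mesh is not None else get_mesh()
  return NamedSharding(mesh, PartitionSpec())


def get_batch_sharding(mesh: Optional[Mesh] = None) -> NamedSharding:
  # Split the leading (batch) dimension across the 'batch' mesh axis.
  mesh = mesh if mesh is not None else get_mesh()
  return NamedSharding(mesh, PartitionSpec('batch'))


def shard(pytree):
  # Replicate every leaf of a pytree (successor to shard_replicated()).
  sharding = get_replicated_sharding()
  return jax.tree.map(lambda x: jax.device_put(x, sharding), pytree)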

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@
 import tensorflow_datasets as tfds

 from algoperf import param_utils
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf import spec
 from algoperf.workloads.cifar.cifar_jax import models
 from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
@@ -186,10 +186,10 @@ def _eval_model(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          sharding_utils.get_replicated_sharding(),  # params
-          sharding_utils.get_naive_sharding_spec(),  # batch
-          sharding_utils.get_replicated_sharding(),  # model_state
-          sharding_utils.get_naive_sharding_spec(),  # rng
+          jax_sharding_utils.get_replicated_sharding(),  # params
+          jax_sharding_utils.get_batch_sharding(),  # batch
+          jax_sharding_utils.get_replicated_sharding(),  # model_state
+          jax_sharding_utils.get_batch_sharding(),  # rng
       ),
   )
   def _per_device_eval_model(
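
This decorator pattern repeats across the JAX workloads: in_shardings declares how each positional argument is laid out across devices (replicated for params, model_state, and rng; split along the batch dimension for inputs), and out_shardings pins the layout of the result. A self-contained sketch of the same pattern, with plain NamedSharding objects standing in for the repository helpers:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
replicated = NamedSharding(mesh, PartitionSpec())
batched = NamedSharding(mesh, PartitionSpec('batch'))


@functools.partial(
    jax.jit,
    in_shardings=(replicated, batched),
    out_shardings=replicated)
def eval_batch(params, batch):
  # params arrive replicated, batch arrives split along dim 0; the
  # scalar loss is replicated onto every device on the way out.
  return jnp.mean((batch @ params) ** 2)


params = jnp.ones((3, 4))
batch = jnp.ones((8, 3))  # leading dim must divide by the device count
print(eval_batch(params, batch))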

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 5 additions & 5 deletions
@@ -11,7 +11,7 @@
 from algoperf import param_utils
 from algoperf import spec
 from algoperf.workloads.criteo1tb.criteo1tb_jax import models
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf.workloads.criteo1tb.workload import \
     BaseCriteo1TbDlrmSmallWorkload

@@ -106,7 +106,7 @@ def init_model_fn(
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    return sharding_utils.shard_replicated(initial_params), None
+    return jax_sharding_utils.shard(initial_params), None

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
     return param_key == 'Dense_7'
@@ -132,11 +132,11 @@ def model_fn(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          sharding_utils.get_replicated_sharding(),
-          sharding_utils.get_naive_sharding_spec(),
+          jax_sharding_utils.get_replicated_sharding(),
+          jax_sharding_utils.get_batch_sharding(),
       ),
       static_argnums=(0,),
-      out_shardings=sharding_utils.get_replicated_sharding())
+      out_shardings=jax_sharding_utils.get_replicated_sharding())
   def _eval_batch_jitted(self,
                          params: spec.ParameterContainer,
                          batch: Dict[str, spec.Tensor]) -> spec.Tensor:

algoperf/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 6 additions & 6 deletions
@@ -10,7 +10,7 @@

 from algoperf import param_utils
 from algoperf import spec
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 import algoperf.random_utils as prng
 from algoperf.workloads.fastmri.fastmri_jax.models import UNet
 from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim
@@ -40,7 +40,7 @@ def init_model_fn(
     params = variables['params']
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    params = sharding_utils.shard_replicated(params)
+    params = jax_sharding_utils.shard(params)
     return params, None

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -96,11 +96,11 @@ def loss_fn(

   @functools.partial(
       jax.jit,
-      in_shardings=(sharding_utils.get_replicated_sharding(),
-                    sharding_utils.get_naive_sharding_spec(),
-                    sharding_utils.get_replicated_sharding()),
+      in_shardings=(jax_sharding_utils.get_replicated_sharding(),
+                    jax_sharding_utils.get_batch_sharding(),
+                    jax_sharding_utils.get_replicated_sharding()),
       static_argnums=(0,),
-      out_shardings=sharding_utils.get_replicated_sharding())
+      out_shardings=jax_sharding_utils.get_replicated_sharding())
   def _eval_model(self,
                   params: spec.Tensor,
                   batch: Dict[str, spec.Tensor],

algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 9 additions & 9 deletions
@@ -20,7 +20,7 @@

 from algoperf import param_utils
 from algoperf import random_utils as prng
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf import spec
 from algoperf.workloads.imagenet_resnet import imagenet_v2
 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
@@ -103,14 +103,14 @@ def init_model_fn(
     model_state, params = pop(variables, "params")
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    mesh = sharding_utils.get_mesh()
+    mesh = jax_sharding_utils.get_mesh()
     params = jax.tree_map(
         lambda x: jax.device_put(x,
-                                 sharding_utils.get_replicated_sharding(mesh)),
+                                 jax_sharding_utils.get_replicated_sharding(mesh)),
         params)
     model_state = jax.tree_map(
         lambda x: jax.device_put(x,
-                                 sharding_utils.get_replicated_sharding(mesh)),
+                                 jax_sharding_utils.get_replicated_sharding(mesh)),
         model_state)
     return params, model_state

@@ -120,13 +120,13 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
   @functools.partial(
       jax.jit,
       in_shardings=(
-          sharding_utils.get_replicated_sharding(),  # params
-          sharding_utils.get_naive_sharding_spec(),  # batch
-          sharding_utils.get_replicated_sharding(),  # model_state
-          sharding_utils.get_replicated_sharding(),  # rng
+          jax_sharding_utils.get_replicated_sharding(),  # params
+          jax_sharding_utils.get_batch_sharding(),  # batch
+          jax_sharding_utils.get_replicated_sharding(),  # model_state
+          jax_sharding_utils.get_replicated_sharding(),  # rng
       ),
       static_argnums=(0,),
-      out_shardings=sharding_utils.get_replicated_sharding())
+      out_shardings=jax_sharding_utils.get_replicated_sharding())
   def _eval_model(self,
                   params: spec.ParameterContainer,
                   batch: Dict[str, spec.Tensor],

algoperf/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 import jax.numpy as jnp

 from algoperf import param_utils
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf import spec
 from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
     ImagenetResNetWorkload
@@ -46,8 +46,8 @@ def init_model_fn(
     params, model_state = self.initialized(rng, self._model)
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    params = sharding_utils.shard_replicated(params)
-    model_state = sharding_utils.shard_replicated(model_state)
+    params = jax_sharding_utils.shard(params)
+    model_state = jax_sharding_utils.shard(model_state)
     return params, model_state

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:

algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 8 additions & 8 deletions
@@ -12,7 +12,7 @@

 from algoperf import data_utils
 from algoperf import param_utils
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf import spec
 from algoperf.workloads.librispeech_conformer import metrics
 from algoperf.workloads.librispeech_conformer import workload
@@ -94,8 +94,8 @@ def init_model_fn(
     self._param_types = param_utils.jax_param_types(self._param_shapes)

     # Add sharding
-    params = sharding_utils.shard_replicated(params)
-    model_state = sharding_utils.shard_replicated(model_state)
+    params = jax_sharding_utils.shard(params)
+    model_state = jax_sharding_utils.shard(model_state)

     return params, model_state

@@ -310,12 +310,12 @@ def greedy_decode(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          sharding_utils.get_replicated_sharding(),  # params
-          sharding_utils.get_naive_sharding_spec(),  # batch
-          sharding_utils.get_replicated_sharding(),  # model_state
-          sharding_utils.get_replicated_sharding(),  # rng
+          jax_sharding_utils.get_replicated_sharding(),  # params
+          jax_sharding_utils.get_batch_sharding(),  # batch
+          jax_sharding_utils.get_replicated_sharding(),  # model_state
+          jax_sharding_utils.get_replicated_sharding(),  # rng
       ),
-      out_shardings=sharding_utils.get_naive_sharding_spec(),
+      out_shardings=jax_sharding_utils.get_batch_sharding(),
      static_argnums=(0,))
   def _eval_step(
       self,
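
Unlike the other eval functions in this commit, _eval_step keeps its outputs batch-sharded rather than replicated: greedy decoding yields one result per example, so replicating the outputs would force an all-gather inside the jitted function. A brief sketch of that variant, again with plain NamedSharding objects in place of the helpers:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
replicated = NamedSharding(mesh, PartitionSpec())
batched = NamedSharding(mesh, PartitionSpec('batch'))


@functools.partial(
    jax.jit,
    in_shardings=(replicated, batched),
    out_shardings=batched)  # per-example outputs stay split along dim 0
def eval_step(params, batch):
  # One prediction per example; no cross-device gather is required.
  return jnp.argmax(batch @ params, axis=-1)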

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 4 additions & 4 deletions
@@ -10,7 +10,7 @@

 from algoperf import param_utils
 from algoperf import spec
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
     LibriSpeechConformerWorkload
 from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models
@@ -55,8 +55,8 @@ def init_model_fn(
     params = variables['params']
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = sharding_utils.shard_replicated(model_state)
-    params = sharding_utils.shard_replicated(params)
+    model_state = jax_sharding_utils.shard(model_state)
+    params = jax_sharding_utils.shard(params)
     return params, model_state

   def model_fn_ref(
@@ -108,7 +108,7 @@ def model_fn(
         use_running_average_bn=use_running_average_bn)

     model_fn_sharded = shard_map(model_fn_partial,
-                                 sharding_utils.get_mesh(),
+                                 jax_sharding_utils.get_mesh(),
                                  in_specs=(None, P('batch'), None),
                                  out_specs=(P('batch'), None),
                                  )
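
shard_map complements the jit-based sharding used elsewhere in this commit: the wrapped function runs on each device's local slice of any argument partitioned along 'batch', while a None spec passes the argument in whole. A self-contained sketch of the same pattern; the workload's model_fn_partial is not reproduced here, and explicit PartitionSpec() stands in for the None specs above:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))


def per_device_fn(w, x):
  # x is this device's shard of the batch; w is the full replicated weight.
  return x @ w


sharded_fn = shard_map(
    per_device_fn,
    mesh,
    in_specs=(P(), P('batch')),  # w replicated, x split along dim 0
    out_specs=P('batch'))        # output stays sharded along dim 0

w = jnp.ones((3, 2))
x = jnp.ones((8, 3))
print(sharded_fn(w, x).shape)  # (8, 2), assembled from per-device shards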

algoperf/workloads/mnist/mnist_jax/workload.py

Lines changed: 5 additions & 5 deletions
@@ -10,7 +10,7 @@
 import optax

 from algoperf import param_utils
-from algoperf import sharding_utils
+from algoperf import jax_sharding_utils
 from algoperf import spec
 from algoperf.workloads.mnist.workload import BaseMnistWorkload

@@ -103,10 +103,10 @@ def loss_fn(
   @functools.partial(
       jax.jit,
       in_shardings=(
-          sharding_utils.get_replicated_sharding(),  # params
-          sharding_utils.get_naive_sharding_spec(),  # batch
-          sharding_utils.get_replicated_sharding(),  # model_state
-          sharding_utils.get_naive_sharding_spec(),  # rng
+          jax_sharding_utils.get_replicated_sharding(),  # params
+          jax_sharding_utils.get_batch_sharding(),  # batch
+          jax_sharding_utils.get_replicated_sharding(),  # model_state
+          jax_sharding_utils.get_batch_sharding(),  # rng
       ),
       static_argnums=(0,))
   def _eval_model(
