Commit 47c8d2b

Merge pull request #848 from mlcommons/jit_switch
Migrate JAX workloads from pmap to jit
2 parents 51eb65f + 43d2191 commit 47c8d2b

File tree

1,312 files changed: 1,477 additions, 303,382 deletions (only a subset of the changed files is shown below)

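For orientation, here is a minimal sketch of the migration pattern this commit applies across the JAX workloads (all names and shapes below are invented for illustration, not taken from the repository): per-device `jax.pmap` entry points with a named `batch` axis are replaced by a single `jax.jit` call whose argument shardings are declared explicitly via `NamedSharding`.

```python
# Hedged sketch of the pmap -> jit migration; names and shapes are illustrative.
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())            # e.g. model parameters
batch_sharded = NamedSharding(mesh, P('batch'))  # e.g. input batches

# Before: pmap over a leading device axis, with cross-device psum/pmean.
# @functools.partial(jax.pmap, axis_name='batch', in_axes=(None, 0))
# def eval_step(params, batch): ...

# After: a single jit over global arrays; shardings are declared instead.
@functools.partial(jax.jit, in_shardings=(replicated, batch_sharded))
def eval_step(params, batch):
  logits = batch @ params  # toy stand-in for a real model_fn
  return jnp.sum(logits)   # global reduction, no lax.psum needed

params = jax.device_put(jnp.ones((8, 1)), replicated)
batch = jax.device_put(jnp.ones((16, 8)), batch_sharded)  # batch size must divide by device count
print(eval_step(params, batch))
```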

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ jobs:
        pip install -e .
        python tests/reference_algorithm_tests.py --workload=ogbg --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nesterov.py --tuning_search_space=reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
        python tests/reference_algorithm_tests.py --workload=ogbg --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nesterov.py --tuning_search_space=reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
-  pytest:
+  pytest-params:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v3

README.md

Lines changed: 4 additions & 8 deletions
@@ -57,20 +57,16 @@ You can install this package and dependencies in a [Python virtual environment](
We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments.
Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document.

-_TL;DR to install the Jax version for GPU run:_
+*TL;DR to install the Jax version for GPU and all workload dependencies run:*

```bash
-pip3 install -e '.[pytorch_cpu]'
-pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
-pip3 install -e '.[full]'
+pip3 install -e '.[pytorch_cpu,jax_gpu,full]' --extra-index-url https://download.pytorch.org/whl/cpu
```

-_TL;DR to install the PyTorch version for GPU run:_
+*TL;DR to install the PyTorch version for GPU and all workload dependencies run:*

```bash
-pip3 install -e '.[jax_cpu]'
-pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'
-pip3 install -e '.[full]'
+pip3 install -e '.[jax_cpu,pytorch_gpu,full]'
```

## Getting Started

algoperf/checkpoint_utils.py

Lines changed: 0 additions & 4 deletions
@@ -7,7 +7,6 @@
import os
from typing import Sequence, Tuple

-import jax
import numpy as np
import torch
from absl import logging
@@ -210,10 +209,7 @@ def save_checkpoint(
      train_state, eval_results, global_step, preemption_count).
  """
  if framework == 'jax':
-    model_params = jax.device_get(jax_utils.unreplicate(model_params))
    opt_state, _ = optimizer_state
-    opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
-    model_state = jax.device_get(jax_utils.unreplicate(model_state))
  else:
    if isinstance(
        model_params,
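A brief note on why the `unreplicate` calls can go away, with a hedged sketch (`to_host` is a hypothetical helper, not part of the repository): under `pmap`, parameters carried a leading per-device axis and had to be unreplicated before saving, whereas under `jit` with `NamedSharding` each leaf is already a single global `jax.Array` and can be fetched to host directly.

```python
# Hypothetical helper illustrating the point above; not code from algoperf.
import jax
import numpy as np


def to_host(pytree):
  # Works for replicated and batch-sharded global jax.Arrays alike:
  # device_get returns the full logical value, no unreplicate step needed.
  return jax.tree.map(lambda x: np.asarray(jax.device_get(x)), pytree)
```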

algoperf/data_utils.py

Lines changed: 9 additions & 4 deletions
@@ -62,14 +62,19 @@ def _prepare(x):
    if remainder_size != 0 or pad_to_global_batch_size:
      x = pad(x, pad_size, padding_value=padding_value)

-    # Reshape (global_batch_size, ...) to
-    # (local_device_count, per_device_batch_size, ...).
-    # Assumes that `global_batch_size % local_device_count == 0`.
-    return x.reshape((local_device_count, -1, *x.shape[1:]))
+    # return x.reshape((local_device_count, -1, *x.shape[1:]))
+    return x

  return jax.tree.map(_prepare, batch)


+def shard(batch):
+  local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
+  return jax.tree.map(
+      lambda x: x.reshape((local_device_count, -1, *x.shape[1:])), batch
+  )
+
+
def pad(
  tensor: np.ndarray, pad_size: int, padding_value: int = 0
) -> np.ndarray:
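For reference, a small usage sketch of the new `data_utils.shard` helper (batch contents and sizes are invented, and the global batch size is assumed to divide evenly by the local device count): it reshapes each leaf from `(global_batch_size, ...)` to `(local_device_count, per_device_batch_size, ...)`, which callers now apply explicitly instead of it happening inside `_prepare`.

```python
# Illustrative only; shapes are invented and must divide by the device count.
import numpy as np

from algoperf import data_utils

batch = {
    'inputs': np.zeros((32, 39), dtype=np.float32),  # global batch of 32
    'targets': np.zeros((32,), dtype=np.float32),
}
sharded = data_utils.shard(batch)
# With, say, 8 local devices each leaf now has a leading (8, 4, ...) shape.
print(sharded['inputs'].shape, sharded['targets'].shape)
```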

algoperf/jax_sharding_utils.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+"""Utilities for dealing with sharding in JAX."""
+
+import jax
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+
+def get_replicate_sharding():
+  """Returns a sharding spec that replicates data across all devices."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return NamedSharding(mesh, P())
+
+
+def get_batch_dim_sharding():
+  """Returns a sharding spec that shards data along the first axis."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return NamedSharding(mesh, P('batch'))
+
+
+def shard_along_batch_dim(x):
+  """Shards a tensor across all devices."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return jax.tree.map(
+      lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x
+  )
+
+
+def replicate(x):
+  """Replicates tensor across all devices."""
+  mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
+  return jax.tree.map(lambda x: jax.device_put(x, NamedSharding(mesh, P())), x)
+
+
+def display_shard_info(x: jax.Array):
+  """Displays shard info of a jax array."""
+  for shard in x.addressable_shards:
+    print(
+        f'shard.device: {shard.device}, index: {shard.index}, replica_id:'
+        f' {shard.replica_id}.\n'
+    )
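A short usage sketch of these helpers (the model and shapes are invented for illustration, and the batch size is assumed to divide by the device count): parameters get the replicated sharding, batches get the batch-dimension sharding, and both are passed to `jax.jit` via `in_shardings`, mirroring how the workloads below use them.

```python
# Illustrative sketch; not code from the repository.
import functools

import jax
import jax.numpy as jnp

from algoperf import jax_sharding_utils


@functools.partial(
    jax.jit,
    in_shardings=(
        jax_sharding_utils.get_replicate_sharding(),  # params
        jax_sharding_utils.get_batch_dim_sharding(),  # batch
    ),
)
def predict(params, batch):
  return batch['inputs'] @ params['w']  # toy linear model


params = jax_sharding_utils.replicate({'w': jnp.ones((39, 1))})
batch = jax_sharding_utils.shard_along_batch_dim({'inputs': jnp.ones((32, 39))})
logits = predict(params, batch)
jax_sharding_utils.display_shard_info(logits)
```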

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 36 additions & 23 deletions
@@ -7,12 +7,11 @@
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
-from flax import jax_utils
from flax import linen as nn
from flax.core import pop
from jax import lax

-from algoperf import param_utils, spec
+from algoperf import jax_sharding_utils, param_utils, spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
from algoperf.workloads.cifar.workload import BaseCifarWorkload
@@ -29,6 +28,7 @@ def _build_cifar_dataset(
    repeat_final_dataset: Optional[bool] = None,
  ) -> Iterator[Dict[str, spec.Tensor]]:
    ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
+    ds_builder.download_and_prepare()
    train = split == 'train'
    assert self.num_train_examples + self.num_validation_examples == 50000
    if split in ['train', 'eval_train']:
@@ -89,8 +89,8 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
    model_state, params = pop(variables, 'params')
    self._param_shapes = param_utils.jax_param_shapes(params)
    self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_utils.replicate(model_state)
-    params = jax_utils.replicate(params)
+    model_state = jax_sharding_utils.replicate(params)
+    params = jax_sharding_utils.replicate(params)
    return params, model_state

  def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -171,15 +171,8 @@ def _compute_metrics(
      'loss': summed_loss,
      'accuracy': accuracy,
    }
-    metrics = lax.psum(metrics, axis_name='batch')
    return metrics

-  @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0, 0, None),
-      static_broadcasted_argnums=(0,),
-  )
  def _eval_model(
      self,
      params: spec.ParameterContainer,
@@ -188,21 +181,41 @@
      rng: spec.RandomState,
  ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
    """Return the mean accuracy and loss as a dict."""
-    logits, _ = self.model_fn(
-        params,
-        batch,
-        model_state,
-        spec.ForwardPassMode.EVAL,
-        rng,
-        update_batch_norm=False,
+
+    @functools.partial(
+        jax.jit,
+        in_shardings=(
+            jax_sharding_utils.get_replicate_sharding(),  # params
+            jax_sharding_utils.get_batch_dim_sharding(),  # batch
+            jax_sharding_utils.get_replicate_sharding(),  # model_state
+            jax_sharding_utils.get_batch_dim_sharding(),  # rng
+        ),
    )
-    weights = batch.get('weights')
-    if weights is None:
-      weights = jnp.ones(len(logits))
-    return self._compute_metrics(logits, batch['targets'], weights)
+    def _eval_model_jitted(
+        params: spec.ParameterContainer,
+        batch: Dict[str, spec.Tensor],
+        model_state: spec.ModelAuxiliaryState,
+        rng: spec.RandomState,
+    ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+      """Return the mean accuracy and loss as a dict."""
+      logits, _ = self.model_fn(
+          params,
+          batch,
+          model_state,
+          spec.ForwardPassMode.EVAL,
+          rng,
+          update_batch_norm=False,
+      )
+      weights = batch.get('weights')
+      if weights is None:
+        weights = jnp.ones(len(logits))
+      return self._compute_metrics(logits, batch['targets'], weights)
+
+    metrics = _eval_model_jitted(params, batch, model_state, rng)
+    return jax.tree.map(lambda x: x.item(), metrics)

  def _normalize_eval_metrics(
      self, num_examples: int, total_metrics: Dict[str, Any]
  ) -> Dict[str, float]:
    """Normalize eval metrics."""
-    return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
+    return jax.tree_map(lambda x: x / num_examples, total_metrics)
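Two styles of jitting appear in this commit, compared in the hedged toy sketch below (class and attribute names are invented): the CIFAR workload above defines a nested function that closes over `self` and jits that, while the Criteo workload below jits the bound method directly and marks `self` as a static argument with `static_argnums=(0,)`.

```python
# Toy comparison of the two jit styles; not code from the repository.
import functools

import jax
import jax.numpy as jnp


class ToyWorkload:
  scale = 2.0

  def eval_via_closure(self, x):
    # Style 1: jit an inner function that closes over `self`.
    @jax.jit
    def _eval(x):
      return jnp.sum(x) * self.scale

    return _eval(x).item()

  @functools.partial(jax.jit, static_argnums=(0,))
  def eval_via_static_self(self, x):
    # Style 2: jit the method; `self` must be hashable and is treated as static.
    return jnp.sum(x) * self.scale


w = ToyWorkload()
x = jnp.ones((4,))
print(w.eval_via_closure(x), w.eval_via_static_self(x).item())
```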

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 35 additions & 11 deletions
@@ -6,9 +6,8 @@
import jax
import jax.numpy as jnp
import numpy as np
-from flax import jax_utils

-from algoperf import param_utils, spec
+from algoperf import jax_sharding_utils, param_utils, spec
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload

@@ -106,7 +105,7 @@ def init_model_fn(
    initial_params = initial_variables['params']
    self._param_shapes = param_utils.jax_param_shapes(initial_params)
    self._param_types = param_utils.jax_param_types(self._param_shapes)
-    return jax_utils.replicate(initial_params), None
+    return jax_sharding_utils.replicate(initial_params), None

  def is_output_params(self, param_key: spec.ParameterKey) -> bool:
    return param_key == 'Dense_7'
@@ -132,13 +131,40 @@ def model_fn(
    logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
    return logits_batch, None

+  def _build_input_queue(
+      self,
+      data_rng: spec.RandomState,
+      split: str,
+      data_dir: str,
+      global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None,
+  ):
+    it = super()._build_input_queue(
+        data_rng,
+        split,
+        data_dir,
+        global_batch_size,
+        cache,
+        repeat_final_dataset,
+        num_batches,
+    )
+    f = functools.partial(
+        jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
+    )
+    return map(f, it)
+
  @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0),
-      static_broadcasted_argnums=(0,),
+      jax.jit,
+      in_shardings=(
+          jax_sharding_utils.get_replicate_sharding(),
+          jax_sharding_utils.get_batch_dim_sharding(),
+      ),
+      static_argnums=(0,),
+      out_shardings=jax_sharding_utils.get_replicate_sharding(),
  )
-  def _eval_batch_pmapped(
+  def _eval_batch_jitted(
      self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
  ) -> spec.Tensor:
    logits, _ = self.model_fn(
@@ -162,9 +188,7 @@ def _eval_batch(
  ) -> spec.Tensor:
    # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
    # shape (local_device_count,) will all be different values.
-    return np.array(
-        self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64
-    )
+    return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
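The new `_build_input_queue` override above wraps the host-side iterator so that every batch is placed onto the device mesh with a batch-dimension `NamedSharding` before it reaches the jitted eval/train steps. A hedged, self-contained sketch of that wrapping (the generator and shapes are invented):

```python
# Illustrative only; batch shapes are invented and must divide by device count.
import functools

import jax
import numpy as np

from algoperf import jax_sharding_utils


def host_batches():
  while True:
    yield {
        'inputs': np.zeros((64, 39), dtype=np.float32),
        'targets': np.zeros((64,), dtype=np.float32),
    }


put = functools.partial(
    jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
)
device_iter = map(put, host_batches())
batch = next(device_iter)
# Each leaf is now a global jax.Array sharded along axis 0 across local devices.
jax_sharding_utils.display_shard_info(batch['inputs'])
```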

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

-from algoperf import param_utils, spec
+from algoperf import data_utils, param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models
from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
@@ -152,6 +152,7 @@ def _build_input_queue(
        num_batches=num_batches,
        repeat_final_dataset=repeat_final_dataset,
    )
+    np_iter = map(data_utils.shard, np_iter)
    weights = None
    while True:
      if RANK == 0:
