Skip to content

Commit cfbaf7a

Browse files
committed
remove jax.device_put from imagenet test pipeline because it results in an OOM for subsequent training steps
1 parent 43d2191 commit cfbaf7a

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

algoperf/workloads/imagenet_resnet/imagenet_v2.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import functools
77
from typing import Dict, Iterator, Tuple
88

9-
import jax
109
import tensorflow_datasets as tfds
1110

12-
from algoperf import data_utils, jax_sharding_utils, spec
11+
from algoperf import data_utils, spec
1312
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
1413

1514

@@ -47,10 +46,4 @@ def _decode_example(example: Dict[str, float]) -> Dict[str, float]:
4746
if framework == 'pytorch':
4847
it = map(data_utils.shard, it)
4948

50-
elif framework == 'jax':
51-
f = functools.partial(
52-
jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
53-
)
54-
it = map(f, it)
55-
5649
return it

0 commit comments

Comments
 (0)