
Commit f1db3d3

committed: reformatting
1 parent 004afbd

15 files changed: +95 additions, -88 deletions


algoperf/profiler.py

Lines changed: 2 additions & 2 deletions

@@ -72,8 +72,8 @@ def _make_report(
                float(np.std(d)),
                len(d),
                float(np.sum(d)),
-               100.0 * float(np.sum(d)) / total_duration)
-              for a, d in self.recorded_durations.items()]
+               100.0 * float(np.sum(d)) / total_duration) for a,
+              d in self.recorded_durations.items()]
     report.sort(key=lambda x: x[5], reverse=True)
     total_calls = sum(x[3] for x in report)
     return report, total_calls, total_duration
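For context, a minimal runnable sketch of the report pattern this hunk reformats; the sample durations are assumptions, only the comprehension and sort mirror the file:

import numpy as np

# Assumed sample data; in the real class this comes from profiler timings.
recorded_durations = {'forward': [0.12, 0.11, 0.13], 'backward': [0.30, 0.28]}
total_duration = sum(float(np.sum(d)) for d in recorded_durations.values())

# One row per action: (name, mean, std, num_calls, total, percent_of_total).
report = [(a,
           float(np.mean(d)),
           float(np.std(d)),
           len(d),
           float(np.sum(d)),
           100.0 * float(np.sum(d)) / total_duration) for a,
          d in recorded_durations.items()]
report.sort(key=lambda x: x[5], reverse=True)  # slowest actions first
total_calls = sum(x[3] for x in report)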

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 10 additions & 11 deletions

@@ -130,16 +130,16 @@ def model_fn(
     return logits_batch, None

   @functools.partial(
-      jax.jit,
-      in_shardings=(sharding_utils.get_replicated_sharding(),
-                    sharding_utils.get_naive_sharding_spec(),
-                    ),
-      static_argnums=(0,),
-      out_shardings=sharding_utils.get_replicated_sharding()
-  )
+      jax.jit,
+      in_shardings=(
+          sharding_utils.get_replicated_sharding(),
+          sharding_utils.get_naive_sharding_spec(),
+      ),
+      static_argnums=(0,),
+      out_shardings=sharding_utils.get_replicated_sharding())
   def _eval_batch_jitted(self,
-                        params: spec.ParameterContainer,
-                        batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+                         params: spec.ParameterContainer,
+                         batch: Dict[str, spec.Tensor]) -> spec.Tensor:
     logits, _ = self.model_fn(
         params,
         batch,

@@ -160,8 +160,7 @@ def _eval_batch(self,
                   batch: Dict[str, spec.Tensor]) -> spec.Tensor:
     # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
     # shape (local_device_count,) will all be different values.
-    return np.array(
-        self._eval_batch_jitted(params, batch), dtype=np.float64)
+    return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64)


 class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
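Both hunks in this file touch the same jitted-method pattern: static_argnums=(0,) treats `self` as a static (hashable) argument, while in_shardings/out_shardings place the two dynamic arguments and the result. A hedged single-process sketch, with NamedSharding standing in for the repo's sharding_utils helpers and a toy computation in place of the real model; it assumes the batch axis divides evenly across devices:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('batch',))
replicated = NamedSharding(mesh, PartitionSpec())            # full copy per device
batch_sharded = NamedSharding(mesh, PartitionSpec('batch'))  # split along axis 0

class TinyWorkload:

  @functools.partial(
      jax.jit,
      in_shardings=(replicated, batch_sharded),
      static_argnums=(0,),  # `self` is baked into the compiled trace
      out_shardings=replicated)
  def _eval_batch_jitted(self, params, batch):
    # Toy stand-in for the real model evaluation.
    return jnp.sum((batch @ params)**2)

loss = TinyWorkload()._eval_batch_jitted(jnp.ones((3, 1)), jnp.ones((8, 3)))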

algoperf/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 4 additions & 7 deletions

@@ -96,14 +96,11 @@ def loss_fn(

   @functools.partial(
       jax.jit,
-      in_shardings=(
-          sharding_utils.get_replicated_sharding(),
-          sharding_utils.get_naive_sharding_spec(),
-          sharding_utils.get_replicated_sharding()
-      ),
+      in_shardings=(sharding_utils.get_replicated_sharding(),
+                    sharding_utils.get_naive_sharding_spec(),
+                    sharding_utils.get_replicated_sharding()),
       static_argnums=(0,),
-      out_shardings=sharding_utils.get_replicated_sharding()
-  )
+      out_shardings=sharding_utils.get_replicated_sharding())
   def _eval_model(self,
                   params: spec.Tensor,
                   batch: Dict[str, spec.Tensor],

algoperf/workloads/fastmri/fastmri_pytorch/workload.py

Lines changed: 3 additions & 1 deletion

@@ -250,7 +250,9 @@ def _eval_model_on_split(self,
     for _ in range(num_batches):
       batch = next(self._eval_iters[split])
       batch_metrics = self._eval_model(params, batch, model_rng)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
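The wrapped comprehension above accumulates per-batch metric sums; a hedged sketch of that accumulation plus the DDP reduction (metric names and values are invented):

import torch
import torch.distributed as dist

total_metrics = {'loss': torch.zeros(1), 'ssim': torch.zeros(1)}
batches = [{'loss': torch.tensor([0.4]), 'ssim': torch.tensor([0.9])},
           {'loss': torch.tensor([0.2]), 'ssim': torch.tensor([0.8])}]
for batch_metrics in batches:
  total_metrics = {
      k: v + batch_metrics[k] for k, v in total_metrics.items()
  }
# Under DDP every rank holds partial sums; all_reduce sums them in place.
if dist.is_available() and dist.is_initialized():
  for metric in total_metrics.values():
    dist.all_reduce(metric)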

algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py

Lines changed: 21 additions & 25 deletions

@@ -20,31 +20,27 @@
     tf.dtypes.float64,
 }

-Number = Union[
-    float,
-    int,
-    np.float16,
-    np.float32,
-    np.float64,
-    np.int8,
-    np.int16,
-    np.int32,
-    np.int64,
-    np.uint8,
-    np.uint16,
-    np.uint32,
-    np.uint64,
-]
-
-TensorLike = Union[
-    List[Union[Number, list]],
-    tuple,
-    Number,
-    np.ndarray,
-    tf.Tensor,
-    tf.SparseTensor,
-    tf.Variable,
-]
+Number = Union[float,
+               int,
+               np.float16,
+               np.float32,
+               np.float64,
+               np.int8,
+               np.int16,
+               np.int32,
+               np.int64,
+               np.uint8,
+               np.uint16,
+               np.uint32,
+               np.uint64,]
+
+TensorLike = Union[List[Union[Number, list]],
+                   tuple,
+                   Number,
+                   np.ndarray,
+                   tf.Tensor,
+                   tf.SparseTensor,
+                   tf.Variable,]


 def get_ndims(image):
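The aliases only name the set of accepted input types; a short hedged usage sketch (`to_tensor` is illustrative, not part of the file):

import tensorflow as tf

def to_tensor(value: 'TensorLike', dtype: tf.DType = tf.float32) -> tf.Tensor:
  # Any Number, nested list, ndarray, or tf value is accepted via the alias.
  return tf.convert_to_tensor(value, dtype=dtype)

print(to_tensor([[1, 2], [3, 4]]).shape)  # (2, 2)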

algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py

Lines changed: 5 additions & 3 deletions

@@ -316,7 +316,8 @@ def build_lut(histo, step):
     # If step is zero, return the original image. Otherwise, build
     # lut from the full histogram and step and then index from it.
     result = tf.cond(
-        tf.equal(step, 0), lambda: im,
+        tf.equal(step, 0),
+        lambda: im,
         lambda: tf.gather(build_lut(histo, step), im))

     return tf.cast(result, tf.uint8)

@@ -551,6 +552,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key):
         translate_const=100)
     image = tf.cond(
         tf.equal(i, op_to_select),
-        lambda selected_func=func, selected_args=args: selected_func(
-            image, *selected_args), lambda: image)
+        lambda selected_func=func,
+        selected_args=args: selected_func(image, *selected_args),
+        lambda: image)
   return image
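The default-argument lambdas in the second hunk are there to defeat Python's late binding: without `selected_func=func, selected_args=args`, every branch would see the loop's final `func`/`args`. A hedged standalone sketch of the pattern (the op list is invented):

import tensorflow as tf

def apply_selected_op(image, ops, op_to_select):
  for i, (func, args) in enumerate(ops):
    image = tf.cond(
        tf.equal(i, op_to_select),
        # Defaults capture the *current* func/args at definition time.
        lambda selected_func=func, selected_args=args: selected_func(
            image, *selected_args),
        lambda: image)
  return image

ops = [(tf.image.flip_left_right, ()), (tf.image.rot90, (1,))]
out = apply_selected_op(tf.zeros([4, 4, 3]), ops, tf.constant(1))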

algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 4 additions & 4 deletions

@@ -105,12 +105,12 @@ def init_model_fn(
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     mesh = sharding_utils.get_mesh()
     params = jax.tree_map(
-        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)
-                                ),
+        lambda x: jax.device_put(x,
+                                 sharding_utils.get_replicated_sharding(mesh)),
         params)
     model_state = jax.tree_map(
-        lambda x: jax.device_put(x, sharding_utils.get_replicated_sharding(mesh)
-                                ),
+        lambda x: jax.device_put(x,
+                                 sharding_utils.get_replicated_sharding(mesh)),
         model_state)
     return params, model_state
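jax.tree_map applies the device_put leaf-by-leaf, so every array in params and model_state ends up replicated under the mesh. A hedged miniature, with NamedSharding standing in for get_replicated_sharding:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('batch',))
replicated = NamedSharding(mesh, PartitionSpec())

params = {'conv': {'kernel': jnp.ones((3, 3, 8))}, 'bn': {'scale': jnp.ones(8)}}
params = jax.tree_map(
    lambda x: jax.device_put(x,
                             replicated),
    params)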

algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py

Lines changed: 3 additions & 1 deletion

@@ -307,7 +307,9 @@ def _eval_model_on_split(self,
           update_batch_norm=False)
       weights = batch.get('weights')
       batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)

algoperf/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 6 additions & 4 deletions

@@ -153,8 +153,8 @@ def setup(self):
     self.kernel = self.param('kernel',
                              nn.initializers.xavier_uniform(),
                              self.filter_shape)
-    self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32),
-                           self.output_channels)
+    self.bias = self.param(
+        'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)

   @nn.compact
   def __call__(self, inputs, paddings):

@@ -442,10 +442,12 @@ def setup(self):
     dtype = self.config.dtype

     self.ra_mean = self.variable('batch_stats',
-                                 'mean', lambda s: jnp.zeros(s, dtype),
+                                 'mean',
+                                 lambda s: jnp.zeros(s, dtype),
                                  dim)
     self.ra_var = self.variable('batch_stats',
-                                'var', lambda s: jnp.ones(s, dtype),
+                                'var',
+                                lambda s: jnp.ones(s, dtype),
                                 dim)

     self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
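Both hunks reformat the same Flax idiom: self.param registers trainable weights in the 'params' collection, while self.variable('batch_stats', ...) registers non-trainable running statistics. A hedged toy module (names, shapes, and the ones-initialized scale are illustrative):

import flax.linen as nn
import jax
import jax.numpy as jnp

class RunningNorm(nn.Module):
  dim: int = 8

  def setup(self):
    # Non-trainable running stats live in the 'batch_stats' collection.
    self.ra_mean = self.variable('batch_stats',
                                 'mean',
                                 lambda s: jnp.zeros(s, jnp.float32),
                                 self.dim)
    self.ra_var = self.variable('batch_stats',
                                'var',
                                lambda s: jnp.ones(s, jnp.float32),
                                self.dim)
    # Trainable parameter in the 'params' collection.
    self.gamma = self.param('scale', nn.initializers.ones, (self.dim,))

  def __call__(self, x):
    return self.gamma * (x - self.ra_mean.value) / jnp.sqrt(
        self.ra_var.value + 1e-6)

variables = RunningNorm().init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
# `variables` now holds both 'params' and 'batch_stats' collections.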

algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py

Lines changed: 2 additions & 2 deletions

@@ -81,8 +81,8 @@ def _get_mask(self,
         jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
         [batch_size, 1])
     multiplicity_tensor = masks_per_frame * choose_range
-    multiplicity_weights = (multiplicity_weights
-                            < multiplicity_tensor).astype(jnp.int32)
+    multiplicity_weights = (multiplicity_weights <
+                            multiplicity_tensor).astype(jnp.int32)
     pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
   else:
     pre_mask = jnp.einsum('bmt->bt', pre_mask)
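For reference, the operator move above sits inside this masking computation; a hedged runnable miniature (sizes and the threshold value are invented):

import jax.numpy as jnp

batch_size, multiplicity, time_len = 2, 3, 5
pre_mask = jnp.ones((batch_size, multiplicity, time_len))
multiplicity_weights = jnp.tile(
    jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
    [batch_size, 1])
multiplicity_tensor = 2  # stands in for masks_per_frame * choose_range
# Keep only the first `multiplicity_tensor` masks per example...
multiplicity_weights = (multiplicity_weights <
                        multiplicity_tensor).astype(jnp.int32)
# ...then collapse the surviving masks over the m axis.
pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)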
