Skip to content

Commit 2717519

Browse files
committed
reformatting
1 parent c337cc4 commit 2717519

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

algoperf/workloads/ogbg/input_pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def _get_weights_by_nan_and_padding(labels, padding_mask):
9898
return replaced_labels, weights
9999

100100

101-
def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None, shard=False):
101+
def _get_batch_iterator(
102+
dataset_iter, global_batch_size, num_shards=None, shard=False
103+
):
102104
"""Turns a per-example iterator into a batched iterator.
103105
104106
Constructs the batch from num_shards smaller batches, so that we can easily
@@ -160,8 +162,11 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None, shard=
160162
'weights': np.vstack(weights_shards),
161163
}
162164
else:
165+
163166
def f(x):
164-
return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])
167+
return jax.tree.map(
168+
lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]
169+
)
165170

166171
graphs_shards = f(graphs_shards)
167172
labels_shards = f(labels_shards)

0 commit comments

Comments
 (0)