Skip to content

Commit 91912cc

Browse files
committed
fix
1 parent 2717519 commit 91912cc

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

algoperf/workloads/ogbg/ogbg_pytorch/workload.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _build_input_queue(
7777
split: str,
7878
data_dir: str,
7979
global_batch_size: int,
80+
shard: bool = True
8081
):
8182
# TODO: Check where the + 1 comes from.
8283
per_device_batch_size = int(global_batch_size / N_GPUS) + 1
@@ -86,7 +87,7 @@ def _build_input_queue(
8687
if RANK == 0:
8788
data_rng = data_rng.astype('uint32')
8889
dataset_iter = super()._build_input_queue(
89-
data_rng, split, data_dir, global_batch_size, shard=True
90+
data_rng, split, data_dir, global_batch_size, shard
9091
)
9192

9293
while True:

algoperf/workloads/ogbg/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def _build_input_queue(
100100
split: str,
101101
data_dir: str,
102102
global_batch_size: int,
103-
shard: bool,
103+
shard: bool = False,
104104
):
105105
dataset_iter = input_pipeline.get_dataset_iter(
106106
split, data_rng, data_dir, global_batch_size, shard

0 commit comments

Comments
 (0)