
Commit 5af0fdc

fix style

Parent: 50989eb

2 files changed (+9 −8 lines)


algoperf/workloads/lm/lm_pytorch/workload.py (8 additions, 8 deletions)

@@ -45,8 +45,8 @@ def _build_input_queue(
     not_train = split != 'train'
     per_device_batch_size = int(global_batch_size / N_GPUS)

-    seq_len = 2048 # TODO: define it somewehere else
-    DTYPE = torch.int32 # TODO: decide between int32 and int64.
+    seq_len = self._seq_len # TODO: define it somewehere else?
+    dtype = torch.int32 # TODO: decide between int32 and int64.

     # Only create and iterate over tf input pipeline in one Python process to
     # avoid creating too many threads.
@@ -66,18 +66,18 @@ def _build_input_queue(
       if RANK == 0:
         batch = next(np_iter)  # pylint: disable=stop-iteration-return
         inputs = torch.as_tensor(
-            batch['inputs'], dtype=DTYPE,
+            batch['inputs'], dtype=dtype,
             device=DEVICE)  # (N_GPUS, global_batch_size, seq_len)
         targets = torch.as_tensor(
-            batch['targets'], dtype=DTYPE,
+            batch['targets'], dtype=dtype,
             device=DEVICE)  # (N_GPUS, global_batch_size, seq_len)

         # Send batch to other devices when using DDP.
         if USE_PYTORCH_DDP:
           if not_train:
             # During eval, the batch size of the remainder might be different.
             per_device_batch_size = torch.tensor(
-                len(targets[0]), dtype=DTYPE, device=DEVICE)
+                len(targets[0]), dtype=dtype, device=DEVICE)
             dist.broadcast(per_device_batch_size, src=0)
             # We don't broadcast the shard for RANK 0.
             dist.broadcast(inputs[1:], src=0)
@@ -90,15 +90,15 @@ def _build_input_queue(
         # Receive batch from rank 0.
         if not_train:
           # During eval, the batch size of the remainder might be different.
-          per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE)
+          per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE)
           dist.broadcast(per_device_batch_size, src=0)

         # N_GPUS - 1 since we don't broadcast the shard for RANK 0.
         inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
-                             dtype=DTYPE,
+                             dtype=dtype,
                              device=DEVICE)
         targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len),
-                              dtype=DTYPE,
+                              dtype=dtype,
                               device=DEVICE)
         dist.broadcast(inputs, src=0)
         dist.broadcast(targets, src=0)
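The code this hunk touches follows a common DDP input-distribution pattern: rank 0 materializes the full batch, keeps its own shard, and broadcasts the remaining shards, while the other ranks allocate empty buffers of matching shape and receive them. A minimal sketch of that pattern, assuming the process group is already initialized; broadcast_batch and its parameters are illustrative names, not the workload's actual API:

import torch
import torch.distributed as dist


def broadcast_batch(np_batch, per_device_batch_size, seq_len, device,
                    dtype=torch.int32):
  # Assumes dist.init_process_group(...) has already been called.
  rank = dist.get_rank()
  world_size = dist.get_world_size()
  if rank == 0:
    # Full batch on rank 0, shaped (world_size, per_device_batch_size, seq_len).
    inputs = torch.as_tensor(np_batch['inputs'], dtype=dtype, device=device)
    # Rank 0 keeps shard 0; only shards 1..world_size-1 are broadcast.
    dist.broadcast(inputs[1:], src=0)
    return inputs[0]
  # Non-zero ranks allocate a matching buffer and receive the broadcast.
  inputs = torch.empty((world_size - 1, per_device_batch_size, seq_len),
                       dtype=dtype,
                       device=device)
  dist.broadcast(inputs, src=0)
  return inputs[rank - 1]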

algoperf/workloads/lm/workload.py (1 addition, 0 deletions)

@@ -21,6 +21,7 @@ class BaseLmWorkload(spec.Workload):
   """LM workload."""

   _vocab_size: int = 32000
+  _seq_len: int = 2048

   def __init__(self) -> None:
     super().__init__()
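Defining _seq_len as a class attribute on BaseLmWorkload is what lets the PyTorch input queue above read self._seq_len instead of a hard-coded 2048. A small self-contained sketch of the lookup, using stand-in classes rather than the repo's actual class hierarchy; the override value 1024 is purely hypothetical:

# Stand-ins for illustration only; the real classes live in algoperf/workloads/lm/.
class BaseLmWorkload:
  _vocab_size: int = 32000
  _seq_len: int = 2048


class ShortContextLmWorkload(BaseLmWorkload):
  _seq_len: int = 1024  # hypothetical override of the inherited default


print(BaseLmWorkload()._seq_len)             # 2048
print(ShortContextLmWorkload()._seq_len)     # 1024
print(ShortContextLmWorkload()._vocab_size)  # 32000, inherited from the base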
