
Commit f4ffbe7

Fix torch sharding issue; update the input pipeline and workload classes to use int32 tensor types; add a dropout-rate parameter.

1 parent: f0c6e75


5 files changed: +37 -27 lines


algoperf/workloads/lm/input_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -119,8 +119,8 @@ def tf_generator():
   ds = tf.data.Dataset.from_generator(
     tf_generator,
     output_signature={
-      "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64),
-      "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64),
+      "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32),
+      "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32),
     })

   # Avoid creating too many threads when using PyTorch DDP.
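The switch from int64 to int32 is lossless here: token IDs are bounded by the vocabulary size, which sits far below the int32 maximum, and the narrower dtype halves the memory footprint of every token tensor. A minimal sketch of the bound (the vocabulary size is an assumption, not taken from this repo):

import numpy as np

# Hypothetical GPT-2-sized vocabulary; any vocab below 2**31 - 1 fits in int32.
vocab_size = 50_257
assert vocab_size < np.iinfo(np.int32).max

batch = np.random.randint(0, vocab_size, size=(8, 2048), dtype=np.int64)
assert np.array_equal(batch, batch.astype(np.int32))  # narrowing loses nothing
print(batch.nbytes, batch.astype(np.int32).nbytes)  # 131072 vs 65536 bytes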

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 3 additions & 2 deletions
@@ -90,8 +90,9 @@ def model_fn(
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
-      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-    del mode, rng, update_batch_norm, model_state
+      update_batch_norm: bool,
+      dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+    del mode, rng, update_batch_norm, model_state, dropout_rate
     inputs = batch['inputs']
     # Convert one-hot inputs to token IDs if needed
     if inputs.ndim == 3:  # one-hot encoded
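The JAX workload accepts dropout_rate only to satisfy the shared model_fn signature and immediately discards it via del. For context, a minimal sketch of the standard inverted dropout that this parameter would drive in a workload that does use it (illustrative, not repo code):

import jax
import jax.numpy as jnp

def dropout(rng, x, dropout_rate, train):
  """Inverted dropout: a no-op at eval time or when the rate is zero."""
  if not train or dropout_rate == 0.0:
    return x
  keep_prob = 1.0 - dropout_rate
  # Scale kept activations by 1/keep_prob so the expected value matches eval.
  mask = jax.random.bernoulli(rng, keep_prob, x.shape)
  return jnp.where(mask, x / keep_prob, 0.0)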

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 20 additions & 17 deletions
@@ -6,7 +6,8 @@
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
-
+from itertools import islice
+from algoperf import data_utils
 from algoperf import param_utils
 from algoperf import pytorch_utils
 from algoperf import spec
@@ -84,19 +85,22 @@ def _build_input_queue(
       num_batches: Optional[int] = None,
       repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
     """Build an input queue for the given split."""
-    from algoperf.workloads.lm.input_pipeline import get_hf_dataloader
-
-    loader = get_hf_dataloader(
-      cache_dir=data_dir,
+    from algoperf.workloads.lm.input_pipeline import get_lm_dataset
+    local_batch_size = global_batch_size // N_GPUS
+
+    loader = get_lm_dataset(
       data_rng=data_rng,
-      batch_size=global_batch_size,
-      seq_len=self._seq_len,
-      framework="torch",
-      split=split)
+      split=split,
+      data_dir=data_dir,
+      global_batch_size=local_batch_size,
+      num_batches=num_batches
+    )
+    if USE_PYTORCH_DDP:
+      loader = islice(loader, RANK, None, N_GPUS)
     seq_len = self._seq_len
     weights = None

-    dtype = torch.long
+    dtype = torch.int32
     is_train = split == 'train'

     for batch in loader:
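The islice call above is the core of the sharding fix: each DDP rank strides through its batch stream starting at its own RANK and stepping by N_GPUS, so ranks consume disjoint batches without any tensor broadcasts. A minimal sketch of the pattern:

from itertools import islice

def shard_batches(loader, rank, world_size):
  # Round-robin sharding: rank r keeps batches r, r + world_size, ...
  return islice(loader, rank, None, world_size)

# Stand-in for the real batch iterator.
print(list(shard_batches(iter(range(8)), rank=0, world_size=2)))  # [0, 2, 4, 6]
print(list(shard_batches(iter(range(8)), rank=1, world_size=2)))  # [1, 3, 5, 7]

Note that each rank still materializes the batches it skips, so this trades some redundant host-side work for a much simpler distribution scheme than broadcasting from rank 0.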
@@ -109,17 +113,16 @@ def _build_input_queue(
         per_device_batch_size = torch.tensor(
           targets.shape[0], dtype=dtype, device=DEVICE)
         dist.broadcast(per_device_batch_size, src=0)
-
+        local_batch_size = per_device_batch_size.item()
         # Broadcast to all devices
-        dist.broadcast(inputs, src=0)
-        dist.broadcast(targets, src=0)
+        # dist.broadcast(inputs, src=0)
+        # dist.broadcast(targets, src=0)

       if weights is None:
-        batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item()
-        weights = torch.ones((batch_size, seq_len), device=DEVICE)
+        weights = torch.ones((local_batch_size, seq_len), device=DEVICE)
       batch = {
-        'inputs': inputs,
-        'targets': targets,
+        'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype),
+        'targets': torch.tensor(targets, device=DEVICE, dtype=dtype),
         'weights': weights,
       }
       yield batch
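One detail of the new conversion: torch.tensor always copies its argument, whereas torch.as_tensor can share memory with a NumPy input when the dtype matches (a copy is unavoidable anyway once the data moves to a CUDA device). A small illustration of the difference:

import numpy as np
import torch

host = np.ones((4, 16), dtype=np.int32)  # stand-in for a loader batch

copied = torch.tensor(host)     # always copies
shared = torch.as_tensor(host)  # shares the CPU buffer for matching dtypes

host[0, 0] = 7
assert copied[0, 0].item() == 1  # unaffected by the host-side write
assert shared[0, 0].item() == 7  # views the same memory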

algoperf/workloads/lm/tests/test_build_input_queue_torch.py

Lines changed: 10 additions & 5 deletions
@@ -17,9 +17,9 @@ def sync_ddp():
 def test_dataloader_torch():
   # Test config.
   rng_seed = 1996
-  data_dir = '/fast/najroldi/data/finewebedu'
+  data_dir = '/home/ak4605/data/finewebedu/'
   split = 'train'
-  global_batch_size = 8
+  global_batch_size = 64
   dtype = torch.int32
   seq_len = 2048

@@ -44,35 +44,40 @@ def test_dataloader_torch():
   # print(f"inputs: {inputs}")

   # Start test.
-  for _ in range(100):
+  for _ in range(1):

     batch = next(input_queue)
+    print(f"RANK {RANK} got batch")

     assert type(batch) == dict
     assert 'inputs' in batch
     assert 'targets' in batch

     inputs, targets = batch['inputs'], batch['targets']
-
+    print(f"RANK {RANK} inputs.shape: {inputs.shape}")
+    print(f"RANK {RANK} targets.shape: {targets.shape}")
+    print(f"RANK {RANK} type(inputs): {type(inputs)}")
     assert type(inputs) == torch.Tensor
     assert type(targets) == torch.Tensor

     assert inputs.device == DEVICE
     assert targets.device == DEVICE
-
     assert inputs.dtype == dtype
     assert targets.dtype == dtype

+    print(local_batch_size, seq_len)
     assert inputs.shape == (local_batch_size, seq_len)
     assert targets.shape == (local_batch_size, seq_len)

     assert torch.equal(inputs[:, 1:], targets[:, :-1])
+    print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}")

   print(f"=== ALL TEST PASSED ===")


 def main():
   profiler = PassThroughProfiler()
+  print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS)
   pytorch_init(USE_PYTORCH_DDP, RANK, profiler)
   test_dataloader_torch()
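Since the test exercises the DDP path, it is meant to be launched with one process per GPU; with the constants above, each rank should then receive batches of shape (64 // N_GPUS, 2048). A typical launch (the script path mirrors the repo layout; adjust to your checkout):

torchrun --nproc_per_node=4 algoperf/workloads/lm/tests/test_build_input_queue_torch.py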

algoperf/workloads/lm/workload.py

Lines changed: 2 additions & 1 deletion
@@ -132,7 +132,8 @@ def _eval_batch(self,
       model_state,
       spec.ForwardPassMode.EVAL,
       rng,
-      update_batch_norm=False)
+      update_batch_norm=False,
+      dropout_rate=None)

     loss_dict = self.loss_fn(batch['targets'], logits)
     return loss_dict['summed']
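Passing dropout_rate=None at eval time is consistent with the updated model_fn signatures: dropout should only ever fire in TRAIN mode, and the LM workloads discard the argument entirely. A hypothetical caller-side convention (not repo code) that makes the contract explicit:

def effective_dropout_rate(mode, dropout_rate):
  # 'train' stands in for spec.ForwardPassMode.TRAIN.
  if mode != 'train' or dropout_rate is None:
    return 0.0
  return dropout_rate

assert effective_dropout_rate('eval', None) == 0.0
assert effective_dropout_rate('train', 0.1) == 0.1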
