Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cn/docs/parallelism/05_ddp.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
download=True,
)

sampler = flow.utils.data.distributed.DistributedSampler(training_data)

train_dataloader = flow.utils.data.DataLoader(
training_data, BATCH_SIZE, shuffle=True
training_data, BATCH_SIZE, shuffle=(sampler is None), sampler=sampler
)

model = flowvision.models.mobilenet_v2().to(DEVICE)
Expand All @@ -48,6 +50,7 @@

for t in range(EPOCH_NUM):
print(f"Epoch {t+1}\n-------------------------------")
train_dataloader.sampler.set_epoch(t)
size = len(train_dataloader.dataset)
for batch, (x, y) in enumerate(train_dataloader):
x = x.to_global(placement=PLACEMENT, sbp=S0)
Expand Down