
Conversation

@davidtweedle
Contributor

A draft PR with potential changes to upgrade the PyTorch workloads from DDP (distributed data parallel) to FSDP (fully sharded data parallel).

Summary of changes to: cifar, mnist, criteo1tb, imagenet vit, imagenet resnet, librispeech deepspeech, librispeech conformer, ogbg, wmt, fastmri

  • import the required packages (e.g., FSDP)
  • construct the model with the FSDP wrapper instead of the DDP wrapper (see the sketch after this list)
  • very naive sharding for now
  • Crucially: gradients must be zeroed before eval
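
A minimal sketch of what the DDP-to-FSDP swap might look like for a workload's model setup; the function and variable names (`init_model`, `local_rank`) are illustrative assumptions, not the PR's actual code:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def init_model(model: torch.nn.Module, local_rank: int) -> FSDP:
    # Move parameters onto the local device before wrapping.
    model = model.to(local_rank)
    # FSDP constructor call in place of the DDP one; with no
    # auto_wrap_policy this is the "very naive sharding" mentioned above
    # (the whole model is flattened into a single FSDP unit).
    return FSDP(model, device_id=local_rank)

# Before running eval, zero the gradients so no stale sharded grads are
# carried into the evaluation pass:
#   optimizer.zero_grad(set_to_none=True)
```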

Summary of changes to momentum (as a simple test optimizer):

  • first compute the weighted loss on each device
  • then call loss.backward() (the gradients will now be all-reduced by a PyTorch communication hook)
  • then display the correct, globally reduced loss (see the sketch after this list)
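
A hedged sketch of the training step described above, assuming a per-example loss function with `reduction='none'` and a `weights` tensor marking valid examples; the names `train_step`, `loss_fn`, and `weights` are assumptions, and the communication hook used for the gradient reduction in the PR is not reproduced here:

```python
import torch
import torch.distributed as dist

def train_step(model, optimizer, inputs, targets, weights, loss_fn):
    optimizer.zero_grad(set_to_none=True)
    logits = model(inputs)
    per_example = loss_fn(logits, targets)   # per-example losses, reduction='none'
    local_weight = weights.sum()
    # (1) weighted loss on this device
    loss = (per_example * weights).sum() / local_weight
    # (2) backward pass; the sharded gradients are reduced across devices
    loss.backward()
    optimizer.step()
    # (3) reduce the loss across devices so the displayed value is the
    # global weighted average rather than a single device's loss
    stats = torch.stack([loss.detach() * local_weight, local_weight])
    dist.all_reduce(stats)
    return stats[0] / stats[1]
```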

@github-actions

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@priyakasimbeg
Contributor

Won't fix. Closing for now.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 13, 2025