Conversation

@rka97
Contributor

@rka97 rka97 commented Dec 9, 2024

Purpose

The goal of this PR is to enable sharding of model parameters and optimizer state, and to migrate the JAX code from jax.pmap to jax.jit.
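
At a high level, the migration replaces explicit per-device replication plus jax.pmap with a device mesh and jit-managed sharding. The sketch below is purely illustrative (not this PR's actual code); the toy loss and variable names are placeholders, only the jax.sharding / jax.jit calls are real API:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Before (pmap): params were replicated with flax.jax_utils.replicate and the
# step was wrapped as jax.pmap(train_step, axis_name='batch').

# After (jit): describe placement with a device mesh and shardings; the
# compiler inserts the cross-device collectives.
mesh = Mesh(jax.devices(), axis_names=('batch',))
replicated = NamedSharding(mesh, P())           # params / optimizer state
data_sharded = NamedSharding(mesh, P('batch'))  # batch split across devices

@jax.jit
def train_step(params, batch):
  # Toy regression loss; gradients are taken over the logically global batch
  # even though the inputs live sharded across devices.
  def loss_fn(p):
    return jnp.mean((batch['x'] @ p - batch['y']) ** 2)
  grads = jax.grad(loss_fn)(params)
  return params - 0.1 * grads

n = jax.device_count()
params = jax.device_put(jnp.zeros((8,)), replicated)
batch = {
    'x': jax.device_put(jnp.ones((2 * n, 8)), data_sharded),
    'y': jax.device_put(jnp.zeros((2 * n,)), data_sharded),
}
params = train_step(params, batch)
```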

TODOs:

  • Migrate reference optimizers to use jax.jit
    • Nesterov
    • AdamW
    • Others
  • Migrate workloads to use jax.jit
    • (Test workload) MNIST
    • (Test workload) CIFAR
    • WMT
    • Criteo1TB
    • FastMRI
    • Librispeech
    • OGBG
    • ImageNet

Changelog

  • Added sharding utilities to handle data distribution (see the sketch after this list)
  • Replaced pmap code for CIFAR/MNIST with jit
  • Modified AdamW and Nesterov accordingly
  • Updated checkpoint and data_utils to support the new approach (mostly removing explicit jax_utils.replicate calls).
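
The sharding utilities mentioned above might look roughly like the sketch below; the helper names are hypothetical, and only the jax.sharding / jax.device_put calls are real JAX API. The last helper illustrates how explicit flax.jax_utils.replicate calls become unnecessary once placement is expressed through shardings:

```python
# Hypothetical shape of the sharding helpers; function names are illustrative.
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def get_mesh():
  """1D mesh over all available devices, with a single 'batch' axis."""
  return Mesh(jax.devices(), axis_names=('batch',))

def replicated_sharding(mesh):
  """Keep a pytree (params, optimizer state) fully replicated on every device."""
  return NamedSharding(mesh, P())

def batch_sharding(mesh):
  """Split the leading (batch) dimension of each array across devices."""
  return NamedSharding(mesh, P('batch'))

def shard_batch(batch, mesh):
  """Place a host batch on devices split along the batch axis, instead of
  adding the leading device dimension that pmap/jax_utils.replicate required."""
  return jax.tree_util.tree_map(
      lambda x: jax.device_put(x, batch_sharding(mesh)), batch)
```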

Issues

  • Prefetching in CIFAR is temporarily disabled (marked with FIXME); I'm not sure how best to support it under jit.
  • I haven't edited any of the PyTorch code; we will need to make sure the two frameworks still perform comparably.

@rka97 rka97 requested a review from a team as a code owner December 9, 2024 21:21
@github-actions

github-actions bot commented Dec 9, 2024

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@rka97
Contributor Author

rka97 commented Dec 9, 2024

recheck

Still need to test (a) output losses and (b) speed, and (c) look into the other Librispeech workload.
@priyakasimbeg
Contributor

Migrating this PR to one that merges from a branch on this repo.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 6, 2025
