
Support FSDP in JAX workloads #797

@priyakasimbeg

Description


Sharding optimizer state across devices saves significant memory and reflects current practice, so we want to support it.

  • We want to switch from no sharding to naive model parameter sharding in both frameworks.
  • We will forbid (in the rules) any hacks that change the model parallelization strategy, and each workload will have a default sharding.
  • We will allow submitters to opt out of it on a per-workload basis.
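
As a rough illustration of what "naive model parameter sharding" with a per-workload opt-out could look like in JAX, here is a minimal sketch. It is not the implementation proposed in this issue: the helper names (`naive_shard_axis`, `shard_params`), the mesh axis name `"data"`, the `enable_sharding` flag, and the rule "shard the largest axis divisible by the device count, otherwise replicate" are all assumptions made for the example. Only `jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`, `jax.device_put`, and `jax.tree_util.tree_map` are real JAX APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def naive_shard_axis(shape, num_devices):
    """Hypothetical rule: pick the largest axis whose size is divisible by
    num_devices. Returns None when no axis qualifies (replicate instead)."""
    candidates = [
        (dim, axis)
        for axis, dim in enumerate(shape)
        if dim > 0 and dim % num_devices == 0
    ]
    if not candidates:
        return None
    _, axis = max(candidates)
    return axis


def shard_params(params, mesh, enable_sharding=True):
    """Place every leaf of a parameter pytree on the mesh.

    enable_sharding=False stands in for the per-workload opt-out: all leaves
    are then fully replicated across devices.
    """
    num_devices = mesh.devices.size

    def place(x):
        axis = naive_shard_axis(x.shape, num_devices) if enable_sharding else None
        spec = [None] * x.ndim
        if axis is not None:
            spec[axis] = "data"  # assumed mesh axis name
        return jax.device_put(x, NamedSharding(mesh, P(*spec)))

    return jax.tree_util.tree_map(place, params)


# One-dimensional mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Toy parameter tree; an optimizer state with the same tree structure
# could be passed through shard_params in the same way.
params = {"w": jnp.ones((8, 4)), "b": jnp.ones((4,))}
sharded = shard_params(params, mesh)
replicated = shard_params(params, mesh, enable_sharding=False)
```

Under this sketch, the sharding of each leaf can be inspected with `sharded["w"].sharding`. A leaf with no axis divisible by the device count falls back to replication, which keeps the rule safe for small tensors such as biases.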
