
from_pretrained orchestration + distributed save/load #45409

Open

3outeille wants to merge 1 commit into moe-sequence-parallel from orchestration-save-load

Conversation

@3outeille
Member

Summary

  • Full distributed_config integration in from_pretrained() — mesh creation, apply TP + FSDP, attach model.device_mesh
  • gather_full_state_dict() for streaming DTensor→full tensor saving (rank 0 only)
  • convert_strided_to_shard() / restore_strided_from_shard() for DCP compatibility with _StridedShard
  • save_optimizer() / load_optimizer() in distributed/utils.py
  • Rename apply_fsdp2 → apply_fully_shard_data_parallel
  • Trainer integration with distributed_config
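The orchestration order in the first bullet can be sketched as a plain-Python mock; all names below (build_device_mesh, apply_tensor_parallel, and the distributed_config shape) are hypothetical stand-ins for illustration, not the actual transformers internals — only the ordering (mesh → TP → FSDP → attach device_mesh) comes from the PR summary.

```python
# Hypothetical sketch of the orchestration order wired into from_pretrained().
# Stub functions record their call order so the sequence is visible.

calls = []

def build_device_mesh(distributed_config):
    # 1. Create the device mesh from the distributed config.
    calls.append("mesh")
    return {"tp": distributed_config["tp"], "dp": distributed_config["dp"]}

def apply_tensor_parallel(model, mesh):
    # 2. Apply tensor parallelism first (intra-layer sharding).
    calls.append("tp")
    return model

def apply_fully_shard_data_parallel(model, mesh):
    # 3. Then wrap with fully-sharded data parallelism (FSDP).
    calls.append("fsdp")
    return model

def from_pretrained_orchestration(model, distributed_config):
    mesh = build_device_mesh(distributed_config)
    model = apply_tensor_parallel(model, mesh)
    model = apply_fully_shard_data_parallel(model, mesh)
    # 4. Attach the mesh so save/load helpers can find it later.
    model["device_mesh"] = mesh
    return model

model = from_pretrained_orchestration({}, {"tp": 2, "dp": 4})
print(calls)                  # ['mesh', 'tp', 'fsdp']
print(model["device_mesh"])   # {'tp': 2, 'dp': 4}
```

TP before FSDP matters: tensor-parallel sharding rewrites module parameters, so the FSDP wrapper must see the already-sharded modules.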

Part of the distributed training API chain: #44989

Chain: main ← #44989 ← #44083 ← #44974 ← #45028 ← #45408 ← this PR

Review question

Does from_pretrained wire everything up in the right order? Is the save/load round-trip correct?

Test plan

  • End-to-end from_pretrained with distributed_config
  • gather_full_state_dict() roundtrip verification
  • save_optimizer() / load_optimizer() roundtrip
  • Run existing FSDP and TP mixin tests
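The save_optimizer()/load_optimizer() round-trip in the test plan can be illustrated with a minimal single-process mock; the real helpers live in distributed/utils.py and handle DTensor shards and rank-0 gating, both of which this sketch omits.

```python
# Simplified round-trip sketch (not the actual distributed/utils.py code):
# serialize an optimizer-like state dict, reload it, and check equality.
import os
import pickle
import tempfile

def save_optimizer(optim_state, path):
    # The real helper writes from rank 0 only; here we just serialize.
    with open(path, "wb") as f:
        pickle.dump(optim_state, f)

def load_optimizer(path):
    with open(path, "rb") as f:
        return pickle.load(f)

state = {"step": 10, "exp_avg": [0.1, 0.2], "lr": 3e-4}
path = os.path.join(tempfile.mkdtemp(), "optim.bin")
save_optimizer(state, path)
restored = load_optimizer(path)
assert restored == state  # round-trip preserves the optimizer state
```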

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates
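The streaming DTensor→full-tensor gather described above can be sketched without torch; ShardStub below is a hypothetical stand-in for a real DTensor and its full_tensor() call (which all-gathers shards across the mesh), and the rank-0-only behavior mirrors the gather_full_state_dict() bullet.

```python
# Dependency-free illustration of a streaming DTensor -> full-tensor gather.
# ShardStub is a stand-in for torch.distributed's DTensor.

class ShardStub:
    """Mock DTensor: holds one shard of a larger tensor per rank."""
    def __init__(self, shards):
        self.shards = shards  # list with one entry per rank

    def full_tensor(self):
        # A real DTensor.full_tensor() all-gathers shards across the mesh.
        return [x for shard in self.shards for x in shard]

def gather_full_state_dict(state_dict, rank):
    full = {}
    for name, value in state_dict.items():
        # Materialize shards one key at a time: streaming keeps peak memory
        # to a single full tensor instead of the whole state dict at once.
        tensor = value.full_tensor() if isinstance(value, ShardStub) else value
        if rank == 0:  # only rank 0 keeps the gathered result
            full[name] = tensor
    return full

sd = {"w": ShardStub([[1, 2], [3, 4]]), "bias": [0, 0]}
full_rank0 = gather_full_state_dict(sd, rank=0)
full_rank1 = gather_full_state_dict(sd, rank=1)
print(full_rank0)  # {'w': [1, 2, 3, 4], 'bias': [0, 0]}
print(full_rank1)  # {} on non-zero ranks
```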
