
Conversation

@hshahTT (Contributor) commented Aug 5, 2025

This PR implements the necessary changes to support the Shardy dialect within Torch-XLA (relevant issue: #9348):

  1. Adds support for V2 HLO sharding within the OpSharding and XlaShardingSpec classes (since Shardy doesn't support the V1 shardings that are currently implemented).
  2. Adds the OpenXLA addStablehloImportPipeline() pass that performs the SHLO-to-Shardy conversion.
  3. Gates both changes behind the "CONVERT_SHLO_TO_SHARDY" environment variable.
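
For anyone who wants to try this out, here is a minimal usage sketch (my own illustration, not code from this PR). It assumes the torch_xla SPMD API (xs.Mesh / xs.mark_sharding) and that the environment variable is read before the graph is traced:

```python
# Hypothetical end-to-end usage of the new flag; names come from torch_xla's SPMD API.
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # set before torch_xla is imported/traced

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

t = torch.randn(16, 128).to(xm.xla_device())
# Shard dim 0 across the first mesh axis, replicate dim 1. With the flag set,
# the lowered StableHLO should carry V2 (iota) sharding annotations, which the
# new import pipeline then converts to the Shardy dialect.
xs.mark_sharding(t, mesh, (0, None))
```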

@hshahTT (Contributor, Author) commented Aug 5, 2025

Note: this PR still needs tests, which I will add in a future commit before merging. I was hoping someone more familiar with the codebase could let me know where to add them, and also take a look at the V2 logic to make sure I haven't made any obvious mistakes.

We (Tenstorrent) have tested this with our own MLIR compiler that ingests Shardy graphs and we saw that the sharding worked as intended for some basic sharding specs. We were also able to run tensor parallel inference on the Llama 3.1 8B model with these changes.

Also, the visualize_tensor_sharding() function is broken when the "CONVERT_SHLO_TO_SHARDY" environment variable is set, since it expects the sharding string to be in V1 format. I will fix that in a future commit once I know where to add the tests and can make sure I've accounted for all possible sharding specs correctly.

@hshahTT mentioned this pull request on Aug 5, 2025
@hshahTT (Contributor, Author) commented Aug 14, 2025

Hi, could I get some eyes on this please?

@hshahTT (Contributor, Author) commented Aug 30, 2025

Thanks @ZixuanJiang for your review! The link you gave me to the OpenXLA implementation and your comment here were very useful.

Based on your comments I made the following changes:

  • xla_sharding.py:

    • Modified the logic in _get_op_sharding_args_v2() to match the convertToHloSharding() function within OpenXLA
  • test_xla_sharding.py:

    • Modified the existing test cases to test the V2 sharding logic when the CONVERT_SHLO_TO_SHARDY environment variable is set.
    • The way it is currently set up requires you to set the environment variable before running the test (i.e., by running CONVERT_SHLO_TO_SHARDY=1 python test_xla_sharding.py in whatever Bash script actually calls it in CI).
    • Another way is to parameterize the test class itself on whether the env var should be set (meaning we run all the tests once with the variable unset and then again with it set); a sketch of this option appears at the end of this comment. I can do this but may need to add the parameterized pip module as a test dependency to do it cleanly.
  • debugging.py:

    • Added support for V2 shardings within the visualize_tensor_sharding() debugging function by converting them into V1 shardings first (which are already supported) via the construct_v1_sharding_str() function (a rough sketch of the conversion idea appears right after this list).
    • I added a _get_xla_op_sharding_v2_params Pybind function inside init_python_bindings.cpp that takes a tensor and returns all the V2 sharding parameters needed to represent it: tile_assignment_dims, reshape_dims, transpose_perm, is_last_tile_dim_replicate.
    • Another way I can get those params is by just parsing the V2 sharding string directly, since it already has the format:
      {devices=[tile_assignment_dims]<=[reshape_dims]T(transpose_perm) last_tile_dim_replicate}
      
      but I thought that the Pybind function would be more readable.
  • test_spmd_debugging.py: Added the ConvertV2ShardingToV1Test testing class to test the construct_v1_sharding_str() function.
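
For context, here is a rough sketch of the conversion idea behind construct_v1_sharding_str() (my own reconstruction from the V2 format shown above, not the actual implementation; the parameter names mirror the values returned by the new Pybind helper):

```python
import numpy as np

def v2_params_to_v1_sharding_str(tile_assignment_dims,
                                 reshape_dims,
                                 transpose_perm,
                                 is_last_tile_dim_replicate):
    """Expand a V2 (iota) sharding into an explicit V1 device-list string."""
    # V2 encodes the tile assignment implicitly: iota over the device count,
    # reshaped to reshape_dims, transposed by transpose_perm, then reshaped
    # to tile_assignment_dims. V1 lists every device id explicitly.
    num_devices = int(np.prod(reshape_dims))
    devices = (np.arange(num_devices)
                 .reshape(reshape_dims)
                 .transpose(transpose_perm)
                 .reshape(tile_assignment_dims))
    dims_str = ",".join(str(d) for d in tile_assignment_dims)
    devices_str = ",".join(str(d) for d in devices.flatten())
    suffix = " last_tile_dim_replicate" if is_last_tile_dim_replicate else ""
    return "{devices=[%s]%s%s}" % (dims_str, devices_str, suffix)

# Example: the V2 annotation devices=[2,1,4]<=[8] expands to the V1 string
# "{devices=[2,1,4]0,1,2,3,4,5,6,7}".
print(v2_params_to_v1_sharding_str([2, 1, 4], [8], [0], False))
```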

I tested everything on a Cloud v4-8 TPU courtesy of Google's TPU Research Cloud program. Please let me know if anything else needs to be done!
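
As a postscript on the parameterization option mentioned above, here is a rough sketch of what that could look like with the parameterized package (my own illustration; the class and test names are placeholders, and note that if the flag is only read once at process startup, separate test processes would still be needed):

```python
# Sketch of parameterizing the test class over the CONVERT_SHLO_TO_SHARDY flag.
# Class and test names here are placeholders, not the real test_xla_sharding.py code.
import os
import unittest

from parameterized import parameterized_class


@parameterized_class([
    {"convert_shlo_to_shardy": False},
    {"convert_shlo_to_shardy": True},
])
class XlaShardingV1V2Test(unittest.TestCase):
    """Runs every test twice: once with V1 shardings, once with the Shardy path."""

    def setUp(self):
        # Caveat: if the flag is only read once at startup on the C++ side,
        # toggling it here has no effect and separate processes are needed.
        if self.convert_shlo_to_shardy:
            os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
        else:
            os.environ.pop("CONVERT_SHLO_TO_SHARDY", None)

    def test_sharding_spec_format(self):
        # Real assertions on the generated sharding strings would go here.
        self.assertIsInstance(self.convert_shlo_to_shardy, bool)


if __name__ == "__main__":
    unittest.main()
```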

@hshahTT force-pushed the hshah/v2-sharding-pr branch from 7ac5b09 to 3f2af6b on September 2, 2025, 21:03
hshahTT added a commit to tenstorrent/pytorch-xla that referenced this pull request Sep 3, 2025
…ation (#7)

This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs.

See pytorch#9541 for the upstream PR discussion and additional context.

* Add some tests and reviewer suggestions. Will update V2 op sharding logic in a later commit soon.

* New implementation (WIP)

* Fix new implementation

* Fix visualize_tensor_sharding function for V2 shardings
@hshahTT force-pushed the hshah/v2-sharding-pr branch from 9c0e2c0 to d1e0b17 on September 7, 2025, 06:14
@hshahTT (Contributor, Author) commented Sep 7, 2025

Saw the failing build on CI and added a fix for it: d1e0b17. Could someone please re-run it?

@hshahTT (Contributor, Author) commented Sep 11, 2025

Hi, friendly ping here for a review and CI run, please.

@ZixuanJiang left a comment

Thank you!

hshahTT and others added 7 commits October 3, 2025 16:58
…n PyTorch/XLA (#1)

Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things:

- Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]).
- Converts the new GSPMD module with the V2 annotations into a Shardy module.
…chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

These changes are required to support multi-chip training for real models on the torch-xla side.

- Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.

The new logic now correctly handles cases that were previously unsupported:

  case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
          -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
          -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]

  case 3: mesh_shape=(2,4), partition_spec=(0,None)
          -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

  (A short Python sketch reproducing these cases appears after this commit message.)

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
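
To make the cases above concrete, here is a small sketch of how dims, reshape_dims, and transpose_perm could be derived from mesh_shape and partition_spec (my own reconstruction of the described behavior, not the PR's code; it ignores details such as multi-axis partition entries):

```python
import numpy as np

def compute_v2_sharding_params(mesh_shape, partition_spec):
    """Derive (dims, reshape_dims, transpose_perm, replicate_last_dim) for a V2 sharding.

    mesh_shape: size of each mesh axis, e.g. (2, 4).
    partition_spec: per tensor dimension, the mesh axis it is sharded on, or None.
    """
    # One tile dimension per tensor dimension: the mesh axis size if sharded, else 1.
    dims = [mesh_shape[axis] if axis is not None else 1 for axis in partition_spec]

    # Mesh axes referenced by the partition spec, in the order they appear,
    # followed by the unreferenced (replicated) axes.
    used = [axis for axis in partition_spec if axis is not None]
    unused = [axis for axis in range(len(mesh_shape)) if axis not in used]
    transpose_perm = used + unused

    # Replicated mesh axes are folded into one trailing tile dimension.
    replicated = int(np.prod([mesh_shape[a] for a in unused])) if unused else 1
    replicate_last_dim = replicated > 1
    if replicate_last_dim:
        dims.append(replicated)

    reshape_dims = list(mesh_shape)
    return dims, reshape_dims, transpose_perm, replicate_last_dim

# Reproduces the three cases listed in the commit message above:
print(compute_v2_sharding_params((2, 1, 1, 1), (0, None, None, None)))
# -> ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3], False)
print(compute_v2_sharding_params((2, 1, 1, 1), (0,)))
# -> ([2], [2, 1, 1, 1], [0, 1, 2, 3], False)
print(compute_v2_sharding_params((2, 4), (0, None)))
# -> ([2, 1, 4], [2, 4], [0, 1], True)
```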
@hshahTT force-pushed the hshah/v2-sharding-pr branch from d1e0b17 to ba3995c on October 3, 2025, 17:08
@hshahTT requested review from ZixuanJiang and qihqi on October 3, 2025, 17:15
sshonTT pushed a commit to tenstorrent/pytorch-xla that referenced this pull request Oct 3, 2025
…ation (#7)

@hshahTT (Contributor, Author) commented Oct 7, 2025

Thanks for re-running the CI! I fixed the errors in ConvertV2ShardingToV1Test in this commit. However, I wasn't able to reproduce the failure in test/tpu/run_training_tests.sh on a v4-8 TPU. It also seems unrelated to my changes in this PR, since the CONVERT_SHLO_TO_SHARDY flag wasn't enabled for that test. Does anyone know what might be happening there?

@hshahTT (Contributor, Author) commented Oct 10, 2025

Hey there, does this seem good to be merged in now?

@bfolie (Collaborator) commented Oct 10, 2025

> Hey there, does this seem good to be merged in now?

yes

@hshahTT (Contributor, Author) commented Oct 10, 2025

> Hey there, does this seem good to be merged in now?
>
> yes

Awesome, would you be able to squash merge it in for me? I don't have the permissions to do it myself.

@bfolie merged commit 0fa6e31 into pytorch:master on Oct 10, 2025
26 of 27 checks passed
@bfolie (Collaborator) commented Oct 10, 2025

> Hey there, does this seem good to be merged in now?
>
> yes
>
> Awesome, would you be able to squash merge it in for me? I don't have the permissions to do it myself.

Oh yes, sorry -- didn't realize that
