Commit 654b567

Ubunturajkthakur authored and committed
Revert "Update PyTorch and XLA pin. (pytorch#9668)"
This reverts commit 11590c1.
1 parent be96adf · commit 654b567

18 files changed: +82, -1390 lines

.github/workflows/_tpu_ci.yml

Lines changed: 13 additions & 5 deletions
@@ -51,14 +51,10 @@ jobs:
           pip install fsspec
           pip install rich
 
-          # Test dependencies
-          pip install --upgrade protobuf
-          pip install flax
-
           # PyTorch/XLA Optional Dependencies
           # =================================
           #
-          # Install `jax` and `libtpu` dependencies for pallas and TPU tests.
+          # Install `JAX` and `libtpu` dependencies for pallas and TPU tests.
           #
           # Note that we might need to install pre-release versions of both, in
           # external artifact repositories.
@@ -74,6 +70,18 @@ jobs:
           pip install "$WHL[pallas]" --pre --index-url $INDEX --find-links $LINKS
           pip install "$WHL[tpu]" --pre --index-url $INDEX --find-links $LINKS
 
+          pip install --upgrade protobuf
+
+          # Flax Pin
+          # ========
+          #
+          # Be careful when bumping the `flax` version, since it can cause tests that
+          # depend on `jax` to start breaking.
+          #
+          # Newer `flax` versions might pull newer `jax` versions, which might be incompatible
+          # with the current version of PyTorch/XLA.
+          pip install flax==0.11.2
+
       - name: Run Tests (${{ matrix.test_script }})
         if: inputs.has_code_changes == 'true'
         env:
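
For context on the `flax` pin above, here is a minimal Python sketch (not part of the workflow) of the compatibility concern the comment describes, assuming only that `flax`, `jax`, and `jaxlib` are installed in the CI environment; the expected jax version is taken from the setup.py hunk further down, and the guard itself is illustrative:

import importlib.metadata as md

# Versions that end up installed after the steps above.
installed = {pkg: md.version(pkg) for pkg in ("flax", "jax", "jaxlib")}
print(installed)  # e.g. {'flax': '0.11.2', 'jax': '0.7.1...', 'jaxlib': '0.7.1...'}

# The workflow comment warns that bumping flax can drag in a newer jax.
# A simple guard: assert jax stayed on the version line setup.py pins.
expected_jax = "0.7.1"  # value from the setup.py diff in this commit
assert installed["jax"].startswith(expected_jax), (
    f"flax pulled jax {installed['jax']}, expected {expected_jax}.x")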

.torch_commit

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-# 2025-09-29
-21fec65781bebe867faf209f89bb687ffd236ca4
+# 2025-09-17
+928ac57c2ab03f9f79376f9995553eea2e6f4ca8

BUILD

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
+load("@python//:defs.bzl", "compile_pip_requirements")
 load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
-load("@rules_python//python:pip.bzl", "compile_pip_requirements")
 
 compile_pip_requirements(
     name = "requirements",

WORKSPACE

Lines changed: 1 addition & 2 deletions
@@ -52,7 +52,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
 # the openxla git commit hash and note the date of the commit.
-xla_hash = '9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca'  # Committed on 2025-10-01.
+xla_hash = '92f7b5952dd585c5be17c9a5caad27407005b513'  # Committed on 2025-08-15.
 
 http_archive(
     name = "xla",
@@ -63,7 +63,6 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:no_fortify.diff",
-        "//openxla_patches:if_constexpr_static_assert.diff",
     ],
     strip_prefix = "xla-" + xla_hash,
     urls = [

openxla_patches/if_constexpr_static_assert.diff

Lines changed: 0 additions & 40 deletions
This file was deleted.

setup.py

Lines changed: 5 additions & 5 deletions
@@ -112,12 +112,12 @@
 
 USE_NIGHTLY = True  # Whether to use nightly or stable libtpu and JAX.
 
-_libtpu_version = '0.0.24'
-_libtpu_date = '20250929'
+_libtpu_version = '0.0.21'
+_libtpu_date = '20250813'
 
-_jax_version = '0.8.0'
-_jaxlib_version = '0.8.0'
-_jax_date = '20251001'  # Date for jax and jaxlib.
+_jax_version = '0.7.1'
+_jaxlib_version = '0.7.1'
+_jax_date = '20250813'  # Date for jax and jaxlib.
 
 if USE_NIGHTLY:
   _libtpu_version += f".dev{_libtpu_date}+nightly"
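
For reference, a short sketch of how these pinned values turn into the requirement strings when USE_NIGHTLY is true, based only on the suffix logic visible in this hunk; the jax suffix format shown here is an assumption, only the libtpu line is copied from setup.py:

# Reconstruct the nightly artifact versions from the pinned values above.
_libtpu_version = '0.0.21'
_libtpu_date = '20250813'
_jax_version = '0.7.1'
_jax_date = '20250813'

USE_NIGHTLY = True
if USE_NIGHTLY:
  _libtpu_version += f".dev{_libtpu_date}+nightly"  # copied from setup.py
  _jax_version += f".dev{_jax_date}"  # assumed format; the real rule lives in setup.py

print(_libtpu_version)  # 0.0.21.dev20250813+nightly
print(_jax_version)  # 0.7.1.dev20250813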

test/spmd/test_fsdp_v2.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def test_fsdp_v2_basic(self):
     # Make sure optimization barrier is applied.
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        'opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.2',
+        'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37',
         hlo)
 
     # Make sure the model can execute without error.
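
The expected string encodes exact HLO instruction ids (`opt-barrier.38`, `%tuple.37`), which shift whenever the pinned PyTorch/XLA lowering changes; that is why reverting the pin also reverts the literal. A hedged sketch of an id-agnostic check, using a hard-coded sample line instead of a real `torch_xla._XLAC._get_xla_tensors_hlo` dump so it runs without a TPU:

import re

# Sample HLO text in the shape the test expects, with arbitrary instruction ids.
hlo = ('%opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) '
       'opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37)')

# Match the opt-barrier structure while ignoring the numbering suffixes.
pattern = (r'opt-barrier\.\d+ = \(f32\[1,64\]\{0,1\}, f32\[1\]\{0\}, '
           r'f32\[16,64\]\{1,0\}\) opt-barrier\(')
assert re.search(pattern, hlo), 'optimization barrier not found in HLO'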

test/spmd/test_xla_sharding.py

Lines changed: 13 additions & 14 deletions
@@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)
 
     # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -713,8 +713,7 @@ def test_xla_sharded_hlo_dump(self):
                             partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    print(hlo)
-    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -829,13 +828,13 @@ def test_mark_sharding_ir(self):
                     (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)
 
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
+        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
         hlo)
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1142,7 +1141,7 @@ def test_backward_optimization_barrier(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
+        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
         hlo)
 
   def test_mark_shard_scalar(self):
@@ -1199,7 +1198,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.1 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1216,7 +1215,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1247,7 +1246,7 @@ def test_spmd_shard_to_full_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1298,7 +1297,7 @@ def test_spmd_reduce_scatter(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1319,7 +1318,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1339,7 +1338,7 @@ def test_spmd_all_reduce(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1360,7 +1359,7 @@ def test_spmd_all_reduce_scale(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1714,7 +1713,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved
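
The same pattern runs through this whole file: the expected literals bake in instruction ids (`%custom-call.7`, `%add.6`, `%AddComputation.3`, ...) that belong to one particular toolchain pin. A small sketch of how one might extract the current ids from a fresh HLO dump when updating these expectations, again with a hard-coded sample string standing in for `torch_xla._XLAC._get_xla_tensors_hlo` output:

import re

hlo = ('%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), '
       'custom_call_target="Sharding", sharding={devices=[1,2]0,1}')

# Collect the id of every Sharding custom-call in the dump.
sharding_ids = re.findall(
    r'%custom-call\.(\d+) = .+? custom-call\(.*?\), custom_call_target="Sharding"',
    hlo)
print(sharding_ids)  # ['7'] -- the id to drop into the assertIn literal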

torch_xla/csrc/lowering_context.h

Lines changed: 2 additions & 2 deletions
@@ -124,8 +124,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
   };
 
   // Reports an XLA builder error for the given node.
-  ABSL_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
-                                                  absl::string_view error_msg);
+  TF_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
+                                                absl::string_view error_msg);
 
   xla::XlaBuilder builder_;
   std::unordered_map<torch::lazy::BackendData::Handle, Parameter>

torch_xla/csrc/runtime/BUILD

Lines changed: 3 additions & 19 deletions
@@ -381,34 +381,18 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "tsl_platform_logging",
-    srcs = ["tsl_platform_logging.cpp"],
-    hdrs = ["tsl_platform_logging.h"],
-    deps = [
-        "@xla//xla/tsl/platform:env_time",
-        "@xla//xla/tsl/platform:logging",
-        "@xla//xla/tsl/platform:macros",
-        "@xla//xla/tsl/platform:types",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/base:log_severity",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
 cc_library(
     name = "tf_logging",
     srcs = ["tf_logging.cpp"],
     hdrs = ["tf_logging.h"],
     deps = [
-        ":tsl_platform_logging",
         "//torch_xla/csrc:status",
         "@torch//:headers",
        "@torch//:runtime_headers",
+        "@tsl//tsl/platform:stacktrace",
+        "@tsl//tsl/platform:statusor",
+        "@xla//xla/service:platform_util",
         "@com_google_absl//absl/base:log_severity",
-        "@com_google_absl//absl/log:absl_log",
     ],
 )
 
