@@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)

   # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -713,8 +713,7 @@ def test_xla_sharded_hlo_dump(self):
                             partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    print(hlo)
-    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -829,13 +828,13 @@ def test_mark_sharding_ir(self):
                               (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)

     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
+        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
         hlo)

     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1142,7 +1141,7 @@ def test_backward_optimization_barrier(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
+        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
         hlo)

   def test_mark_shard_scalar(self):
@@ -1199,7 +1198,7 @@ def test_spmd_full_to_shard_shape(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.1 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1216,7 +1215,7 @@ def test_spmd_full_to_shard_shape(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1247,7 +1246,7 @@ def test_spmd_shard_to_full_shape(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1298,7 +1297,7 @@ def test_spmd_reduce_scatter(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
         hlo)

     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1319,7 +1318,7 @@ def test_spmd_reduce_scatter_canonical_index(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
         hlo)

     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1339,7 +1338,7 @@ def test_spmd_all_reduce(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)

     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1360,7 +1359,7 @@ def test_spmd_all_reduce_scale(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)

     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1714,7 +1713,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved