@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)
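Note: the flag is read once per test class, and every Shardy-specific expectation below gates on it. A minimal sketch of the helper's behavior, assuming the usual `torch_xla` `check_env_flag` semantics (the exact truthy-string list is an assumption):

```python
import os

def check_env_flag(name, default=''):
  # Sketch only: assumed semantics of xu.check_env_flag --
  # any of these strings enables the flag.
  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']

# e.g. run the suite in both modes:
#   python test_xla_sharding.py
#   CONVERT_SHLO_TO_SHARDY=1 python test_xla_sharding.py
```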
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in reversed(range(self.n_devices))]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

   def test_mark_sharding_2d(self):
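Note: `<=[%d]` is HLO's V2 (iota) tile assignment: instead of spelling out every device id, the list is generated from an iota, optionally reshaped and transposed. A small expansion sketch (`expand_iota` is a hypothetical helper, for illustration only):

```python
import numpy as np

def expand_iota(reshape_dims, perm=None):
  # Hypothetical helper: expand the V2 form `<=[reshape_dims]T(perm)`
  # into the explicit V1 device list (reshape an iota, transpose, flatten).
  ids = np.arange(np.prod(reshape_dims)).reshape(reshape_dims)
  if perm is not None:
    ids = ids.transpose(perm)
  return ids.flatten().tolist()

# {devices=[1,8]<=[8]} names the same devices as {devices=[1,8]0,1,2,3,4,5,6,7}
assert expand_iota([8]) == list(range(8))
```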
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

     actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
           z_dim, self.n_devices // z_dim, ','.join(
               [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

     actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
     # Shard the first dimension on `r` and `b`, replicate the second dimension
     t = torch.randn(16, 16).to('xla')
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t),
-        "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

     # Replicate the first dimension, shard the second on `b` and `m`
     u = torch.randn(16, 16).to('xla')
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

     # Replicate the first dimension, shard the second on `r` and `m`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('r', 'm')))
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v),
-        "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

     # Replicate the first dimension, shard the second on `m` and `b`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('m', 'b')))
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
+                                                       self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'Multiple devices required for tupled partition spec')
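Note: the `T(1,0)` variants are the interesting cases: `<=[2,%d]T(1,0)` reshapes the iota to `[2, n//2]` and transposes before flattening, reproducing the interleaved order the V1 expectation derives from the logical mesh. A sketch for eight devices and a `(1, 2, 4)` `('r', 'b', 'm')` mesh (sizes assumed, but consistent with the annotations above):

```python
import numpy as np

n = 8  # assumed device count for illustration
# V2: {devices=[1,4,2]<=[2,4]T(1,0) ...}: reshape iota to (2, 4),
# transpose with permutation (1, 0), then flatten.
v2_order = np.arange(n).reshape(2, n // 2).transpose(1, 0).flatten()
# V1 order the test derives from the logical mesh for (None, ('r', 'm')):
v1_order = np.arange(n).reshape(1, 2, n // 2).transpose(0, 2, 1).flatten()
assert (v2_order == v1_order).all()  # both are [0, 4, 1, 5, 2, 6, 3, 7]
```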
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
                           ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
   def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
+                                                 self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   def test_partial_replication_addmm(self):
     device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):

     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)

     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value(counter_name), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)

-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreateIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)

   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
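Note: switching on `counter_name` keeps the caching assertions meaningful in both modes; under Shardy conversion the lowering builds iota (V2) op shardings, so a different counter is bumped. A hedged sketch of inspecting these counters via `torch_xla.debug.metrics` (the module this file imports as `met`):

```python
import torch_xla.debug.metrics as met

met.clear_counters()  # start from a clean slate before measuring
# ... xs.mark_sharding(...) calls under test ...
for name in met.counter_names():
  if "OpSharding" in name:
    print(name, met.counter_value(name))
```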
@@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
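Note: to make the two f-strings concrete, here is what they evaluate to assuming an eight-device `'data'` mesh (the size is an assumption for illustration):

```python
n = 8  # assumed mesh.size() for illustration
v1 = f"{{devices=[{n},1,1,1]{','.join(str(i) for i in range(n))}}}"
v2 = f"{{devices=[{n},1,1,1]<=[{n}]}}"
print(v1)  # {devices=[8,1,1,1]0,1,2,3,4,5,6,7}
print(v2)  # {devices=[8,1,1,1]<=[8]}
```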
@@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,