66import  unittest 
77from  unittest .mock  import  patch 
88import  sys 
9- import  os 
109
1110import  torch 
1211from  torch  import  nn 
2726from  torch_xla ._internal  import  tpu 
2827
2928
30- def  should_convert_to_shardy ():
31-   return  os .environ .get ("CONVERT_SHLO_TO_SHARDY" ,
32-                         "" ).lower () in  ("1" , "true" , "yes" )
33- 
34- 
3529class  BasicXlaShardingTest (test_xla_sharding_base .XlaShardingTest ):
3630
3731  @classmethod  
3832  def  setUpClass (cls ):
3933    super ().setUpClass ()
34+     cls .convert_to_shardy  =  xu .check_env_flag ("CONVERT_SHLO_TO_SHARDY" )
4035
4136  def  test_xla_sharded_tensor (self ):
4237    partition_spec  =  (0 , 1 )
@@ -244,7 +239,7 @@ def test_custom_tile_assignment(self):
244239    if  self .n_devices  >  1 :
245240      annotation  =  '{devices=[1,%d]%s}'  %  (self .n_devices , ',' .join (
246241          [str (i ) for  i  in  reversed (range (self .n_devices ))]))
247-       if  should_convert_to_shardy () :
242+       if  self . convert_to_shardy :
248243        annotation  =  '{devices=[1,%d]<=[%d]}'  %  (self .n_devices , self .n_devices )
249244      self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
250245
@@ -260,7 +255,7 @@ def test_mark_sharding_2d(self):
260255    if  self .n_devices  >  1 :
261256      annotation  =  '{devices=[1,%d]%s}'  %  (self .n_devices , ',' .join (
262257          [str (i ) for  i  in  range (self .n_devices )]))
263-       if  should_convert_to_shardy () :
258+       if  self . convert_to_shardy :
264259        annotation  =  '{devices=[1,%d]<=[%d]}'  %  (self .n_devices , self .n_devices )
265260      self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt1 ))
266261
@@ -281,7 +276,7 @@ def test_mark_sharding_4d(self):
281276      annotation  =  '{devices=[1,1,%d,%d]%s}'  %  (
282277          z_dim , self .n_devices  //  z_dim , ',' .join (
283278              [str (i ) for  i  in  range (self .n_devices )]))
284-       if  should_convert_to_shardy () :
279+       if  self . convert_to_shardy :
285280        annotation  =  '{devices=[1,1,%d,%d]<=[%d]}'  %  (z_dim , self .n_devices  // 
286281                                                      z_dim , self .n_devices )
287282      self .assertEqual (annotation , torch_xla ._XLAC ._get_xla_sharding_spec (xt ))
@@ -418,7 +413,7 @@ def test_tupled_partition_spec(self):
418413    xs .mark_sharding (t , mesh , ((0 , 1 ),))
419414    annotation  =  "{devices=[%d]%s}"  %  (self .n_devices , ',' .join (
420415        str (x ) for  x  in  range (self .n_devices )))
421-     if  should_convert_to_shardy () :
416+     if  self . convert_to_shardy :
422417      annotation  =  "{devices=[%d]<=[%d]}"  %  (self .n_devices , self .n_devices )
423418    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
424419
@@ -432,7 +427,7 @@ def test_named_partial_tupled_partition_spec(self):
432427    xs .mark_sharding (t , mesh , (('r' , 'b' ), None ))
433428    annotation  =  "{devices=[2,1,%d]%s last_tile_dim_replicate}"  %  (
434429        self .n_devices  //  2 , ',' .join (str (x ) for  x  in  range (self .n_devices )))
435-     if  should_convert_to_shardy () :
430+     if  self . convert_to_shardy :
436431      annotation  =  "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}"  %  (
437432          self .n_devices  //  2 , self .n_devices )
438433    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -442,7 +437,7 @@ def test_named_partial_tupled_partition_spec(self):
442437    xs .mark_sharding (u , mesh , (None , ('b' , 'm' )))
443438    annotation  =  "{devices=[1,%d]%s}"  %  (self .n_devices , ',' .join (
444439        str (x ) for  x  in  range (self .n_devices )))
445-     if  should_convert_to_shardy () :
440+     if  self . convert_to_shardy :
446441      annotation  =  "{devices=[1,%d]<=[%d]}"  %  (self .n_devices , self .n_devices )
447442    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (u ), annotation )
448443
@@ -452,7 +447,7 @@ def test_named_partial_tupled_partition_spec(self):
452447    device_order  =  mesh .get_logical_mesh ().transpose ((0 , 2 , 1 )).flatten ()
453448    annotation  =  "{devices=[1,%d,2]%s last_tile_dim_replicate}"  %  (
454449        self .n_devices  //  2 , ',' .join (str (x ) for  x  in  device_order ))
455-     if  should_convert_to_shardy () :
450+     if  self . convert_to_shardy :
456451      annotation  =  "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}"  %  (
457452          self .n_devices  //  2 , self .n_devices  //  2 )
458453    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
@@ -463,7 +458,7 @@ def test_named_partial_tupled_partition_spec(self):
463458    device_order  =  mesh .get_logical_mesh ().transpose ((2 , 1 , 0 )).flatten ()
464459    annotation  =  "{devices=[1,%d]%s}"  %  (self .n_devices , ',' .join (
465460        str (x ) for  x  in  device_order ))
466-     if  should_convert_to_shardy () :
461+     if  self . convert_to_shardy :
467462      annotation  =  "{devices=[1,%d]<=[2,%d]T(1,0)}"  %  (self .n_devices ,
468463                                                       self .n_devices  //  2 )
469464    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (v ), annotation )
@@ -478,7 +473,7 @@ def test_multiple_tuples_in_spec(self):
478473    xs .mark_sharding (t , mesh , (('a' , 'b' ), ('c' , 'd' )))
479474    annotation  =  "{devices=[2,%d]%s}"  %  (self .n_devices  //  2 , ',' .join (
480475        str (x ) for  x  in  range (self .n_devices )))
481-     if  should_convert_to_shardy () :
476+     if  self . convert_to_shardy :
482477      annotation  =  "{devices=[2,%d]<=[%d]}"  %  (self .n_devices  //  2 ,
483478                                               self .n_devices )
484479    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -491,7 +486,7 @@ def test_3d_tensor_2d_mesh(self):
491486    xs .mark_sharding (t , mesh , (None , 0 , 1 ))
492487    annotation  =  '{devices=[1,2,%d]%s}'  %  (self .n_devices  //  2 , ',' .join (
493488        str (x ) for  x  in  range (self .n_devices )))
494-     if  should_convert_to_shardy () :
489+     if  self . convert_to_shardy :
495490      annotation  =  '{devices=[1,2,%d]<=[%d]}'  %  (self .n_devices  //  2 ,
496491                                                 self .n_devices )
497492    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (t ), annotation )
@@ -1013,8 +1008,7 @@ def test_op_sharding_cache(self):
10131008
10141009    t  =  torch .randn (1 , self .n_devices ).to ('xla' )
10151010    xs .mark_sharding (t , mesh , (0 , 1 ))
1016-     counter_name  =  "CreateIotaOpSharding"  if  should_convert_to_shardy (
1017-     ) else  "CreateOpSharding" 
1011+     counter_name  =  "CreateIotaOpSharding"  if  self .convert_to_shardy  else  "CreateOpSharding" 
10181012    self .assertIn (counter_name , met .counter_names ())
10191013    self .assertEqual (met .counter_value (counter_name ), 1 )
10201014
@@ -1435,7 +1429,7 @@ def test_data_loader_with_sharding(self):
14351429    data , _  =  iter (train_device_loader ).__next__ ()
14361430    self .assertEqual (data .size (), torch .Size ([8 , 3 , 64 , 64 ]))
14371431    annotation  =  f"{{devices=[{ mesh .size ()} { ',' .join ([str (i ) for  i  in  range (mesh .size ())])}  
1438-     if  should_convert_to_shardy () :
1432+     if  self . convert_to_shardy :
14391433      annotation  =  f"{{devices=[{ mesh .size ()} { mesh .size ()}  
14401434    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (data ), annotation )
14411435
@@ -1458,7 +1452,7 @@ def test_data_loader_with_non_batch_size(self):
14581452    data , _  =  iter (train_device_loader ).__next__ ()
14591453    self .assertEqual (data .size (), torch .Size ([mesh .size () -  1 , 3 , 64 , 64 ]))
14601454    annotation  =  f"{{devices=[{ mesh .size ()} { ',' .join ([str (i ) for  i  in  range (mesh .size ())])}  
1461-     if  should_convert_to_shardy () :
1455+     if  self . convert_to_shardy :
14621456      annotation  =  f"{{devices=[{ mesh .size ()} { mesh .size ()}  
14631457    self .assertEqual (torch_xla ._XLAC ._get_xla_sharding_spec (data ), annotation )
14641458
0 commit comments