@@ -8432,16 +8432,16 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
84328432 return op .SplitToSequence (self , split_sizes , axis = dim , keepdims = False )
84338433
84348434
@torch_op("aten::unflatten.int", trace_only=True)
def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
    """unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)

    Expands dimension ``dim`` of ``self`` into the dimensions given by
    ``sizes``; the surrounding dimensions are kept as-is.  Implemented as a
    single Reshape whose target shape is assembled at trace time.

    NOTE(review): reconstructed from a diff; the two assignments for
    ``head_part_rank`` and ``tail_start_idx`` were not visible and were
    inferred from their uses — confirm against the original file.
    """

    self_size = op.Shape(self)

    # PyTorch allows negative dims counted from the end; normalize once,
    # in Python, since this is a trace-only function and dim is a plain int.
    self_rank = len(self.shape)
    if dim < 0:
        dim = self_rank + dim

    # Dynamic shape of the dimensions before `dim`: self_size[0:dim].
    head_start_idx = op.Constant(value_ints=[0])
    head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
    head_part_rank = op.Slice(self_size, head_start_idx, head_end_idx)

    # Dynamic shape of the dimensions after `dim`: self_size[dim+1:].
    tail_start_idx = op.Reshape(dim + 1, op.Constant(value_ints=[1]))
    tail_end_idx = op.Constant(value_ints=[_INT64_MAX])
    tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx)

    # Each replacement size becomes a rank-1, length-1 tensor so that all
    # Concat inputs have matching rank.
    sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes]

    # Avoid concatenating the empty head/tail slices at the boundaries.
    if dim == 0:
        # Corner case: no dimensions precede `dim`.
        final_shape = op.Concat(*sizes, tail_part_rank, axis=0)
    elif dim == self_rank - 1:
        # Corner case: no dimensions follow `dim`.
        final_shape = op.Concat(head_part_rank, *sizes, axis=0)
    else:
        final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
    return op.Reshape(self, final_shape)
84578465
84588466
0 commit comments