@@ -72,3 +72,91 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc
7272 %slice = torch.aten.slice.Tensor %shape , %dim , %start , %end , %step : !torch.vtensor <[4 ], si32 >, !torch.int , !torch.int , !torch.int , !torch.int -> !torch.vtensor <[2 ], si32 >
7373 return %slice : !torch.vtensor <[2 ],si32 >
7474}
75+
76+
// -----

// A view that keeps the leading dynamic dims and merges the two static
// trailing dims (16*64 = 1024) should fold to a single flatten.using_ints.
// CHECK-LABEL: @view_as_flatten_static
func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> {
  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32>
  %int1024 = torch.constant.int 1024
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
  %2 = torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list<int> -> !torch.vtensor<[?,?,1024],f32>
  return %3 : !torch.vtensor<[?,?,1024],f32>
}
94+
95+
// -----

// The inverse direction: a view that splits the static trailing dim
// (1024 -> 16x64) should fold to a single unflatten.int with a static
// sizes list.
// CHECK-LABEL: @view_as_unflatten_static
func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> {
  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16
  // CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32>
  %int16 = torch.constant.int 16
  %int64 = torch.constant.int 64
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int
  %2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list<int> -> !torch.vtensor<[?,?,16,64],f32>
  return %3 : !torch.vtensor<[?,?,16,64],f32>
}
116+
117+
// -----

// Same flatten fold but with a -1 (inferred) trailing size and fully
// dynamic input dims. NOTE: the scraped source spelled the constant's
// SSA name as "%int -1", which is not a valid MLIR identifier; the
// conventional torch-mlir spelling "%int-1" is restored here.
// CHECK-LABEL: @view_as_flatten_dynamic
func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32>
  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32>
  %int-1 = torch.constant.int -1
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
  %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
  return %3 : !torch.vtensor<[?,?,?],f32>
}
135+
136+
// -----

// A shape_as_tensor / index_select / squeeze / unsqueeze / cat / slice /
// item chain that ultimately extracts dim 0 of %arg0 should scalarize
// down to a single aten.size.int (per the CHECK lines below).
// CHECK-LABEL: @unsqueeze_squeeze_combo
func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
  // CHECK: %int0 = torch.constant.int 0
  // CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int
  // CHECK: return %0 : !torch.int
  %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
  %4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
  %5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64>
  %7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
  %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
  %10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
  %11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<vtensor>
  %12 = torch.aten.cat %11, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[3],si64>
  %13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
  return %14 : !torch.int
}
// (end of scraped hunk; removed stray GitHub "0 commit comments" page artifact)