Skip to content

Commit 6550542

Browse files
committed
add a few tests
1 parent 357db3e commit 6550542

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

test/Dialect/Torch/scalarize-shapes.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,58 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !t
160160
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
161161
return %14 : !torch.int
162162
}
163+
164+
165+
// -----
166+
167+
// CHECK-LABEL: @eq_tensor_and_where_self
168+
func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],si64> {
169+
// CHECK-DAG: %[[false:.*]] = torch.constant.bool false
170+
// CHECK-DAG: %[[none:.*]] = torch.constant.none
171+
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
172+
// CHECK-DAG: %[[I0:.*]] = torch.constant.int 0
173+
// CHECK-DAG: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
174+
// CHECK-DAG: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
175+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
176+
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
177+
// CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64>
178+
%none = torch.constant.none
179+
%0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
180+
%1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
181+
%false = torch.constant.bool false
182+
%int1 = torch.constant.int 1
183+
%int0 = torch.constant.int 0
184+
%2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
185+
%3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
186+
%4 = torch.prim.ListConstruct %3, %int1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
187+
%5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
188+
%6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
189+
%7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
190+
return %7 : !torch.vtensor<[4],si64>
191+
}
192+
193+
194+
// -----
195+
196+
// CHECK-LABEL: @eq_tensor_from_tensor_and_literal
197+
func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],i1> {
198+
// CHECK-DAG: %[[none:.*]] = torch.constant.none
199+
// CHECK-DAG: %[[false:.*]] = torch.constant.bool false
200+
// CHECK-DAG: %[[true:.*]] = torch.constant.bool true
201+
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[false]], %[[true]], %[[false]], %[[false]] : (!torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
202+
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<bool>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
203+
// CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1>
204+
%none = torch.constant.none
205+
%0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
206+
%1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
207+
%false = torch.constant.bool false
208+
%int1 = torch.constant.int 1
209+
%int-1 = torch.constant.int -1
210+
%int0 = torch.constant.int 0
211+
%2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
212+
%3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
213+
%4 = torch.prim.ListConstruct %3, %int-1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
214+
%5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
215+
%6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
216+
return %6 : !torch.vtensor<[4],i1>
217+
}

0 commit comments

Comments
 (0)