@@ -160,3 +160,58 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !t
160160 %14 = torch.aten.item %13 : !torch.vtensor <[1 ],si64 > -> !torch.int
161161 return %14 : !torch.int
162162}
163+
164+
165+ // -----
166+
167+ // CHECK-LABEL: @eq_tensor_and_where_self
168+ func.func @eq_tensor_and_where_self (%arg0: !torch.vtensor <[?,?],si64 >) -> !torch.vtensor <[4 ],si64 > {
169+ // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
170+ // CHECK-DAG: %[[none:.*]] = torch.constant.none
171+ // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
172+ // CHECK-DAG: %[[I0:.*]] = torch.constant.int 0
173+ // CHECK-DAG: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
174+ // CHECK-DAG: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
175+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
176+ // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
177+ // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64>
178+ %none = torch.constant.none
179+ %0 = torch.vtensor.literal (dense <-1 > : tensor <4 xsi64 >) : !torch.vtensor <[4 ],si64 >
180+ %1 = torch.vtensor.literal (dense <1 > : tensor <4 xsi64 >) : !torch.vtensor <[4 ],si64 >
181+ %false = torch.constant.bool false
182+ %int1 = torch.constant.int 1
183+ %int0 = torch.constant.int 0
184+ %2 = torch.aten.size.int %arg0 , %int1 : !torch.vtensor <[?,?],si64 >, !torch.int -> !torch.int
185+ %3 = torch.aten.size.int %arg0 , %int0 : !torch.vtensor <[?,?],si64 >, !torch.int -> !torch.int
186+ %4 = torch.prim.ListConstruct %3 , %int1 , %2 , %2 : (!torch.int , !torch.int , !torch.int , !torch.int ) -> !torch.list <int >
187+ %5 = torch.aten.tensor %4 , %none , %none , %false : !torch.list <int >, !torch.none , !torch.none , !torch.bool -> !torch.vtensor <[4 ],si64 >
188+ %6 = torch.aten.eq.Tensor %5 , %0 : !torch.vtensor <[4 ],si64 >, !torch.vtensor <[4 ],si64 > -> !torch.vtensor <[4 ],i1 >
189+ %7 = torch.aten.where.self %6 , %1 , %5 : !torch.vtensor <[4 ],i1 >, !torch.vtensor <[4 ],si64 >, !torch.vtensor <[4 ],si64 > -> !torch.vtensor <[4 ],si64 >
190+ return %7 : !torch.vtensor <[4 ],si64 >
191+ }
192+
193+
194+ // -----
195+
196+ // CHECK-LABEL: @eq_tensor_from_tensor_and_literal
197+ func.func @eq_tensor_from_tensor_and_literal (%arg0: !torch.vtensor <[?,?],si64 >) -> !torch.vtensor <[4 ],i1 > {
198+ // CHECK-DAG: %[[none:.*]] = torch.constant.none
199+ // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
200+ // CHECK-DAG: %[[true:.*]] = torch.constant.bool true
201+ // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[false]], %[[true]], %[[false]], %[[false]] : (!torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
202+ // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<bool>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
203+ // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1>
204+ %none = torch.constant.none
205+ %0 = torch.vtensor.literal (dense <-1 > : tensor <4 xsi64 >) : !torch.vtensor <[4 ],si64 >
206+ %1 = torch.vtensor.literal (dense <1 > : tensor <4 xsi64 >) : !torch.vtensor <[4 ],si64 >
207+ %false = torch.constant.bool false
208+ %int1 = torch.constant.int 1
209+ %int -1 = torch.constant.int -1
210+ %int0 = torch.constant.int 0
211+ %2 = torch.aten.size.int %arg0 , %int1 : !torch.vtensor <[?,?],si64 >, !torch.int -> !torch.int
212+ %3 = torch.aten.size.int %arg0 , %int0 : !torch.vtensor <[?,?],si64 >, !torch.int -> !torch.int
213+ %4 = torch.prim.ListConstruct %3 , %int -1 , %2 , %2 : (!torch.int , !torch.int , !torch.int , !torch.int ) -> !torch.list <int >
214+ %5 = torch.aten.tensor %4 , %none , %none , %false : !torch.list <int >, !torch.none , !torch.none , !torch.bool -> !torch.vtensor <[4 ],si64 >
215+ %6 = torch.aten.eq.Tensor %5 , %0 : !torch.vtensor <[4 ],si64 >, !torch.vtensor <[4 ],si64 > -> !torch.vtensor <[4 ],i1 >
216+ return %6 : !torch.vtensor <[4 ],i1 >
217+ }
0 commit comments