@@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
261261
262262// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
263263func.func @test_scatter_elements_with_duplicate_indices (%arg0: !torch.vtensor <[1 ,5 ],f32 >, %arg1: !torch.vtensor <[1 ,2 ],si64 >, %arg2: !torch.vtensor <[1 ,2 ],f32 >) -> !torch.vtensor <[1 ,5 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 18 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
264- // CHECK: %[[AXIS:.*]] = torch.constant.int 1
265- // CHECK: %[[ZERO:.+]] = torch.constant.int 0
266- // CHECK: %[[ONE:.+]] = torch.constant.int 1
267- // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
268- // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
269- // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
270- // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
271- // CHECK: %[[STR:.*]] = torch.constant.str "add"
272- // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
264+ // CHECK: %[[AXIS:.*]] = torch.constant.int 1
265+ // CHECK: %[[ZERO:.*]] = torch.constant.int 0
266+ // CHECK: %[[FIVE:.*]] = torch.constant.int 1
267+ // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
268+ // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
269+ // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
270+ // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
271+ // CHECK: %[[STR:.*]] = torch.constant.str "sum"
272+ // CHECK: %[[TRUE:.*]] = torch.constant.bool true
273+ // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
273274 %0 = torch.operator " onnx.ScatterElements" (%arg0 , %arg1 , %arg2 ) {torch.onnx.axis = 1 : si64 , torch.onnx.reduction = " add" } : (!torch.vtensor <[1 ,5 ],f32 >, !torch.vtensor <[1 ,2 ],si64 >, !torch.vtensor <[1 ,2 ],f32 >) -> !torch.vtensor <[1 ,5 ],f32 >
274275 return %0 : !torch.vtensor <[1 ,5 ],f32 >
275276}
@@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
294295
295296// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
296297func.func @test_scatter_elements_with_reduction_mul (%arg0: !torch.vtensor <[1 ,5 ],f32 >, %arg1: !torch.vtensor <[1 ,2 ],si64 >, %arg2: !torch.vtensor <[1 ,2 ],f32 >) -> !torch.vtensor <[1 ,5 ],f32 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 18 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
297- // CHECK: %[[AXIS:.*]] = torch.constant.int 1
298- // CHECK: %[[ZERO:.+]] = torch.constant.int 0
299- // CHECK: %[[ONE:.+]] = torch.constant.int 1
300- // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
301- // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
302- // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
303- // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
304- // CHECK: %[[STR:.*]] = torch.constant.str "multiply"
305- // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
298+ // CHECK: %[[AXIS:.*]] = torch.constant.int 1
299+ // CHECK: %[[ZERO:.*]] = torch.constant.int 0
300+ // CHECK: %[[FIVE:.*]] = torch.constant.int 1
301+ // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
302+ // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
303+ // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
304+ // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
305+ // CHECK: %[[STR:.*]] = torch.constant.str "prod"
306+ // CHECK: %[[TRUE:.*]] = torch.constant.bool true
307+ // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
306308 %0 = torch.operator " onnx.ScatterElements" (%arg0 , %arg1 , %arg2 ) {torch.onnx.axis = 1 : si64 , torch.onnx.reduction = " mul" } : (!torch.vtensor <[1 ,5 ],f32 >, !torch.vtensor <[1 ,2 ],si64 >, !torch.vtensor <[1 ,2 ],f32 >) -> !torch.vtensor <[1 ,5 ],f32 >
307309 return %0 : !torch.vtensor <[1 ,5 ],f32 >
308310}
0 commit comments