Skip to content

Conversation

@AmosLewis
Copy link
Collaborator

@AmosLewis AmosLewis commented Oct 4, 2024

test torch.scatter.reduce linalg lowering llvm/torch-mlir#3754

torch.scatter.reduce step by step example:

src = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
self = [1, 2, 3, 4]
Step 0:
self[index[0]] += src[0]
self[0] += 1  = 1+1 = 2
1+1 = 2
self = [2, 2, 3, 4])

Step 1:
self[index[1]] += src[1]
self[1] += 2  = 2+2 = 4
self = [2, 4, 3, 4])

Step 2:
self[index[2]] += src[2]
self[0] += 3  = 2+3 = 5
self = [5, 4, 3, 4])

Step 3:
self[index[3]] += src[3]
self[1] += 4  = 4+4 = 8
self = [5, 8, 3, 4])

Step 4:
self[index[4]] += src[4]
self[2] += 5  = 3+5 = 8
self = [5, 8, 8, 4])

Step 5:
self[index[5]] += src[5]
self[1] += 6  = 8+6 = 14
self = [5, 14, 8, 4])

@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Oct 4, 2024

python -m torch_mlir.tools.import_onnx --opset-version=21 model.onnx -o ScatterElements.default.torch-onnx.mlir ScatterElements.default.torch-onnx.mlir

module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>} : () -> !torch.vtensor<[6],si64> 
    %1 = torch.operator "onnx.ScatterElements"(%arg0, %0, %arg1) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> 
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch),torch-lower-to-backend-contract,func.func(cse,canonicalize))' ScatterElements.default.torch-onnx.mlir > ScatterElements.default.onnx.torch.mlir ScatterElements.default.onnx.torch.mlir

module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %0 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %int0 = torch.constant.int 0
    %1 = torch.aten.scatter_reduce.two %arg0, %int0, %0, %arg1, %str, %true : !torch.vtensor<[4],f32>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4],f32>
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg ScatterElements.default.onnx.torch.mlir > linalg.mlir linalg.mlir

#map = affine_map<(d0) -> (d0, 0)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %c0 = arith.constant 0 : index
    %c6 = arith.constant 6 : index
    %c1 = arith.constant 1 : index
    %4 = arith.muli %c1, %c6 : index
    %5 = arith.index_cast %4 : index to i64
    %6 = arith.index_cast %5 : i64 to index
    %c0_0 = arith.constant 0 : index
    %c6_1 = arith.constant 6 : index
    %c1_2 = arith.constant 1 : index
    %7 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32 = arith.constant 0 : i32
    %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %9 = tensor.empty(%6) : tensor<?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?xf32>) -> tensor<?xf32>
    %11:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%8, %10 : tensor<?x1xi32>, tensor<?xf32>) {
    ^bb0(%out: i32, %out_13: f32):
      %16 = linalg.index 0 : index
      %17 = arith.remsi %16, %c6_1 : index
      %18 = arith.divsi %16, %c6_1 : index
      %extracted = tensor.extract %3[%17] : tensor<6xi64>
      %extracted_14 = tensor.extract %0[%17] : tensor<6xf32>
      %19 = arith.index_cast %17 : index to i64
      %20 = arith.trunci %19 : i64 to i32
      %21 = arith.trunci %extracted : i64 to i32
      linalg.yield %21, %extracted_14 : i32, f32
    } -> (tensor<?x1xi32>, tensor<?xf32>)
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c1_5 = arith.constant 1 : index
    %c1_6 = arith.constant 1 : index
    %c1_7 = arith.constant 1 : index
    %12 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32_8 = arith.constant 0 : i32
    %13 = linalg.fill ins(%c0_i32_8 : i32) outs(%12 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %c0_9 = arith.constant 0 : index
    %dim = tensor.dim %11#0, %c0_9 : tensor<?x1xi32>
    %c1_10 = arith.constant 1 : index
    %c1_11 = arith.constant 1 : index
    %inserted_slice = tensor.insert_slice %11#0 into %13[0, 0] [%dim, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
    %c1_12 = arith.constant 1 : index
    %14 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%11#1, %inserted_slice : tensor<?xf32>, tensor<?x1xi32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %16 = arith.addf %arg2, %arg3 : f32
      tm_tensor.yield %16 : f32
    } -> tensor<4xf32>
    %cast = tensor.cast %14 : tensor<4xf32> to tensor<4xf32>
    %15 = torch_c.from_builtin_tensor %cast : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %15 : !torch.vtensor<[4],f32>
  }
}

@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Oct 4, 2024

Pass by most recent patch

Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu

| tests                          | model-run   | onnx-import   | torch-mlir   | iree-compile   | inference   |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed      | passed        | passed       | passed         | passed      |

AmosLewis added a commit to llvm/torch-mlir that referenced this pull request Oct 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant