Skip to content

Conversation

@chentong319
Copy link
Collaborator

@chentong319 chentong319 commented May 13, 2025

ONNX model provides symbolic name for dynamic dimension of input tensor with argument attribute "onnx.dim_params". This PR tries to propagate this symbolic name for dynamic dimension during shape inference with extra attribute to operation "onnx.dim_params_[n]", where n is the index of the output tensor.
The major change is on IndexExpr to carry such info, if it exists. Other changes include the use of the dim_param info in shape inference and test case.
So far, the usu of dim_param info is far from complete. The PR is kind of experimental.
At least one benefit of this approach is that the readability of the output of onnx-mlir can be improved with the symbolic names.
Another potential benefit is the reuse the existing shape inference code to do symbolic dim propagation. For example, without touching the shape inference for onnx.MatMul, we can get the result of the following test case:

func.func @test_matmul_2_param(%arg0 : tensor<16x?x64x42xf32> {onnx.dim_params="1:dim1"}, %arg1 : tensor<42x?xf32> {onnx.dim_params="1:dim2"}) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x?xf32>) -> tensor<*xf32>
  "onnx.Return"(%0) : (tensor<*xf32>) -> ()
}

Result from onnx-mlir-opt --shape-inference

  func.func @test_matmul_2_param(%arg0: tensor<16x?x64x42xf32> {onnx.dim_params = "1:dim1"}, %arg1: tensor<42x?xf32> {onnx.dim_params = "1:dim2"}) -> tensor<16x?x64x?xf32> {
    %0 = "onnx.MatMul"(%arg0, %arg1) {onnx.dim_params_0 = "1:dim1,3:dim2"} : (tensor<16x?x64x42xf32>, tensor<42x?xf32>) -> tensor<16x?x64x?xf32>
    onnx.Return %0 : tensor<16x?x64x?xf32>
  }

Signed-off-by: Chen Tong <[email protected]>
Signed-off-by: Chen Tong <[email protected]>
Signed-off-by: Chen Tong <[email protected]>
Signed-off-by: Chen Tong <[email protected]>
if (dimParamsStr == "")
return;
StringAttr dimParamsAttr = StringAttr::get(op->getContext(), dimParamsStr);
op->setAttr("onnx.dim_params_" + std::to_string(resultIndex),
Copy link
Collaborator

@tungld tungld May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: how about using onnx.out_dim_params_? I thought onnx.dim_params_0 was for the first input, but actually it was for the first output.

Copy link
Collaborator

@tungld tungld left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a minor change in naming.

The propagation looks effective. Let's see how we can utilize this information in future PRs.

For the changes to IndexExpr, I would like to hear comments from @AlexandreEichenberger as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants