diff --git a/src/Dialect/Mlir/CMakeLists.txt b/src/Dialect/Mlir/CMakeLists.txt index 01b4062aa4..be01a16345 100644 --- a/src/Dialect/Mlir/CMakeLists.txt +++ b/src/Dialect/Mlir/CMakeLists.txt @@ -13,6 +13,7 @@ add_onnx_mlir_library(OMMlirDialects LINK_LIBS PUBLIC OMCompilerOptions + OMONNXOps MLIRMathDialect MLIRAffineDialect MLIRSCFDialect diff --git a/src/Dialect/Mlir/IndexExpr.cpp b/src/Dialect/Mlir/IndexExpr.cpp index 65ac1f30ce..39b905ab23 100644 --- a/src/Dialect/Mlir/IndexExpr.cpp +++ b/src/Dialect/Mlir/IndexExpr.cpp @@ -237,6 +237,8 @@ bool IndexExpr::hasAffineExpr() const { return getObj().hasAffineExpr(); } bool IndexExpr::hasValue() const { return getObj().hasValue(); } +bool IndexExpr::hasDimParam() const { return getObj().hasDimParam(); } + //===----------------------------------------------------------------------===// // IndexExpr list queries. //===----------------------------------------------------------------------===// @@ -413,6 +415,8 @@ AffineExpr IndexExpr::getAffineExpr() const { Value IndexExpr::getValue() const { return getObj().getValue(); } +std::string IndexExpr::getDimParam() const { return getObj().getDimParam(); } + void IndexExpr::getAffineMapAndOperands( AffineMap &map, SmallVectorImpl &operands) const { assert(!isFloat() && "attempt to get affine map of a float index expr"); diff --git a/src/Dialect/Mlir/IndexExpr.hpp b/src/Dialect/Mlir/IndexExpr.hpp index 0ff8658f35..f190236cea 100644 --- a/src/Dialect/Mlir/IndexExpr.hpp +++ b/src/Dialect/Mlir/IndexExpr.hpp @@ -452,6 +452,7 @@ class IndexExpr { } bool hasAffineExpr() const; bool hasValue() const; + bool hasDimParam() const; // Value/values has/have to be literal and satisfy the test. bool isLiteralAndIdenticalTo(int b) const; // Values equal. @@ -486,6 +487,7 @@ class IndexExpr { mlir::AffineMap &map, llvm::SmallVectorImpl &operands) const; mlir::Value getValue() const; int64_t getShape(bool uniqueQuestionMark = false) const; + std::string getDimParam() const; // Helpers for list of IndexExpressions: given a (list of) IndexExpr, provide // the (list of) Shape/Value/OpFoldResult corresponding to the original (list diff --git a/src/Dialect/Mlir/IndexExprBuilder.cpp b/src/Dialect/Mlir/IndexExprBuilder.cpp index 74d1944678..440426cc44 100644 --- a/src/Dialect/Mlir/IndexExprBuilder.cpp +++ b/src/Dialect/Mlir/IndexExprBuilder.cpp @@ -239,7 +239,12 @@ IndexExpr IndexExprBuilder::getValFromArray( else return DimIE(castedVal); } - return QuestionmarkIndexExpr(isFloat); + + if (isFloat) + return QuestionmarkIndexExpr(isFloat); + else + // Try to get more info from array[i] + return QuestionmarkIndexExpr(array, i); } IndexExpr IndexExprBuilder::getIntAsSymbol(Value value) { diff --git a/src/Dialect/Mlir/IndexExprDetail.cpp b/src/Dialect/Mlir/IndexExprDetail.cpp index c3753d9058..907b8881a8 100644 --- a/src/Dialect/Mlir/IndexExprDetail.cpp +++ b/src/Dialect/Mlir/IndexExprDetail.cpp @@ -24,6 +24,8 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include #define DEBUG_TYPE "index-expr" @@ -41,7 +43,7 @@ namespace onnx_mlir { IndexExprImpl::IndexExprImpl() : defined(false), literal(false), isFloat(false), kind(IndexExprKind::NonAffine), intLit(0), affineExpr(nullptr), - value(nullptr) { + value(nullptr), dimParam("") { // Set scope from thread private global. 
  scope = IndexExprScope::getCurrentScopePtr();
  assert(scope && "expected IndexExpr Scope to be defined");
@@ -66,6 +68,109 @@ void IndexExprImpl::initAsQuestionmark(int64_t const val, bool isFloatFlag) {
       IndexExprKind::Questionmark, val, AffineExpr(nullptr), Value(nullptr));
 }
 
+std::string getDimParamFromString(std::string dimParams, int64_t index) {
+  std::stringstream shapeInfoString(dimParams);
+  std::string dimString;
+  while (std::getline(shapeInfoString, dimString, ',')) {
+    size_t pos = dimString.find(':');
+    std::string inputString = dimString.substr(0, pos);
+    std::string paramString = dimString.substr(pos + 1);
+
+    int64_t inputID = std::stoi(inputString);
+    if (inputID == index) {
+      return (paramString);
+    }
+  }
+  return std::string("");
+}
+
+// Get the DimParam for dimension `index` from the direct defining op of
+// tensorOrMemref.
+std::string getDimParamFromDirectDefiningOpUtil(
+    Value tensorOrMemref, int64_t index) {
+  if (auto blockArg = llvm::dyn_cast<BlockArgument>(tensorOrMemref)) {
+    int64_t argIndex = blockArg.getArgNumber();
+    Block *block = blockArg.getOwner();
+    Operation *op = block->getParentOp();
+    if (op && llvm::isa<func::FuncOp>(op)) {
+      func::FuncOp funcOp = llvm::cast<func::FuncOp>(op);
+      DictionaryAttr dictAttr =
+          mlir::function_interface_impl::getArgAttrDict(funcOp, argIndex);
+      if (dictAttr && dictAttr.contains(FUNC_DIM_PARAMS)) {
+        StringAttr dimParamAttr = mlir::cast<StringAttr>(
+            dictAttr.getNamed(FUNC_DIM_PARAMS).value().getValue());
+        return getDimParamFromString(
+            std::string(dimParamAttr.getValue().str()), index);
+      }
+      return std::string("");
+    } else {
+      // ToFix: handle Loop, If, etc.
+      return std::string("");
+    }
+  } else {
+    Operation *op = tensorOrMemref.getDefiningOp();
+    if (!op) {
+      // func.func parameter?
+      return std::string("");
+    } else {
+      // Get the info from attribute "onnx.out_dim_params_*".
+      auto opResult = llvm::cast<OpResult>(tensorOrMemref);
+      unsigned resultIndex = opResult.getResultNumber();
+      Attribute dimParamAttr =
+          op->getAttr(OP_DIM_PARAMS + std::to_string(resultIndex));
+      if (!dimParamAttr)
+        return std::string("");
+      return getDimParamFromString(
+          std::string(llvm::cast<StringAttr>(dimParamAttr).getValue().str()),
+          index);
+    }
+  }
+}
+
+// Initialize a Questionmark with the value of val[index].
+// Assume that the existing code handles the constant case already.
+// Here a Questionmark is generated, perhaps with dimParam info.
+// To find out the info for dimParam, the definition chain of val is
+// inspected. A typical pattern is a value produced by ConcatOp.
+
+std::string getDimParamForDimOp(Value val) {
+  auto dimOp = val.getDefiningOp<ONNXDimOp>();
+  if (dimOp) {
+    Value dataOfDim = dimOp.getData();
+    // Get the axis queried by onnx.Dim.
+    int64_t axis = dimOp.getAxis();
+    return getDimParamFromDirectDefiningOpUtil(dataOfDim, axis);
+  }
+  return std::string("");
+}
+
+static std::string getDimParamForIndexedValueUtil(Value val, int64_t index) {
+  // Pattern #1: the value comes from Concat. The index can be used to trace
+  // back the particular input of Concat.
+  // Code copied from src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp.
+  if (areDimsFromConcat(val)) {
+    SmallVector<Value, 4> shapeDimVals;
+    // Question: do we need to check the shape of the Concat inputs?
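+    // Illustrative example (hypothetical shapes and names): a shape operand
+    // built as
+    //   %d0 = "onnx.Dim"(%x) {axis = 0 : si64} : (tensor<?x384xf32>) -> tensor<1xi64>
+    //   %s = "onnx.Concat"(%d0, %c) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+    // matches this pattern: val[0] traces back to dim 0 of %x, so a dim_param
+    // attached to that dimension (e.g. "bs") can be returned below.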
+    getDims(val, shapeDimVals);
+    return getDimParamForDimOp(shapeDimVals[index]);
+  }
+  return std::string("");
+}
+
+std::string getDimParamUtil(Value tensorOrMemref, int64_t index) {
+  if (std::string resultString =
+          getDimParamFromDirectDefiningOpUtil(tensorOrMemref, index);
+      resultString != "") {
+    return resultString;
+  } else if (std::string resultString =
+                 getDimParamForIndexedValueUtil(tensorOrMemref, index);
+             resultString != "") {
+    return resultString;
+  } else {
+    return std::string("");
+  }
+}
+
 // Used for runtime dims; integer by default.
 void IndexExprImpl::initAsQuestionmark(Value tensorOrMemref, int64_t index) {
   // Each question mark is assigned a unique integer that is obtained
@@ -78,6 +183,25 @@ void IndexExprImpl::initAsQuestionmark(Value tensorOrMemref, int64_t index) {
   init(/*isDefined*/ true, /*literal*/ false,
       /*isLitFloat, as this is for shapes*/ false, IndexExprKind::Questionmark,
       questionValue, AffineExpr(nullptr), Value(nullptr));
+
+  // Get the dimSymbol from the dim_params.
+  // This symbol acts similarly to questionValue, but is predefined by the
+  // onnx model.
+  std::string dimSymbol = getDimParamUtil(tensorOrMemref, index);
+  if (dimSymbol != "")
+    dimParam = dimSymbol;
+}
+
+void IndexExprImpl::initAsQuestionmarkForIndexedValue(
+    Value tensorOrMemref, int64_t index) {
+  llvm::hash_code questionValue = llvm::hash_combine(
+      mlir::hash_value(tensorOrMemref), llvm::hash_value(index));
+  init(/*isDefined*/ true, /*literal*/ false,
+      /*isLitFloat, as this is for shapes*/ false, IndexExprKind::Questionmark,
+      questionValue, AffineExpr(nullptr), Value(nullptr));
+
+  std::string dimSymbol = getDimParamForIndexedValueUtil(tensorOrMemref, index);
+  if (dimSymbol != "")
+    dimParam = dimSymbol;
 }
 
 void IndexExprImpl::initAsLiteral(int64_t const val, const IndexExprKind kind) {
@@ -329,6 +453,11 @@ bool IndexExprImpl::hasValue() const {
   return value != nullptr;
 }
 
+bool IndexExprImpl::hasDimParam() const {
+  assert(isDefined());
+  return dimParam != "";
+}
+
 //===----------------------------------------------------------------------===//
 // IndexExprExpr getters.
 //===----------------------------------------------------------------------===//
@@ -508,6 +637,11 @@ Value IndexExprImpl::getValue() {
   return value;
 }
 
+std::string IndexExprImpl::getDimParam() {
+  // Should this be restricted to Questionmark kind?
+  return dimParam;
+}
+
 //===----------------------------------------------------------------------===//
 // IndexExprExpr setters.
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/Mlir/IndexExprDetail.hpp b/src/Dialect/Mlir/IndexExprDetail.hpp
index 177462c24c..d14523a58c 100644
--- a/src/Dialect/Mlir/IndexExprDetail.hpp
+++ b/src/Dialect/Mlir/IndexExprDetail.hpp
@@ -43,6 +43,11 @@ class IndexExprImpl {
   // question mark is assigned to a unique value hashed from the given
   // tensorOrMemref and dimension index.
   void initAsQuestionmark(mlir::Value tensorOrMemref, int64_t index);
+  // Initialize a question mark for an indexed value, i.e. the value of
+  // val[index]. In general this value could be a constant; here only the
+  // symbolic case is considered, and the dim_param symbol is propagated if
+  // it exists. In contrast, the function above initializes from the shape
+  // of val.
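+  // The dim_param symbol is the ONNX symbolic dimension name, e.g.
+  // "batch_size" from an onnx.dim_params entry such as "0:batch_size".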
+  void initAsQuestionmarkForIndexedValue(mlir::Value val, int64_t index);
   void initAsLiteral(int64_t const value, IndexExprKind const kind);
   void initAsLiteral(double const value, IndexExprKind const kind);
   void initAsKind(mlir::Value const value, IndexExprKind const kind);
@@ -69,6 +74,7 @@ class IndexExprImpl {
   bool isInCurrentScope() const;
   bool hasAffineExpr() const;
   bool hasValue() const;
+  bool hasDimParam() const;
 
   // Getters.
   IndexExprScope &getScope() const;
@@ -83,6 +89,7 @@ class IndexExprImpl {
   void getAffineMapAndOperands(
       mlir::AffineMap &map, llvm::SmallVectorImpl<mlir::Value> &operands);
   mlir::Value getValue();
+  std::string getDimParam();
 
   // Setters.
   void setLiteral(int64_t val);
@@ -122,6 +129,7 @@ class IndexExprImpl {
   mlir::AffineExpr affineExpr;
   // Value expression, may be defined whenever the IndexExpr is defined.
   mlir::Value value;
+  std::string dimParam;
 };
 
 } // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
index 92fc63d07f..4b24736e3e 100644
--- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
@@ -67,6 +67,37 @@ int64_t getAxisInRange(int64_t axis, Value val, bool includeRank) {
 // ONNX Op Shape Helper
 //===----------------------------------------------------------------------===//
 
+static void refineDimParams(
+    Operation *op, DimsExpr &inferredDims, Value output) {
+  // Get the result index of output.
+  if (!llvm::isa<OpResult>(output)) {
+    // output is a block parameter. It could belong to Func, Loop, If, etc.
+    // Give up now due to the complicated control flow.
+    return;
+  }
+
+  auto opResult = llvm::cast<OpResult>(output);
+  unsigned resultIndex = opResult.getResultNumber();
+  std::string dimParamsStr("");
+  bool isFirst = true;
+  for (unsigned i = 0; i < inferredDims.size(); ++i) {
+    if (inferredDims[i].isQuestionmark() && inferredDims[i].hasDimParam()) {
+      if (isFirst) {
+        isFirst = false;
+      } else {
+        dimParamsStr.append(",");
+      }
+      dimParamsStr = dimParamsStr + std::to_string(i) + ":" +
+                     inferredDims[i].getDimParam();
+    }
+  }
+  if (dimParamsStr == "")
+    return;
+  StringAttr dimParamsAttr = StringAttr::get(op->getContext(), dimParamsStr);
+  op->setAttr(
+      OP_DIM_PARAMS + std::to_string(resultIndex), StringAttr(dimParamsAttr));
+}
+
 /// Refine `inferredDims` using the output's shape if possible. For example,
 /// replacing a dynamic dim in `inferredDims` by a static dim in the output's
 /// shape.
@@ -85,7 +116,8 @@ static void refineDims(Operation *op, DimsExpr &inferredDims, Value output) {
          "Inferred shape and existing shape are inconsistent in the number "
          "of elements");
 
-  // Try to update inferredDim if existingDim is static.
+  refineDimParams(op, inferredDims, output);
+
   for (unsigned i = 0; i < existingDims.size(); ++i) {
     // Safety checks for old convention of using -1 for dynamic.
     assert(existingDims[i] != -1 && "dynamic use kDynamic now");
@@ -377,6 +409,11 @@ LogicalResult ONNXBroadcastOpShapeHelper::customComputeShape(
         continue;
       }
       // Case: QuestionMark - QuestionMark
+      if (currentDimExpr.hasDimParam() && nextDimExpr.hasDimParam() &&
+          currentDimExpr.getDimParam() == nextDimExpr.getDimParam()) {
+        // Same symbolic dim: the two dims are known to be equal.
+        continue;
+      }
       if (!hasUniBroadcasting) {
         dimsExpr[j] = IndexExpr::max(currentDimExpr, nextDimExpr);
       }
@@ -404,6 +441,22 @@ bool ONNXBroadcastOpShapeHelper::hasNoBroadcast(DimAnalysis *dimAnalysis) {
   // broadcasting for any reasons, hasNoBroadcast is set to false.
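+  // Two dynamic dims carrying the same dim_param symbol (e.g. both "bs") are
+  // known to denote the same runtime size, so such a dimension cannot cause
+  // broadcasting even though its static size is unknown; the dim_param check
+  // in the loop below relies on this.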
   bool hasNoBroadcast = true;
   for (uint64_t r = 0; r < outputRank && hasNoBroadcast; ++r) {
+    // Check with dim_param info: if all inputs carry the same dim_param for
+    // this dimension, sameDyn remains true and no further check of this
+    // dimension is needed.
+    DimsExpr dimsInput0 = inputsDims[0];
+    if (dimsInput0[r].isQuestionmark() && dimsInput0[r].hasDimParam()) {
+      bool sameDyn = true;
+      for (uint64_t i = 1; i < inputsDims.size(); i++) {
+        DimsExpr dims = inputsDims[i];
+        if (!(dims[r].isQuestionmark() && dims[r].hasDimParam() &&
+                dimsInput0[r].getDimParam() == dims[r].getDimParam())) {
+          sameDyn = false;
+        }
+      }
+      if (sameDyn)
+        continue;
+    }
     bool hasOne, hasOtherThanOne;
     hasOne = hasOtherThanOne = false;
     for (DimsExpr dims : inputsDims) {
diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
index 0a6267ee62..fdb5cf0ab7 100644
--- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
@@ -38,6 +38,10 @@ namespace onnx_mlir {
 
+// Attribute names used for onnx.dim_params on function arguments and for its
+// propagation to op results.
+const std::string FUNC_DIM_PARAMS("onnx.dim_params");
+const std::string OP_DIM_PARAMS("onnx.out_dim_params_");
+
 //===----------------------------------------------------------------------===//
 // Support functions.
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp
index 104cffbdff..b3ee03def5 100644
--- a/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp
@@ -159,8 +159,15 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() {
   for (unsigned i = 0; i < outputRank; ++i) {
     if (hasShapeAndRank(data)) {
       IndexExpr dimShape = createIE->getIntFromArrayAsSymbol(shape, i);
-      outputDims[i] = outputDims[i].selectOrSelf(
-          dimShape == -1, numOfElements.floorDiv(numOfElementsFromShape));
+      if (auto search = outputIgnoredDims.find(i);
+          search != outputIgnoredDims.end())
+        // The outputIgnoredDims are dims whose symbolic value matches a dim
+        // in data. Therefore, the dim cannot be -1. The current folding of
+        // IndexExpr cannot propagate the dim_param info.
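+        // Keeping dimShape (built directly from shape[i]) also preserves any
+        // dim_param it carries, e.g. a "bs" symbol (illustrative name)
+        // traced through an onnx.Dim/onnx.Concat chain.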
+ outputDims[i] = dimShape; + else + outputDims[i] = outputDims[i].selectOrSelf( + dimShape == -1, numOfElements.floorDiv(numOfElementsFromShape)); } else { // ToFix: can not check getAllowzero because the operandAdaptor is // constructed without attributes diff --git a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir index 3e7d0b4c0d..c84fd75bf8 100644 --- a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir +++ b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir @@ -157,6 +157,8 @@ func.func @test_nd_qlinearmatmul_nd_nd(%arg0: tensor {onnx.dim_p // CHECK: } } +// ----- + func.func @test_nd_qlinearmatmul_nd_2d(%arg0: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg1: tensor<64x384xf32>, %arg2: tensor, %arg3: tensor) -> tensor { %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor %1 = "onnx.QuantizeLinear"(%arg1, %arg2, %arg3) : (tensor<64x384xf32>, tensor, tensor) -> tensor<64x384xi8> @@ -191,6 +193,8 @@ func.func @test_nd_qlinearmatmul_nd_2d(%arg0: tensor {onnx.dim_p // CHECK: } } +// ----- + func.func @test_nd_qlinearmatmul_2d_nd(%arg0: tensor<384x64xf32>, %arg1: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg2: tensor, %arg3: tensor) -> tensor { %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor<384x64xf32>, tensor, tensor) -> tensor<384x64xi8> %1 = "onnx.QuantizeLinear"(%arg1, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor @@ -226,6 +230,8 @@ func.func @test_nd_qlinearmatmul_2d_nd(%arg0: tensor<384x64xf32>, %arg1: tensor< // CHECK: } } +// ----- + // Do not rewrite because of potential broadcasting. func.func @test_nd_qlinearmatmul_nd_nd_not_rewriting(%arg0: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg1: tensor<1x?x64x384xf32> {onnx.dim_params = "1:sl"}, %arg2: tensor, %arg3: tensor) -> tensor { %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor @@ -236,10 +242,10 @@ func.func @test_nd_qlinearmatmul_nd_nd_not_rewriting(%arg0: tensor {onnx.dim_params = "0:bs,1:sl"}, [[PARAM_1_:%.+]]: tensor<1x?x64x384xf32> {onnx.dim_params = "1:sl"}, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor) -> tensor { -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.QuantizeLinear"([[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor<1x?x64x384xf32>, tensor, tensor) -> tensor<1x?x64x384xi8> -// CHECK: [[VAR_2_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_2_]], [[PARAM_3_]]) : (tensor, tensor, tensor, tensor<1x?x64x384xi8>, tensor, tensor, tensor, tensor) -> tensor -// CHECK: [[VAR_3_:%.+]] = "onnx.DequantizeLinear"([[VAR_2_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, onnx.out_dim_params_0 = "0:bs,1:sl", saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.QuantizeLinear"([[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, onnx.out_dim_params_0 = "1:sl", saturate = 1 : si64} : (tensor<1x?x64x384xf32>, tensor, tensor) -> tensor<1x?x64x384xi8> +// CHECK: [[VAR_2_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_2_]], 
[[PARAM_3_]], [[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_2_]], [[PARAM_3_]]) {onnx.out_dim_params_0 = "0:bs,1:sl"} : (tensor, tensor, tensor, tensor<1x?x64x384xi8>, tensor, tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_3_:%.+]] = "onnx.DequantizeLinear"([[VAR_2_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, onnx.out_dim_params_0 = "0:bs,1:sl"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[VAR_3_]] : tensor // CHECK: } } diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/Reshape_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/Reshape_with_canonicalize.mlir index a05fad928d..1d3b17343a 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/Reshape_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/Reshape_with_canonicalize.mlir @@ -3,11 +3,13 @@ // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 10)> +// ----- + func.func private @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tensor<*xf32> { %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor, tensor<4xi64>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 10)> // CHECK-LABEL: func.func private @test_reshape // CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref<4xi64>) -> memref { // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index diff --git a/test/mlir/onnx/onnx_dim_analysis.mlir b/test/mlir/onnx/onnx_dim_analysis.mlir index 74f51c0f65..e1d091557f 100644 --- a/test/mlir/onnx/onnx_dim_analysis.mlir +++ b/test/mlir/onnx/onnx_dim_analysis.mlir @@ -1,5 +1,7 @@ // RUN: onnx-mlir-opt --onnx-dim-analysis %s -split-input-file | FileCheck %s +// ----- + // Check if dim_analysis takes into account the relationship between inputs via dim_params. 
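+// In onnx.dim_params, each "axis:name" entry names a dynamic dimension; here
+// both inputs share the symbol "M" on axis 0, while axis 1 uses distinct
+// symbols ("N" vs "P").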
func.func @test_dim_params_onnx_return(%arg0: tensor {onnx.dim_params = "0:M,1:N", onnx.name = "X"}, %arg1: tensor {onnx.dim_params = "0:M,1:P", onnx.name = "Y"}) -> (tensor {onnx.dim_params = "0:M,1:N", onnx.name = "Z"}) { %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -11,7 +13,7 @@ func.func @test_dim_params_onnx_return(%arg0: tensor {onnx.dim_params = // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) {onnx.out_dim_params_0 = "0:M"} : (tensor, tensor) -> tensor // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: onnx.Return [[VAR_0_]] : tensor @@ -31,7 +33,7 @@ func.func @test_dim_params_std_return(%arg0: tensor {onnx.dim_params = // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 1 : si64, group_id = 4 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[PARAM_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) {onnx.out_dim_params_0 = "0:M"} : (tensor, tensor) -> tensor // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 1 : si64, group_id = 1 : si64} : (tensor) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_0_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () // CHECK: return [[VAR_0_]] : tensor diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index ec8145659e..386a5285b9 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -4046,7 +4046,18 @@ func.func @test_grid_sample_dim_shape3(%arg0: tensor, %arg1: tensor // ----- -// Test Binarizer Sample +func.func @dim_params_1(%arg0: tensor {onnx.dim_params = "0:batch_size"}, %arg1: tensor {onnx.dim_params = "0:batch_size"}) -> (tensor {onnx.name = "sum"}) { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "onnx.Add"(%0, %arg0) : (tensor, tensor) -> tensor + onnx.Return %1 : tensor +// CHECK-LABEL: func.func @dim_params_1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:batch_size"}, [[PARAM_1_:%.+]]: tensor {onnx.dim_params = "0:batch_size"}) -> (tensor {onnx.name = "sum"}) { +// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) {onnx.out_dim_params_0 = "0:batch_size"} : (tensor, tensor) -> tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_0_]]) {onnx.out_dim_params_0 = "0:batch_size"} : (tensor, tensor) -> tensor +// CHECK: onnx.Return [[VAR_1_]] : tensor +} + +// ----- func.func @test_binarizer(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Binarizer"(%arg0) {threshold = 1.0 : f32} : (tensor) -> tensor<*xf32> @@ -4061,6 +4072,20 @@ func.func @test_binarizer(%arg0 : tensor) -> tensor<*xf32> { // ----- + +func.func @test_matmul_2_param(%arg0 : tensor<16x?x64x42xf32> {onnx.dim_params="1:dim1"}, %arg1 : tensor<42x?xf32> 
{onnx.dim_params="1:dim2"}) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x?xf32>) -> tensor<*xf32> + "onnx.Return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func.func @test_matmul_2_param +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x?x64x42xf32> {onnx.dim_params = "1:dim1"}, [[PARAM_1_:%.+]]: tensor<42x?xf32> {onnx.dim_params = "1:dim2"}) -> tensor<16x?x64x?xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[PARAM_1_]]) {onnx.out_dim_params_0 = "1:dim1,3:dim2"} : (tensor<16x?x64x42xf32>, tensor<42x?xf32>) -> tensor<16x?x64x?xf32> +// CHECK: onnx.Return [[VAR_0_]] : tensor<16x?x64x?xf32> +// CHECK: } +} + +// ----- + + func.func private @test_hammingwindow_shape(%arg0 : tensor<1xi32>) -> tensor { %0 = "onnx.HammingWindow"(%arg0) {output_datatype = 1 : si64 , periodic = 1 : si64} : (tensor<1xi32>) -> tensor "func.return"(%0) : (tensor) -> () @@ -4139,3 +4164,4 @@ func.func @test_random_uniform_static_bf16() -> tensor<*xbf16> { } //===----------------------------------------------------------------------===// +>>>>>>> upstream/main