Combine parallel dense Optimization pass in ONNX Dialect #3123
@@ -28,6 +28,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -90,6 +91,183 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
// #include "src/Dialect/ONNX/Transforms/ONNXRecompose.inc"

struct CombineParallelDensePattern : public OpRewritePattern<ONNXGemmOp> {
  using OpRewritePattern<ONNXGemmOp>::OpRewritePattern;

  // Helper function to check if a gemm is mergeable
  static bool areCompatible(ONNXGemmOp a, ONNXGemmOp b) {
    if (a.getAlpha() != b.getAlpha() || a.getBeta() != b.getBeta() ||
        a.getTransA() != b.getTransA() || a.getTransB() != b.getTransB())
      return false;

    auto aBShape = mlir::cast<ShapedType>(a.getB().getType()).getShape();
    auto bBShape = mlir::cast<ShapedType>(b.getB().getType()).getShape();
    int64_t axis = a.getTransB() ? 1 : 0;
    if (aBShape[axis] != bBShape[axis])
      return false;

    // Check C compatibility: only allow None or 1D
    Value aC = a.getC();
    Value bC = b.getC();
    if (!onnx_mlir::isNoneValue(aC) && !onnx_mlir::isNoneValue(bC)) {
      auto aCType = mlir::cast<ShapedType>(aC.getType());
      auto bCType = mlir::cast<ShapedType>(bC.getType());
      auto aCShape = aCType.getShape();
      auto bCShape = bCType.getShape();
      if (aCShape.size() != 1 || bCShape.size() != 1)
        return false;
      if (aCType.isDynamicDim(0) || bCType.isDynamicDim(0))
        return false;
      // Check that output channels match
      if (aCShape[0] == 1 && bCShape[0] == 1) {
        auto aOutputShape =
            mlir::cast<ShapedType>(a.getResult().getType()).getShape();
        auto bOutputShape =
            mlir::cast<ShapedType>(b.getResult().getType()).getShape();
        // Output channels is the last dim
        if (aOutputShape.back() != bOutputShape.back())
          return false;
Collaborator:
It does not make sense to me how this can solve the problem. You must check that there is no broadcasting here; the last dim in the output must also be 1, for example:

    if (aOutputShape.back() != 1 || bOutputShape.back() != 1)
      return false;

Also, please add a lit test for this case, to make sure the gemm ops are not merged.
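For illustration, a minimal sketch of that stricter guard inside areCompatible (the surrounding names come from this PR; the condition itself is only the reviewer's suggestion, not what the diff currently implements):

    // Sketch: a [1] bias broadcasts across the output's channel dim, so
    // merging is only safe when that dim is itself 1 for both gemms.
    if (aCShape[0] == 1 && bCShape[0] == 1) {
      auto aOutputShape =
          mlir::cast<ShapedType>(a.getResult().getType()).getShape();
      auto bOutputShape =
          mlir::cast<ShapedType>(b.getResult().getType()).getShape();
      if (aOutputShape.back() != 1 || bOutputShape.back() != 1)
        return false;
    }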
      }
      // Otherwise, shapes must be equal
      else if (aCShape[0] != bCShape[0])
        return false;
    }
    return true;
Collaborator:
When aC is None, do check that the last dim of aOutput is static. Otherwise, it fails when you create a constant tensor of zeros in the later code that uses the last dim of aOutput. Check the same thing for bC. Please add a lit test for the case where aC or bC is None.
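A minimal sketch of that guard, assuming it sits in areCompatible next to the existing bias checks (the static-dim requirement is the reviewer's request; the exact placement is an assumption):

    // Sketch: with no bias (C is None), the rewrite later materializes a
    // zero bias sized by the output's last dim, so that dim must be static.
    auto aOutType = mlir::cast<ShapedType>(a.getResult().getType());
    auto bOutType = mlir::cast<ShapedType>(b.getResult().getType());
    if (onnx_mlir::isNoneValue(aC) &&
        aOutType.isDynamicDim(aOutType.getRank() - 1))
      return false;
    if (onnx_mlir::isNoneValue(bC) &&
        bOutType.isDynamicDim(bOutType.getRank() - 1))
      return false;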
  }

  LogicalResult matchAndRewrite(
      ONNXGemmOp gemmOp1, PatternRewriter &rewriter) const final {
    Value input = gemmOp1.getA();
    if (!onnx_mlir::isRankedShapedType(input.getType()) ||
        !mlir::cast<ShapedType>(input.getType()).hasStaticShape())
      return failure();

    SmallVector<ONNXGemmOp> parallelGemms = {gemmOp1};

    for (auto user : input.getUsers()) {
      ONNXGemmOp currentGemm = dyn_cast<ONNXGemmOp>(user);
      if (currentGemm && currentGemm != gemmOp1 &&
          areCompatible(gemmOp1, currentGemm)) {
        parallelGemms.push_back(currentGemm);
      }
    }
    if (parallelGemms.size() < 2)
      return failure();

    Location loc = gemmOp1.getLoc();
    ShapedType inputType = mlir::cast<ShapedType>(input.getType());
    Type elementType = inputType.getElementType();
    onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
        rewriter, loc);

    // Identify axes based on Gemm shape
    int64_t concatWeightAxis = gemmOp1.getTransB() ? 0 : 1;
    int64_t splitAxis = 1;

    // Concatenate weights
    SmallVector<Value> weightValues;
Collaborator:
Redundant empty line?
    for (auto gemm : parallelGemms) {
      weightValues.push_back(gemm.getB());
    }
    Type unrankedTensorType = mlir::UnrankedTensorType::get(elementType);
    Type newWeightType = unrankedTensorType;
    Value newWeight =
        create.onnx.concat(newWeightType, weightValues, concatWeightAxis);
Collaborator:
Replace
    // Concatenate biases (create zero constants for missing biases)
    SmallVector<Value> biasValues;
    for (auto gemm : parallelGemms) {
      if (!onnx_mlir::isNoneValue(gemm.getC())) {
        biasValues.push_back(gemm.getC());
      } else {
        auto gemmShape =
            mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
Collaborator:
Please use:
        Value zeroBias = create.onnx.constant(DenseElementsAttr::get(
            RankedTensorType::get({gemmShape[splitAxis]}, elementType), 0.0));
Collaborator:
Check if
        biasValues.push_back(zeroBias);
      }
    }

    Type newBiasType = unrankedTensorType;
    Value newBias = create.onnx.concat(newBiasType, biasValues, 0);
Collaborator:
Replace
    // Create combined Gemm operation
    SmallVector<int64_t, 2> newOutputShape(
        mlir::cast<ShapedType>(parallelGemms[0].getResult().getType())
            .getShape());

    // Sum output channels from parallel gemms
    int64_t totalOutputChannels = 0;
    for (auto gemm : parallelGemms) {
      int64_t outCh = mlir::cast<ShapedType>(gemm.getResult().getType())
                          .getShape()[splitAxis];
      totalOutputChannels += outCh;
    }
    newOutputShape[splitAxis] = totalOutputChannels;
    auto newOutputType = RankedTensorType::get(newOutputShape, elementType);

    auto newGemm = rewriter.create<ONNXGemmOp>(loc, newOutputType, input,
        newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
        gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());
Collaborator:
Please replace this by:

    Value newGemmOutput = create.onnx.gemm(unrankedTensorType, input,
        newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
        gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());
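A likely rationale for the suggested builder call (an inference from this review thread, not stated in it): emitting the fused Gemm with an unranked result type lets onnx-mlir's shape inference recompute the output shape later, instead of hand-assembling newOutputShape as the current diff does.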
    // Check for common ConcatOp
Collaborator:
Check this earlier, just after you collect all
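A sketch of what hoisting that check could look like: detect the shared Concat right after parallelGemms is collected, before any new ops are created, so the pattern can bail out without mutating the IR. The helper below is hypothetical; only the user-scanning logic is taken from this diff:

    // Sketch: scan users once, up front. A wrong-axis or mixed-user Concat
    // yields nullptr, and the rewrite then falls back to the Split path.
    auto findCommonConcat = [&](ArrayRef<ONNXGemmOp> gemms) -> ONNXConcatOp {
      ONNXConcatOp common = nullptr;
      for (ONNXGemmOp gemm : gemms)
        for (Operation *user : gemm.getResult().getUsers()) {
          auto concatOp = dyn_cast<ONNXConcatOp>(user);
          if (!concatOp || concatOp.getAxis() != splitAxis ||
              (common && concatOp != common))
            return nullptr;
          common = concatOp;
        }
      return common;
    };
    ONNXConcatOp commonConcatOp = findCommonConcat(parallelGemms);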
    ONNXConcatOp commonConcatOp = nullptr;
    for (auto gemm : parallelGemms) {
      for (auto user : gemm.getResult().getUsers()) {
        if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
          if (!commonConcatOp) {
            commonConcatOp = concatOp;
            if (concatOp.getAxis() != splitAxis)
              return failure();
          }
          if (concatOp != commonConcatOp || concatOp.getAxis() != splitAxis) {
            commonConcatOp = nullptr;
            break;
          }
        } else {
          commonConcatOp = nullptr;
          break;
        }
      }
      if (!commonConcatOp) {
        break;
      }
    }

    if (commonConcatOp) {
      if (commonConcatOp.getAxis() == splitAxis) {
        commonConcatOp.getResult().replaceAllUsesWith(newGemm.getResult());
        rewriter.eraseOp(commonConcatOp);
      }
    } else {
      SmallVector<int64_t, 4> splitSizesVec;
      for (auto gemm : parallelGemms) {
        int64_t outputChannels =
            mlir::cast<ShapedType>(gemm.getResult().getType())
                .getShape()[splitAxis];
        splitSizesVec.push_back(outputChannels);
      }

      ArrayRef<int64_t> splitSizes(splitSizesVec);
      ValueRange splitResults = onnx_mlir::emitSplitByChannels(
          rewriter, loc, newGemm.getResult(), splitSizes, splitAxis);
Collaborator:
Please replace this by:

    SmallVector<Type, 4> splitTypes(splitSizes.size(), unrankedTensorType);
    ValueRange splitResults = create.onnx.split(splitTypes, newGemmOutput,
        create.onnx.constantInt64(splitSizes), splitAxis);
      for (size_t i = 0; i < parallelGemms.size(); ++i) {
        parallelGemms[i].replaceAllUsesWith(splitResults[i]);
      }
    }

    for (auto gemm : parallelGemms) {
      if (gemm.getResult().use_empty()) {
        rewriter.eraseOp(gemm);
      }
    }

    return success();
  }
};
struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
  using OpRewritePattern<ONNXMulOp>::OpRewritePattern;
@@ -897,6 +1075,9 @@ void RecomposeONNXToONNXPass::runOnOperation() {
void onnx_mlir::getRecomposeONNXToONNXPatterns(
    mlir::RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  if (fuseParallelOnnxGemm) {
    patterns.insert<CombineParallelDensePattern>(context);
  }
  patterns.insert<RecomposeGeluFromMulPattern>(context);
  patterns.insert<RecomposeLayerNormFromMulPattern>(context);
  patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
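The new pattern is gated on the fuseParallelOnnxGemm option, which is why the first hunk adds the CompilerOptions.hpp include. A plausible declaration, modeled on how onnx-mlir exposes other flags via llvm::cl; the actual definition is not shown in this diff, so treat the exact signature and default as assumptions:

    // src/Compiler/CompilerOptions.hpp (sketch)
    extern llvm::cl::opt<bool> fuseParallelOnnxGemm;

    // src/Compiler/CompilerOptions.cpp (sketch); the flag spelling matches
    // the --fuse-parallel-onnx-gemm option in the lit test's RUN line below.
    llvm::cl::opt<bool> fuseParallelOnnxGemm("fuse-parallel-onnx-gemm",
        llvm::cl::desc("Merge parallel onnx.Gemm ops sharing an input into "
                       "one Gemm followed by a Split or a fused Concat."),
        llvm::cl::init(false));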
New test file:

@@ -0,0 +1,61 @@
// RUN: onnx-mlir --useOnnxModelTypes=false --fuse-parallel-onnx-gemm --EmitONNXIR --printIR %s | FileCheck %s

func.func @test_gemm_concat_simple(%arg0: tensor<1x4xf32>) -> tensor<1x6xf32> {
  %0 = onnx.Constant dense<5.5> : tensor<4x3xf32>
  %1 = onnx.Constant dense<0.2> : tensor<3xf32>
  %2 = onnx.Constant dense<4.5> : tensor<4x3xf32>
  %3 = onnx.Constant dense<0.5> : tensor<3xf32>
  %4 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
  %5 = "onnx.Gemm"(%arg0, %2, %3) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
  %6 = "onnx.Concat"(%4, %5) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x6xf32>
  return %6 : tensor<1x6xf32>

// CHECK-LABEL: func @test_gemm_concat_simple
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x6xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x6xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<6xf32>
// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x6xf32>, tensor<6xf32>) -> tensor<1x6xf32>
// CHECK-NEXT: return [[VAR_2_]] : tensor<1x6xf32>
}
func.func @test_combine_gemm_split(%arg0: tensor<1x4xf32>) -> tensor<1x12xf32> {
  %0 = onnx.Constant dense<1.6> : tensor<4x3xf32>
  %1 = onnx.Constant dense<2.7> : tensor<4x3xf32>
  %2 = onnx.Constant dense<3.7> : tensor<4x3xf32>
  %3 = onnx.Constant dense<4.6> : tensor<4x3xf32>
  %4 = onnx.Constant dense<0.1> : tensor<3xf32>
  %5 = onnx.Constant dense<0.9> : tensor<3xf32>
  %6 = onnx.Constant dense<0.2> : tensor<3xf32>
  %7 = onnx.Constant dense<0.8> : tensor<3xf32>
  %8 = "onnx.Gemm"(%arg0, %0, %4) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
  %9 = "onnx.Gemm"(%arg0, %1, %5) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
  %10 = "onnx.Gemm"(%arg0, %2, %6) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_3", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
  %11 = "onnx.Gemm"(%arg0, %3, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_4", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
  %12 = "onnx.Relu"(%8) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
  %13 = "onnx.Sigmoid"(%9) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
  %14 = "onnx.Tanh"(%10) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
  %15 = "onnx.LeakyRelu"(%11) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
  %16 = "onnx.Concat"(%12, %13, %14, %15) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
  return %16 : tensor<1x12xf32>

// CHECK-LABEL: func @test_combine_gemm_split
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x12xf32> {
// CHECK: [[CONST_SPLIT_:%.+]] = onnx.Constant dense<3> : tensor<4xi64>
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x12xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<12xf32>
// CHECK: [[GEMM_OUT_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x12xf32>, tensor<12xf32>) -> tensor<1x12xf32>
// CHECK: [[VAR_2_:[^ ]+]]:4 = "onnx.Split"([[GEMM_OUT_]], [[CONST_SPLIT_]]) {axis = 1 : si64, onnx_node_name = "onnx.Split_2"} : (tensor<1x12xf32>, tensor<4xi64>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]#0) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_4_:%.+]] = "onnx.Sigmoid"([[VAR_2_]]#3) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Tanh"([[VAR_2_]]#2) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_6_:%.+]] = "onnx.LeakyRelu"([[VAR_2_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[FINAL_OUT:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]], [[VAR_6_]]) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
// CHECK: return [[FINAL_OUT]] : tensor<1x12xf32>
}
Collaborator:
Replace these by: