Commits (25 total; changes shown from 24)
7aa505a
Combine parallel dense optimization pass
Arkar-Hema Apr 16, 2025
997d9e5
Clang format modified
Arkar-Hema Apr 16, 2025
15343bf
Clang format modified
Arkar-Hema Apr 16, 2025
4fad7ea
Added the unit test for the pass
Arkar-Hema Apr 16, 2025
d4d2fda
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 1, 2025
e6d8d6c
Updated test case, added compiler flag, and builder for gemm
Arkar-Hema May 2, 2025
7b357e2
Clang format fix
Arkar-Hema May 2, 2025
ab7a2aa
Clang fix
Arkar-Hema May 2, 2025
2b466ca
Added compiler option
Arkar-Hema May 2, 2025
6aff5ab
Added compiler option in test case
Arkar-Hema May 2, 2025
d8611f5
Test case updation
Arkar-Hema May 2, 2025
20cab0c
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 2, 2025
3bf58c0
Merge branch 'main' into combine_parallel_dense
Arkar-Hema May 5, 2025
d07e896
Added lit test for dynamic shapes
Arkar-Hema May 8, 2025
2266d90
Clang format fix
Arkar-Hema May 8, 2025
8882476
Added unrankedtype for outputtype
Arkar-Hema May 8, 2025
c8d2946
Added ranked type for output type
Arkar-Hema May 8, 2025
dd42652
Clang format fix
Arkar-Hema May 8, 2025
ca09e94
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 8, 2025
3f66539
Updated output type
Arkar-Hema May 9, 2025
2f0f113
Updated Compatible function
Arkar-Hema May 13, 2025
c2b3728
clang fix
Arkar-Hema May 13, 2025
81df6d9
Resolved conflicts
Arkar-Hema May 15, 2025
5f132a7
Merge branch 'main' into combine_parallel_dense
AlexandreEichenberger May 16, 2025
92ff3d4
Merge branch 'main' into combine_parallel_dense
chentong319 Sep 24, 2025
8 changes: 8 additions & 0 deletions src/Compiler/CompilerOptions.cpp
@@ -97,6 +97,7 @@ OptReport optReport; // onnx-mlir only
bool useOldBufferization; // onnx-mlir only
bool enableTiming; // onnx-mlir only
bool enableBoundCheck; // onnx-mlir only
bool fuseParallelOnnxGemm; // onnx-mlir only
bool split_input_file; // onnx-mlir-opt only
bool verify_diagnostics; // onnx-mlir-opt only
bool verify_passes; // onnx-mlir-opt only
@@ -723,6 +724,13 @@ static llvm::cl::opt<bool, true> enable_bound_check("enable-bound-check",
llvm::cl::location(enableBoundCheck), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> fuse_parallel_onnx_gemm(
"fuse-parallel-onnx-gemm",
llvm::cl::desc("Enable Combine parallel dense layers (default=false)."),
llvm::cl::location(
fuseParallelOnnxGemm), // Link directly to the existing variable
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

/*
How to use the optional optimization for testing.

1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -143,6 +143,7 @@ extern bool useOldBufferization; // onnx-mlir only
extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
extern bool debugTestCompilerOpt; // onnx-mlir only
extern bool fuseParallelOnnxGemm; // onnx-mlir only

extern bool split_input_file; // onnx-mlir-opt only
extern bool verify_diagnostics; // onnx-mlir-opt only
7 changes: 7 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
@@ -175,6 +175,13 @@ Value OnnxBuilder::gelu(Value input, StringAttr approximateAttr) const {
toTensor(input.getType()), input, approximateAttr);
}

Value OnnxBuilder::gemm(Type Y, Value A, Value B, Value C, FloatAttr alpha,
FloatAttr beta, IntegerAttr transA, IntegerAttr transB) const {

return createOpAndInferShapes<ONNXGemmOp>(
toTensor(Y), A, B, C, alpha, beta, transA, transB);
}

// ONNXLayerNormalizationOp, version with one output only (Y).
Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
Value bias, int64_t axis, FloatAttr epsilon) const {
5 changes: 5 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -98,6 +98,11 @@ struct OnnxBuilder : DialectBuilder {
// ONNXGeluOp
mlir::Value gelu(mlir::Value input, mlir::StringAttr approximateAttr) const;

// ONNXGemmOp
mlir::Value gemm(mlir::Type Y, mlir::Value A, mlir::Value B, mlir::Value C,
mlir::FloatAttr alpha, mlir::FloatAttr beta, mlir::IntegerAttr transA,
mlir::IntegerAttr transB) const;

// ONNXLayerNormalizationOp, version with one output only (Y).
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
181 changes: 181 additions & 0 deletions src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -28,6 +28,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -90,6 +91,183 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
// #include "src/Dialect/ONNX/Transforms/ONNXRecompose.inc"

struct CombineParallelDensePattern : public OpRewritePattern<ONNXGemmOp> {
using OpRewritePattern<ONNXGemmOp>::OpRewritePattern;

// Helper function to check if a Gemm is mergeable
static bool areCompatible(ONNXGemmOp a, ONNXGemmOp b) {
if (a.getAlpha() != b.getAlpha() || a.getBeta() != b.getBeta() ||
a.getTransA() != b.getTransA() || a.getTransB() != b.getTransB())
return false;

auto aBShape = mlir::cast<ShapedType>(a.getB().getType()).getShape();
auto bBShape = mlir::cast<ShapedType>(b.getB().getType()).getShape();
int64_t axis = a.getTransB() ? 1 : 0;
if (aBShape[axis] != bBShape[axis])
return false;

// Check C compatibility — only allow None or 1D
Value aC = a.getC();
Value bC = b.getC();
if (!onnx_mlir::isNoneValue(aC) && !onnx_mlir::isNoneValue(bC)) {
auto aCType = mlir::cast<ShapedType>(aC.getType());
auto bCType = mlir::cast<ShapedType>(bC.getType());
auto aCShape = aCType.getShape();
auto bCShape = bCType.getShape();
if (aCShape.size() != 1 || bCShape.size() != 1)
return false;
if (aCType.isDynamicDim(0) || bCType.isDynamicDim(0))
return false;
// check output channels match
if (aCShape[0] == 1 && bCShape[0] == 1) {
auto aOutputShape =
mlir::cast<ShapedType>(a.getResult().getType()).getShape();
auto bOutputShape =
mlir::cast<ShapedType>(b.getResult().getType()).getShape();
Review comment (Collaborator):
Replace these by:

ArrayRef<int64_t> aOutputShape = getShape(a.getY().getType());
ArrayRef<int64_t> bOutputShape = getShape(b.getY().getType());

// Output channels is the last dim
if (aOutputShape.back() != bOutputShape.back())
return false;
Review comment (Collaborator):
Both Biases as tensor<1xf32>:
If both biases are of shape tensor<1xf32>, I now check their corresponding Gemm output shapes and ensure their output channels (last dimension) match before considering them compatible. If they differ, the function returns false, as merging them without this check would be invalid.

It is not clear to me how this solves the problem. You must check that there is no broadcasting here, i.e. the last dim of the output must also be 1, for example:

if (aOutputShape.back() != 1 || bOutputShape.back() != 1)
   return false;

Also, please do add a lit test for this case, to make sure gemm ops are not merged.

}
// Otherwise, shapes must be equal
else if (aCShape[0] != bCShape[0])
return false;
}
return true;
Review comment (Collaborator):
When aC is None, do check that the last dim of aOutput is static. Otherwise, the later code that creates a constant tensor of zeros sized by the last dim of aOutput will fail.

Check the same thing for bC.

Please add a lit test for the case where aC or bC is None.
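For illustration only, a minimal sketch of that extra guard in areCompatible(), assuming the getShape helper from OpHelper and ShapedType::isDynamic are usable at this point (a suggestion, not code from this PR):

    // When a bias is None, the rewrite later materializes a zero constant
    // sized by the last output dim, so that dim must be static.
    if (onnx_mlir::isNoneValue(aC) &&
        ShapedType::isDynamic(getShape(a.getY().getType()).back()))
      return false;
    if (onnx_mlir::isNoneValue(bC) &&
        ShapedType::isDynamic(getShape(b.getY().getType()).back()))
      return false;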

}

LogicalResult matchAndRewrite(
ONNXGemmOp gemmOp1, PatternRewriter &rewriter) const final {
Value input = gemmOp1.getA();
if (!onnx_mlir::isRankedShapedType(input.getType()) ||
!mlir::cast<ShapedType>(input.getType()).hasStaticShape())
return failure();

SmallVector<ONNXGemmOp> parallelGemms = {gemmOp1};

for (auto user : input.getUsers()) {
ONNXGemmOp currentGemm = dyn_cast<ONNXGemmOp>(user);
if (currentGemm && currentGemm != gemmOp1 &&
areCompatible(gemmOp1, currentGemm)) {
parallelGemms.push_back(currentGemm);
}
}
if (parallelGemms.size() < 2)
return failure();

Location loc = gemmOp1.getLoc();
ShapedType inputType = mlir::cast<ShapedType>(input.getType());
Type elementType = inputType.getElementType();
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
rewriter, loc);

// Identify axis based on Gemm shape
int64_t concatWeightAxis = gemmOp1.getTransB() ? 0 : 1;
int64_t splitAxis = 1;

// Concatenate weights
SmallVector<Value> weightValues;

Review comment (Collaborator):
Redundant empty line?

for (auto gemm : parallelGemms) {
weightValues.push_back(gemm.getB());
}
Type unrankedTensorType = mlir::UnrankedTensorType::get(elementType);
Type newWeightType = unrankedTensorType;
Value newWeight =
create.onnx.concat(newWeightType, weightValues, concatWeightAxis);
Review comment (Collaborator):
Replace newWeightType by unrankedTensorType. It's redundant to define newWeightType.


// Concatenate biases (create zero constants for missing biases)
SmallVector<Value> biasValues;
for (auto gemm : parallelGemms) {
if (!onnx_mlir::isNoneValue(gemm.getC())) {
biasValues.push_back(gemm.getC());
} else {
auto gemmShape =
mlir::cast<ShapedType>(gemm.getResult().getType()).getShape();
Review comment (Collaborator):
Please use: ArrayRef<int64_t> gemmShape = getShape(gemm.getY().getType());

Value zeroBias = create.onnx.constant(DenseElementsAttr::get(
RankedTensorType::get({gemmShape[splitAxis]}, elementType), 0.0));
Review comment (Collaborator):
Check that gemmShape[splitAxis] is a static dimension in the areCompatible() function. Otherwise, creating the constant tensor here fails.

biasValues.push_back(zeroBias);
}
}

Type newBiasType = unrankedTensorType;
Value newBias = create.onnx.concat(newBiasType, biasValues, 0);
Review comment (Collaborator):
Replace newBiasType by unrankedTensorType. It's redundant to define newBiasType.


// Create combined Gemm operation
SmallVector<int64_t, 2> newOutputShape(
mlir::cast<ShapedType>(parallelGemms[0].getResult().getType())
.getShape());

// Sum output channels from parallel gemms
int64_t totalOutputChannels = 0;
for (auto gemm : parallelGemms) {
int64_t outCh = mlir::cast<ShapedType>(gemm.getResult().getType())
.getShape()[splitAxis];
totalOutputChannels += outCh;
}
newOutputShape[splitAxis] = totalOutputChannels;
auto newOutputType = RankedTensorType::get(newOutputShape, elementType);

auto newGemm = rewriter.create<ONNXGemmOp>(loc, newOutputType, input,
newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());
Review comment (Collaborator):
Please replace this by

Value newGemmOutput = create.onnx.gemm(unrankedTensorType, input,
        newWeight, newBias, gemmOp1.getAlphaAttr(), gemmOp1.getBetaAttr(),
        gemmOp1.getTransAAttr(), gemmOp1.getTransBAttr());


// Check for common ConcatOp
Review comment (Collaborator):
Check this earlier, just after you collect all parallelGemms. The reason is that the return failure() here can abort the whole rewrite after you have already created the new weight, new bias, and new gemm. Moving this check before any new ops are created keeps the IR clean.
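A rough sketch of the hoisted check, placed right after parallelGemms is collected (assuming splitAxis is also computed by then; an illustration of the suggestion, not the PR's code):

    // Resolve the common Concat (or decide there is none) before creating
    // any new weight/bias/gemm, so no failure path leaves half-built IR.
    ONNXConcatOp commonConcatOp = nullptr;
    for (ONNXGemmOp gemm : parallelGemms) {
      for (Operation *user : gemm.getResult().getUsers()) {
        auto concatOp = dyn_cast<ONNXConcatOp>(user);
        if (!concatOp || concatOp.getAxis() != splitAxis ||
            (commonConcatOp && concatOp != commonConcatOp)) {
          commonConcatOp = nullptr;
          break;
        }
        commonConcatOp = concatOp;
      }
      if (!commonConcatOp)
        break;
    }
    // Afterwards commonConcatOp is either the single shared Concat on
    // splitAxis, or nullptr, in which case the Split-based path is used.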

ONNXConcatOp commonConcatOp = nullptr;
for (auto gemm : parallelGemms) {
for (auto user : gemm.getResult().getUsers()) {
if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
if (!commonConcatOp) {
commonConcatOp = concatOp;
if (concatOp.getAxis() != splitAxis)
return failure();
}
if (concatOp != commonConcatOp || concatOp.getAxis() != splitAxis) {
commonConcatOp = nullptr;
break;
}
} else {
commonConcatOp = nullptr;
break;
}
}
if (!commonConcatOp) {
break;
}
}

if (commonConcatOp) {
if (commonConcatOp.getAxis() == splitAxis) {
commonConcatOp.getResult().replaceAllUsesWith(newGemm.getResult());
rewriter.eraseOp(commonConcatOp);
}
} else {
SmallVector<int64_t, 4> splitSizesVec;
for (auto gemm : parallelGemms) {
int64_t outputChannels =
mlir::cast<ShapedType>(gemm.getResult().getType())
.getShape()[splitAxis];
splitSizesVec.push_back(outputChannels);
}

ArrayRef<int64_t> splitSizes(splitSizesVec);
ValueRange splitResults = onnx_mlir::emitSplitByChannels(
rewriter, loc, newGemm.getResult(), splitSizes, splitAxis);
Review comment (Collaborator):
Please replace this by

    SmallVector<Type, 4> splitTypes(splitSizes.size(), unrankedTensorType);
    ValueRange splitResults = create.onnx.split(
        splitTypes, newGemmOutput, create.onnx.constantInt64(splitSizes), splitAxis);


for (size_t i = 0; i < parallelGemms.size(); ++i) {
parallelGemms[i].replaceAllUsesWith(splitResults[i]);
}
}

for (auto gemm : parallelGemms) {
if (gemm.getResult().use_empty()) {
rewriter.eraseOp(gemm);
}
}

return success();
}
};

struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
using OpRewritePattern<ONNXMulOp>::OpRewritePattern;

@@ -897,6 +1075,9 @@ void RecomposeONNXToONNXPass::runOnOperation() {
void onnx_mlir::getRecomposeONNXToONNXPatterns(
mlir::RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
if (fuseParallelOnnxGemm) {
patterns.insert<CombineParallelDensePattern>(context);
}
patterns.insert<RecomposeGeluFromMulPattern>(context);
patterns.insert<RecomposeLayerNormFromMulPattern>(context);
patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
61 changes: 61 additions & 0 deletions test/mlir/onnx/onnx_recompose_combine_parallel_dense.mlir
@@ -0,0 +1,61 @@
// RUN: onnx-mlir --useOnnxModelTypes=false --fuse-parallel-onnx-gemm --EmitONNXIR --printIR %s | FileCheck %s

func.func @test_gemm_concat_simple(%arg0: tensor<1x4xf32>) -> tensor<1x6xf32> {
%0 = onnx.Constant dense<5.5>: tensor<4x3xf32>
%1 = onnx.Constant dense<0.2> : tensor<3xf32>
%2 = onnx.Constant dense<4.5>: tensor<4x3xf32>
%3 = onnx.Constant dense<0.5> : tensor<3xf32>
%4 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%5 = "onnx.Gemm"(%arg0, %2, %3) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%6 = "onnx.Concat"(%4, %5) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x6xf32>
return %6 : tensor<1x6xf32>

// CHECK-LABEL: func @test_gemm_concat_simple
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x6xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x6xf32>

// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<6xf32>

// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x6xf32>, tensor<6xf32>) -> tensor<1x6xf32>
// CHECK-NEXT: return [[VAR_2_]] : tensor<1x6xf32>

}

func.func @test_combine_gemm_split(%arg0: tensor<1x4xf32>) -> tensor<1x12xf32> {
%0 = onnx.Constant dense<1.6> : tensor<4x3xf32>
%1 = onnx.Constant dense<2.7> : tensor<4x3xf32>
%2 = onnx.Constant dense<3.7> : tensor<4x3xf32>
%3 = onnx.Constant dense<4.6> : tensor<4x3xf32>
%4 = onnx.Constant dense<0.1> : tensor<3xf32>
%5 = onnx.Constant dense<0.9> : tensor<3xf32>
%6 = onnx.Constant dense<0.2> : tensor<3xf32>
%7 = onnx.Constant dense<0.8> : tensor<3xf32>
%8 = "onnx.Gemm"(%arg0, %0, %4) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_1", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%9 = "onnx.Gemm"(%arg0, %1, %5) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_2", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%10 = "onnx.Gemm"(%arg0, %2, %6) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_3", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%11 = "onnx.Gemm"(%arg0, %3, %7) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "Gemm_4", transA = 0 : si64, transB = 0 : si64} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32>
%12 = "onnx.Relu"(%8) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%13 = "onnx.Sigmoid"(%9) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%14 = "onnx.Tanh"(%10) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%15 = "onnx.LeakyRelu"(%11) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
%16 = "onnx.Concat"(%12, %13, %14, %15) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
return %16 : tensor<1x12xf32>

// CHECK-LABEL: func @test_combine_gemm_split
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4xf32>) -> tensor<1x12xf32> {
// CHECK: [[CONST_SPLIT_:%.+]] = onnx.Constant dense<3> : tensor<4xi64>
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<4x12xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<12xf32>
// CHECK: [[GEMM_OUT_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]])
// CHECK-SAME: : (tensor<1x4xf32>, tensor<4x12xf32>, tensor<12xf32>) -> tensor<1x12xf32>
// CHECK: [[VAR_2_:[^ ]+]]:4 = "onnx.Split"([[GEMM_OUT_]], [[CONST_SPLIT_]]) {axis = 1 : si64, onnx_node_name = "onnx.Split_2"} : (tensor<1x12xf32>, tensor<4xi64>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>)
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]#0) {onnx_node_name = "ReLU_1"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_4_:%.+]] = "onnx.Sigmoid"([[VAR_2_]]#3) {onnx_node_name = "Sigmoid_2"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Tanh"([[VAR_2_]]#2) {onnx_node_name = "Tanh_3"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[VAR_6_:%.+]] = "onnx.LeakyRelu"([[VAR_2_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "LeakyReLU_4"} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: [[FINAL_OUT:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]], [[VAR_6_]]) {axis = 1 : si64, onnx_node_name = "Concat"} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x12xf32>
// CHECK: return [[FINAL_OUT]] : tensor<1x12xf32>


}