
Commit ae1a6e4

[onnx] Lower onnx.Gemm to torch (#2663)
General lowering for `onnx.Gemm` to `torch`, covering the `transA`/`transB`, `alpha`, and `beta` attributes.
1 parent cee8563 commit ae1a6e4
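For reference, ONNX Gemm computes Y = alpha * A' * B' + beta * C, where A' and B' are A and B optionally transposed according to the transA/transB attributes, and C is broadcast against the matmul result. The pattern added below maps this onto torch.aten.transpose.int, torch.aten.mm, torch.aten.mul.Scalar, and torch.aten.add.Tensor. A minimal scalar sketch of the Gemm semantics in plain C++ follows; the function name, signature, and row-major layout are illustrative assumptions, not taken from the commit:

#include <vector>

// Scalar reference for onnx.Gemm: Y = alpha * op(A) * op(B) + beta * C.
// op(X) transposes X when the corresponding trans flag is set. Row-major
// storage; `c` holds a single [1,N] row broadcast over all M rows, matching
// the bias shape used in the tests below.
std::vector<float> gemmRef(const std::vector<float> &a,
                           const std::vector<float> &b,
                           const std::vector<float> &c, int m, int k, int n,
                           float alpha, float beta, bool transA, bool transB) {
  std::vector<float> y(m * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        // A is [M,K] ([K,M] when transA); B is [K,N] ([N,K] when transB).
        float av = transA ? a[p * m + i] : a[i * k + p];
        float bv = transB ? b[j * k + p] : b[p * n + j];
        acc += av * bv;
      }
      y[i * n + j] = alpha * acc + beta * c[j];
    }
  }
  return y;
}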

File tree

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

2 files changed: +150 -0 lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 74 additions & 0 deletions
@@ -107,6 +107,80 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, data, constAxis, indices, sparseGrad);
         return success();
       });
+  patterns.onOp(
+      "Gemm", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value a, b, c;
+        float alpha, beta;
+        int64_t transA, transB;
+        if (binder.tensorOperandAtIndex(a, 0) ||
+            binder.tensorOperandAtIndex(b, 1) ||
+            binder.tensorOperandAtIndex(c, 2) ||
+            binder.s64IntegerAttr(transA, "transA", 0) ||
+            binder.s64IntegerAttr(transB, "transB", 0) ||
+            binder.f32FloatAttr(alpha, "alpha", 1.0) ||
+            binder.f32FloatAttr(beta, "beta", 1.0) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        Value one = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
+
+        auto transpose = [&](Value m) -> Value {
+          auto tty = m.getType().cast<Torch::ValueTensorType>();
+          auto shape = tty.getOptionalSizes();
+          if (shape.has_value()) {
+            llvm::SmallVector<int64_t> newShape(shape.value());
+            std::reverse(newShape.begin(), newShape.end());
+            shape = std::move(newShape);
+          }
+          auto oty = Torch::ValueTensorType::get(tty.getContext(), shape,
+                                                 tty.getOptionalDtype());
+          return rewriter.create<Torch::AtenTransposeIntOp>(binder.getLoc(),
+                                                            oty, m, zero, one);
+        };
+
+        if (transA) {
+          a = transpose(a);
+        }
+
+        if (transB) {
+          b = transpose(b);
+        }
+
+        Value mm =
+            rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
+        if (alpha == 1.0 && beta == 1.0) {
+          rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
+              binder.op, resultType, mm, c, one);
+          return success();
+        }
+
+        if (alpha != 1.0 && beta != 1.0) {
+          Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+              rewriter.getF64FloatAttr(alpha));
+          mm = rewriter.create<Torch::AtenMulScalarOp>(
+              binder.getLoc(), resultType, mm, constAlpha);
+          alpha = 1.0;
+        }
+
+        if (alpha != 1.0) {
+          std::swap(alpha, beta);
+          std::swap(mm, c);
+        }
+
+        Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(beta));
+        rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
+            binder.op, resultType, mm, c, constBeta);
+        return success();
+      });
   patterns.onOp("LeakyRelu", 16,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
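A note on the alpha/beta handling above: the trailing scalar operand of torch.aten.add.Tensor computes self + scalar * other, so beta folds directly into the final add. When only alpha is non-unit, swapping mm with c (and alpha with beta) lets alpha ride the add's scalar instead of requiring a separate multiply; torch.aten.mul.Scalar is emitted only when both attributes differ from 1. A small self-contained check of that algebra in plain C++ (the helper and the sample values are illustrative):

#include <cassert>

// Scalar model of torch.aten.add.Tensor(self, other, alpha): self + alpha * other.
float addTensor(float self, float other, float scalar) {
  return self + scalar * other;
}

int main() {
  float mm = 3.0f, c = 7.0f;        // stand-ins for the matmul result and bias
  float alpha = 0.5f, beta = 0.25f; // powers of two keep the float math exact

  // alpha == 1 && beta == 1: add(mm, c, one).
  assert(addTensor(mm, c, 1.0f) == mm + c);
  // beta != 1 only: beta becomes the add's scalar: add(mm, c, beta).
  assert(addTensor(mm, c, beta) == mm + beta * c);
  // alpha != 1 only: operands are swapped so alpha rides the add: add(c, mm, alpha).
  assert(addTensor(c, mm, alpha) == alpha * mm + c);
  // both non-unit: mul.Scalar scales mm by alpha first, then add(mm', c, beta).
  assert(addTensor(alpha * mm, c, beta) == alpha * mm + beta * c);
  return 0;
}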

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 76 additions & 0 deletions
@@ -40,6 +40,82 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_default
+func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_transposeA
+func.func @test_gemm_transposeA(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I1]] : !torch.vtensor<[5,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,5],f32>
+  // CHECK: %[[MM:.+]] = torch.aten.mm %[[TRANS]], %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transA = 1 : si64} : (!torch.vtensor<[5,3],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_transposeB
+func.func @test_gemm_transposeB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[4,5],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg1, %[[I0]], %[[I1]] : !torch.vtensor<[4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,4],f32>
+  // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %[[TRANS]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transB = 1 : si64} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_alpha
+func.func @test_gemm_alpha(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.add.Tensor %arg2, %[[MM]], %[[ALPHA]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_beta
+func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.beta = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_alpha_beta
+func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[MUL:.+]] = torch.aten.mul.Scalar %[[MM]], %[[ALPHA]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  // CHECK: torch.aten.add.Tensor %[[MUL]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 2.500000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_leaky_relu
 func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} {
   // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2