
Commit 4f252c8

[MLIR][ONNX] Add OnnxToTorch support for GlobalAveragePool op. (#2692)
This commit adds OnnxToTorch support for the GlobalAveragePool op.

Signed-off-by: [email protected]
1 parent ee75e8d commit 4f252c8
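
The pattern derives each pooling window directly from the static shapes: for every spatial dimension it sets kernel = inputDim - resultDim + 1 with stride 1 and padding 0, so for a true global pool (result spatial extent 1) the kernel spans the whole input. A minimal standalone C++ sketch of that arithmetic, using hypothetical shapes borrowed from the test below (not part of the commit):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical NCHW shapes mirroring the test input:
  // [1,3,5,5] pooled globally to [1,3,1,1].
  std::vector<int64_t> inputShape = {1, 3, 5, 5};
  std::vector<int64_t> resultShape = {1, 3, 1, 1};
  // Same formula the pattern uses for each spatial dim (i >= 2):
  // a global pool (output extent 1) gets kernel == input extent.
  for (size_t i = 2; i < inputShape.size(); ++i) {
    int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
    std::cout << "dim " << i << ": kernel=" << kernelSize
              << ", stride=1, padding=0\n"; // prints kernel=5 for both dims
  }
}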

File tree

2 files changed (+107, -0 lines)


lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 71 additions & 0 deletions
@@ -242,6 +242,77 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, mm, c, constBeta);
         return success();
       });
+  patterns.onOp(
+      "GlobalAveragePool", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
+        if (!inputTensorType || !inputTensorType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input type having sizes");
+        }
+        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+        unsigned inputRank = inputShape.size();
+        if (!resultType || !resultType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result type having sizes");
+        }
+        ArrayRef<int64_t> resultShape = resultType.getSizes();
+
+        SmallVector<Value> cstKernel, cstPadding, cstStrides;
+        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        for (unsigned i = 2; i < inputRank; i++) {
+          int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
+          cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
+          cstPadding.push_back(cstZero);
+          cstStrides.push_back(cstOne);
+        }
+        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstKernel);
+        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstPadding);
+        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstStrides);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value cstCeilMode = cstFalse;
+        Value cstCountIncludePad = cstFalse;
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+        if (inputRank == 3) {
+          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
+              binder.op, resultType, operand, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad);
+          return success();
+        } else if (inputRank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
+              binder.op, resultType, operand, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstNone);
+          return success();
+        } else if (inputRank == 5) {
+          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
+              binder.op, resultType, operand, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstNone);
+          return success();
+        }
+        return failure();
+      });
   patterns.onOp("LeakyRelu", 16,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
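
For reference, the semantics being matched: ONNX GlobalAveragePool averages each channel over all of its spatial positions, so an avg_pool with kernel equal to the spatial extent, stride 1, and padding 0 computes the same value. A naive C++ sketch of the 2-D case (a hypothetical helper, purely illustrative, not from the commit):

#include <cstddef>
#include <vector>

// Naive GlobalAveragePool over a row-major NCHW buffer:
// out[n][c] = mean of in[n][c][h][w] over all h and w.
std::vector<float> globalAvgPool2d(const std::vector<float> &in,
                                   size_t N, size_t C, size_t H, size_t W) {
  std::vector<float> out(N * C, 0.0f);
  for (size_t n = 0; n < N; ++n) {
    for (size_t c = 0; c < C; ++c) {
      float sum = 0.0f;
      for (size_t h = 0; h < H; ++h)
        for (size_t w = 0; w < W; ++w)
          sum += in[((n * C + c) * H + h) * W + w];
      out[n * C + c] = sum / static_cast<float>(H * W);
    }
  }
  return out;
}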

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 36 additions & 0 deletions
@@ -265,3 +265,39 @@ func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torc
   %0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_globalaveragepool
+func.func @test_globalaveragepool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C5:.*]] = torch.constant.int 5
+  // CHECK: %[[C5_0:.*]] = torch.constant.int 5
+  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.avg_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,1,1],f32>
+  %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32>
+  return %0 : !torch.vtensor<[1,3,1,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_globalaveragepool_precomputed
+func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C3:.*]] = torch.constant.int 3
+  // CHECK: %[[C3_0:.*]] = torch.constant.int 3
+  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.avg_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+  %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1,1],f32>
+}
