
Commit 828f1b6 (parent: 057f51d)

Generalize MaxUnpool lowering including 2d case

File tree: 6 files changed (+82 −23 lines)

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td (+27 −0)

@@ -7159,6 +7159,33 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
+def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$indices,
+    AnyTorchListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
   AllowsTypeRefinement,
   HasValueSemantics,
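For reference, the semantics the new op encodes: max_unpool2d scatters each input value back to the position its max-pooling counterpart came from, producing zeros elsewhere. A minimal standalone C++ sketch of this, assuming PyTorch's convention that each index is flattened over the outH x outW plane of the unpooled result; the names here (maxUnpool2dRef etc.) are illustrative, not torch-mlir APIs:

#include <cstdint>
#include <vector>

// Scatter `input` values (shape [N, C, inH, inW], row-major) into a
// zero-initialized [N, C, outH, outW] buffer at the positions recorded in
// `indices`, which address the flattened outH*outW plane per (n, c).
std::vector<float> maxUnpool2dRef(const std::vector<float> &input,
                                  const std::vector<int64_t> &indices,
                                  int64_t N, int64_t C, int64_t inH,
                                  int64_t inW, int64_t outH, int64_t outW) {
  std::vector<float> out(N * C * outH * outW, 0.0f);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t i = 0; i < inH * inW; ++i) {
        int64_t src = (n * C + c) * inH * inW + i;
        out[(n * C + c) * outH * outW + indices[src]] = input[src];
      }
  return out;
}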

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp (+6 −0)

@@ -3464,6 +3464,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesList = createConstantIntList(binder, rewriter, strides);
 
+        if (rank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
+              binder.op, resultType, data, indices, resultShapeList,
+              stridesList, paddingList);
+          return success();
+        }
         rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
             binder.op, resultType, data, indices, resultShapeList, stridesList,
             paddingList);
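A quick sanity check on this dispatch: a rank-4 (NCHW) operand now takes the new AtenMaxUnpool2dOp branch, and when "onnx.MaxUnpool" carries no explicit output shape, each spatial result extent follows the ONNX shape-inference formula. A small sketch, assuming symmetric padding:

#include <cstdint>
#include <cstdio>

// ONNX MaxUnpool shape inference for one spatial dimension.
int64_t unpoolDimSize(int64_t in, int64_t kernel, int64_t stride, int64_t pad) {
  return (in - 1) * stride + kernel - 2 * pad;
}

int main() {
  // Numbers from the tests later in this commit: a 2x2 input with kernel 2,
  // stride 2, and no padding unpools to 4x4.
  std::printf("%lld\n", static_cast<long long>(unpoolDimSize(2, 2, 2, 0))); // 4
  return 0;
}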

lib/Conversion/TorchToLinalg/Pooling.cpp (+44 −19)

@@ -596,37 +596,51 @@ namespace {
 // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
 // What worse, without knowing kernel size we cannot even reliably detect such
 // cases and this conversion will just return invalid values.
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+
+template <> struct DimensionTraits<AtenMaxUnpool2dOp> {
+  static constexpr int64_t Dim = 2;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <> struct DimensionTraits<AtenMaxUnpool3dOp> {
+  static constexpr int64_t Dim = 3;
+  // unused const variable warning suppression:
+  static_assert(Dim == Dim);
+};
+
+template <typename OpTy>
+class ConvertAtenMaxUnpoolOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  LogicalResult createUnpoolOp(OpTy &op, typename OpTy::Adaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
     Location loc = op->getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter = this->getTypeConverter();
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    size_t spatial = selfType.getRank() - 2;
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(spatial))
+    if (inputSize != indicesType.getShape().take_back(Dim))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(Dim);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -637,7 +651,7 @@ class ConvertAtenMaxUnpool3dOp final
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int output");
 
-      if (inferredOutSize != ArrayRef(output).take_back(spatial))
+      if (inferredOutSize != ArrayRef(output).take_back(Dim))
         return rewriter.notifyMatchFailure(op, "Invalid output size");
     }
     SmallVector<int64_t> stride;

@@ -653,12 +667,12 @@ class ConvertAtenMaxUnpool3dOp final
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
     // (padding.size() == 6).
-    if (stride.size() != spatial || padding.size() != spatial)
+    if (stride.size() != Dim || padding.size() != Dim)
       return rewriter.notifyMatchFailure(
-          op, "stride and padding must be of size 3");
+          op, "stride and padding must be of size Dim");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - spatial;
+    int64_t NC = outRank - Dim;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -695,7 +709,7 @@ class ConvertAtenMaxUnpool3dOp final
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
     // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(spatial));
+        llvm::to_vector(resType.getShape().drop_back(Dim));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
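Worked numbers for the defensive padding in this hunk, as a sketch: the largest input extent that can legitimately map into a result extent of resSize is ceilDiv(resSize, stride) + 2 * pad, so the input and indices tensors are padded up to that size before the scatter.

#include <cstdio>

int ceilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  // resSize=5, stride=2, pad=0: inputs of up to ceilDiv(5, 2) = 3 elements
  // are possible, so a size-2 input (the ambiguous case above) is padded to
  // 3, keeping every scattered access in bounds.
  std::printf("%d\n", ceilDiv(5, 2) + 0 * 2); // prints 3
  return 0;
}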
@@ -708,7 +722,7 @@ class ConvertAtenMaxUnpool3dOp final
     SmallVector<int64_t> low(outRank, 0);
     SmallVector<int64_t> high(NC, 0);
     for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
+             inputSize, ArrayRef(expectedInputShape).take_back(Dim))) {
       high.emplace_back(outSize - inpSize);
     }

@@ -827,6 +841,13 @@ class ConvertAtenMaxUnpool3dOp final
     rewriter.replaceOp(op, result);
     return success();
   }
+
+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return createUnpoolOp(op, adaptor, rewriter);
+  }
 };
 } // namespace

@@ -1527,8 +1548,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
                                                                  context);
 
+  target.addIllegalOp<AtenMaxUnpool2dOp>();
   target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool2dOp>>(typeConverter,
+                                                          context);
+  patterns.add<ConvertAtenMaxUnpoolOp<AtenMaxUnpool3dOp>>(typeConverter,
+                                                          context);
 
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
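The registration above instantiates one conversion pattern per op via the DimensionTraits specializations introduced earlier, replacing the old hard-coded `selfType.getRank() - 2` computation with a compile-time constant. The pattern, reduced to a self-contained C++ sketch (the tag types here are stand-ins, not torch-mlir classes):

#include <cstdint>
#include <cstdio>

struct AtenMaxUnpool2dOpTag {};
struct AtenMaxUnpool3dOpTag {};

// A traits specialization carries the spatial rank per op type, so a single
// template body can serve both the 2-D and 3-D unpool lowerings.
template <typename OpTy> struct DimensionTraits;
template <> struct DimensionTraits<AtenMaxUnpool2dOpTag> {
  static constexpr int64_t Dim = 2;
};
template <> struct DimensionTraits<AtenMaxUnpool3dOpTag> {
  static constexpr int64_t Dim = 3;
};

template <typename OpTy> void lowerUnpool() {
  // The shared body reads the spatial rank at compile time.
  constexpr int64_t Dim = DimensionTraits<OpTy>::Dim;
  std::printf("lowering with %lld spatial dims\n", static_cast<long long>(Dim));
}

int main() {
  lowerUnpool<AtenMaxUnpool2dOpTag>(); // 2
  lowerUnpool<AtenMaxUnpool3dOpTag>(); // 3
  return 0;
}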

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py (+1 −0)

@@ -622,6 +622,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit("aten::max_unpool2d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir (+1 −1)

@@ -1680,7 +1680,7 @@ func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4],f32>

test/Conversion/TorchToLinalg/pooling.mlir (+3 −3)

@@ -100,8 +100,8 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 
 // CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
 // CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func @forward_max_unpool
-func.func @forward_max_unpool(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func @forward_max_unpool2d
+func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %int1 = torch.constant.int 1
   %int1_0 = torch.constant.int 1
   %int4 = torch.constant.int 4

@@ -113,7 +113,7 @@ func.func @forward_max_unpool(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torc
   %int2 = torch.constant.int 2
   %int2_3 = torch.constant.int 2
   %2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %3 = torch.aten.max_unpool3d %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  %3 = torch.aten.max_unpool2d %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
 
   // CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
   // CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
