
Commit 057f51d

Generalize max_unpool lowering
1 parent 94f5410 commit 057f51d

6 files changed: +62 -44 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 0 additions & 25 deletions
@@ -7159,31 +7159,6 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
-def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
-  AllowsTypeRefinement,
-  HasValueSemantics,
-  ReadOnly
-]> {
-  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$indices,
-    AnyTorchListOfTorchIntType:$output_size
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 3, 1);
-    }
-    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 3, 1);
-    }
-  }];
-}
-
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
   AllowsTypeRefinement,
   HasValueSemantics,
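
GeneratedTorchOps.td is emitted by the build_tools/torch_ods_gen.py script, so this TableGen block disappears as a consequence of deleting the corresponding emit() call further down in this commit.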

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 0 additions & 5 deletions
@@ -3430,11 +3430,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         SmallVector<int64_t> resultShape(resultType.getSizes());
         Value resultShapeList =
            createConstantIntList(binder, rewriter, resultShape);
-        if (rank == 4) {
-          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
-              binder.op, resultType, data, indices, resultShapeList);
-          return success();
-        }
 
        SmallVector<int64_t> padding, strides;
        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
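
With the rank-4 special case removed, 2-D onnx.MaxUnpool inputs fall through to the same code path as 3-D ones and lower to torch.aten.max_unpool3d with explicit stride and padding lists, as the updated test in simple_ops_g_to_p.mlir below checks.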

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 9 additions & 8 deletions
@@ -611,21 +611,22 @@ class ConvertAtenMaxUnpool3dOp final
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
+    size_t spatial = selfType.getRank() - 2;
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(3))
+    if (inputSize != indicesType.getShape().take_back(spatial))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -636,7 +637,7 @@ class ConvertAtenMaxUnpool3dOp final
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int output");
 
-    if (inferredOutSize != ArrayRef(output))
+    if (inferredOutSize != ArrayRef(output).take_back(spatial))
       return rewriter.notifyMatchFailure(op, "Invalid output size");
   }
   SmallVector<int64_t> stride;
@@ -652,12 +653,12 @@ class ConvertAtenMaxUnpool3dOp final
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
     // (padding.size() == 6).
-    if (stride.size() != 3 || padding.size() != 3)
+    if (stride.size() != spatial || padding.size() != spatial)
       return rewriter.notifyMatchFailure(
           op, "stride and padding must be of size 3");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - 3;
+    int64_t NC = outRank - spatial;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -694,7 +695,7 @@ class ConvertAtenMaxUnpool3dOp final
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
     // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(3));
+        llvm::to_vector(resType.getShape().drop_back(spatial));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@@ -707,7 +708,7 @@ class ConvertAtenMaxUnpool3dOp final
     SmallVector<int64_t> low(outRank, 0);
     SmallVector<int64_t> high(NC, 0);
     for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize, ArrayRef(expectedInputShape).take_back(3))) {
+             inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
       high.emplace_back(outSize - inpSize);
     }
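
The core of the generalization is deriving the spatial rank from the input instead of hard-coding 3: for an (N, C, spatial...) tensor, spatial = rank - 2, so the same pattern now covers rank-4 (2-D) and rank-5 (3-D) unpooling. A minimal standalone sketch of that shape split, assuming LLVM's ArrayRef API; splitPoolingShape is a hypothetical helper for illustration, not part of this commit:

// Hypothetical helper: split an (N, C, spatial...) shape the same way the
// lowering does. Assumes LLVM's ArrayRef (llvm/ADT/ArrayRef.h).
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"

bool splitPoolingShape(llvm::ArrayRef<int64_t> shape,
                       llvm::ArrayRef<int64_t> &batchChannel,
                       llvm::ArrayRef<int64_t> &spatialDims) {
  if (shape.size() < 3)
    return false;                          // need (N, C, at least one spatial dim)
  size_t spatial = shape.size() - 2;       // 2 for rank-4 input, 3 for rank-5
  batchChannel = shape.drop_back(spatial); // (N, C)
  spatialDims = shape.take_back(spatial);  // (H, W) or (D, H, W)
  return true;
}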

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 0 additions & 1 deletion
@@ -622,7 +622,6 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
-    emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 11 additions & 5 deletions
@@ -1667,23 +1667,29 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape
-func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_maxunpool_2d_export_without_output_shape
+func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
   // CHECK: %[[INT4:.*]] = torch.constant.int 4
   // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
   // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4],f32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape
-func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_maxunpool_3d_export_without_output_shape
+func.func @test_maxunpool_3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
   // CHECK: %[[INT4:.*]] = torch.constant.int 4

test/Conversion/TorchToLinalg/pooling.mlir

Lines changed: 42 additions & 0 deletions
@@ -95,3 +95,45 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: } -> tensor<?x?x?x?x?xf32>
   return %4 : !torch.vtensor<[?,?,?,?,?],f32>
 }
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @forward_max_unpool
+func.func @forward_max_unpool(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %int1 = torch.constant.int 1
+  %int1_0 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int4_1 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1, %int1_0, %int4, %int4_1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %int0 = torch.constant.int 0
+  %int0_2 = torch.constant.int 0
+  %1 = torch.prim.ListConstruct %int0, %int0_2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int2 = torch.constant.int 2
+  %int2_3 = torch.constant.int 2
+  %2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.max_unpool3d %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+
+  // CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
+  // CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<1x1x2x2xf32>
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32>
+  // CHECK: %[[SHAPE:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4x4xf32>
+  // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]], %[[INDICES]] : tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>) outs(%[[SHAPE]] : tensor<?x?x4x4xf32>) {
+  // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[CURRENT_INDEX:.*]]: i64, %[[OUT:.*]]: f32):
+  // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-NEXT: %[[INDEX_CAST:.*]] = arith.index_cast %[[CURRENT_INDEX:.*]] : i64 to index
+  // CHECK-NEXT: %[[INDEX2:.*]] = linalg.index 2 : index
+  // CHECK-NEXT: %[[INDEX3:.*]] = linalg.index 3 : index
+  // CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-NEXT: %[[MULI:.*]] = arith.muli %[[INDEX2:.*]], %[[C4:.*]] : index
+  // CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[MULI:.*]], %[[INDEX3:.*]] : index
+  // CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi eq, %[[INDEX_CAST:.*]], %[[ADDI:.*]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMPI:.*]], %[[CURRENT_VALUE:.*]], %[[CST:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[SELECT:.*]] : f32
+  // CHECK: } -> tensor<?x?x4x4xf32>
+  return %3 : !torch.vtensor<[1,1,4,4],f32>
+}
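
For reference, the linalg.generic checked above computes unpooling as a gather over the output: each output element compares its own flattened spatial position (linalg.index 2 * W_out + linalg.index 3) against the index recorded by the preceding max_pool and keeps the pooled value only on a match, yielding zero otherwise. A plain C++ sketch of those semantics for the 2-D case, assuming NCHW layout and output-relative indices as in the test; maxUnpool2dRef is an illustrative helper, not code from this commit:

#include <cstdint>
#include <vector>

// Gather-style reference for the generated linalg.generic (2-D case).
std::vector<float> maxUnpool2dRef(const std::vector<float> &input,
                                  const std::vector<int64_t> &indices,
                                  int64_t N, int64_t C, int64_t hIn, int64_t wIn,
                                  int64_t hOut, int64_t wOut,
                                  int64_t strideH, int64_t strideW) {
  std::vector<float> out(N * C * hOut * wOut, 0.0f);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t h = 0; h < hOut; ++h)
        for (int64_t w = 0; w < wOut; ++w) {
          // Mirrors #map: the pooled element feeding (h, w) lives at
          // (h floordiv strideH, w floordiv strideW) in input/indices.
          int64_t in = ((n * C + c) * hIn + h / strideH) * wIn + w / strideW;
          // Body of the generic: keep the value only where the stored index
          // equals the flattened output position h * W_out + w.
          int64_t flat = h * wOut + w;
          out[((n * C + c) * hOut + h) * wOut + w] =
              (indices[in] == flat) ? input[in] : 0.0f;
        }
  return out;
}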
