
Commit 35dd8c5

[ONNX] Add OnnxToTorch Lowering for MaxUnpool op (#3413)

This commit also adds the Torch declarations for the aten.max_unpool2d and aten.max_unpool3d ops. The TorchToLinalg lowering for these ops will be added in a follow-up commit.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 89f7d24 commit 35dd8c5

File tree: 4 files changed, +171 lines, -0 lines
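For context, ONNX MaxUnpool (and the aten unpool ops declared here) scatters each pooled maximum back to the position recorded by the pooling indices and fills every other element with zero. A minimal PyTorch sketch of that round trip (illustration only, not part of the patch):

import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
# Pooling records the flat index of each selected maximum.
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
# Unpooling places each maximum back at its recorded position.
unpooled = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
# unpooled is 1x1x4x4: the maxima sit at the positions stored in
# indices; all other elements are zero.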

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 52 additions & 0 deletions
@@ -6819,6 +6819,31 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
+def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$indices,
+    AnyTorchListOfTorchIntType:$output_size
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -6907,6 +6932,33 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
   }];
 }
 
+def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$indices,
+    AnyTorchListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [
     AllowsTypeRefinement,
     HasValueSemantics,
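The two declarations above mirror the operator signatures already registered in PyTorch. A quick sketch exercising the same signatures through torch.ops.aten (illustrative, not part of the patch):

import torch
import torch.nn.functional as F

# aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)
x2 = torch.rand(1, 1, 4, 4)
p2, i2 = F.max_pool2d(x2, kernel_size=2, stride=2, return_indices=True)
y2 = torch.ops.aten.max_unpool2d(p2, i2, [4, 4])

# aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)
x3 = torch.rand(1, 1, 4, 4, 4)
p3, i3 = F.max_pool3d(x3, kernel_size=2, stride=2, return_indices=True)
y3 = torch.ops.aten.max_unpool3d(p3, i3, [4, 4, 4], [2, 2, 2], [0, 0, 0])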

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 78 additions & 0 deletions
@@ -1926,4 +1926,82 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 
         return success();
       });
+  patterns.onOp(
+      "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // TODO: Add support for `output_shape` arg.
+        if (binder.op->getNumOperands() == 3)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: output_shape arg is not supported");
+
+        Torch::ValueTensorType resultType;
+        Value data, indices;
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorOperandAtIndex(indices, 1) ||
+            binder.tensorResultType(resultType))
+          return rewriter.notifyMatchFailure(
+              binder.op, "data/indices/resultType bind failure");
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(data);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        int64_t rank = *maybeRank;
+        int64_t spatial = rank - 2;
+
+        if (rank <= 3 || rank > 5)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: MaxUnpool support "
+                                             "only present for rank 4/5 input");
+
+        if (!(resultType.hasSizes() && resultType.areAllSizesKnown()))
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: expected result to have all shapes "
+                         "statically known");
+
+        SmallVector<int64_t> resultShape(resultType.getSizes());
+        Value resultShapeList =
+            createConstantIntList(binder, rewriter, resultShape);
+        if (rank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
+              binder.op, resultType, data, indices, resultShapeList);
+          return success();
+        }
+
+        SmallVector<int64_t> padding, strides;
+        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
+          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
+        if (!padding.empty() &&
+            padding.size() != static_cast<size_t>(2 * spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list must contain (begin,end) pair for each "
+                         "spatial axis");
+        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
+          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
+        if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+
+        if (padding.empty())
+          padding.resize(spatial, 0);
+        if (strides.empty())
+          strides.resize(spatial, 1);
+
+        // If the padding is symmetric we can push the padding
+        // operation to the torch operator.
+        if (padding.size() == static_cast<size_t>(2 * spatial)) {
+          bool equal = true;
+          for (int i = 0; i < spatial; ++i) {
+            equal = equal && (padding[i] == padding[i + spatial]);
+          }
+          if (equal)
+            padding.resize(spatial);
+        }
+
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
+            binder.op, resultType, data, indices, resultShapeList, stridesList,
+            paddingList);
+        return success();
+      });
 }
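The only non-mechanical step in the pattern above is the pads handling: ONNX lays the attribute out as all begin values followed by all end values, while aten.max_unpool3d takes a single padding value per spatial axis, so padding can only be forwarded when each begin equals its matching end. A Python sketch of the same check (function name is hypothetical, for illustration):

def collapse_symmetric_pads(pads, spatial):
    # ONNX pads layout: [x1_begin, ..., xn_begin, x1_end, ..., xn_end].
    if not pads:
        return [0] * spatial  # same default the pattern applies
    if len(pads) == 2 * spatial and all(
        pads[i] == pads[i + spatial] for i in range(spatial)
    ):
        return pads[:spatial]  # symmetric: keep one value per axis
    return pads  # asymmetric: left untouched, as in the C++ above

assert collapse_symmetric_pads([1, 2, 3, 1, 2, 3], 3) == [1, 2, 3]
assert collapse_symmetric_pads([], 3) == [0, 0, 0]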

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -597,6 +597,7 @@ def emit_with_mutating_variants(key, **kwargs):
     )
     emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,
@@ -605,6 +606,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
     )
     emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
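These emit() keys are what generate the GeneratedTorchOps.td entries shown earlier: torch_ods_gen parses each JIT operator signature string and emits the corresponding ODS definition. A rough sketch of how the op class name falls out of the key (hypothetical helper; the real logic lives elsewhere in build_tools and is not part of this diff):

def ods_name(jit_op_key: str) -> str:
    # e.g. "aten::max_unpool2d" -> "Torch_AtenMaxUnpool2dOp"
    ns, name = jit_op_key.split("::")
    camel = "".join(part.capitalize() for part in name.split("_"))
    return f"Torch_{ns.capitalize()}{camel}Op"

assert ods_name("aten::max_unpool2d") == "Torch_AtenMaxUnpool2dOp"
assert ods_name("aten::max_unpool3d") == "Torch_AtenMaxUnpool3dOp"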

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 39 additions & 0 deletions
@@ -1087,3 +1087,42 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc
   %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32>
   return %0 : !torch.vtensor<[3,4,1,6,7],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape
+func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[INT4:.*]] = torch.constant.int 4
+  // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
+  // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
+  %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
+  return %0 : !torch.vtensor<[1,1,4,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape
+func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[INT4:.*]] = torch.constant.int 4
+  // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
+  // CHECK: %[[INT4_1:.*]] = torch.constant.int 4
+  // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]], %[[INT4_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]], %[[INT0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
+  // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32>
+  %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
+  return %0 : !torch.vtensor<[1,1,4,4,4],f32>
+}
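Both tests omit the optional output_shape input, so the expected 4x4(x4) result sizes follow from the inferred output-size rule in the ONNX MaxUnpool spec: out = (in - 1) * stride + kernel - pad_begin - pad_end per spatial axis. A quick Python check of the arithmetic (illustration only):

def maxunpool_out_dim(in_dim, kernel, stride, pad_begin=0, pad_end=0):
    # Inferred output size per spatial axis when output_shape is absent
    # (per the ONNX MaxUnpool operator spec).
    return (in_dim - 1) * stride + kernel - pad_begin - pad_end

# Matches the tests: 2x2(x2) input, kernel 2, stride 2 -> 4 per axis.
assert maxunpool_out_dim(2, kernel=2, stride=2) == 4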
