Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7273,6 +7273,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices",
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
Expand Down
44 changes: 37 additions & 7 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5011,26 +5011,56 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
}

//===----------------------------------------------------------------------===//
// AtenMaxPool2dWithIndicesOp
// AtenMaxPoolWithIndicesOp
//===----------------------------------------------------------------------===//

void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
namespace {

template <typename OpTy> struct MaxPoolWithoutIndices {
using type = OpTy;
};

template <> struct MaxPoolWithoutIndices<AtenMaxPool2dWithIndicesOp> {
using type = AtenMaxPool2dOp;
};

template <> struct MaxPoolWithoutIndices<AtenMaxPool3dWithIndicesOp> {
using type = AtenMaxPool3dOp;
};

} // namespace

template <typename OpTy>
struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern<OpTy> {
SimplifyMaxPoolWithIndices(mlir::MLIRContext *context)
: OpRewritePattern<OpTy>(context, /*benefit=*/1) {}

LogicalResult
matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override {
if (!op.getResult1().use_empty()) {
return rewriter.notifyMatchFailure(
op, "result1 of MaxPool2dWithIndices should be unused");
op, "result1 of MaxPoolWithIndices should be unused");
}

Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
Value result = rewriter.create<typename MaxPoolWithoutIndices<OpTy>::type>(
op->getLoc(), op.getResult0().getType(), op.getSelf(),
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
op.getCeilMode());

op.getResult0().replaceAllUsesWith(result);
rewriter.eraseOp(op);
return success();
});
}
};

void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool2dWithIndicesOp>>(context);
}

void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool3dWithIndicesOp>>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 0 additions & 5 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,12 +747,7 @@
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
"MaxPool3dCeilModeTrueModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic",
"MaxPool3dWithIndicesAllNegativeValuesModule_basic",
"MaxPool3dWithIndicesAllOnesModule_basic",
"MaxPool3dWithIndicesCeilModeTrueModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
emit(
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True,
)
emit(
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
Expand Down
18 changes: 18 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor

// -----

// CHECK-LABEL: @torch.aten.max_pool3d_with_indices$canonicalize(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> {
// CHECK: %[[RET:.*]] = torch.aten.max_pool3d %[[ARG]]
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56,56],f32>
func.func @torch.aten.max_pool3d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%result0, %result1 = torch.aten.max_pool3d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64>
return %result0 : !torch.vtensor<[10,64,56,56,56],f32>
}

// -----

// CHECK-LABEL: @torch.aten.clone$no_fold(
func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
// CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
Expand Down
Loading