Commit 7712b97

[Torch] Canonicalize aten.convolution with single int tuple params (#4388)
This PR is a more robust fix for the issue captured in #4380. Essentially, lowering `torch.ops.aten.convolution` to `tosa`, `linalg`, and `stablehlo` fails if `stride`, `padding`, `dilation`, or `output_padding` is a tuple with a single element while the convolution operates on 2 or 3 spatial dimensions. In the failing case, `torch.nn.Conv2d` with `padding='valid'` generates a `torch.ops.aten.conv2d.padding` op in the `ExportedProgram`, which is later decomposed to `torch.ops.aten.convolution.default` with a single padding value of `[0]` after running `ep.run_decompositions()`. In #4380, I attempted to fix just the `torch-to-tosa` pass, but I later realised that this is a more general bug affecting multiple params in all the backends (thanks to #4380 (comment)).

### Fix

I followed #4250 and canonicalize `aten.convolution` when it operates on 2 or 3 spatial dims but its params are singletons. For example, if `aten.convolution` is 2D but `padding == [0]`, we canonicalize it to `padding == [0, 0]`.
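For context, the failure mode can be reproduced in a few lines. A minimal sketch (the module and shapes are illustrative, not taken from this PR's test suite):

```python
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # padding='valid' is what produces aten.conv2d.padding on export.
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding="valid")

    def forward(self, x):
        return self.conv(x)


ep = torch.export.export(M(), (torch.randn(1, 3, 16, 16),))
# Decomposition rewrites aten.conv2d.padding into aten.convolution.default,
# carrying a singleton padding list ([0]) even though the conv is 2D.
ep = ep.run_decompositions()
print(ep.graph_module.code)
```

Importing the resulting program into torch-mlir and lowering to any of the three backends is what previously failed.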
1 parent d1f6dcf commit 7712b97

File tree

6 files changed: +301 −1 lines


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions

```diff
@@ -7119,6 +7119,7 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
       printDefaultTorchOp(printer, *this, 9, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }

 def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
```

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 155 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@
 // Also available under a BSD-style license. See LICENSE.
 //
 //===----------------------------------------------------------------------===//
+#include "llvm/ADT/SmallVector.h"
 #define DEBUG_TYPE "torch-mlir-torch-dialect"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -5898,6 +5899,160 @@ void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
 }

+namespace {
+class CanonicalizeConvolutionWithSingleIntTuple
+    : public OpRewritePattern<AtenConvolutionOp> {
+public:
+  using OpRewritePattern<AtenConvolutionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenConvolutionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto weight = op.getWeight();
+    auto weightType = dyn_cast<ValueTensorType>(weight.getType());
+
+    if (!weightType) {
+      return rewriter.notifyMatchFailure(op, "weight is not a vtensor");
+    }
+    auto optionalSizes = weightType.getOptionalSizes();
+    if (!optionalSizes.has_value()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unranked weight tensor unsupported!");
+    }
+
+    // The rank is the size of the dimensions array.
+    int64_t weightRank = optionalSizes.value().size();
+
+    // We canonicalize rank 4 (2D conv) or rank 5 (3D conv) weights.
+    if (weightRank < 4 || weightRank > 5) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported weight rank (must be 4 or 5)");
+    }
+    int requiredSpatialDims = weightRank - 2;
+
+    // Validate that stride, padding, output_padding, and dilation are
+    // constant lists.
+    SmallVector<int64_t, 3> strideInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int stride unsupported!");
+    }
+    SmallVector<int64_t, 3> paddingInts;
+    if (!matchPattern(op.getPadding(),
+                      m_TorchListOfConstantInts(paddingInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int padding unsupported!");
+    }
+
+    SmallVector<int64_t, 3> dilationInts;
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int dilation unsupported!");
+    }
+
+    bool transposed;
+    if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const bool transposed unsupported!");
+    }
+
+    SmallVector<int64_t, 3> outputPaddingInts;
+    if (!matchPattern(op.getOutputPadding(),
+                      m_TorchListOfConstantInts(outputPaddingInts))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const int output_padding unsupported!");
+    }
+
+    // Canonicalization logic: only rewrite if a provided attribute has one
+    // element but the convolution requires 2 or 3 elements.
+    auto isCanonical = [requiredSpatialDims](ArrayRef<int64_t> param) {
+      return param.size() == static_cast<size_t>(requiredSpatialDims);
+    };
+
+    if (isCanonical(strideInts) && isCanonical(paddingInts) &&
+        isCanonical(dilationInts)) {
+      return rewriter.notifyMatchFailure(
+          op, "stride, padding, and dilation are already fully specified");
+    }
+
+    if (transposed && isCanonical(outputPaddingInts)) {
+      return rewriter.notifyMatchFailure(
+          op, "output_padding is already fully specified");
+    }
+
+    expand(strideInts, requiredSpatialDims);
+    expand(paddingInts, requiredSpatialDims);
+    expand(dilationInts, requiredSpatialDims);
+
+    if (transposed)
+      expand(outputPaddingInts, requiredSpatialDims);
+
+    // Construct the new lists.
+    // For example: if the user provided padding=[1] and we need 2 or 3 dims,
+    // we create padding=[1, 1] or padding=[1, 1, 1].
+    Location loc = op.getLoc();
+    SmallVector<Value> cstPadding, cstStrides, cstDilation, cstOutputPadding;
+
+    for (auto dim : llvm::seq<int>(0, requiredSpatialDims)) {
+
+      cstStrides.push_back(Torch::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64IntegerAttr(strideInts[dim])));
+
+      cstPadding.push_back(Torch::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64IntegerAttr(paddingInts[dim])));
+
+      cstDilation.push_back(Torch::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
+
+      if (transposed)
+        cstOutputPadding.push_back(Torch::ConstantIntOp::create(
+            rewriter, loc, rewriter.getI64IntegerAttr(outputPaddingInts[dim])));
+    }
+
+    auto targetListType =
+        Torch::ListType::get(Torch::IntType::get(op->getContext()));
+
+    // Create the list construct ops.
+    auto stridesList = Torch::PrimListConstructOp::create(
+        rewriter, loc, targetListType, cstStrides);
+    auto paddingList = Torch::PrimListConstructOp::create(
+        rewriter, loc, targetListType, cstPadding);
+    auto dilationsList = Torch::PrimListConstructOp::create(
+        rewriter, loc, targetListType, cstDilation);
+
+    Value outputPaddingList;
+    if (transposed) {
+      outputPaddingList = Torch::PrimListConstructOp::create(
+          rewriter, loc, targetListType, cstOutputPadding);
+    } else {
+      outputPaddingList = op.getOutputPadding();
+    }
+
+    // Replace the op.
+    // We create a new convolution op, keeping all operands the same except
+    // stride, padding, dilation, and output_padding, which are now fully
+    // specified.
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op.getType(), op.getInput(), op.getWeight(), op.getBias(),
+        stridesList.getResult(), paddingList.getResult(),
+        dilationsList.getResult(), op.getTransposed(), outputPaddingList,
+        op.getGroups());
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// AtenConvolutionOp Registration
+//===----------------------------------------------------------------------===//
+void AtenConvolutionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<CanonicalizeConvolutionWithSingleIntTuple>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLinalgCrossOp
 //===----------------------------------------------------------------------===//
```
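The pattern defers the actual broadcasting to an `expand` helper (presumably shared with the max-pool canonicalization from #4250; it is not part of this diff). A minimal Python sketch of the intended rule, assuming `expand` simply repeats the single element once per spatial dimension:

```python
def expand(params, required_spatial_dims):
    # Assumed behaviour of the C++ helper: broadcast a singleton parameter
    # list to one entry per spatial dimension; leave fully specified lists
    # untouched.
    if len(params) == 1:
        return params * required_spatial_dims
    return params


# A 2D convolution has weight rank 4, hence 4 - 2 = 2 spatial dims:
assert expand([0], 2) == [0, 0]  # padding=[0] -> padding=[0, 0]
assert expand([1], 3) == [1, 1, 1]  # the 3D case
assert expand([2, 3], 2) == [2, 3]  # already canonical, untouched
```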

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -1130,8 +1130,10 @@
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Convolution2DStaticModule_basic",
+    "Convolution2DSingleIntTupleModule_basic",
     "ConvolutionBackwardModule2DStatic_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
+    "ConvolutionModule2DTransposeScalarTupleParams_basic",
     "Conv_Transpose1dStaticModule_basic",
     "ConstantPad2dStaticModule_basic",
     "ConstantPadNdModule_basic",
@@ -2163,6 +2165,7 @@
     "Conv2dWithValidPaddingModule_basic",
     "Conv2dWithSamePaddingModule_basic",
     "Convolution2DStaticModule_basic",
+    "Convolution2DSingleIntTupleModule_basic",
     "Conv3dModule_basic",
     "Conv3dWithSamePaddingModule_basic",
     "Conv3dWithValidPaddingModule_basic",
@@ -2912,6 +2915,13 @@
     "Conv2dWithPaddingModule_basic",
     "Conv2dWithSamePaddingModule_basic",
     "Conv2dWithValidPaddingModule_basic",
+    "Conv3dModule_basic",
+    "Conv3dModuleScalarTupleParams_basic",
+    "Conv3dWithSamePaddingModule_basic",
+    "Conv3dWithValidPaddingModule_basic",
+    "ConvolutionModule3DGroups_basic",
+    "ConvolutionModule3DGroupsStrided_basic",
+    "ConvolutionModule3DGroupsDilated_basic",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "Conv_Transpose2dModule_basic",
@@ -2922,7 +2932,9 @@
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "ConvolutionModule2DGroups_basic",
+    "Convolution2DSingleIntTupleModule_basic",
     "ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
+    "ConvolutionModule2DTransposeScalarTupleParams_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
     # Error: onnx lowering,
@@ -4310,18 +4322,24 @@
     "Conv2dWithPaddingModule_basic",
     "Conv2dWithSamePaddingModule_basic",
     "Conv2dWithValidPaddingModule_basic",
+    "Conv3dModule_basic",
+    "Conv3dModuleScalarTupleParams_basic",
+    "Conv3dWithSamePaddingModule_basic",
+    "Conv3dWithValidPaddingModule_basic",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "Conv_Transpose2dModule_basic",
     "Convolution2DModule_basic",
     "Convolution2DStridedModule_basic",
+    "Convolution2DSingleIntTupleModule_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStatic_basic",
    "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "ConvolutionModule2DGroups_basic",
     "ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
+    "ConvolutionModule2DTransposeScalarTupleParams_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
     "ConvolutionModule2DGroupedTranspose_basic",
```

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -612,7 +612,8 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)"
     )
     emit(
-        "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)"
+        "aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)",
+        has_canonicalizer=True,
     )
     emit(
         "aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)"
```

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 97 additions & 0 deletions

```diff
@@ -304,6 +304,37 @@ def Convolution2DStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))


+class Convolution2DSingleIntTupleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 3, 10, 10], torch.float32, True),
+            ([3, 3, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=None,
+            stride=(1,),
+            padding=(0,),
+            dilation=(1,),
+            transposed=False,
+            output_padding=[0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Convolution2DSingleIntTupleModule())
+def Convolution2DSingleIntTupleModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
+
+
 class Convolution2DStridedModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -901,6 +932,39 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils
     module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3))


+class ConvolutionModule2DTransposeScalarTupleParams(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 5, 6], torch.float32, True),
+            ([2, 5, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=None,
+            stride=(1,),
+            padding=(1,),
+            dilation=(1,),
+            transposed=True,
+            output_padding=(0,),
+            groups=1,
+        )
+
+
+@register_test_case(
+    module_factory=lambda: ConvolutionModule2DTransposeScalarTupleParams()
+)
+def ConvolutionModule2DTransposeScalarTupleParams_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
+
+
 class Conv_Transpose1dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1569,6 +1633,39 @@ def Conv3dWithValidPaddingModule_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)


+class Conv3dModuleScalarTupleParams(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.conv3d(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=(1,),
+            padding=(0,),
+            dilation=(1,),
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv3dModuleScalarTupleParams())
+def Conv3dModuleScalarTupleParams_basic(module, tu: TestUtils):
+    inputVec = tu.rand(2, 2, 6, 6, 6)
+    weight = torch.randn(8, 2, 3, 3, 3)
+    bias = torch.randn(8)
+    module.forward(inputVec, weight, bias)
+
+
 class ConvTbcModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
```
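The three new modules deliberately pass singleton tuples for `stride`, `padding`, and `dilation` (plus `output_padding` in the transposed case), which is exactly the shape of input the new canonicalization rewrites; the `xfail_sets.py` changes above register them with the per-backend pass and xfail lists.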
