Commit 614fcdd

[MLIR][TORCH] Add support for 1-d group convolution (#3770)
This commit adds support for 1-d depthwise convolution as a special case of 1-d group convolution.

Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent f6721e5 commit 614fcdd
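For context, "depthwise" here means the grouped case where Cin == Cout == groups and the weight has shape (C, 1, K), so each output channel is convolved from exactly one input channel. A minimal PyTorch sketch of the newly supported configuration (illustrative only, not part of the change; shapes match the new e2e test below):

```python
import torch

# 1-d depthwise convolution: Cin == Cout == groups == 4, weight (C, 1, K).
x = torch.randn(2, 4, 6)  # (N, C, W)
w = torch.randn(4, 1, 3)  # (Cout, Cin/groups, K) with Cin/groups == 1
y = torch.nn.functional.conv1d(x, w, stride=1, padding=4, dilation=1, groups=4)
print(y.shape)  # torch.Size([2, 4, 12]); W_out = (6 + 2*4 - (3-1) - 1) + 1 = 12
```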

File tree

3 files changed: +67 -13 lines

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 37 additions & 13 deletions
@@ -1184,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (numSpatialDims != 2)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only 2D grouped convolution supported");
-
     // Special depthwise case: Cin = Cout = groups.
     // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
     // of groups) to be depthwise in their documentation, but the linalg ops
@@ -1199,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     if (inShape[1] == numGroups && weightShape[0] == numGroups &&
         weightShape[1] == 1) {
       // Collapse weight shape (C/G == 1)
-      SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
-      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
-                                          weightShape[2], weightShape[3]};
+      SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
+      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
+      for (unsigned i = 0; i < numSpatialDims; i++) {
+        collapsedDims.push_back({i + 2});
+        collapsedShape.push_back(weightShape[i + 2]);
+      }
       Type collapsedType = RankedTensorType::get(
           makeShapeLLVMCompatible(collapsedShape), weightDTy);
       Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
           loc, collapsedType, weight, collapsedDims);
       if (!inputZp) {
-        conv = rewriter
-                   .create<linalg::DepthwiseConv2DNchwChwOp>(
-                       loc, outputTensor.getType(),
-                       ValueRange{paddedInput, collapsedWeight}, outputTensor,
-                       stridesAttr, dilationAttr)
-                   .getResult(0);
+        switch (numSpatialDims) {
+        case 1:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv1DNcwCwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        case 2:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv2DNchwChwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        default:
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 1D and 2D depthwise convolution "
+                  "supported for special case of group convolution");
+        };
       } else {
+        if (numSpatialDims != 2)
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 2D depthwise quantized convolution "
+                  "supported for special case of group convolution");
+
         // currently, the only named depthwise qconv op is nhwc_hwc
         // input: nchw -> nhwc; weight (collapsed): chw -> hwc
         // linalg conv result nhwc -> nchw
@@ -1260,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
+    if (numSpatialDims != 2)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only 2D grouped convolution supported");
+
     // Grouped case, use the grouped conv linalg op
     auto expandGroups = [&](Value tensor, size_t dim) {
       auto inType = cast<RankedTensorType>(tensor.getType());
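The rewrite above hinges on an algebraic fact: when Cin/G == 1, a grouped convolution is a per-channel convolution, so the weight's unit dimension can be collapsed (the tensor.collapse_shape above) and the result computed by a depthwise linalg op on a (C, K) kernel (linalg.depthwise_conv_1d_ncw_cw in the 1-d case). A hedged Python check of that equivalence, not part of the commit:

```python
import torch
import torch.nn.functional as F

# Grouped conv1d with Cin == Cout == groups and weight (C, 1, K) ...
N, C, W, K = 2, 4, 6, 3
x = torch.randn(N, C, W)
w = torch.randn(C, 1, K)
grouped = F.conv1d(x, w, groups=C)

# ... equals a per-channel convolution with the unit dim squeezed away,
# i.e. the collapsed (C, K) kernel the lowering feeds the depthwise op.
w_ck = w.squeeze(1)  # (C, K)
per_channel = torch.stack(
    [F.conv1d(x[:, c : c + 1, :], w_ck[c].view(1, 1, K))[:, 0] for c in range(C)],
    dim=1,
)
assert torch.allclose(grouped, per_channel, atol=1e-6)
```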

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -1048,6 +1048,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "ContiguousModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
@@ -3395,6 +3396,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
@@ -4087,6 +4089,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dBiasNoPaddingModule_basic",
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 27 additions & 0 deletions
@@ -1067,6 +1067,33 @@ def Conv1dModule_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
+class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 6], torch.float32, True),
+            ([4, 1, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv1d(
+            inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
+        )
+
+
+@register_test_case(
+    module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule()
+)
+def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
+    inputVec = tu.rand(2, 4, 6)
+    weight = torch.randn(4, 1, 3)
+    module.forward(inputVec, weight)
+
+
 class Conv2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
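Outside the e2e harness, the same configuration can be sanity-checked in eager PyTorch; with W = 6, K = 3, stride 1, dilation 1, and padding 4, the expected output width is (6 + 2*4 - (3 - 1) - 1) + 1 = 12. An illustrative sketch, not part of the commit:

```python
import torch

inputVec = torch.rand(2, 4, 6)
weight = torch.randn(4, 1, 3)
out = torch.ops.aten.conv1d(
    inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
)
assert out.shape == (2, 4, 12)  # W_out = (6 + 2*4 - (3 - 1) - 1) + 1
```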
