Commit 158387c

[MLIR][TORCH] Add support for 1-d group convolution
This commit adds support for 1-d depthwise convolution as a special case of 1-d group convolution.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent f4840ed commit 158387c
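
For context, a 1-d depthwise convolution is the grouped case in which groups == Cin and the weight carries a unit channel multiplier, i.e. shape (Cout, 1, K). The eager-PyTorch sketch below is not part of the commit (shapes are borrowed from the e2e test it adds); it checks that such a convolution is just an independent 1-d convolution per channel:

import torch

# Depthwise 1-d conv: groups == Cin == Cout, weight shape (Cout, 1, K).
# Shapes borrowed from the new e2e test: input (2, 4, 6), weight (4, 1, 3).
x = torch.randn(2, 4, 6)
w = torch.randn(4, 1, 3)
out = torch.ops.aten.conv1d(
    x, w, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
)

# Reference: output channel c depends only on input channel c.
ref = torch.cat(
    [
        torch.nn.functional.conv1d(x[:, c : c + 1], w[c : c + 1], padding=4)
        for c in range(4)
    ],
    dim=1,
)
assert torch.allclose(out, ref, atol=1e-6)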

3 files changed: 65 additions & 12 deletions

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 37 additions & 12 deletions
@@ -1221,9 +1221,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (numSpatialDims != 2)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only 2D grouped convolution supported");
 
     // Special depthwise case: Cin = Cout = groups.
     // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
@@ -1236,21 +1233,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     if (inShape[1] == numGroups && weightShape[0] == numGroups &&
         weightShape[1] == 1) {
       // Collapse weight shape (C/G == 1)
-      SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
-      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
-                                          weightShape[2], weightShape[3]};
+      SmallVector<ReassociationIndices> collapsedDims = {{0, 1}};
+      SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1]};
+      for (unsigned i = 0; i < numSpatialDims; i++) {
+        collapsedDims.push_back({i + 2});
+        collapsedShape.push_back(weightShape[i + 2]);
+      }
       Type collapsedType = RankedTensorType::get(
           makeShapeLLVMCompatible(collapsedShape), weightDTy);
       Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
           loc, collapsedType, weight, collapsedDims);
       if (!inputZp) {
-        conv = rewriter
-                   .create<linalg::DepthwiseConv2DNchwChwOp>(
-                       loc, outputTensor.getType(),
-                       ValueRange{paddedInput, collapsedWeight}, outputTensor,
-                       stridesAttr, dilationAttr)
-                   .getResult(0);
+        switch (numSpatialDims) {
+        case 1:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv1DNcwCwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        case 2:
+          conv = rewriter
+                     .create<linalg::DepthwiseConv2DNchwChwOp>(
+                         loc, outputTensor.getType(),
+                         ValueRange{paddedInput, collapsedWeight}, outputTensor,
+                         stridesAttr, dilationAttr)
+                     .getResult(0);
+          break;
+        default:
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 1D and 2D depthwise convolution "
+                  "supported for special case of group convolution");
+        };
       } else {
+        if (numSpatialDims != 2)
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: only 2D depthwise quantized convolution "
+                  "supported for special case of group convolution");
+
         // currently, the only named depthwise qconv op is nhwc_hwc
         // input: nchw -> nhwc; weight (collapsed): chw -> hwc
         // linalg conv result nhwc -> nchw
@@ -1297,6 +1318,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
+    if (numSpatialDims != 2)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only 2D grouped convolution supported");
+
     // Grouped case, use the grouped conv linalg op
     auto expandGroups = [&](Value tensor, size_t dim) {
       auto inType = cast<RankedTensorType>(tensor.getType());
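
The reworked lowering folds the unit channel-multiplier dimension out of the (G, 1, k1, ..., kn) depthwise weight and then picks a rank-matched named op, so the 1-d case lands on linalg::DepthwiseConv1DNcwCwOp. As a reading aid, here is a sketch of that op's semantics on an already-padded input (my Python rendering, not the MLIR implementation):

import torch

def depthwise_conv_1d_ncw_cw(x, w, stride=1, dilation=1):
    # x: (N, C, W) already-padded input; w: (C, K) collapsed depthwise weight.
    # out[n, c, ow] = sum_k x[n, c, ow * stride + k * dilation] * w[c, k]
    n, c, wi = x.shape
    _, k = w.shape
    wout = (wi - dilation * (k - 1) - 1) // stride + 1
    out = torch.zeros(n, c, wout, dtype=x.dtype)
    for ow in range(wout):
        for kk in range(k):
            out[:, :, ow] += x[:, :, ow * stride + kk * dilation] * w[:, kk]
    return out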

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -1048,6 +1048,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "ContiguousModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
@@ -3385,6 +3386,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
@@ -4091,6 +4093,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
+    "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv2dBiasNoPaddingModule_basic",
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 25 additions & 0 deletions
@@ -1067,6 +1067,31 @@ def Conv1dModule_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
+class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 6], torch.float32, True),
+            ([4, 1, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv1d(
+            inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
+        )
+
+
+@register_test_case(module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule())
+def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
+    inputVec = tu.rand(2, 4, 6)
+    weight = torch.randn(4, 1, 3)
+    module.forward(inputVec, weight)
+
+
 class Conv2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
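
A quick sanity check on the static shapes in the new test (my arithmetic, not part of the commit): with W = 6, padding = 4, kernel K = 3, stride = 1, dilation = 1, the output width is (6 + 2*4 - 1*(3 - 1) - 1) / 1 + 1 = 12, so the module returns a (2, 4, 12) tensor:

import torch

x = torch.rand(2, 4, 6)
w = torch.randn(4, 1, 3)
out = torch.ops.aten.conv1d(
    x, w, bias=None, stride=[1], padding=[4], dilation=[1], groups=4
)
# Wout = (W + 2*pad - dilation*(K - 1) - 1) // stride + 1
#      = (6 + 8 - 2 - 1) // 1 + 1 = 12
assert out.shape == (2, 4, 12)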
