@@ -1184,10 +1184,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
11841184 return success ();
11851185 }
11861186
1187- if (numSpatialDims != 2 )
1188- return rewriter.notifyMatchFailure (
1189- op, " unimplemented: only 2D grouped convolution supported" );
1190-
11911187 // Special depthwise case: Cin = Cout = groups.
11921188 // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
11931189 // of groups) to be depthwise in their documentation, but the linalg ops
@@ -1199,21 +1195,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
11991195 if (inShape[1 ] == numGroups && weightShape[0 ] == numGroups &&
12001196 weightShape[1 ] == 1 ) {
12011197 // Collapse weight shape (C/G == 1)
1202- SmallVector<ReassociationIndices, 4 > collapsedDims = {{0 , 1 }, {2 }, {3 }};
1203- SmallVector<int64_t > collapsedShape{weightShape[0 ] * weightShape[1 ],
1204- weightShape[2 ], weightShape[3 ]};
1198+ SmallVector<ReassociationIndices> collapsedDims = {{0 , 1 }};
1199+ SmallVector<int64_t > collapsedShape{weightShape[0 ] * weightShape[1 ]};
1200+ for (unsigned i = 0 ; i < numSpatialDims; i++) {
1201+ collapsedDims.push_back ({i + 2 });
1202+ collapsedShape.push_back (weightShape[i + 2 ]);
1203+ }
12051204 Type collapsedType = RankedTensorType::get (
12061205 makeShapeLLVMCompatible (collapsedShape), weightDTy);
12071206 Value collapsedWeight = rewriter.create <tensor::CollapseShapeOp>(
12081207 loc, collapsedType, weight, collapsedDims);
12091208 if (!inputZp) {
1210- conv = rewriter
1211- .create <linalg::DepthwiseConv2DNchwChwOp>(
1212- loc, outputTensor.getType (),
1213- ValueRange{paddedInput, collapsedWeight}, outputTensor,
1214- stridesAttr, dilationAttr)
1215- .getResult (0 );
1209+ switch (numSpatialDims) {
1210+ case 1 :
1211+ conv = rewriter
1212+ .create <linalg::DepthwiseConv1DNcwCwOp>(
1213+ loc, outputTensor.getType (),
1214+ ValueRange{paddedInput, collapsedWeight}, outputTensor,
1215+ stridesAttr, dilationAttr)
1216+ .getResult (0 );
1217+ break ;
1218+ case 2 :
1219+ conv = rewriter
1220+ .create <linalg::DepthwiseConv2DNchwChwOp>(
1221+ loc, outputTensor.getType (),
1222+ ValueRange{paddedInput, collapsedWeight}, outputTensor,
1223+ stridesAttr, dilationAttr)
1224+ .getResult (0 );
1225+ break ;
1226+ default :
1227+ return rewriter.notifyMatchFailure (
1228+ op, " unimplemented: only 1D and 2D depthwise convolution "
1229+ " supported for special case of group convolution" );
1230+ };
12161231 } else {
1232+ if (numSpatialDims != 2 )
1233+ return rewriter.notifyMatchFailure (
1234+ op, " unimplemented: only 2D depthwise quantized convolution "
1235+ " supported for special case of group convolution" );
1236+
12171237 // currently, the only named depthwise qconv op is nhwc_hwc
12181238 // input: nchw -> nhwc; weight (collapsed): chw -> hwc
12191239 // linalg conv result nhwc -> nchw
@@ -1260,6 +1280,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
12601280 return success ();
12611281 }
12621282
1283+ if (numSpatialDims != 2 )
1284+ return rewriter.notifyMatchFailure (
1285+ op, " unimplemented: only 2D grouped convolution supported" );
1286+
12631287 // Grouped case, use the grouped conv linalg op
12641288 auto expandGroups = [&](Value tensor, size_t dim) {
12651289 auto inType = cast<RankedTensorType>(tensor.getType ());
0 commit comments