@@ -1221,9 +1221,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
12211221      return  success ();
12221222    }
12231223
1224-     if  (numSpatialDims != 2 )
1225-       return  rewriter.notifyMatchFailure (
1226-           op, " unimplemented: only 2D grouped convolution supported"  );
12271224
12281225    //  Special depthwise case: Cin = Cout = groups.
12291226    //  Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
@@ -1236,21 +1233,45 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
12361233    if  (inShape[1 ] == numGroups && weightShape[0 ] == numGroups &&
12371234        weightShape[1 ] == 1 ) {
12381235      //  Collapse weight shape (C/G == 1)
1239-       SmallVector<ReassociationIndices, 4 > collapsedDims = {{0 , 1 }, {2 }, {3 }};
1240-       SmallVector<int64_t > collapsedShape{weightShape[0 ] * weightShape[1 ],
1241-                                           weightShape[2 ], weightShape[3 ]};
1236+       SmallVector<ReassociationIndices> collapsedDims = {{0 , 1 }};
1237+       SmallVector<int64_t > collapsedShape{weightShape[0 ] * weightShape[1 ]};
1238+       for  (unsigned  i = 0 ; i < numSpatialDims; i++) {
1239+         collapsedDims.push_back ({i + 2 });
1240+         collapsedShape.push_back (weightShape[i + 2 ]);
1241+       }
12421242      Type collapsedType = RankedTensorType::get (
12431243          makeShapeLLVMCompatible (collapsedShape), weightDTy);
12441244      Value collapsedWeight = rewriter.create <tensor::CollapseShapeOp>(
12451245          loc, collapsedType, weight, collapsedDims);
12461246      if  (!inputZp) {
1247-         conv = rewriter
1248-                    .create <linalg::DepthwiseConv2DNchwChwOp>(
1249-                        loc, outputTensor.getType (),
1250-                        ValueRange{paddedInput, collapsedWeight}, outputTensor,
1251-                        stridesAttr, dilationAttr)
1252-                    .getResult (0 );
1247+         switch  (numSpatialDims) {
1248+         case  1 :
1249+           conv = rewriter
1250+                      .create <linalg::DepthwiseConv1DNcwCwOp>(
1251+                          loc, outputTensor.getType (),
1252+                          ValueRange{paddedInput, collapsedWeight}, outputTensor,
1253+                          stridesAttr, dilationAttr)
1254+                      .getResult (0 );
1255+           break ;
1256+         case  2 :
1257+           conv = rewriter
1258+                      .create <linalg::DepthwiseConv2DNchwChwOp>(
1259+                          loc, outputTensor.getType (),
1260+                          ValueRange{paddedInput, collapsedWeight}, outputTensor,
1261+                          stridesAttr, dilationAttr)
1262+                      .getResult (0 );
1263+           break ;
1264+         default :
1265+           return  rewriter.notifyMatchFailure (
1266+               op, " unimplemented: only 1D and 2D depthwise convolution " 
1267+                   " supported for special case of group convolution"  );
1268+         };
12531269      } else  {
1270+         if  (numSpatialDims != 2 )
1271+           return  rewriter.notifyMatchFailure (
1272+               op, " unimplemented: only 2D depthwise quantized convolution " 
1273+                   " supported for special case of group convolution"  );
1274+ 
12541275        //  currently, the only named depthwise qconv op is nhwc_hwc
12551276        //  input: nchw -> nhwc; weight (collapsed): chw -> hwc
12561277        //  linalg conv result nhwc -> nchw
@@ -1297,6 +1318,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
12971318      return  success ();
12981319    }
12991320
1321+     if  (numSpatialDims != 2 )
1322+       return  rewriter.notifyMatchFailure (
1323+           op, " unimplemented: only 2D grouped convolution supported"  );
1324+ 
13001325    //  Grouped case, use the grouped conv linalg op
13011326    auto  expandGroups = [&](Value tensor, size_t  dim) {
13021327      auto  inType = cast<RankedTensorType>(tensor.getType ());
0 commit comments