@@ -1921,6 +1921,56 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
19211921 IT::reduction, IT::reduction, IT::reduction,
19221922 IT::reduction};
19231923 }
1924+ } else {
1925+ if (numSpatialDims == 1 ) {
1926+ AffineExpr n, g, cg, o, fg, k;
1927+ bindDims (context, n, g, cg, o, fg, k);
1928+ AffineExpr d0 = rewriter.getAffineConstantExpr (dilationInts[0 ]);
1929+ SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
1930+ SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
1931+ SmallVector<AffineExpr> outExprs = {n, g, cg, o};
1932+ indexingMaps = {
1933+ AffineMap::get (6 , 0 , goExprs, context),
1934+ AffineMap::get (6 , 0 , weiExprs, context),
1935+ AffineMap::get (6 , 0 , outExprs, context)};
1936+ iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, IT::parallel,
1937+ IT::reduction, IT::reduction};
1938+ } else if (numSpatialDims == 2 ) {
1939+ AffineExpr n, g, cg, oh, ow, fg, kh, kw;
1940+ bindDims (context, n, g, cg, oh, ow, fg, kh, kw);
1941+ AffineExpr d0 = rewriter.getAffineConstantExpr (dilationInts[0 ]);
1942+ AffineExpr d1 = rewriter.getAffineConstantExpr (dilationInts[1 ]);
1943+ SmallVector<AffineExpr> goExprs = {
1944+ n, g, fg, d0 * kh + oh, d1 * kw + ow};
1945+ SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
1946+ SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
1947+ indexingMaps = {
1948+ AffineMap::get (8 , 0 , goExprs, context),
1949+ AffineMap::get (8 , 0 , weiExprs, context),
1950+ AffineMap::get (8 , 0 , outExprs, context)};
1951+ iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, IT::parallel, IT::parallel,
1952+ IT::reduction, IT::reduction, IT::reduction};
1953+ } else {
1954+ AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
1955+ bindDims (context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
1956+ AffineExpr d0 = rewriter.getAffineConstantExpr (dilationInts[0 ]);
1957+ AffineExpr d1 = rewriter.getAffineConstantExpr (dilationInts[1 ]);
1958+ AffineExpr d2 = rewriter.getAffineConstantExpr (dilationInts[2 ]);
1959+ SmallVector<AffineExpr> goExprs = {n,
1960+ g,
1961+ fg,
1962+ d0 * kd + od,
1963+ d1 * kh + oh,
1964+ d2 * kw + ow};
1965+ SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
1966+ SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
1967+ indexingMaps = {
1968+ AffineMap::get (10 , 0 , goExprs, context),
1969+ AffineMap::get (10 , 0 , weiExprs, context),
1970+ AffineMap::get (10 , 0 , outExprs, context)};
1971+ iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, IT::parallel, IT::parallel, IT::parallel,
1972+ IT::reduction, IT::reduction, IT::reduction, IT::reduction};
1973+ }
19241974 }
19251975
19261976 auto genericRes = linalg::GenericOp::create (
0 commit comments