Skip to content

Commit 04b4eed

Browse files
committed
added group support for input gradient
1 parent f96844a commit 04b4eed

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,6 +1921,56 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
19211921
IT::reduction, IT::reduction, IT::reduction,
19221922
IT::reduction};
19231923
}
1924+
} else {
1925+
if (numSpatialDims == 1) {
1926+
AffineExpr n, g, cg, o, fg, k;
1927+
bindDims(context, n, g, cg, o, fg, k);
1928+
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
1929+
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
1930+
SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
1931+
SmallVector<AffineExpr> outExprs = {n, g, cg, o};
1932+
indexingMaps = {
1933+
AffineMap::get(6, 0, goExprs, context),
1934+
AffineMap::get(6, 0, weiExprs, context),
1935+
AffineMap::get(6, 0, outExprs, context)};
1936+
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, IT::parallel,
1937+
IT::reduction, IT::reduction};
1938+
} else if (numSpatialDims == 2) {
1939+
AffineExpr n, g, cg, oh, ow, fg, kh, kw;
1940+
bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
1941+
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
1942+
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
1943+
SmallVector<AffineExpr> goExprs = {
1944+
n, g, fg, d0 * kh + oh, d1 * kw + ow};
1945+
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
1946+
SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
1947+
indexingMaps = {
1948+
AffineMap::get(8, 0, goExprs, context),
1949+
AffineMap::get(8, 0, weiExprs, context),
1950+
AffineMap::get(8, 0, outExprs, context)};
1951+
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, IT::parallel, IT::parallel,
1952+
IT::reduction, IT::reduction, IT::reduction};
1953+
} else {
1954+
AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
1955+
bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
1956+
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
1957+
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
1958+
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
1959+
SmallVector<AffineExpr> goExprs = {n,
1960+
g,
1961+
fg,
1962+
d0 * kd + od,
1963+
d1 * kh + oh,
1964+
d2 * kw + ow};
1965+
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
1966+
SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
1967+
indexingMaps = {
1968+
AffineMap::get(10, 0, goExprs, context),
1969+
AffineMap::get(10, 0, weiExprs, context),
1970+
AffineMap::get(10, 0, outExprs, context)};
1971+
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, IT::parallel, IT::parallel, IT::parallel,
1972+
IT::reduction, IT::reduction, IT::reduction, IT::reduction};
1973+
}
19241974
}
19251975

19261976
auto genericRes = linalg::GenericOp::create(

0 commit comments

Comments
 (0)