
Commit f1d59b4

none results handled
1 parent 7c14720 commit f1d59b4

File tree

1 file changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 24 additions & 14 deletions
@@ -1816,11 +1816,10 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
       gradInput =
           tensor::CastOp::create(rewriter, loc, getTypeConverter()->convertType(op->getResult(0).getType()), gradInput);
     } else {
-      // If the backward-weight convolution is not needed, zero init the grad_input tensor.
-      SmallVector<Value> gradInputSizes =
-          getTensorSizes(rewriter, loc, input);
-      gradInput =
-          createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy);
+      // If input gradient is not needed, we replace torch.none with a constant zero.
+      // This constant will be eliminated by DCE.
+      gradInput = arith::ConstantOp::create(rewriter, loc,
+                                            rewriter.getI64IntegerAttr(0));
     }
 
     // Computing Backward-Weight Convolution.
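Side note on the pattern this hunk introduces: the result being produced here corresponds to a `torch.none` output that has no users after conversion, so any cheap scalar can stand in for it and is later deleted as dead code, whereas the previous code allocated and zero-filled a whole tensor. A minimal sketch of that idea, assuming an MLIR build with the Arith dialect available (the helper name `makeDeadPlaceholder` is illustrative and not part of this patch):

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative only: produce a throwaway scalar for a gradient the caller did
// not request. Nothing uses the corresponding torch.none result, so this
// constant is trivially dead and removed by DCE/canonicalization.
static Value makeDeadPlaceholder(OpBuilder &rewriter, Location loc) {
  return arith::ConstantOp::create(rewriter, loc,
                                   rewriter.getI64IntegerAttr(0));
}
```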
@@ -1856,6 +1855,18 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
     SmallVector<AffineMap> indexingMaps;
     SmallVector<IT> iteratorTypes;
 
+    // To calculate the convolution backward-weight, we use a generic operation.
+    // The generic operation is a generalization of the convolution operation
+    // that can handle any number of spatial dimensions.
+    // The generic operation is defined as follows:
+    // ```
+    // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o] for n in range(batch_size) for o in range(output_spatial_dims))
+    // ```
+    // where `n` is the batch dimension, `g` is the group dimension,
+    // `c` is the input channel dimension, `f` is the output channel dimension,
+    // `k` is the kernel spatial dimension, `o` is the output spatial dimension,
+    // d0 is dilation and s0 is stride.
+    // `x` is the input tensor, `dLdy` is the output-gradient tensor, and `dLdw` is the weight-gradient tensor.
     if (!isGroupedConvBwd) {
       if (numSpatialDims == 1) {
         AffineExpr f, c, k, n, o;
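To make the indices in the comment above concrete, here is a small self-contained reference of the same formula as plain loops for one spatial dimension. It models only the math, not the linalg lowering; the shapes x[N][G][C][Win], dLdy[N][G][F][Wout], dLdw[F][G][C][K] and the concrete sizes are illustrative assumptions:

```cpp
#include <cstdio>
#include <vector>

// Reference loop nest for the backward-weight formula:
//   dLdw[f, g, c, k] = sum_{n, o} x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
// Shapes (1 spatial dim): x[N][G][C][Win], dLdy[N][G][F][Wout], dLdw[F][G][C][K].
int main() {
  const int N = 2, G = 2, C = 3, F = 4, K = 3;        // batch, groups, in/out channels, kernel width
  const int s0 = 2, d0 = 1;                           // stride, dilation
  const int Wout = 5;
  const int Win = s0 * (Wout - 1) + d0 * (K - 1) + 1; // input width covering every tap exactly

  // Flattened 4-D indexing: [A][B][C][D] row-major.
  auto idx4 = [](int a, int b, int c_, int d, int B, int C_, int D) {
    return ((a * B + b) * C_ + c_) * D + d;
  };

  std::vector<float> x(N * G * C * Win), dLdy(N * G * F * Wout);
  for (size_t i = 0; i < x.size(); ++i) x[i] = 0.01f * float(i);
  for (size_t i = 0; i < dLdy.size(); ++i) dLdy[i] = 0.02f * float(i);

  std::vector<float> dLdw(F * G * C * K, 0.0f);
  for (int f = 0; f < F; ++f)
    for (int g = 0; g < G; ++g)
      for (int c = 0; c < C; ++c)
        for (int k = 0; k < K; ++k)
          for (int n = 0; n < N; ++n)
            for (int o = 0; o < Wout; ++o)
              dLdw[idx4(f, g, c, k, G, C, K)] +=
                  x[idx4(n, g, c, d0 * k + s0 * o, G, C, Win)] *
                  dLdy[idx4(n, g, f, o, G, F, Wout)];

  std::printf("dLdw[0][0][0][0] = %f\n", dLdw[0]);
  return 0;
}
```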
@@ -2011,11 +2022,10 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
           getTypeConverter()->convertType(op->getResult(1).getType()),
           genericRes);
     } else {
-      // If the backward-weight convolution is not needed, zero init the grad_weight tensor.
-      SmallVector<Value> gradWeightSizes =
-          getTensorSizes(rewriter, loc, weight);
-      gradWeight =
-          createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
+      // If weight gradient is not needed, we replace torch.none with a constant zero.
+      // This constant will be eliminated by DCE.
+      gradWeight = arith::ConstantOp::create(rewriter, loc,
+                                             rewriter.getI64IntegerAttr(0));
     }
 
     // Computing Backward-Bias Convolution.
@@ -2045,10 +2055,10 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
       gradBias =
           tensor::CastOp::create(rewriter, loc, getTypeConverter()->convertType(op->getResult(2).getType()), gradBias);
     } else {
-      // If the bias are not needed, zero init the grad_bias tensor.
-      // TODO FIX IT
-      SmallVector<Value> gradBiasSizes = getTensorSizes(rewriter, loc, gradOutput);
-      gradBias = createZeroInitTensor(rewriter, loc, gradBiasSizes, gradOutputDTy);
+      // If bias gradient is not needed, we replace torch.none with a constant zero.
+      // This constant will be eliminated by DCE.
+      gradBias = arith::ConstantOp::create(rewriter, loc,
+                                           rewriter.getI64IntegerAttr(0));
     }
 
     rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});
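For completeness, the backward-bias computation referenced above is mathematically just a reduction of the output gradient over the batch and spatial dimensions. A minimal standalone sketch of that reduction for one spatial dimension (shapes and names are illustrative assumptions, not taken from Linear.cpp):

```cpp
#include <cstdio>
#include <vector>

// Reference for the backward-bias reduction: dLdb[f] = sum_{n, o} dLdy[n][f][o].
// Only illustrates the math behind "Computing Backward-Bias Convolution".
int main() {
  const int N = 2, F = 4, Wout = 5;
  std::vector<float> dLdy(N * F * Wout);
  for (size_t i = 0; i < dLdy.size(); ++i) dLdy[i] = 0.1f * float(i);

  std::vector<float> dLdb(F, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int f = 0; f < F; ++f)
      for (int o = 0; o < Wout; ++o)
        dLdb[f] += dLdy[(n * F + f) * Wout + o];

  for (int f = 0; f < F; ++f)
    std::printf("dLdb[%d] = %f\n", f, dLdb[f]);
  return 0;
}
```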
