@@ -1816,11 +1816,10 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
       gradInput =
           tensor::CastOp::create(rewriter, loc, getTypeConverter()->convertType(op->getResult(0).getType()), gradInput);
     } else {
-      // If the backward-weight convolution is not needed, zero init the grad_input tensor.
-      SmallVector<Value> gradInputSizes =
-          getTensorSizes(rewriter, loc, input);
-      gradInput =
-          createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy);
+      // If the input gradient is not needed, replace the torch.none result
+      // with a constant zero; the constant is later eliminated by DCE.
+      gradInput = arith::ConstantOp::create(rewriter, loc,
+                                            rewriter.getI64IntegerAttr(0));
     }
 
     // Computing Backward-Weight Convolution.
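The same placeholder trick is used for the weight and bias gradients in the hunks below: `rewriter.replaceOp` needs one replacement value per result of `aten.convolution_backward`, and a result the caller did not request is `torch.none` with no real users, so a throwaway `i64` zero constant satisfies `replaceOp` and is then deleted by DCE, instead of materializing a dead zero-initialized tensor. A minimal sketch of that leg (not code from this patch; `outputMask[0]` is a hypothetical flag standing in for however the pattern tests whether the input gradient was requested):

```cpp
// Sketch only: the grad-input leg of the lowering under the assumptions above.
Value gradInput;
if (outputMask[0]) {
  // ... build the real backward-input convolution and cast it to the
  // converted result type, as in the hunk above ...
} else {
  // No users exist for this result, so any placeholder Value works; an i64
  // zero constant is cheap and is removed by DCE after the replacement.
  gradInput = arith::ConstantOp::create(rewriter, loc,
                                        rewriter.getI64IntegerAttr(0));
}
// Eventually: rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});
```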
@@ -1856,6 +1855,18 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
     SmallVector<AffineMap> indexingMaps;
     SmallVector<IT> iteratorTypes;
 
+    // The backward-weight convolution is computed with a linalg.generic op,
+    // which generalizes the convolution to an arbitrary number of spatial
+    // dimensions:
+    // ```
+    // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
+    //                        for n in range(batch_size)
+    //                        for o in range(output_spatial_dims))
+    // ```
+    // where `n` is the batch dimension, `g` the group dimension, `f` the output
+    // channel dimension, `c` the input channel dimension, `k` the kernel spatial
+    // dimension, `o` the output spatial dimension, `d0` the dilation, and `s0`
+    // the stride. `x` is the input, `dLdy` the output gradient, `dLdw` the weight gradient.
     if (!isGroupedConvBwd) {
       if (numSpatialDims == 1) {
         AffineExpr f, c, k, n, o;
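As a sanity check on the formula in the new comment, the following standalone program (not part of the patch; the sizes, initialization, and the absence of padding and groups are assumptions for illustration) compares the 1-D, non-grouped weight gradient computed by that formula against a difference quotient of the forward convolution, which is exact because the output is linear in the weights:

```cpp
// build: c++ -std=c++17 check_dldw.cpp && ./a.out
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Small made-up sizes; s0 = stride, d0 = dilation, no padding, no groups.
  const int N = 2, C = 3, F = 4, K = 3, O = 5, s0 = 2, d0 = 2;
  const int I = s0 * (O - 1) + d0 * (K - 1) + 1; // input spatial size

  // Flat 3-D indexing helper: idx(a, b, c) into an [A][B][Cdim] array.
  auto idx = [](int a, int b, int c, int B, int Cdim) { return (a * B + b) * Cdim + c; };

  std::vector<double> x(N * C * I), w(F * C * K), dLdy(N * F * O);
  for (size_t i = 0; i < x.size(); ++i) x[i] = 0.1 * double(i) - 1.0;
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.05 * double(i) + 0.2;
  for (size_t i = 0; i < dLdy.size(); ++i) dLdy[i] = 0.03 * double(i) - 0.4;

  // Forward convolution y[n, f, o] = sum_{c, k} w[f, c, k] * x[n, c, s0*o + d0*k],
  // contracted with dLdy to give a scalar loss L = <dLdy, y>.
  auto loss = [&](const std::vector<double> &wt) {
    double L = 0.0;
    for (int n = 0; n < N; ++n)
      for (int f = 0; f < F; ++f)
        for (int o = 0; o < O; ++o) {
          double y = 0.0;
          for (int c = 0; c < C; ++c)
            for (int k = 0; k < K; ++k)
              y += wt[idx(f, c, k, C, K)] * x[idx(n, c, s0 * o + d0 * k, C, I)];
          L += dLdy[idx(n, f, o, F, O)] * y;
        }
    return L;
  };

  // Weight gradient from the comment's formula (1-D, non-grouped case):
  // dLdw[f, c, k] = sum_{n, o} x[n, c, d0*k + s0*o] * dLdy[n, f, o].
  std::vector<double> dLdw(F * C * K, 0.0);
  for (int f = 0; f < F; ++f)
    for (int c = 0; c < C; ++c)
      for (int k = 0; k < K; ++k)
        for (int n = 0; n < N; ++n)
          for (int o = 0; o < O; ++o)
            dLdw[idx(f, c, k, C, K)] +=
                x[idx(n, c, d0 * k + s0 * o, C, I)] * dLdy[idx(n, f, o, F, O)];

  // L is linear in w, so a one-sided difference quotient is exact up to rounding.
  const double eps = 1.0, L0 = loss(w);
  double maxErr = 0.0;
  for (int i = 0; i < F * C * K; ++i) {
    std::vector<double> wp = w;
    wp[i] += eps;
    maxErr = std::max(maxErr, std::fabs((loss(wp) - L0) / eps - dLdw[i]));
  }
  std::printf("max |difference-quotient - formula| = %g\n", maxErr);
  assert(maxErr < 1e-9);
  return 0;
}
```

This contraction over the reduction dimensions `n` and `o` is what the indexing maps built from the dims `(f, c, k, n, o)` presumably encode in the `linalg.generic` op for the 1-D case.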
@@ -2011,11 +2022,10 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
             getTypeConverter()->convertType(op->getResult(1).getType()),
             genericRes);
     } else {
-      // If the backward-weight convolution is not needed, zero init the grad_weight tensor.
-      SmallVector<Value> gradWeightSizes =
-          getTensorSizes(rewriter, loc, weight);
-      gradWeight =
-          createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
+      // If the weight gradient is not needed, replace the torch.none result
+      // with a constant zero; the constant is later eliminated by DCE.
+      gradWeight = arith::ConstantOp::create(rewriter, loc,
+                                             rewriter.getI64IntegerAttr(0));
     }
 
     // Computing Backward-Bias Convolution.
@@ -2045,10 +2055,10 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolut
       gradBias =
           tensor::CastOp::create(rewriter, loc, getTypeConverter()->convertType(op->getResult(2).getType()), gradBias);
     } else {
-      // If the bias are not needed, zero init the grad_bias tensor.
-      // TODO FIX IT
-      SmallVector<Value> gradBiasSizes = getTensorSizes(rewriter, loc, gradOutput);
-      gradBias = createZeroInitTensor(rewriter, loc, gradBiasSizes, gradOutputDTy);
+      // If the bias gradient is not needed, replace the torch.none result
+      // with a constant zero; the constant is later eliminated by DCE.
+      gradBias = arith::ConstantOp::create(rewriter, loc,
+                                           rewriter.getI64IntegerAttr(0));
     }
 
     rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});