Commit 404b8ab

added non-unit stride support for input gradient
1 parent 04b4eed commit 404b8ab

2 files changed: +112, -15 lines

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 54 additions & 15 deletions
@@ -1826,36 +1826,75 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolutionBackwardOp>
     torch_to_linalg::flipTensor(rewriter, loc, weightExpanded, kernelFlipDims);
 
     // For backward-input, padding must be adjusted to:
-    //   pad_bwd[i] = dilation[i] * (kernel_size[i] - 1) - pad_fwd[i]
-    Value cstOne = arith::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    //   p'[i] = d[i] * (K[i] - 1) - p[i]
+    Value c1 = arith::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
     Value padVal = arith::ConstantOp::create(
         rewriter, loc, rewriter.getFloatAttr(gradOutputDTy, 0.0));
-    SmallVector<Value> gradOutputPaddingValues(numSpatialDims);
     SmallVector<Value> dilationIntValues =
         getAsConstantIntValues(rewriter, loc, dilationInts);
+    SmallVector<Value> weiSizes = getTensorSizes(rewriter, loc, weightExpanded);
+    SmallVector<Value> gradOutputPaddingValues(numSpatialDims);
     for (size_t i = 0; i < numSpatialDims; ++i) {
-      Value kSize = castIndexToInt64(rewriter, loc, getDimOp(rewriter, loc, weightExpanded, spatialStartDimIdx + i));
+      Value kSize = castIndexToInt64(rewriter, loc, weiSizes[spatialStartDimIdx + i]);
       Value kMinusOne = rewriter.createOrFold<arith::SubIOp>(
-          loc, kSize, cstOne);
+          loc, kSize, c1);
       Value dilated = rewriter.createOrFold<arith::MulIOp>(
-          loc, kMinusOne, castIntToIndex(rewriter, loc, dilationIntValues[i]));
+          loc, kMinusOne, dilationIntValues[i]);
       gradOutputPaddingValues[i] = arith::SubIOp::create(rewriter, loc, dilated, paddingIntValues[i]);
+
+      if (isValueNegative(gradOutputPaddingValues[i]))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: negative padding values are not supported.");
     }
 
-    bool do_insert_slice = llvm::any_of(strideInts, [](int64_t stride) { return stride > 1; });
-    if (do_insert_slice) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: do_insert_slice");
+    // If there are non-unit strides, we have to scatter `grad_output` into a zero-initialized tensor.
+    SmallVector<Value> gradInputSizes = getTensorSizes(rewriter, loc, input);
+    Value gradOutputSliced;
+    if (llvm::any_of(strideInts, [](int64_t stride) { return stride > 1; })) {
+      // Destination spatial sizes are computed as:
+      //   size[i] = (D[i] - 1) + d[i] * (K[i] - 1) + 1
+      // Offsets on spatial dims are the adjusted paddings.
+      // Strides on spatial dims are the original stride[i].
+      Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0));
+      Value one = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1));
+
+      // Initialize slice strides, sizes, and offsets.
+      SmallVector<Value> goSizes = getTensorSizes(rewriter, loc, gradOutputExpanded);
+      SmallVector<Value> sizes(goSizes.begin(), goSizes.begin() + spatialStartDimIdx);
+      SmallVector<Value> offsets(spatialStartDimIdx, zero);
+      SmallVector<Value> strides(spatialStartDimIdx, one);
+      for (size_t i = 0; i < numSpatialDims; ++i) {
+        // The shape of `grad_input` is collapsed here.
+        Value h = gradInputSizes[2 + i];
+        Value k = weiSizes[spatialStartDimIdx + i];
+        Value hMinusOne = rewriter.createOrFold<arith::SubIOp>(loc, h, one);
+        Value kMinusOne = rewriter.createOrFold<arith::SubIOp>(loc, k, one);
+        Value mul = rewriter.createOrFold<arith::MulIOp>(
+            loc, castIntToIndex(rewriter, loc, dilationIntValues[i]), kMinusOne);
+        Value sum = rewriter.createOrFold<arith::AddIOp>(loc, hMinusOne, mul);
+        sizes.push_back(rewriter.createOrFold<arith::AddIOp>(loc, sum, one));
+        offsets.push_back(castIntToIndex(rewriter, loc, gradOutputPaddingValues[i]));
+
+        Value strideIntValue = arith::ConstantOp::create(
+            rewriter, loc, rewriter.getI64IntegerAttr(strideInts[i]));
+        strides.push_back(castIntToIndex(rewriter, loc, strideIntValue));
+      }
+
+      Value zeroInit =
+          createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy);
+      gradOutputSliced = tensor::InsertSliceOp::create(
+          rewriter, loc, torch_to_linalg::removeSizeInformation(rewriter, loc, gradOutputExpanded),
+          zeroInit, offsets, goSizes, strides);
     } else {
-      // Pad `grad_output` spatial dims with zeros. If grouped, input has shape:
-      // N x G x F/G x <spatial>. Otherwise: N x F x <spatial>.
-      gradOutputExpanded = torch_to_linalg::getDynamicZeroPaddedTensor(
+      // For unit strides, pad `grad_output` spatial dims with zeros.
+      // If the conv is grouped, the output has shape:
+      //   N x G x F/G x <spatial>. Otherwise: N x F x <spatial>.
+      gradOutputSliced = torch_to_linalg::getDynamicZeroPaddedTensor(
           op, rewriter, gradOutputExpanded, gradOutputPaddingValues, spatialStartDimIdx, padVal);
     }
 
     // Initialize output buffer. For grouped, compute into an expanded
     // [N, G, C/G, D*] tensor and collapse back to the original input shape.
-    SmallVector<Value> gradInputSizes = getTensorSizes(rewriter, loc, input);
     Value gradInputInit =
         createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy);
     SmallVector<ReassociationIndices> gradInputCollapseIndices;

@@ -1975,7 +2014,7 @@ class ConvertAtenConvolutionBackwardOp : public OpConversionPattern<AtenConvolutionBackwardOp>
 
     auto genericRes = linalg::GenericOp::create(
         rewriter, loc, gradInputInit.getType(),
-        ValueRange{gradOutputExpanded, weightExpanded},
+        ValueRange{gradOutputSliced, weightExpanded},
         gradInputInit, indexingMaps, iteratorTypes,
         [&](OpBuilder &b, Location loc, ValueRange args) {
           Value grad = args[0];
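
For a quick sanity check of the two formulas in this hunk, here is a minimal standalone C++ sketch (not part of the pattern; the constants mirror one spatial dimension of the test added below: input 64, kernel 2, and stride/padding/dilation all 2):

#include <cassert>

int main() {
  const int D = 64, K = 2, s = 2, p = 2, d = 2; // one spatial dim of the test

  // Adjusted backward-input padding: p'[i] = d[i] * (K[i] - 1) - p[i].
  const int pAdj = d * (K - 1) - p;
  assert(pAdj == 0); // matches the [0, 0, 0, 0] offsets of tensor.insert_slice

  // Scatter destination size: size[i] = (D[i] - 1) + d[i] * (K[i] - 1) + 1.
  const int destSize = (D - 1) + d * (K - 1) + 1;
  assert(destSize == 66); // matches the tensor<2x16x66x66xf32> destination

  // Forward output size; grad_output is written back with step s.
  const int Dout = (D + 2 * p - d * (K - 1) - 1) / s + 1;
  assert(Dout == 33);
  assert(pAdj + (Dout - 1) * s < destSize); // scattered writes stay in bounds
  return 0;
}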

test/Conversion/TorchToLinalg/convolution_bwd.mlir

Lines changed: 58 additions & 0 deletions
@@ -59,6 +59,64 @@ func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2
 
 // -----
 
+// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) {
+func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) {
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32>
+  // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32>
+  // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32>
+  // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32>
+  // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32):
+  // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index
+  // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index
+  // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index
+  // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index
+  // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index
+  // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index
+  // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32>
+  // CHECK-NEXT: linalg.yield %[[EX]] : f32
+  // CHECK-NEXT: } -> tensor<16x128x2x2xf32>
+  // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x16x66x66xf32>
+  // CHECK-NEXT: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x16x66x66xf32>) -> tensor<2x16x66x66xf32>
+  // CHECK-NEXT: %[[SLICE:.*]] = tensor.insert_slice %[[T0]] into %[[SLICE_FILLED]][0, 0, 0, 0] [2, 16, 33, 33] [1, 1, 2, 2] : tensor<2x16x33x33xf32> into tensor<2x16x66x66xf32>
+  // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32>
+  // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32>
+  // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d5 * 2 + d2, d6 * 2 + d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[SLICE]], %[[W_REV]] : tensor<2x16x66x66xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+  // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+  // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+  // CHECK-NEXT: linalg.yield %[[ACC]] : f32
+  // CHECK-NEXT: } -> tensor<2x128x64x64xf32>
+  // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32>
+  // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+  // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+  // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+  // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+  // CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+  // CHECK-NEXT: } -> tensor<16xf32>
+  // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32>
+  // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32>
+  return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>,
 // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>,
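
Why scattering `grad_output` with the forward stride and then convolving at stride 1 reproduces the input gradient can be checked with a small 1-D reference model. This is a hypothetical standalone sketch, not code from this commit; the variable names and fill values are arbitrary, and the parameters mirror one spatial dimension of the test above:

#include <cassert>
#include <vector>

int main() {
  const int D = 64, K = 2, s = 2, p = 2, d = 2;
  const int Dout = (D + 2 * p - d * (K - 1) - 1) / s + 1; // 33, as in the test

  std::vector<double> w(K), gradOut(Dout);
  for (int k = 0; k < K; ++k) w[k] = 0.5 + k;
  for (int o = 0; o < Dout; ++o) gradOut[o] = 1.0 + 0.1 * o;

  // Reference: accumulate gradOut[o] * w[k] into every x[i] the forward
  // convolution read, i.e. i = o*s + k*d - p.
  std::vector<double> direct(D, 0.0);
  for (int o = 0; o < Dout; ++o)
    for (int k = 0; k < K; ++k) {
      const int i = o * s + k * d - p;
      if (i >= 0 && i < D)
        direct[i] += gradOut[o] * w[k];
    }

  // Scheme from this commit: scatter grad_output with step s at offset
  // p' = d*(K-1) - p into a zero tensor of size (D-1) + d*(K-1) + 1 ...
  const int pAdj = d * (K - 1) - p;
  assert(pAdj >= 0); // the pattern bails out on negative adjusted padding
  std::vector<double> scattered((D - 1) + d * (K - 1) + 1, 0.0);
  for (int o = 0; o < Dout; ++o)
    scattered[pAdj + o * s] = gradOut[o];

  // ... then run a stride-1, dilation-d convolution with the flipped kernel.
  std::vector<double> viaScatter(D, 0.0);
  for (int i = 0; i < D; ++i)
    for (int k = 0; k < K; ++k)
      viaScatter[i] += scattered[i + k * d] * w[K - 1 - k];

  for (int i = 0; i < D; ++i)
    assert(direct[i] == viaScatter[i]);
  return 0;
}

The inner index `scattered[i + k * d]` is the scalar analogue of the `d5 * 2 + d2` terms in the CONV indexing map above, and the scatter step `s` corresponds to the `[1, 1, 2, 2]` strides of the `tensor.insert_slice`.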
