Commit f6721e5

[MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (#3763)
This commit adds support for negative step values in the aten.slice.Tensor op. Although PyTorch itself does not allow a negative step for the slice op, the Onnx.Slice op does, and it eventually lowers to torch.aten.slice.Tensor. Hence, handling for such values is added during the Torch->Linalg lowering of aten.slice.Tensor. Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent b08d086 commit f6721e5
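
The core idea of the patch is that a slice with a negative step can be rewritten as a flip of the input along the sliced dimension followed by a positive-step slice at a recomputed offset. The following standalone sketch (plain C++, not the MLIR lowering itself) replays that flip-and-re-offset equivalence on the same numbers used in the comment added to DataMovement.cpp; the helper sliceWithStep and all other names are illustrative only.

// Standalone sketch of the flip-and-re-offset rewrite (illustrative only).
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Extract `resultSize` elements starting at `offset` with a positive `step`.
static std::vector<int64_t> sliceWithStep(const std::vector<int64_t> &in,
                                          int64_t offset, int64_t step,
                                          int64_t resultSize) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < resultSize; ++i)
    out.push_back(in[offset + i * step]);
  return out;
}

int main() {
  // Same numbers as the worked example in ConvertAtenSliceTensorOp.
  std::vector<int64_t> input = {0, 1, 2, 3, 4, 5};
  int64_t offset = 3, step = -2, resultSize = 2;

  // Reference semantics of a negative-step slice: walk backwards directly.
  std::vector<int64_t> expected;
  for (int64_t i = 0; i < resultSize; ++i)
    expected.push_back(input[offset + i * step]);

  // Rewrite: flip the input, negate the step, and recompute the offset as
  // input_size - result_size * |step|.
  std::vector<int64_t> flipped(input.rbegin(), input.rend());
  int64_t posStep = -step;
  int64_t flippedOffset =
      static_cast<int64_t>(input.size()) - resultSize * posStep;
  std::vector<int64_t> result =
      sliceWithStep(flipped, flippedOffset, posStep, resultSize);

  assert(result == expected); // both are [3, 1]
  for (int64_t v : result)
    std::cout << v << " ";
  std::cout << "\n";
}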

File tree

4 files changed: +86 −47 lines changed


include/torch-mlir/Conversion/TorchToLinalg/Utils.h

Lines changed: 4 additions & 0 deletions
@@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
                             Location loc, SmallVector<int64_t> dimensions,
                             Value input, Value &result);
 
+// Flips an input tensor based on the values of axis list.
+Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
+                 SmallVector<int64_t> axis);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 40 additions & 9 deletions
@@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
 template <typename OpTy, typename OpAdaptor>
 LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                            ConversionPatternRewriter &rewriter,
+                                           int64_t &dim,
                                            SmallVector<Value> &resultShape,
                                            SmallVector<Value> &offsets,
                                            SmallVector<Value> &strides) {
@@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
 
-  int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
     return op->emitError("unimplemented: dim is not constant");
 
@@ -1857,14 +1857,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
     RankedTensorType resultType = cast<RankedTensorType>(
         typeConverter->convertType(op->getResult(0).getType()));
 
-    SmallVector<Value> resultShape;
-    SmallVector<Value> offsets;
-    SmallVector<Value> strides;
+    SmallVector<Value> resultShape, offsets, strides;
+    int64_t dim;
     if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
                                             AtenSliceTensorOpAdaptor>(
-            op, adaptor, rewriter, resultShape, offsets, strides))) {
+            op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
       return failure();
     }
+
+    // If stride is negative, then flip the input tensor corresponding to that
+    // dim, update the stride for flipped tensor by multiplying it by -1, and
+    // update the offset as follows:
+    //   flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
+    //
+    // For example:
+    //   Input = [0, 1, 2, 3, 4, 5]
+    //   stride = [-2], result_shape = [2], offset = [3]
+    //   Result = [3, 1]
+    // After flipping:
+    //   Input = [5, 4, 3, 2, 1, 0]
+    //   stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
+    //   Result = [3, 1]
+
+    Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
+                                                     SmallVector<int64_t>{dim});
+    Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value isNegativeStride = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, strides[dim], zero);
+    strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
+    Value resShapeMulStride =
+        rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
+    Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
+    Value flippedOffset =
+        rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
+    offsets[dim] = rewriter.create<arith::SelectOp>(
+        loc, isNegativeStride, flippedOffset, offsets[dim]);
+
+    input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
+                                             flippedInput, input);
+
     SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
     auto sliceType = RankedTensorType::get(
         dynShape, resultType.getElementType(), resultType.getEncoding());
@@ -2095,12 +2127,11 @@ class ConvertAtenSliceScatterOp
     RankedTensorType resultType = cast<RankedTensorType>(
         typeConverter->convertType(op->getResult(0).getType()));
 
-    SmallVector<Value> resultShape;
-    SmallVector<Value> offsets;
-    SmallVector<Value> strides;
+    SmallVector<Value> resultShape, offsets, strides;
+    int64_t dim;
     if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
                                             AtenSliceScatterOpAdaptor>(
-            op, adaptor, rewriter, resultShape, offsets, strides))) {
+            op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
       return failure();
     }

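For reference, here is a scalar sketch (plain C++, not MLIR, names hypothetical) of the index arithmetic the hunk above emits for the sliced dimension: the stride comparison, the absolute value, the recomputed offset, and the selects that fall back to the original offset and input when the stride is non-negative.

#include <cassert>
#include <cstdint>

struct SliceParams {
  int64_t offset;
  int64_t stride;
  bool useFlippedInput; // true -> the extract_slice reads the flipped tensor
};

// Mirrors the ops emitted in ConvertAtenSliceTensorOp: arith.cmpi slt,
// math.absi, arith.muli, tensor.dim, arith.subi, and the arith.select ops.
static SliceParams adjustForNegativeStride(int64_t offset, int64_t stride,
                                           int64_t resultSize,
                                           int64_t inputSize) {
  bool isNegativeStride = stride < 0;                          // arith.cmpi slt
  int64_t absStride = isNegativeStride ? -stride : stride;     // math.absi
  int64_t flippedOffset = inputSize - resultSize * absStride;  // subi(dim, muli)
  return {isNegativeStride ? flippedOffset : offset, // arith.select on offset
          absStride, isNegativeStride};              // input is selected too
}

int main() {
  // Same numbers as the worked example: offset 3, stride -2, result size 2.
  SliceParams p = adjustForNegativeStride(/*offset=*/3, /*stride=*/-2,
                                          /*resultSize=*/2, /*inputSize=*/6);
  assert(p.offset == 2 && p.stride == 2 && p.useFlippedInput);
}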
lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 1 addition & 38 deletions
@@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Location loc = op->getLoc();
-    MLIRContext *context = op.getContext();
     Value self = adaptor.getSelf();
     auto selfRank =
         cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
-    Type elementType =
-        cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
-    Value c1 =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
 
     SmallVector<int64_t> axis;
     if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
@@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
       }
     }
 
-    // Only used to calculate flipped values, i.e. those on the flip axes. Other
-    // dims won't be used.
-    SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
-    for (auto flipDim : axis)
-      dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
-
-    Value initTensor = createZeroInitTensor(
-        rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);
-
-    SmallVector<utils::IteratorType> iteratorTypes(
-        selfRank, utils::IteratorType::parallel);
-    SmallVector<AffineMap> indexingMaps(
-        2, AffineMap::getMultiDimIdentityMap(selfRank, context));
-    Value flipped =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, self.getType(), self, initTensor, indexingMaps,
-                iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  SmallVector<Value> indices;
-                  for (auto i = 0; i < selfRank; i++)
-                    indices.push_back(b.create<linalg::IndexOp>(loc, i));
-                  for (auto flipDim : axis) {
-                    indices[flipDim] = b.create<arith::SubIOp>(
-                        loc, dims[flipDim], indices[flipDim]);
-                  }
-                  Value res = b.create<tensor::ExtractOp>(loc, self, indices)
-                                  .getResult();
-                  b.create<linalg::YieldOp>(loc, res);
-                })
-            .getResult(0);
-
+    Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
-
     return success();
   }
 };

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 41 additions & 0 deletions
@@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
           .getResult(0);
   return success();
 }
+
+// Flips an input tensor based on the values of axis list.
+Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
+                                  Value input, SmallVector<int64_t> axis) {
+  Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+  Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
+  auto selfRank = cast<RankedTensorType>(input.getType()).getRank();
+
+  // Only used to calculate flipped values, i.e. those on the flip axes. Other
+  // dims won't be used.
+  SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
+  for (auto flipDim : axis)
+    dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
+
+  Value initTensor = createZeroInitTensor(
+      rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);
+
+  SmallVector<utils::IteratorType> iteratorTypes(selfRank,
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
+  Value flipped =
+      rewriter
+          .create<linalg::GenericOp>(
+              loc, input.getType(), input, initTensor, indexingMaps,
+              iteratorTypes,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                SmallVector<Value> indices;
+                for (auto i = 0; i < selfRank; i++)
+                  indices.push_back(b.create<linalg::IndexOp>(loc, i));
+                for (auto flipDim : axis) {
+                  indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
+                                                             indices[flipDim]);
+                }
+                Value res = b.create<tensor::ExtractOp>(loc, input, indices)
+                                .getResult();
+                b.create<linalg::YieldOp>(loc, res);
+              })
+          .getResult(0);
+  return flipped;
+}

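To make the behaviour of the new helper concrete, below is a small element-level sketch (plain C++, illustrative only, not part of the patch) of what the linalg.generic built by flipTensor computes: for every flip axis, the output at index i reads the input at index size - 1 - i along that axis.

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A 2x3 "tensor" flipped along axis 1, mirroring the index reversal in the
  // generic's body (dims[flipDim] - index, with dims already reduced by 1).
  std::vector<std::vector<int64_t>> input = {{0, 1, 2}, {3, 4, 5}};
  std::vector<int64_t> axis = {1};
  std::array<int64_t, 2> sizes = {2, 3};

  auto output = input; // same-shape buffer, standing in for the init tensor
  for (int64_t i = 0; i < sizes[0]; ++i) {
    for (int64_t j = 0; j < sizes[1]; ++j) {
      std::array<int64_t, 2> idx = {i, j};
      for (int64_t d : axis)
        idx[d] = sizes[d] - 1 - idx[d];
      output[i][j] = input[idx[0]][idx[1]];
    }
  }

  for (const auto &row : output) { // prints: 2 1 0  then  5 4 3
    for (int64_t v : row)
      std::cout << v << " ";
    std::cout << "\n";
  }
}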