@@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
 template <typename OpTy, typename OpAdaptor>
 LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                            ConversionPatternRewriter &rewriter,
+                                           int64_t &dim,
                                            SmallVector<Value> &resultShape,
                                            SmallVector<Value> &offsets,
                                            SmallVector<Value> &strides) {
@@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
 
-  int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
     return op->emitError("unimplemented: dim is not constant");
 
@@ -1857,14 +1857,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
     RankedTensorType resultType = cast<RankedTensorType>(
         typeConverter->convertType(op->getResult(0).getType()));
 
-    SmallVector<Value> resultShape;
-    SmallVector<Value> offsets;
-    SmallVector<Value> strides;
+    SmallVector<Value> resultShape, offsets, strides;
+    int64_t dim;
     if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
                                             AtenSliceTensorOpAdaptor>(
-            op, adaptor, rewriter, resultShape, offsets, strides))) {
+            op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
       return failure();
     }
+
+    // If stride is negative, then flip the input tensor corresponding to that
+    // dim, update the stride for the flipped tensor by multiplying it by -1,
+    // and update the offset as follows:
+    // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
+    //
+    // For example:
+    // Input = [0, 1, 2, 3, 4, 5]
+    // stride = [-2], result_shape = [2], offset = [3]
+    // Result = [3, 1]
+    // After flipping:
+    // Input = [5, 4, 3, 2, 1, 0]
+    // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
+    // Result = [3, 1]
+
+    Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
+                                                     SmallVector<int64_t>{dim});
+    Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value isNegativeStride = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, strides[dim], zero);
+    strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
+    Value resShapeMulStride =
+        rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
+    Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
+    Value flippedOffset =
+        rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
+    offsets[dim] = rewriter.create<arith::SelectOp>(
+        loc, isNegativeStride, flippedOffset, offsets[dim]);
+
+    input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
+                                             flippedInput, input);
+
     SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
     auto sliceType = RankedTensorType::get(
         dynShape, resultType.getElementType(), resultType.getEncoding());
@@ -2095,12 +2127,11 @@ class ConvertAtenSliceScatterOp
     RankedTensorType resultType = cast<RankedTensorType>(
         typeConverter->convertType(op->getResult(0).getType()));
 
-    SmallVector<Value> resultShape;
-    SmallVector<Value> offsets;
-    SmallVector<Value> strides;
+    SmallVector<Value> resultShape, offsets, strides;
+    int64_t dim;
     if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
                                             AtenSliceScatterOpAdaptor>(
-            op, adaptor, rewriter, resultShape, offsets, strides))) {
+            op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
       return failure();
     }
 