
Commit 661be2d

[MLIR][Torch] Add TorchToLinalg lowering for AtenAvgPool3dOp (#3030)

This commit also fixes the average pool op's test that was failing for the OnnxToLinalg lowering.

Signed-off-by: Vivek Khandelwal <[email protected]>

1 parent 35dd8c5 commit 661be2d
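For context, the op this commit lowers computes a mean over 3-d windows of an NCDHW tensor. A hypothetical sketch using PyTorch's C++ API (shapes and kernel size are illustrative, not taken from this commit):

// Hypothetical illustration of torch.aten.avg_pool3d semantics.
#include <torch/torch.h>

int main() {
  auto x = torch::randn({1, 3, 8, 8, 8}); // N, C, D, H, W
  auto y = torch::avg_pool3d(x, /*kernel_size=*/{2, 2, 2});
  // y has shape {1, 3, 4, 4, 4}; each element is the mean of a 2x2x2 window.
  return 0;
}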

File tree

8 files changed: +379 -60 lines changed


include/torch-mlir/Conversion/TorchToLinalg/Utils.h

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,10 @@ getBackendTypeForScalarType(MLIRContext *context,
 
 bool isUnsignedTorchType(Type type);
 
+LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
+                            Location loc, SmallVector<int64_t> dimensions,
+                            Value input, Value &result);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
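A minimal usage sketch for the new helper, assuming `op`, `rewriter`, and `input` come from a surrounding conversion pattern (the real call sites appear in the diffs below):

// Sketch: permute an NCDHW tensor to NDHWC via the new helper.
Value permuted;
SmallVector<int64_t> perm = {0, 2, 3, 4, 1};
if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), perm,
                                          input, permuted)))
  return rewriter.notifyMatchFailure(op,
                                     "failed to perform permutation of tensor");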

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 7 additions & 47 deletions
@@ -1800,55 +1800,15 @@ class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
       return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
 
     Value inVector = adaptor.getSelf();
-    auto inType = cast<RankedTensorType>(inVector.getType());
-    int64_t inputRank = inType.getRank();
-    auto outType = cast<RankedTensorType>(
-        getTypeConverter()->convertType(op->getResult(0).getType()));
-    Type elementType = inType.getElementType();
-
-    // Check if the dimensions are a valid constants.
-    int64_t numDimensions = dimensions.size();
-    if (inputRank != numDimensions)
+    Value result;
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, inVector, result)))
       return rewriter.notifyMatchFailure(
-          op, "size of `dims` must be equal to the rank of the input");
-    for (unsigned i = 0; i < numDimensions; i++) {
-      if (dimensions[i] < 0)
-        dimensions[i] = toPositiveDim(dimensions[i], inputRank);
-      if (!isValidDim(dimensions[i], inputRank))
-        return rewriter.notifyMatchFailure(op, "dimension out of range");
-    }
-
-    Location loc = op.getLoc();
-
-    SmallVector<Value> outputDims;
-    for (unsigned i = 0; i < inputRank; i++)
-      outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i]));
-
-    Value outVector = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outputDims), elementType);
-    SmallVector<AffineExpr> idExprs;
-    SmallVector<AffineExpr> swapExprs;
-    for (unsigned i = 0; i < inputRank; i++)
-      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
-    for (unsigned i = 0; i < inputRank; i++)
-      swapExprs.push_back(idExprs[dimensions[i]]);
+          op, "failed to perform permutation of tensor");
 
-    AffineMap inputMap =
-        AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext());
-    AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0,
-                                         swapExprs, op->getContext());
-    SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
-    SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank, utils::IteratorType::parallel);
-    auto transpose = rewriter
-                         .create<linalg::GenericOp>(
-                             loc, outVector.getType(), inVector, outVector,
-                             indexingMaps, iteratorTypes,
-                             [](OpBuilder &b, Location loc, ValueRange args) {
-                               b.create<linalg::YieldOp>(loc, args[0]);
-                             })
-                         .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
+    auto outType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, result);
     return success();
   }
 };

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 49 additions & 13 deletions
@@ -168,11 +168,42 @@ static LogicalResult createPoolingOp(
   Value windowTensor = rewriter.create<tensor::EmptyOp>(
       loc, getAsOpFoldResult(shape), elementType);
 
-  result = rewriter
-               .create<OpTy>(loc, outTensorInitialized.getType(),
-                             ValueRange{paddedInput, windowTensor},
-                             outTensorInitialized, stridesAttr, dilationAttr)
-               .getResult(0);
+  Value permutedInput = paddedInput, permutedOutput = outTensorInitialized;
+  if (dimensionality == 3) {
+    // Permute input and output tensor as follows:
+    // (n,c,d,h,w) -> (n,d,h,w,c)
+    SmallVector<int64_t> dimensions = {0, 2, 3, 4, 1};
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, paddedInput,
+                                              permutedInput)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to perform permutation of tensor");
+
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, outTensorInitialized,
+                                              permutedOutput)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to perform permutation of tensor");
+  }
+
+  Value poolingResult =
+      rewriter
+          .create<OpTy>(loc, permutedOutput.getType(),
+                        ValueRange{permutedInput, windowTensor}, permutedOutput,
+                        stridesAttr, dilationAttr)
+          .getResult(0);
+
+  result = poolingResult;
+  if (dimensionality == 3) {
+    // Permute output tensor as follows:
+    // (n,d,h,w,c) -> (n,c,d,h,w)
+    SmallVector<int64_t> dimensions = {0, 4, 1, 2, 3};
+    if (failed(torch_to_linalg::permuteTensor(
+            op, rewriter, op->getLoc(), dimensions, poolingResult, result)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to perform permutation of tensor");
+  }
+
   return success();
 }
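The round trip above exists because linalg::PoolingNdhwcSumOp expects a channels-last (NDHWC) layout, while the incoming torch tensor is NCDHW. The two permutation vectors are inverses of each other; a standalone check (illustrative only, not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  int64_t toNdhwc[5] = {0, 2, 3, 4, 1}; // (n,c,d,h,w) -> (n,d,h,w,c)
  int64_t toNcdhw[5] = {0, 4, 1, 2, 3}; // (n,d,h,w,c) -> (n,c,d,h,w)
  for (int64_t i = 0; i < 5; ++i)
    assert(toNdhwc[toNcdhw[i]] == i && "permutations compose to identity");
  return 0;
}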

@@ -604,15 +635,17 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
             paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
             outTensorShape, paddedInput, sumPool)))
       return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
-    Value divisor;
-    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-      Value kHtimeskW = rewriter.create<arith::MulIOp>(
-          loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
+    // }
+
+    Value divisor = kernelSizeIntValues[0];
+    for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) {
+      divisor =
+          rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
+    }
+    if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
       divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
-                    ? kHtimeskW
+                    ? divisor
                     : adaptor.getDivisorOverride();
-    } else {
-      divisor = kernelSizeIntValues[0];
     }
     divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
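The rewritten divisor logic folds a product over every kernel dimension instead of special-casing kH*kW, so one code path serves the 1-d, 2-d, and 3-d ops, with divisor_override still taking precedence for the 2-d and 3-d variants. A standalone model of the fold (hypothetical helper, not part of the commit):

#include <cstdint>
#include <vector>

// Default average-pool divisor: the product of all kernel extents.
int64_t poolDivisor(const std::vector<int64_t> &kernelSize) {
  int64_t divisor = kernelSize[0];
  for (size_t i = 1; i < kernelSize.size(); ++i)
    divisor *= kernelSize[i];
  return divisor; // e.g. {2, 2, 2} -> 8
}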

@@ -1115,13 +1148,16 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
 
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
-  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
+  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
       .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
           typeConverter, context);
   patterns
       .add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
           typeConverter, context);
+  patterns
+      .add<ConvertAtenAvgPoolOp<AtenAvgPool3dOp, linalg::PoolingNdhwcSumOp, 3>>(
+          typeConverter, context);
   target.addIllegalOp<AtenAdaptiveAvgPool1dOp, AtenAdaptiveAvgPool2dOp,
                       AtenAdaptiveAvgPool3dOp, Aten_AdaptiveAvgPool3dOp>();
   patterns.add<ConvertAtenAdaptivePoolOp<AtenAdaptiveAvgPool1dOp>>(

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 52 additions & 0 deletions
@@ -572,3 +572,55 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) {
   llvm_unreachable("Unknown type checked for signedness");
   return false;
 }
+
+LogicalResult torch_to_linalg::permuteTensor(Operation *op,
+                                             PatternRewriter &rewriter,
+                                             Location loc,
+                                             SmallVector<int64_t> dimensions,
+                                             Value input, Value &result) {
+  auto inType = cast<RankedTensorType>(input.getType());
+  int64_t inputRank = inType.getRank();
+  Type elementType = inType.getElementType();
+
+  // Check that the dimensions are valid constants.
+  int64_t numDimensions = dimensions.size();
+  if (inputRank != numDimensions)
+    return rewriter.notifyMatchFailure(
+        op, "size of `dims` must be equal to the rank of the input");
+  for (uint32_t i = 0; i < numDimensions; i++) {
+    if (dimensions[i] < 0)
+      dimensions[i] = toPositiveDim(dimensions[i], inputRank);
+    if (!isValidDim(dimensions[i], inputRank))
+      return rewriter.notifyMatchFailure(op, "dimension out of range");
+  }
+
+  SmallVector<Value> outputDims;
+  for (uint32_t i = 0; i < inputRank; i++)
+    outputDims.push_back(getDimOp(rewriter, loc, input, dimensions[i]));
+
+  Value outVector = rewriter.create<tensor::EmptyOp>(
+      loc, getAsOpFoldResult(outputDims), elementType);
+  SmallVector<AffineExpr> idExprs;
+  SmallVector<AffineExpr> swapExprs;
+  for (uint32_t i = 0; i < inputRank; i++)
+    idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
+  for (uint32_t i = 0; i < inputRank; i++)
+    swapExprs.push_back(idExprs[dimensions[i]]);
+
+  AffineMap inputMap =
+      AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext());
+  AffineMap outputMap =
+      AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext());
+  SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  result = rewriter
+               .create<linalg::GenericOp>(
+                   loc, outVector.getType(), input, outVector, indexingMaps,
+                   iteratorTypes,
+                   [](OpBuilder &b, Location loc, ValueRange args) {
+                     b.create<linalg::YieldOp>(loc, args[0]);
+                   })
+               .getResult(0);
+  return success();
+}
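The helper uses an identity indexing map for the input and a swapped map for the output, so the generated linalg.generic body is a plain element copy; output dimension i takes the size of input dimension dimensions[i]. A hypothetical standalone model of that shape rule (not part of the commit):

#include <cassert>
#include <cstdint>
#include <vector>

// Output dim i gets the size of input dim dimensions[i].
std::vector<int64_t> permuteShape(const std::vector<int64_t> &inShape,
                                  const std::vector<int64_t> &dimensions) {
  assert(inShape.size() == dimensions.size());
  std::vector<int64_t> outShape;
  for (int64_t d : dimensions)
    outShape.push_back(inShape[d]);
  return outShape; // {N,C,D,H,W} with {0,2,3,4,1} -> {N,D,H,W,C}
}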
