Skip to content

Commit b08d086

Browse files
[TOSA] Add legalization for fill, flip, and round (#3768)
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and aten.round - Fix torchScalarToTosaTensor function to correctly convert Torch scalar input to TOSA tensor - Update xfail_sets.py with new e2e results - Update basic.mlir with LIT tests for new ops Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41 Signed-off-by: Justin Ngo <[email protected]>
1 parent f4840ed commit b08d086

File tree

3 files changed

+298
-56
lines changed

3 files changed

+298
-56
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 188 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
153153
return rewriter.notifyMatchFailure(op,
154154
"Unable to extract the scalar constant");
155155

156+
int64_t numElem = 1;
157+
for (int64_t dim : dshape)
158+
numElem *= dim;
159+
156160
if (isa<mlir::FloatType>(dtype)) {
157-
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
158-
(isFloat ? doubleValue : intValue),
159-
dshape, dtype)
160-
.value();
161+
tosaTensor =
162+
tosa::getConstTensor<float>(
163+
rewriter, op,
164+
SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
165+
dshape, dtype)
166+
.value();
161167
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
162168
auto w = intType.getWidth();
163169
if (w != 1 && w != 32 && w != 64)
@@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
173179
}
174180
bool d = isFloat ? static_cast<bool>(doubleValue)
175181
: static_cast<bool>(intValue);
176-
tosaTensor =
177-
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
182+
tosaTensor = tosa::getConstTensor<bool>(
183+
rewriter, op, SmallVector<bool>(numElem, d), dshape)
184+
.value();
178185
} else if (w == 32) {
179186
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
180187
return rewriter.notifyMatchFailure(
@@ -183,17 +190,19 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
183190
}
184191
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
185192
: static_cast<int32_t>(intValue);
186-
tosaTensor =
187-
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
193+
tosaTensor = tosa::getConstTensor<int32_t>(
194+
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
195+
.value();
188196
} else if (w == 64) {
189197
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
190198
return rewriter.notifyMatchFailure(
191199
op, "Supplied value of scalar constant exceeds limits "
192200
"of destination type");
193201
}
194202
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
195-
tosaTensor =
196-
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
203+
tosaTensor = tosa::getConstTensor<int64_t>(
204+
rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
205+
.value();
197206
}
198207
} else {
199208
return rewriter.notifyMatchFailure(op, "Usupported element type");
@@ -5320,7 +5329,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
53205329
};
53215330

53225331
template <typename AtenOpT>
5323-
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
5332+
class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
53245333
public:
53255334
using OpConversionPattern<AtenOpT>::OpConversionPattern;
53265335
using OpAdaptor = typename AtenOpT::Adaptor;
@@ -5336,18 +5345,48 @@ class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
53365345
op, "Only Tensor types with static shapes are currently supported");
53375346

53385347
Type outElemTy = outType.getElementType();
5339-
if (!outElemTy.isIntOrFloat()) {
5348+
if (!outElemTy.isIntOrFloat())
53405349
return rewriter.notifyMatchFailure(
53415350
op, "Only floating-point or integer datatype legalization supported");
5351+
5352+
Value fillValueTargetTensor;
5353+
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
5354+
// Reshape value tensor to have same rank and shape as input
5355+
auto inputRank =
5356+
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
5357+
5358+
auto fillValue = adaptor.getValue();
5359+
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
5360+
if (!fillValueType)
5361+
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
5362+
auto fillValueElemTy = fillValueType.getElementType();
5363+
5364+
SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
5365+
5366+
auto fillValueMatchedInputRankType = RankedTensorType::get(
5367+
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
5368+
fillValueElemTy);
5369+
5370+
auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
5371+
op->getLoc(), fillValueMatchedInputRankType, fillValue,
5372+
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
5373+
5374+
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
5375+
op->getLoc(),
5376+
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
5377+
fillValueElemTy),
5378+
fillValueMatchedInputRankTensor.getResult(),
5379+
makeShapeTorchCompatible(outType.getShape()));
5380+
} else {
5381+
if (failed(torchScalarToTosaTensor(
5382+
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
5383+
makeShapeTorchCompatible(outType.getShape()))))
5384+
return rewriter.notifyMatchFailure(
5385+
op, "Fill value must be a scalar constant");
53425386
}
5343-
Value constOp;
5344-
if (failed(torchScalarToTosaTensor(
5345-
rewriter, op, op.getValue(), constOp, outElemTy,
5346-
makeShapeTorchCompatible(outType.getShape()))))
5347-
return rewriter.notifyMatchFailure(
5348-
op, "Supplied value must be a Scalar constant");
53495387

5350-
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
5388+
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
5389+
fillValueTargetTensor);
53515390

53525391
return success();
53535392
}
@@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
58695908
return success();
58705909
}
58715910

5911+
// Legalization for aten.flip
5912+
template <>
5913+
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
5914+
AtenFlipOp op, OpAdaptor adaptor,
5915+
ConversionPatternRewriter &rewriter) const {
5916+
5917+
auto self = adaptor.getSelf();
5918+
5919+
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
5920+
if (!selfTy)
5921+
return rewriter.notifyMatchFailure(
5922+
op, "Only ranked tensor types are currently supported");
5923+
5924+
SmallVector<int64_t> dims;
5925+
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
5926+
return rewriter.notifyMatchFailure(
5927+
op, "Only constant dims are currently supported");
5928+
5929+
auto selfRank = selfTy.getRank();
5930+
5931+
auto resultTy = getTypeConverter()->convertType(op.getType());
5932+
Value result = self;
5933+
5934+
for (auto &dim : dims) {
5935+
dim = toPositiveDim(dim, selfRank);
5936+
if (!isValidDim(dim, selfRank))
5937+
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
5938+
5939+
result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
5940+
static_cast<int32_t>(dim));
5941+
}
5942+
5943+
rewriter.replaceOp(op, result);
5944+
return success();
5945+
}
5946+
5947+
// Legalization for aten.round:
5948+
// Rounds elements of input to the nearest integer.
5949+
// Implements "round half to even" to break ties when a number is equidistant
5950+
// from two integers.
5951+
template <>
5952+
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
5953+
AtenRoundOp op, OpAdaptor adaptor,
5954+
ConversionPatternRewriter &rewriter) const {
5955+
// To round to the nearest integer, we will consider the fractional part of
5956+
// the input element (= input element - integer part of element). If the
5957+
// fractional part is smaller than 0.5, round the number down. If the
5958+
// fractional part is 0.5, apply "round half to even" rule. If the fractional
5959+
// part is greater than 0.5, round up.
5960+
//
5961+
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
5962+
// res = floor(input)
5963+
// else:
5964+
// res = ceil(input)
5965+
5966+
auto self = adaptor.getSelf();
5967+
5968+
auto selfTy = dyn_cast<TensorType>(self.getType());
5969+
if (!selfTy)
5970+
return rewriter.notifyMatchFailure(op, "Only tensor types supported");
5971+
5972+
auto resultTy =
5973+
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
5974+
5975+
auto boolTy =
5976+
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
5977+
5978+
auto resultElemTy = resultTy.getElementType();
5979+
5980+
auto oneHalf =
5981+
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
5982+
5983+
auto two =
5984+
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
5985+
5986+
auto floorInput =
5987+
rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
5988+
5989+
// input - floor(input)
5990+
auto fractionalPart = rewriter.create<tosa::SubOp>(
5991+
op->getLoc(), resultTy, self, floorInput.getResult());
5992+
5993+
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
5994+
5995+
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
5996+
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
5997+
5998+
auto floorDivResult = rewriter.create<tosa::FloorOp>(
5999+
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
6000+
6001+
// (floor(input) // 2) * 2
6002+
auto evenComparison = rewriter.create<tosa::MulOp>(
6003+
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
6004+
6005+
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
6006+
auto floorInputEven = rewriter.create<tosa::EqualOp>(
6007+
op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
6008+
6009+
auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
6010+
op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
6011+
6012+
auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
6013+
op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
6014+
6015+
// (frac == 0.5) && (floor(input) % 2 == 0)
6016+
auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
6017+
op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
6018+
floorInputEven.getResult());
6019+
6020+
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
6021+
auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
6022+
op->getLoc(), boolTy, fracLtOneHalf.getResult(),
6023+
fracEqualOneHalfCond.getResult());
6024+
6025+
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
6026+
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
6027+
ceilInput.getResult());
6028+
6029+
return success();
6030+
}
6031+
58726032
// Template to create supporting diagonal mask tensor for aten.diagonal
58736033
template <typename T>
58746034
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
@@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
60526212

60536213
return success();
60546214
}
6215+
60556216
} // namespace
60566217

60576218
// -----------------------------------------------------------------------------
@@ -6283,11 +6444,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
62836444
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
62846445
#undef INSERT_CONSTANT_FILL_PATTERN
62856446

6286-
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
6447+
#define INSERT_FILL_PATTERN(AtenOp) \
62876448
target.addIllegalOp<AtenOp>(); \
6288-
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
6289-
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
6290-
#undef INSERT_FILL_SCALAR_PATTERN
6449+
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
6450+
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
6451+
INSERT_FILL_PATTERN(AtenFillScalarOp);
6452+
INSERT_FILL_PATTERN(AtenFillTensorOp);
6453+
#undef INSERT_FILL_PATTERN
62916454

62926455
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
62936456
target.addIllegalOp<AtenOp>(); \
@@ -6359,6 +6522,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
63596522
INSERT_ATENOP_PATTERN(AtenTrilOp);
63606523
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
63616524
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
6525+
INSERT_ATENOP_PATTERN(AtenFlipOp);
6526+
INSERT_ATENOP_PATTERN(AtenRoundOp);
63626527
#undef INSERT_ATENOP_PATTERN
63636528

63646529
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

0 commit comments

Comments
 (0)