Skip to content

Commit f00339e

Browse files
committed
[tosa] : Add support for quantize_per_tensor.
1 parent 1cf2871 commit f00339e

File tree

5 files changed

+331
-149
lines changed

5 files changed

+331
-149
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,31 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
111111
RankedTensorType output_type, Value input_value,
112112
ElementsAttr axes_elems, bool keep_dims);
113113

114+
// Creates IntegerAttrs for clamping, using provided min/max values or the
115+
// numeric limits of the element type if the values are not provided.
116+
LogicalResult getIntegerClampAttrs(ConversionPatternRewriter &rewriter,
117+
Operation *op, Type elemTy,
118+
std::optional<int64_t> minInt,
119+
std::optional<int64_t> maxInt,
120+
IntegerAttr &minAttr, IntegerAttr &maxAttr);
121+
122+
// Creates FloatAttrs for clamping, using provided min/max values or the numeric
123+
// limits of the element type if the values are not provided.
124+
LogicalResult getFloatClampAttrs(ConversionPatternRewriter &rewriter,
125+
Operation *op, Type elemTy,
126+
std::optional<double> minFloat,
127+
std::optional<double> maxFloat,
128+
FloatAttr &minAttr, FloatAttr &maxAttr);
129+
130+
// Implements "round half to even" logic for aten.round using TOSA ops.
131+
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
132+
// res = floor(input)
133+
// else:
134+
// res = ceil(input)
135+
std::optional<Value> createRoundHalfToEven(ConversionPatternRewriter &rewriter,
136+
Operation *op, Value input,
137+
RankedTensorType resultTy);
138+
114139
} // namespace tosa
115140
} // namespace mlir
116141

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 119 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -5304,69 +5304,45 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
53045304
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
53055305
auto outElemTy = outType.getElementType();
53065306

5307-
int64_t minInt, maxInt;
5308-
double minFloat, maxFloat;
5309-
bool isMinNotNone = false;
5310-
bool isMaxNotNone = false;
5311-
5312-
auto isMinInt = matchPattern(op.getMin(), m_TorchConstantInt(&minInt));
5313-
auto isMinFloat = matchPattern(op.getMin(), m_TorchConstantFloat(&minFloat));
5314-
if (isMinInt) {
5315-
minFloat = static_cast<float>(minInt);
5316-
isMinNotNone = true;
5317-
} else if (isMinFloat) {
5318-
minInt = static_cast<int64_t>(minFloat);
5319-
isMinNotNone = true;
5320-
} else {
5321-
if (succeeded(checkNotNone(rewriter, op, op.getMin())))
5307+
std::optional<int64_t> minInt;
5308+
std::optional<double> minFloat;
5309+
{
5310+
int64_t minIntVal;
5311+
double minFloatVal;
5312+
if (matchPattern(op.getMin(), m_TorchConstantInt(&minIntVal))) {
5313+
minInt = minIntVal;
5314+
minFloat = static_cast<double>(minIntVal);
5315+
} else if (matchPattern(op.getMin(), m_TorchConstantFloat(&minFloatVal))) {
5316+
minFloat = minFloatVal;
5317+
minInt = static_cast<int64_t>(minFloatVal);
5318+
} else if (succeeded(checkNotNone(rewriter, op, op.getMin()))) {
53225319
return rewriter.notifyMatchFailure(op,
53235320
"min attr should be a torch constant");
5321+
}
53245322
}
53255323

5326-
auto isMaxInt = matchPattern(op.getMax(), m_TorchConstantInt(&maxInt));
5327-
auto isMaxFloat = matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloat));
5328-
if (isMaxInt) {
5329-
maxFloat = static_cast<float>(maxInt);
5330-
isMaxNotNone = true;
5331-
} else if (isMaxFloat) {
5332-
maxInt = static_cast<int64_t>(maxFloat);
5333-
isMaxNotNone = true;
5334-
} else {
5335-
if (succeeded(checkNotNone(rewriter, op, op.getMax())))
5324+
std::optional<int64_t> maxInt;
5325+
std::optional<double> maxFloat;
5326+
{
5327+
int64_t maxIntVal;
5328+
double maxFloatVal;
5329+
if (matchPattern(op.getMax(), m_TorchConstantInt(&maxIntVal))) {
5330+
maxInt = maxIntVal;
5331+
maxFloat = static_cast<double>(maxIntVal);
5332+
} else if (matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloatVal))) {
5333+
maxFloat = maxFloatVal;
5334+
maxInt = static_cast<int64_t>(maxFloatVal);
5335+
} else if (succeeded(checkNotNone(rewriter, op, op.getMax()))) {
53365336
return rewriter.notifyMatchFailure(op,
53375337
"max attr should be a torch constant");
5338+
}
53385339
}
53395340

53405341
if (!isa<mlir::FloatType>(outElemTy)) {
53415342
IntegerAttr minIntAttr, maxIntAttr;
5342-
if (outElemTy.isInteger(8)) {
5343-
minIntAttr = rewriter.getIntegerAttr(
5344-
outElemTy,
5345-
isMinNotNone ? minInt : std::numeric_limits<int8_t>::min());
5346-
maxIntAttr = rewriter.getIntegerAttr(
5347-
outElemTy,
5348-
isMaxNotNone ? maxInt : std::numeric_limits<int8_t>::max());
5349-
} else if (outElemTy.isInteger(16)) {
5350-
minIntAttr = rewriter.getIntegerAttr(
5351-
outElemTy,
5352-
isMinNotNone ? minInt : std::numeric_limits<int16_t>::min());
5353-
maxIntAttr = rewriter.getIntegerAttr(
5354-
outElemTy,
5355-
isMaxNotNone ? maxInt : std::numeric_limits<int16_t>::max());
5356-
} else if (outElemTy.isInteger(32)) {
5357-
minIntAttr = rewriter.getIntegerAttr(
5358-
outElemTy,
5359-
isMinNotNone ? minInt : std::numeric_limits<int32_t>::min());
5360-
maxIntAttr = rewriter.getIntegerAttr(
5361-
outElemTy,
5362-
isMaxNotNone ? maxInt : std::numeric_limits<int32_t>::max());
5363-
} else if (outElemTy.isInteger(64)) {
5364-
minIntAttr = rewriter.getI64IntegerAttr(
5365-
isMinNotNone ? minInt : std::numeric_limits<int64_t>::min());
5366-
maxIntAttr = rewriter.getI64IntegerAttr(
5367-
isMaxNotNone ? maxInt : std::numeric_limits<int64_t>::max());
5368-
} else {
5369-
return rewriter.notifyMatchFailure(op, "Unsupported integer type");
5343+
if (failed(tosa::getIntegerClampAttrs(rewriter, op, outElemTy, minInt,
5344+
maxInt, minIntAttr, maxIntAttr))) {
5345+
return failure();
53705346
}
53715347

53725348
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -5376,28 +5352,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
53765352
tosa::NanPropagationMode::PROPAGATE));
53775353
} else {
53785354
FloatAttr minFloatAttr, maxFloatAttr;
5379-
if (outElemTy.isF16()) {
5380-
minFloatAttr =
5381-
rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
5382-
maxFloatAttr =
5383-
rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
5384-
} else if (outElemTy.isBF16()) {
5385-
minFloatAttr = rewriter.getFloatAttr(
5386-
rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
5387-
maxFloatAttr = rewriter.getFloatAttr(
5388-
rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
5389-
} else if (outElemTy.isF32()) {
5390-
minFloatAttr = rewriter.getF32FloatAttr(
5391-
isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
5392-
maxFloatAttr = rewriter.getF32FloatAttr(
5393-
isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
5394-
} else if (outElemTy.isF64()) {
5395-
minFloatAttr = rewriter.getF64FloatAttr(
5396-
isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
5397-
maxFloatAttr = rewriter.getF64FloatAttr(
5398-
isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
5399-
} else {
5400-
return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
5355+
if (failed(tosa::getFloatClampAttrs(rewriter, op, outElemTy, minFloat,
5356+
maxFloat, minFloatAttr,
5357+
maxFloatAttr))) {
5358+
return failure();
54015359
}
54025360

54035361
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -7308,17 +7266,6 @@ template <>
73087266
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
73097267
AtenRoundOp op, OpAdaptor adaptor,
73107268
ConversionPatternRewriter &rewriter) const {
7311-
// To round to the nearest integer, we will consider the fractional part of
7312-
// the input element (= input element - integer part of element). If the
7313-
// fractional part is smaller than 0.5, round the number down. If the
7314-
// fractional part is 0.5, apply "round half to even" rule. If the fractional
7315-
// part is greater than 0.5, round up.
7316-
//
7317-
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
7318-
// res = floor(input)
7319-
// else:
7320-
// res = ceil(input)
7321-
73227269
auto self = adaptor.getSelf();
73237270

73247271
auto selfTy = dyn_cast<TensorType>(self.getType());
@@ -7328,67 +7275,13 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
73287275
auto resultTy =
73297276
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
73307277

7331-
auto boolTy =
7332-
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
7333-
7334-
auto resultElemTy = resultTy.getElementType();
7335-
7336-
auto oneHalf =
7337-
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
7338-
7339-
auto two =
7340-
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
7341-
7342-
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf)
7343-
.failed() ||
7344-
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, two).failed())
7278+
auto result = tosa::createRoundHalfToEven(rewriter, op, self, resultTy);
7279+
if (!result) {
73457280
return rewriter.notifyMatchFailure(
7346-
op, "Failed to equalize ranks among operands and result");
7347-
7348-
auto floorInput =
7349-
tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, self);
7350-
7351-
// input - floor(input)
7352-
auto fractionalPart = tosa::SubOp::create(rewriter, op->getLoc(), resultTy,
7353-
self, floorInput.getResult());
7354-
7355-
auto ceilInput = tosa::CeilOp::create(rewriter, op->getLoc(), resultTy, self);
7356-
7357-
auto floorInputDivByTwo = tosa::createMulOpAndCast(
7358-
rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
7359-
7360-
auto floorDivResult = tosa::FloorOp::create(rewriter, op->getLoc(), resultTy,
7361-
floorInputDivByTwo.getResult());
7362-
7363-
// (floor(input) // 2) * 2
7364-
auto evenComparison = tosa::createMulOpAndCast(
7365-
rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0);
7366-
7367-
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
7368-
auto floorInputEven =
7369-
tosa::EqualOp::create(rewriter, op->getLoc(), boolTy,
7370-
floorInput.getResult(), evenComparison.getResult());
7371-
7372-
auto fracEqualOneHalf = tosa::EqualOp::create(
7373-
rewriter, op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
7374-
7375-
auto fracLtOneHalf = tosa::GreaterOp::create(
7376-
rewriter, op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
7377-
7378-
// (frac == 0.5) && (floor(input) % 2 == 0)
7379-
auto fracEqualOneHalfCond = tosa::LogicalAndOp::create(
7380-
rewriter, op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
7381-
floorInputEven.getResult());
7382-
7383-
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
7384-
auto floorResultCond = tosa::LogicalOrOp::create(
7385-
rewriter, op->getLoc(), boolTy, fracLtOneHalf.getResult(),
7386-
fracEqualOneHalfCond.getResult());
7387-
7388-
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
7389-
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
7390-
ceilInput.getResult());
7281+
op, "failed to implement round-half-to-even with TOSA ops");
7282+
}
73917283

7284+
rewriter.replaceOp(op, *result);
73927285
return success();
73937286
}
73947287

@@ -9339,6 +9232,86 @@ LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
93399232
return success();
93409233
}
93419234

9235+
// Legalization for aten.quantize_per_tensor
9236+
// Implements
9237+
// Q = clamp(round(X / scale) + zero_point)
9238+
template <>
9239+
LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewrite(
9240+
AtenQuantizePerTensorOp op, OpAdaptor adaptor,
9241+
ConversionPatternRewriter &rewriter) const {
9242+
Value input = adaptor.getSelf();
9243+
auto loc = op->getLoc();
9244+
9245+
// Get scale and zero_point as constants.
9246+
double scaleConst;
9247+
if (!matchPattern(op.getScale(), m_TorchConstantFloat(&scaleConst)))
9248+
return rewriter.notifyMatchFailure(op, "scale must be a Scalar constant");
9249+
9250+
int64_t zpConst;
9251+
if (!matchPattern(op.getZeroPoint(), m_TorchConstantInt(&zpConst)))
9252+
return rewriter.notifyMatchFailure(op,
9253+
"zero point must be a Scalar constant");
9254+
9255+
// Get input and result types.
9256+
auto inputTy = cast<RankedTensorType>(input.getType());
9257+
auto inputElemTy = inputTy.getElementType();
9258+
auto resultTy = cast<RankedTensorType>(
9259+
getTypeConverter()->convertType(op->getResult(0).getType()));
9260+
auto resultElemTy = resultTy.getElementType();
9261+
9262+
// Rescale the input: input * (1.0 / scale)
9263+
auto scaleReciprocal = 1.0 / scaleConst;
9264+
auto scaleConstTensor = tosa::getConstTensor<float>(
9265+
rewriter, op, scaleReciprocal, {}, inputElemTy)
9266+
.value();
9267+
if (mlir::tosa::EqualizeRanks(rewriter, loc, input, scaleConstTensor)
9268+
.failed())
9269+
return rewriter.notifyMatchFailure(
9270+
op, "Failed to equalize ranks among operands");
9271+
Value rescaledInput = tosa::createMulOpAndCast(
9272+
rewriter, op, inputTy, input, scaleConstTensor, /*shift =*/0);
9273+
9274+
// Round
9275+
auto rounded =
9276+
tosa::createRoundHalfToEven(rewriter, op, rescaledInput, inputTy);
9277+
if (!rounded) {
9278+
return rewriter.notifyMatchFailure(
9279+
op, "failed to implement round-half-to-even with TOSA ops");
9280+
}
9281+
9282+
// Cast to the destination integer type.
9283+
auto intermediateIntTy = resultTy.clone(resultElemTy);
9284+
Value castToInt =
9285+
tosa::CastOp::create(rewriter, loc, intermediateIntTy, *rounded);
9286+
9287+
// Add the zero point.
9288+
Value zpTensor =
9289+
tosa::createZeroPointTensor(rewriter, loc, intermediateIntTy, zpConst)
9290+
.value();
9291+
if (mlir::tosa::EqualizeRanks(rewriter, loc, castToInt, zpTensor).failed())
9292+
return failure();
9293+
Value withZp = tosa::AddOp::create(rewriter, loc, intermediateIntTy,
9294+
castToInt, zpTensor);
9295+
9296+
// Clamp the result to the valid range of the quantized type.
9297+
std::optional<int64_t> minInt,
9298+
maxInt; // no initialization needed as we want to clamp to the numeric
9299+
// limits of the type
9300+
IntegerAttr minIntAttr, maxIntAttr;
9301+
if (failed(tosa::getIntegerClampAttrs(rewriter, op, resultElemTy, minInt,
9302+
maxInt, minIntAttr, maxIntAttr))) {
9303+
return failure();
9304+
}
9305+
Value clamped = tosa::ClampOp::create(
9306+
rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
9307+
/*nan_mode=*/
9308+
tosa::NanPropagationModeAttr::get(rewriter.getContext(),
9309+
tosa::NanPropagationMode::PROPAGATE));
9310+
9311+
rewriter.replaceOp(op, clamped);
9312+
return success();
9313+
}
9314+
93429315
} // namespace
93439316

93449317
// -----------------------------------------------------------------------------
@@ -9713,6 +9686,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
97139686
INSERT_ATENOP_PATTERN(AtenTanOp);
97149687
INSERT_ATENOP_PATTERN(AtenUnfoldOp);
97159688
INSERT_ATENOP_PATTERN(AtenDequantizeTensorOp);
9689+
INSERT_ATENOP_PATTERN(AtenQuantizePerTensorOp);
97169690
#undef INSERT_ATENOP_PATTERN
97179691

97189692
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

0 commit comments

Comments
 (0)