@@ -5304,69 +5304,45 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
53045304 dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
53055305 auto outElemTy = outType.getElementType ();
53065306
5307- int64_t minInt, maxInt;
5308- double minFloat, maxFloat;
5309- bool isMinNotNone = false ;
5310- bool isMaxNotNone = false ;
5311-
5312- auto isMinInt = matchPattern (op.getMin (), m_TorchConstantInt (&minInt));
5313- auto isMinFloat = matchPattern (op.getMin (), m_TorchConstantFloat (&minFloat));
5314- if (isMinInt) {
5315- minFloat = static_cast <float >(minInt);
5316- isMinNotNone = true ;
5317- } else if (isMinFloat) {
5318- minInt = static_cast <int64_t >(minFloat);
5319- isMinNotNone = true ;
5320- } else {
5321- if (succeeded (checkNotNone (rewriter, op, op.getMin ())))
5307+ std::optional<int64_t > minInt;
5308+ std::optional<double > minFloat;
5309+ {
5310+ int64_t minIntVal;
5311+ double minFloatVal;
5312+ if (matchPattern (op.getMin (), m_TorchConstantInt (&minIntVal))) {
5313+ minInt = minIntVal;
5314+ minFloat = static_cast <double >(minIntVal);
5315+ } else if (matchPattern (op.getMin (), m_TorchConstantFloat (&minFloatVal))) {
5316+ minFloat = minFloatVal;
5317+ minInt = static_cast <int64_t >(minFloatVal);
5318+ } else if (succeeded (checkNotNone (rewriter, op, op.getMin ()))) {
53225319 return rewriter.notifyMatchFailure (op,
53235320 " min attr should be a torch constant" );
5321+ }
53245322 }
53255323
5326- auto isMaxInt = matchPattern (op.getMax (), m_TorchConstantInt (&maxInt));
5327- auto isMaxFloat = matchPattern (op.getMax (), m_TorchConstantFloat (&maxFloat));
5328- if (isMaxInt) {
5329- maxFloat = static_cast <float >(maxInt);
5330- isMaxNotNone = true ;
5331- } else if (isMaxFloat) {
5332- maxInt = static_cast <int64_t >(maxFloat);
5333- isMaxNotNone = true ;
5334- } else {
5335- if (succeeded (checkNotNone (rewriter, op, op.getMax ())))
5324+ std::optional<int64_t > maxInt;
5325+ std::optional<double > maxFloat;
5326+ {
5327+ int64_t maxIntVal;
5328+ double maxFloatVal;
5329+ if (matchPattern (op.getMax (), m_TorchConstantInt (&maxIntVal))) {
5330+ maxInt = maxIntVal;
5331+ maxFloat = static_cast <double >(maxIntVal);
5332+ } else if (matchPattern (op.getMax (), m_TorchConstantFloat (&maxFloatVal))) {
5333+ maxFloat = maxFloatVal;
5334+ maxInt = static_cast <int64_t >(maxFloatVal);
5335+ } else if (succeeded (checkNotNone (rewriter, op, op.getMax ()))) {
53365336 return rewriter.notifyMatchFailure (op,
53375337 " max attr should be a torch constant" );
5338+ }
53385339 }
53395340
53405341 if (!isa<mlir::FloatType>(outElemTy)) {
53415342 IntegerAttr minIntAttr, maxIntAttr;
5342- if (outElemTy.isInteger (8 )) {
5343- minIntAttr = rewriter.getIntegerAttr (
5344- outElemTy,
5345- isMinNotNone ? minInt : std::numeric_limits<int8_t >::min ());
5346- maxIntAttr = rewriter.getIntegerAttr (
5347- outElemTy,
5348- isMaxNotNone ? maxInt : std::numeric_limits<int8_t >::max ());
5349- } else if (outElemTy.isInteger (16 )) {
5350- minIntAttr = rewriter.getIntegerAttr (
5351- outElemTy,
5352- isMinNotNone ? minInt : std::numeric_limits<int16_t >::min ());
5353- maxIntAttr = rewriter.getIntegerAttr (
5354- outElemTy,
5355- isMaxNotNone ? maxInt : std::numeric_limits<int16_t >::max ());
5356- } else if (outElemTy.isInteger (32 )) {
5357- minIntAttr = rewriter.getIntegerAttr (
5358- outElemTy,
5359- isMinNotNone ? minInt : std::numeric_limits<int32_t >::min ());
5360- maxIntAttr = rewriter.getIntegerAttr (
5361- outElemTy,
5362- isMaxNotNone ? maxInt : std::numeric_limits<int32_t >::max ());
5363- } else if (outElemTy.isInteger (64 )) {
5364- minIntAttr = rewriter.getI64IntegerAttr (
5365- isMinNotNone ? minInt : std::numeric_limits<int64_t >::min ());
5366- maxIntAttr = rewriter.getI64IntegerAttr (
5367- isMaxNotNone ? maxInt : std::numeric_limits<int64_t >::max ());
5368- } else {
5369- return rewriter.notifyMatchFailure (op, " Unsupported integer type" );
5343+ if (failed (tosa::getIntegerClampAttrs (rewriter, op, outElemTy, minInt,
5344+ maxInt, minIntAttr, maxIntAttr))) {
5345+ return failure ();
53705346 }
53715347
53725348 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
@@ -5376,28 +5352,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
53765352 tosa::NanPropagationMode::PROPAGATE));
53775353 } else {
53785354 FloatAttr minFloatAttr, maxFloatAttr;
5379- if (outElemTy.isF16 ()) {
5380- minFloatAttr =
5381- rewriter.getF16FloatAttr (isMinNotNone ? minFloat : Float16Lowest);
5382- maxFloatAttr =
5383- rewriter.getF16FloatAttr (isMaxNotNone ? maxFloat : Float16Max);
5384- } else if (outElemTy.isBF16 ()) {
5385- minFloatAttr = rewriter.getFloatAttr (
5386- rewriter.getBF16Type (), isMinNotNone ? minFloat : BFloat16Lowest);
5387- maxFloatAttr = rewriter.getFloatAttr (
5388- rewriter.getBF16Type (), isMaxNotNone ? maxFloat : BFloat16Max);
5389- } else if (outElemTy.isF32 ()) {
5390- minFloatAttr = rewriter.getF32FloatAttr (
5391- isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5392- maxFloatAttr = rewriter.getF32FloatAttr (
5393- isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5394- } else if (outElemTy.isF64 ()) {
5395- minFloatAttr = rewriter.getF64FloatAttr (
5396- isMinNotNone ? minFloat : std::numeric_limits<double >::lowest ());
5397- maxFloatAttr = rewriter.getF64FloatAttr (
5398- isMaxNotNone ? maxFloat : std::numeric_limits<double >::max ());
5399- } else {
5400- return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
5355+ if (failed (tosa::getFloatClampAttrs (rewriter, op, outElemTy, minFloat,
5356+ maxFloat, minFloatAttr,
5357+ maxFloatAttr))) {
5358+ return failure ();
54015359 }
54025360
54035361 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
@@ -7308,17 +7266,6 @@ template <>
73087266LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
73097267 AtenRoundOp op, OpAdaptor adaptor,
73107268 ConversionPatternRewriter &rewriter) const {
7311- // To round to the nearest integer, we will consider the fractional part of
7312- // the input element (= input element - integer part of element). If the
7313- // fractional part is smaller than 0.5, round the number down. If the
7314- // fractional part is 0.5, apply "round half to even" rule. If the fractional
7315- // part is greater than 0.5, round up.
7316- //
7317- // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
7318- // res = floor(input)
7319- // else:
7320- // res = ceil(input)
7321-
73227269 auto self = adaptor.getSelf ();
73237270
73247271 auto selfTy = dyn_cast<TensorType>(self.getType ());
@@ -7328,67 +7275,13 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
73287275 auto resultTy =
73297276 cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
73307277
7331- auto boolTy =
7332- RankedTensorType::get (resultTy.getShape (), rewriter.getIntegerType (1 ));
7333-
7334- auto resultElemTy = resultTy.getElementType ();
7335-
7336- auto oneHalf =
7337- tosa::getConstTensor<float >(rewriter, op, 0.5 , {}, resultElemTy).value ();
7338-
7339- auto two =
7340- tosa::getConstTensor<float >(rewriter, op, 2 , {}, resultElemTy).value ();
7341-
7342- if (mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), self, oneHalf)
7343- .failed () ||
7344- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), self, two).failed ())
7278+ auto result = tosa::createRoundHalfToEven (rewriter, op, self, resultTy);
7279+ if (!result) {
73457280 return rewriter.notifyMatchFailure (
7346- op, " Failed to equalize ranks among operands and result" );
7347-
7348- auto floorInput =
7349- tosa::FloorOp::create (rewriter, op->getLoc (), resultTy, self);
7350-
7351- // input - floor(input)
7352- auto fractionalPart = tosa::SubOp::create (rewriter, op->getLoc (), resultTy,
7353- self, floorInput.getResult ());
7354-
7355- auto ceilInput = tosa::CeilOp::create (rewriter, op->getLoc (), resultTy, self);
7356-
7357- auto floorInputDivByTwo = tosa::createMulOpAndCast (
7358- rewriter, op, resultTy, floorInput.getResult (), oneHalf, /* shift=*/ 0 );
7359-
7360- auto floorDivResult = tosa::FloorOp::create (rewriter, op->getLoc (), resultTy,
7361- floorInputDivByTwo.getResult ());
7362-
7363- // (floor(input) // 2) * 2
7364- auto evenComparison = tosa::createMulOpAndCast (
7365- rewriter, op, resultTy, floorDivResult.getResult (), two, /* shift=*/ 0 );
7366-
7367- // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
7368- auto floorInputEven =
7369- tosa::EqualOp::create (rewriter, op->getLoc (), boolTy,
7370- floorInput.getResult (), evenComparison.getResult ());
7371-
7372- auto fracEqualOneHalf = tosa::EqualOp::create (
7373- rewriter, op->getLoc (), boolTy, fractionalPart.getResult (), oneHalf);
7374-
7375- auto fracLtOneHalf = tosa::GreaterOp::create (
7376- rewriter, op->getLoc (), boolTy, oneHalf, fractionalPart.getResult ());
7377-
7378- // (frac == 0.5) && (floor(input) % 2 == 0)
7379- auto fracEqualOneHalfCond = tosa::LogicalAndOp::create (
7380- rewriter, op->getLoc (), boolTy, fracEqualOneHalf.getResult (),
7381- floorInputEven.getResult ());
7382-
7383- // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
7384- auto floorResultCond = tosa::LogicalOrOp::create (
7385- rewriter, op->getLoc (), boolTy, fracLtOneHalf.getResult (),
7386- fracEqualOneHalfCond.getResult ());
7387-
7388- rewriter.replaceOpWithNewOp <tosa::SelectOp>(
7389- op, resultTy, floorResultCond.getResult (), floorInput.getResult (),
7390- ceilInput.getResult ());
7281+ op, " failed to implement round-half-to-even with TOSA ops" );
7282+ }
73917283
7284+ rewriter.replaceOp (op, *result);
73927285 return success ();
73937286}
73947287
@@ -9339,6 +9232,86 @@ LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
93399232 return success ();
93409233}
93419234
9235+ // Legalization for aten.quantize_per_tensor (constant scale / zero_point only).
9236+ // Emits a mul / round-half-to-even / cast / add / clamp sequence implementing
9237+ // Q = clamp(cast(round(X * (1 / scale))) + zero_point)
9238+ template <>
9239+ LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewrite(
9240+ AtenQuantizePerTensorOp op, OpAdaptor adaptor,
9241+ ConversionPatternRewriter &rewriter) const {
9242+ Value input = adaptor.getSelf ();
9243+ auto loc = op->getLoc ();
9244+
9245+ // scale must match a torch float constant and zero_point a torch int constant; otherwise report a match failure.
9246+ double scaleConst;
9247+ if (!matchPattern (op.getScale (), m_TorchConstantFloat (&scaleConst)))
9248+ return rewriter.notifyMatchFailure (op, " scale must be a Scalar constant" );
9249+
9250+ int64_t zpConst;
9251+ if (!matchPattern (op.getZeroPoint (), m_TorchConstantInt (&zpConst)))
9252+ return rewriter.notifyMatchFailure (op,
9253+ " zero point must be a Scalar constant" );
9254+
9255+ // Get input and type-converted result types; cast<> (not dyn_cast<>) is used, so non-ranked types are not turned into a match failure.
9256+ auto inputTy = cast<RankedTensorType>(input.getType ());
9257+ auto inputElemTy = inputTy.getElementType ();
9258+ auto resultTy = cast<RankedTensorType>(
9259+ getTypeConverter ()->convertType (op->getResult (0 ).getType ()));
9260+ auto resultElemTy = resultTy.getElementType ();
9261+
9262+ // Rescale the input: input * (1.0 / scale); the double reciprocal is narrowed by getConstTensor<float>. NOTE(review): scale == 0 and the .value() below are unguarded -- confirm.
9263+ auto scaleReciprocal = 1.0 / scaleConst;
9264+ auto scaleConstTensor = tosa::getConstTensor<float >(
9265+ rewriter, op, scaleReciprocal, {}, inputElemTy)
9266+ .value ();
9267+ if (mlir::tosa::EqualizeRanks (rewriter, loc, input, scaleConstTensor)
9268+ .failed ())
9269+ return rewriter.notifyMatchFailure (
9270+ op, " Failed to equalize ranks among operands" );
9271+ Value rescaledInput = tosa::createMulOpAndCast (
9272+ rewriter, op, inputTy, input, scaleConstTensor, /* shift =*/ 0 );
9273+
9274+ // Round to an integral value (ties to even, per the helper's name); the result is still typed as inputTy.
9275+ auto rounded =
9276+ tosa::createRoundHalfToEven (rewriter, op, rescaledInput, inputTy);
9277+ if (!rounded) {
9278+ return rewriter.notifyMatchFailure (
9279+ op, " failed to implement round-half-to-even with TOSA ops" );
9280+ }
9281+
9282+ // Cast to the result element type. NOTE(review): intermediateIntTy is resultTy cloned with its own element type (not a wider type), and the cast precedes the zero-point add -- confirm out-of-range values behave as intended.
9283+ auto intermediateIntTy = resultTy.clone (resultElemTy);
9284+ Value castToInt =
9285+ tosa::CastOp::create (rewriter, loc, intermediateIntTy, *rounded);
9286+
9287+ // Add the zero point in the result element type, after equalizing ranks. NOTE(review): .value() on the zero-point tensor is unchecked.
9288+ Value zpTensor =
9289+ tosa::createZeroPointTensor (rewriter, loc, intermediateIntTy, zpConst)
9290+ .value ();
9291+ if (mlir::tosa::EqualizeRanks (rewriter, loc, castToInt, zpTensor).failed ())
9292+ return failure ();
9293+ Value withZp = tosa::AddOp::create (rewriter, loc, intermediateIntTy,
9294+ castToInt, zpTensor);
9295+
9296+ // Clamp to the range of resultElemTy; no explicit min/max bounds are supplied.
9297+ std::optional<int64_t > minInt,
9298+ maxInt; // both left as std::nullopt: presumably getIntegerClampAttrs then
9299+ // falls back to the numeric limits of resultElemTy -- TODO confirm
9300+ IntegerAttr minIntAttr, maxIntAttr;
9301+ if (failed (tosa::getIntegerClampAttrs (rewriter, op, resultElemTy, minInt,
9302+ maxInt, minIntAttr, maxIntAttr))) {
9303+ return failure ();
9304+ }
9305+ Value clamped = tosa::ClampOp::create (
9306+ rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
9307+ /* nan_mode=*/
9308+ tosa::NanPropagationModeAttr::get (rewriter.getContext (),
9309+ tosa::NanPropagationMode::PROPAGATE));
9310+
9311+ rewriter.replaceOp (op, clamped);
9312+ return success ();
9313+ }
9314+
93429315} // namespace
93439316
93449317// -----------------------------------------------------------------------------
@@ -9713,6 +9686,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
97139686 INSERT_ATENOP_PATTERN (AtenTanOp);
97149687 INSERT_ATENOP_PATTERN (AtenUnfoldOp);
97159688 INSERT_ATENOP_PATTERN (AtenDequantizeTensorOp);
9689+ INSERT_ATENOP_PATTERN (AtenQuantizePerTensorOp);
97169690#undef INSERT_ATENOP_PATTERN
97179691
97189692#define INSERT_CLONE_ATENOP_PATTERN (AtenOp ) \
0 commit comments