@@ -405,7 +405,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
405405 Value rhsAsTensor;
406406 if (!rhsTy) {
407407 if (failed (torchScalarToTosaTensor (rewriter, op, op.getOther (),
408- rhsAsTensor, lhsElemTy , {})))
408+ rhsAsTensor, rhs. getType () , {})))
409409 return rewriter.notifyMatchFailure (
410410 op, " Currently only scalar constants are supported for "
411411 " conversion in TOSA operation" );
@@ -414,11 +414,26 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
414414 auto rhsTensorTy = dyn_cast<TensorType>(rhsTensor.getType ());
415415 auto rhsElemTy = rhsTensorTy.getElementType ();
416416
417+ // There is no Lesser operator in TOSA.
418+ constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
419+ std::is_same<AtenOpT, AtenLtScalarOp>() ||
420+ std::is_same<AtenOpT, AtenLeTensorOp>() ||
421+ std::is_same<AtenOpT, AtenLeScalarOp>());
422+
423+ // Promote lhs and rhs dtypes for bitwise operators.
424+ TensorType resultTy = cast<TensorType>(
425+ OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
426+ op.getType ()));
427+ if (isBitwiseOp) {
428+ lhs = tosa::promoteType (rewriter, lhs, resultTy);
429+ rhsTensor = tosa::promoteType (rewriter, rhsTensor, resultTy);
430+ }
431+
432+ // Support different types comparisons
417433 auto isLhsElemFloat = isa<mlir::FloatType>(lhsElemTy);
418434 auto isRhsElemFloat = isa<mlir::FloatType>(rhsElemTy);
419435
420- // Support different types comparisons
421- if (lhsElemTy != rhsElemTy) {
436+ if (lhsElemTy != rhsElemTy && !isBitwiseOp) {
422437 if (isLhsElemFloat && !isRhsElemFloat) {
423438 rhsTensor = tosa::promoteType (rewriter, rhsTensor, lhsTy);
424439 } else if (!isLhsElemFloat && isRhsElemFloat) {
@@ -441,20 +456,6 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
441456 }
442457 }
443458 }
444- // There is no Lesser operator in TOSA.
445- constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
446- std::is_same<AtenOpT, AtenLtScalarOp>() ||
447- std::is_same<AtenOpT, AtenLeTensorOp>() ||
448- std::is_same<AtenOpT, AtenLeScalarOp>());
449-
450- // Promote lhs and rhs dtypes for bitwise operators.
451- TensorType resultTy = cast<TensorType>(
452- OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
453- op.getType ()));
454- if (isBitwiseOp) {
455- lhs = tosa::promoteType (rewriter, lhs, resultTy);
456- rhsTensor = tosa::promoteType (rewriter, rhsTensor, resultTy);
457- }
458459
459460 auto resultOp = rewriter.create <TosaOpT>(op.getLoc (), resultTy,
460461 (swapLhsRhs ? rhsTensor : lhs),
@@ -770,17 +771,24 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
770771 matchAndRewrite (AtenOpT op, OpAdaptor adaptor,
771772 ConversionPatternRewriter &rewriter) const override {
772773 Value self = adaptor.getSelf ();
773- auto selfTy = cast <TensorType>(self.getType ());
774+ auto selfTy = dyn_cast <TensorType>(self.getType ());
774775
775776 if (!selfTy)
776777 return rewriter.notifyMatchFailure (op, " Only Tensor types supported" );
777778
778- if (!isa<mlir::FloatType>(selfTy.getElementType ()))
779+ auto resultTy = dyn_cast<TensorType>(
780+ this ->getTypeConverter ()->convertType (op.getType ()));
781+
782+ if (!isa<mlir::FloatType>(resultTy.getElementType ()))
779783 return rewriter.notifyMatchFailure (
780- op, " Only floating-point datatype legalization currently supported" );
784+ op, " Only floating-point datatype result types are supported" );
781785
782- rewriter.replaceOpWithNewOp <TosaOpT>(
783- op, this ->getTypeConverter ()->convertType (op.getType ()), self);
786+ // Non floating point inputs are not supported for activation functions
787+ // (erf, sigmoid, tanh) in TOSA so we cast the input to result type
788+ if (!isa<mlir::FloatType>(selfTy.getElementType ()))
789+ self = tosa::promoteType (rewriter, self, resultTy);
790+
791+ rewriter.replaceOpWithNewOp <TosaOpT>(op, resultTy, self);
784792
785793 return success ();
786794 }
@@ -1283,6 +1291,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
12831291 auto outType =
12841292 cast<TensorType>(this ->getTypeConverter ()->convertType (op.getType ()));
12851293
1294+ if (!isa<mlir::FloatType>(outType.getElementType ()))
1295+ return rewriter.notifyMatchFailure (
1296+ op, " Only floating-point datatype result types are supported" );
1297+
12861298 Value selfTensor;
12871299 if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
12881300 Value selfScalar = op.getSelf ();
@@ -1299,9 +1311,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
12991311 return rewriter.notifyMatchFailure (
13001312 op, " Only ranked tensor types supported in TOSA Pow" );
13011313
1314+ // Non floating point inputs are not supported for tosa.pow so we cast the
1315+ // input to result type
13021316 if (!isa<mlir::FloatType>(selfTy.getElementType ()))
1303- return rewriter.notifyMatchFailure (
1304- op, " Only floating-point datatype legalization supported" );
1317+ selfTensor = tosa::promoteType (rewriter, selfTensor, outType);
13051318 }
13061319
13071320 Value expTensor;
@@ -1319,6 +1332,11 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
13191332 if (!expTy)
13201333 return rewriter.notifyMatchFailure (
13211334 op, " Only ranked tensor types supported in TOSA Pow" );
1335+
1336+ // Non floating point exponents are not supported for tosa.pow so we cast
1337+ // the exponent to result type
1338+ if (!isa<mlir::FloatType>(expTy.getElementType ()))
1339+ expTensor = tosa::promoteType (rewriter, expTensor, outType);
13221340 }
13231341
13241342 auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
@@ -8198,6 +8216,46 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
81988216 return success ();
81998217}
82008218
8219+ // Legalization for aten.tan
8220+ template <>
8221+ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
8222+ AtenTanOp op, OpAdaptor adaptor,
8223+ ConversionPatternRewriter &rewriter) const {
8224+ // tan = sin / cos
8225+ auto self = adaptor.getSelf ();
8226+
8227+ auto selfType = dyn_cast<TensorType>(self.getType ());
8228+ if (!selfType)
8229+ return rewriter.notifyMatchFailure (op, " Only tensor types are supported" );
8230+
8231+ auto resultType =
8232+ dyn_cast<TensorType>(typeConverter->convertType (op.getType ()));
8233+
8234+ if (!isa<mlir::FloatType>(resultType.getElementType ()))
8235+ return rewriter.notifyMatchFailure (
8236+ op, " Only floating-point datatype result types are supported" );
8237+
8238+ // Non floating point inputs are not supported in TOSA so we cast the input
8239+ // to result type
8240+ if (!isa<mlir::FloatType>(selfType.getElementType ()))
8241+ self = tosa::promoteType (rewriter, self, resultType);
8242+
8243+ auto sinOp = rewriter.create <tosa::SinOp>(op->getLoc (), resultType, self);
8244+
8245+ auto cosOp = rewriter.create <tosa::CosOp>(op->getLoc (), resultType, self);
8246+
8247+ auto reciprocalOp =
8248+ rewriter.create <tosa::ReciprocalOp>(op->getLoc (), resultType, cosOp);
8249+
8250+ auto result = rewriter.create <tosa::MulOp>(
8251+ op->getLoc (), resultType, sinOp.getResult (), reciprocalOp.getResult (),
8252+ /* shift=*/ 0 );
8253+
8254+ rewriter.replaceOp (op, {result.getResult ()});
8255+
8256+ return success ();
8257+ }
8258+
82018259} // namespace
82028260
82038261// -----------------------------------------------------------------------------
@@ -8540,6 +8598,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
85408598 INSERT_ATENOP_PATTERN (AtenLogitOp);
85418599 INSERT_ATENOP_PATTERN (AtenLog1pOp);
85428600 INSERT_ATENOP_PATTERN (AtenLog10Op);
8601+ INSERT_ATENOP_PATTERN (AtenTanOp);
85438602#undef INSERT_ATENOP_PATTERN
85448603
85458604#define INSERT_CLONE_ATENOP_PATTERN (AtenOp ) \
0 commit comments