
Commit 5077090

[TOSA] Add some more mixed dtype handling (#3909)
* Add int input handling for activation functions like erf, sigmoid, and tanh
* Fix mixed dtype handling for scalar comparison ops
* Add mixed dtype handling for pow tensor op (with only floating point result type support for now)
* Add Torch to TOSA lowering for torch.aten.tan

Change-Id: I3a8aa1e6febbc0e39ebdb5734f87ae171b03cd73
Signed-off-by: Justin Ngo <[email protected]>
1 parent a99e378 commit 5077090

3 files changed: +216 -40 lines

lib/Conversion/TorchToTosa/TorchToTosa.cpp (83 additions & 24 deletions)
@@ -405,7 +405,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     Value rhsAsTensor;
     if (!rhsTy) {
       if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
-                                         rhsAsTensor, lhsElemTy, {})))
+                                         rhsAsTensor, rhs.getType(), {})))
         return rewriter.notifyMatchFailure(
             op, "Currently only scalar constants are supported for "
                 "conversion in TOSA operation");
@@ -414,11 +414,26 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     auto rhsTensorTy = dyn_cast<TensorType>(rhsTensor.getType());
     auto rhsElemTy = rhsTensorTy.getElementType();
 
+    // There is no Lesser operator in TOSA.
+    constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
+                                 std::is_same<AtenOpT, AtenLtScalarOp>() ||
+                                 std::is_same<AtenOpT, AtenLeTensorOp>() ||
+                                 std::is_same<AtenOpT, AtenLeScalarOp>());
+
+    // Promote lhs and rhs dtypes for bitwise operators.
+    TensorType resultTy = cast<TensorType>(
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()));
+    if (isBitwiseOp) {
+      lhs = tosa::promoteType(rewriter, lhs, resultTy);
+      rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
+    }
+
+    // Support different types comparisons
     auto isLhsElemFloat = isa<mlir::FloatType>(lhsElemTy);
     auto isRhsElemFloat = isa<mlir::FloatType>(rhsElemTy);
 
-    // Support different types comparisons
-    if (lhsElemTy != rhsElemTy) {
+    if (lhsElemTy != rhsElemTy && !isBitwiseOp) {
       if (isLhsElemFloat && !isRhsElemFloat) {
         rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
       } else if (!isLhsElemFloat && isRhsElemFloat) {
@@ -441,20 +456,6 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
         }
       }
     }
-    // There is no Lesser operator in TOSA.
-    constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
-                                 std::is_same<AtenOpT, AtenLtScalarOp>() ||
-                                 std::is_same<AtenOpT, AtenLeTensorOp>() ||
-                                 std::is_same<AtenOpT, AtenLeScalarOp>());
-
-    // Promote lhs and rhs dtypes for bitwise operators.
-    TensorType resultTy = cast<TensorType>(
-        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-            op.getType()));
-    if (isBitwiseOp) {
-      lhs = tosa::promoteType(rewriter, lhs, resultTy);
-      rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy);
-    }
 
     auto resultOp = rewriter.create<TosaOpT>(op.getLoc(), resultTy,
                                              (swapLhsRhs ? rhsTensor : lhs),
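Note on the changes above: the scalar RHS is now materialized with its own converted dtype (rhs.getType()) rather than being forced to lhsElemTy, and the swapLhsRhs/bitwise promotion logic is hoisted ahead of the mixed float/int promotion, which bitwise ops now skip. A minimal PyTorch-level sketch of the semantics the pattern must reproduce (illustrative only, not part of this commit):

    import torch

    # An int tensor compared against a float scalar; PyTorch promotes the
    # operands before comparing, and the TOSA lowering now mirrors this
    # (exercised by ElementwiseIntTensorLtFloatScalarModule_basic).
    x = torch.tensor([1, 2, 3], dtype=torch.int32)
    print(x < 2.5)           # tensor([ True,  True, False])
    print(torch.le(x, 2.0))  # tensor([ True,  True, False])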
@@ -770,17 +771,24 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value self = adaptor.getSelf();
-    auto selfTy = cast<TensorType>(self.getType());
+    auto selfTy = dyn_cast<TensorType>(self.getType());
 
     if (!selfTy)
       return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
 
-    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+    auto resultTy = dyn_cast<TensorType>(
+        this->getTypeConverter()->convertType(op.getType()));
+
+    if (!isa<mlir::FloatType>(resultTy.getElementType()))
       return rewriter.notifyMatchFailure(
-          op, "Only floating-point datatype legalization currently supported");
+          op, "Only floating-point datatype result types are supported");
 
-    rewriter.replaceOpWithNewOp<TosaOpT>(
-        op, this->getTypeConverter()->convertType(op.getType()), self);
+    // Non floating point inputs are not supported for activation functions
+    // (erf, sigmoid, tanh) in TOSA so we cast the input to result type
+    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+      self = tosa::promoteType(rewriter, self, resultTy);
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(op, resultTy, self);
 
     return success();
   }
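The pattern now gates on the result element type and casts integer inputs up to it instead of rejecting them. A quick illustrative check of the PyTorch behavior being matched (not part of this commit):

    import torch

    # erf/sigmoid/tanh accept integer tensors and produce float results;
    # the lowering casts the int input to the float result type first.
    x = torch.arange(-2, 3, dtype=torch.int64)
    print(torch.sigmoid(x).dtype)  # torch.float32
    print(torch.erf(x).dtype)      # torch.float32
    print(torch.tanh(x).dtype)     # torch.float32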
@@ -1283,6 +1291,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
     auto outType =
         cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
 
+    if (!isa<mlir::FloatType>(outType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype result types are supported");
+
     Value selfTensor;
     if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
       Value selfScalar = op.getSelf();
@@ -1299,9 +1311,10 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
        return rewriter.notifyMatchFailure(
            op, "Only ranked tensor types supported in TOSA Pow");
 
+      // Non floating point inputs are not supported for tosa.pow so we cast the
+      // input to result type
       if (!isa<mlir::FloatType>(selfTy.getElementType()))
-        return rewriter.notifyMatchFailure(
-            op, "Only floating-point datatype legalization supported");
+        selfTensor = tosa::promoteType(rewriter, selfTensor, outType);
     }
 
     Value expTensor;
@@ -1319,6 +1332,11 @@ class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
       if (!expTy)
        return rewriter.notifyMatchFailure(
            op, "Only ranked tensor types supported in TOSA Pow");
+
+      // Non floating point exponents are not supported for tosa.pow so we cast
+      // the exponent to result type
+      if (!isa<mlir::FloatType>(expTy.getElementType()))
+        expTensor = tosa::promoteType(rewriter, expTensor, outType);
     }
 
     auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
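tosa.pow only accepts floating-point operands, so integer bases and exponents are now cast to the (floating-point) result type; integer result types remain unsupported, which is why PowIntIntModule_basic stays in the failing set below while PowIntFloatModule_basic moves to the passing set. Illustrative PyTorch semantics (not part of this commit):

    import torch

    # An integer base with a float exponent yields a float result, the only
    # result dtype this lowering currently supports.
    base = torch.tensor([1, 2, 3], dtype=torch.int32)
    print(torch.pow(base, 2.0))        # tensor([1., 4., 9.])
    print(torch.pow(2.0, base).dtype)  # torch.float32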
@@ -8198,6 +8216,46 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.tan
+template <>
+LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
+    AtenTanOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // tan = sin / cos
+  auto self = adaptor.getSelf();
+
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+
+  if (!isa<mlir::FloatType>(resultType.getElementType()))
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point datatype result types are supported");
+
+  // Non floating point inputs are not supported in TOSA so we cast the input
+  // to result type
+  if (!isa<mlir::FloatType>(selfType.getElementType()))
+    self = tosa::promoteType(rewriter, self, resultType);
+
+  auto sinOp = rewriter.create<tosa::SinOp>(op->getLoc(), resultType, self);
+
+  auto cosOp = rewriter.create<tosa::CosOp>(op->getLoc(), resultType, self);
+
+  auto reciprocalOp =
+      rewriter.create<tosa::ReciprocalOp>(op->getLoc(), resultType, cosOp);
+
+  auto result = rewriter.create<tosa::MulOp>(
+      op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(),
+      /*shift=*/0);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
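The new aten.tan legalization relies on the identity tan(x) = sin(x) / cos(x), written as a multiply by the reciprocal of cosine and built from the tosa.sin, tosa.cos, tosa.reciprocal, and tosa.mul ops shown above. A small numeric sanity check of that identity (illustrative only):

    import math
    import torch

    # tan(x) == sin(x) * (1 / cos(x)) wherever cos(x) != 0
    x = torch.tensor([0.0, 0.5, 1.0, math.pi / 4])
    decomposed = torch.sin(x) * torch.reciprocal(torch.cos(x))
    print(torch.allclose(decomposed, torch.tan(x)))  # True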
@@ -8540,6 +8598,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenLogitOp);
     INSERT_ATENOP_PATTERN(AtenLog1pOp);
     INSERT_ATENOP_PATTERN(AtenLog10Op);
+    INSERT_ATENOP_PATTERN(AtenTanOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \

projects/pt1/e2e_testing/xfail_sets.py (7 additions & 8 deletions)
@@ -1717,6 +1717,13 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseErfIntModule_basic",
+    "ElementwiseIntTensorLtFloatScalarModule_basic",
+    "ElementwiseSigmoidIntModule_basic",
+    "ElementwiseTanIntModule_basic",
+    "ElementwiseTanModule_basic",
+    "ElementwiseUnaryIntModule_basic",
+    "PowIntFloatModule_basic",
     "Deg2radModule_basic",
     "ElementwiseIntTensorLtFloatTensorModule_basic",
     "L1LossMeanReductionModule_basic",
@@ -3658,22 +3665,16 @@
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
-    "ElementwiseErfIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
-    "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
-    "ElementwiseSigmoidIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
-    "ElementwiseTanIntModule_basic",
-    "ElementwiseTanModule_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
-    "ElementwiseUnaryIntModule_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
     "EqIntModule_basic",
     "FloatImplicitModule_basic",
@@ -3780,7 +3781,6 @@
     "NumelZeroRankModule_basic",
     "OnesLikeModule_falsePinMemory",
     "PowIntIntModule_basic",
-    "PowIntFloatModule_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
     "PrimMinIntModule_basic",
@@ -4369,7 +4369,6 @@
     "ElementwiseSqrtIntModule_basic",
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseTanIntModule_basic",
-    "ElementwiseTanModule_basic",
     "ElementwiseTernaryModule_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToI8Module_basic",
