@@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
1007
1007
return b.create <arith::SelectOp>(loc, pred, lhs, rhs);
1008
1008
}
1009
1009
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
1010
- Type dtype = converter->convertType (clamp.getType ())
1011
- .cast <RankedTensorType>()
1012
- .getElementType ();
1013
- if (!dtype.isa <mlir::FloatType>()) {
1014
- clamp.emitError (" unimplemented: non-floating point dtype" );
1015
- return nullptr ;
1016
- }
1017
1010
AtenClampOp::Adaptor adaptor (operands);
1018
1011
auto min = adaptor.getMin ();
1019
1012
auto max = adaptor.getMax ();
@@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
1022
1015
clamp.emitError (" unimplemented: runtime optional type" );
1023
1016
return nullptr ;
1024
1017
}
1025
- auto result = payloadArgs[0 ];
1026
- if (!min.getType ().isa <Torch::NoneType>()) {
1027
- auto minPromoted = convertScalarToDtype (b, loc, min, dtype);
1028
- auto pred = b.create <arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
1029
- result, minPromoted);
1030
- result = b.create <arith::SelectOp>(loc, pred, minPromoted, result);
1018
+
1019
+ Type dtype = converter->convertType (clamp.getType ())
1020
+ .cast <RankedTensorType>()
1021
+ .getElementType ();
1022
+ if (!dtype.isa <mlir::FloatType, mlir::IntegerType>()) {
1023
+ clamp.emitError (" unimplement type for clamp" );
1024
+ return nullptr ;
1031
1025
}
1032
- if (!max. getType (). isa <Torch::NoneType>()) {
1033
- auto maxPromoted = convertScalarToDtype (b, loc, max, dtype );
1034
- auto pred = b. create <arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
1035
- result, maxPromoted);
1036
- result = b. create <arith::SelectOp>(loc, pred, maxPromoted, result );
1026
+
1027
+ Type dstOriginalDtype = clamp. getType (). cast <BaseTensorType>(). getDtype ( );
1028
+ bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
1029
+ if ( auto intTy = dstOriginalDtype. dyn_cast <IntegerType>()) {
1030
+ isUnsigned = intTy. isUnsigned ( );
1037
1031
}
1032
+ auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
1033
+ clamp = convertScalarToDtype (b, loc, clamp, dtype,
1034
+ /* srcOriginalDtype=*/ std::nullopt ,
1035
+ /* dstOriginalDtype=*/ dstOriginalDtype);
1036
+
1037
+ Value pred;
1038
+ if (dtype.isa <mlir::FloatType>()) {
1039
+ auto cmp =
1040
+ getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
1041
+ pred = b.create <arith::CmpFOp>(loc, cmp, input, clamp);
1042
+ } else if (dtype.isa <mlir::IntegerType>()) {
1043
+ auto cmp =
1044
+ isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
1045
+ if (getMax)
1046
+ cmp = arith::invertPredicate (cmp);
1047
+ pred = b.create <arith::CmpIOp>(loc, cmp, input, clamp);
1048
+ }
1049
+ return b.create <arith::SelectOp>(loc, pred, clamp, input);
1050
+ };
1051
+
1052
+ auto result = payloadArgs[0 ];
1053
+ if (!min.getType ().isa <Torch::NoneType>())
1054
+ result = cmpSelect (result, min, /* getMax=*/ false );
1055
+ if (!max.getType ().isa <Torch::NoneType>())
1056
+ result = cmpSelect (result, max, /* getMax=*/ true );
1038
1057
return result;
1039
1058
}
1040
1059
if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {
0 commit comments