Skip to content

Commit 985e779

Browse files
authored
[linalg] Added aten.clamp support with integers to torch-to-linalg (#2718)
The lowering for `aten.clamp` did not support integer types. Added support for integer types, along with a signed-integer e2e test.
1 parent 6096fcb commit 985e779

File tree

2 files changed

+65
-18
lines changed

2 files changed

+65
-18
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10071007
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
10081008
}
10091009
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
1010-
Type dtype = converter->convertType(clamp.getType())
1011-
.cast<RankedTensorType>()
1012-
.getElementType();
1013-
if (!dtype.isa<mlir::FloatType>()) {
1014-
clamp.emitError("unimplemented: non-floating point dtype");
1015-
return nullptr;
1016-
}
10171010
AtenClampOp::Adaptor adaptor(operands);
10181011
auto min = adaptor.getMin();
10191012
auto max = adaptor.getMax();
@@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10221015
clamp.emitError("unimplemented: runtime optional type");
10231016
return nullptr;
10241017
}
1025-
auto result = payloadArgs[0];
1026-
if (!min.getType().isa<Torch::NoneType>()) {
1027-
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
1028-
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
1029-
result, minPromoted);
1030-
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
1018+
1019+
Type dtype = converter->convertType(clamp.getType())
1020+
.cast<RankedTensorType>()
1021+
.getElementType();
1022+
if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) {
1023+
clamp.emitError("unimplement type for clamp");
1024+
return nullptr;
10311025
}
1032-
if (!max.getType().isa<Torch::NoneType>()) {
1033-
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
1034-
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
1035-
result, maxPromoted);
1036-
result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
1026+
1027+
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
1028+
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
1029+
if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) {
1030+
isUnsigned = intTy.isUnsigned();
10371031
}
1032+
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
1033+
clamp = convertScalarToDtype(b, loc, clamp, dtype,
1034+
/*srcOriginalDtype=*/std::nullopt,
1035+
/*dstOriginalDtype=*/dstOriginalDtype);
1036+
1037+
Value pred;
1038+
if (dtype.isa<mlir::FloatType>()) {
1039+
auto cmp =
1040+
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
1041+
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
1042+
} else if (dtype.isa<mlir::IntegerType>()) {
1043+
auto cmp =
1044+
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
1045+
if (getMax)
1046+
cmp = arith::invertPredicate(cmp);
1047+
pred = b.create<arith::CmpIOp>(loc, cmp, input, clamp);
1048+
}
1049+
return b.create<arith::SelectOp>(loc, pred, clamp, input);
1050+
};
1051+
1052+
auto result = payloadArgs[0];
1053+
if (!min.getType().isa<Torch::NoneType>())
1054+
result = cmpSelect(result, min, /*getMax=*/false);
1055+
if (!max.getType().isa<Torch::NoneType>())
1056+
result = cmpSelect(result, max, /*getMax=*/true);
10381057
return result;
10391058
}
10401059
if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils):
988988
# ==============================================================================
989989

990990

991+
class ElementwiseClampTensorInt8Module(torch.nn.Module):
992+
993+
def __init__(self):
994+
super().__init__()
995+
996+
@export
997+
@annotate_args([
998+
None,
999+
([-1, -1], torch.int8, True)
1000+
])
1001+
def forward(self, x):
1002+
min = -5
1003+
max = 5
1004+
min_clamp = torch.clamp(x, min)
1005+
max_clamp = torch.clamp(x, max=max)
1006+
both_clamp = torch.clamp(x, min=min, max=max)
1007+
return min_clamp, max_clamp, both_clamp
1008+
1009+
1010+
@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module())
1011+
def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils):
1012+
module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8))
1013+
1014+
1015+
# ==============================================================================
1016+
1017+
1018+
9911019
class ElementwiseClampMinTensorFloatModule(torch.nn.Module):
9921020

9931021
def __init__(self):

0 commit comments

Comments
 (0)