
Commit 9e1ecf2

Author: Prashant Kumar

Add Add and Sub scalar op conversions.
`aten.add.Scalar` and `aten.sub.Scalar` op conversions have been added as part of the `-convert-torch-to-linalg` pass.
Parent: 3cb46ce
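For context (not part of the commit itself): both scalar variants follow PyTorch's `alpha` semantics, i.e. `torch.sub(x, s, alpha=a)` computes `x - a * s` and `torch.add(x, s, alpha=a)` computes `x + a * s`. A minimal sketch of that behaviour in plain PyTorch, using made-up values:

import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# aten.sub.Scalar: self - alpha * other
sub_out = torch.sub(x, 2.1, alpha=2)
assert torch.allclose(sub_out, x - 2 * 2.1)

# aten.add.Scalar: self + alpha * other
add_out = torch.add(x, 3.0, alpha=2)
assert torch.allclose(add_out, x + 2 * 3.0)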

2 files changed: +104 −1


e2e_testing/torchscript/elementwise.py

Lines changed: 66 additions & 0 deletions
@@ -1000,3 +1000,69 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
                    torch.randint(-10, 10, (3, 4)))


+class ElementwiseSubScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.sub(x, 2.1, alpha=2)
+
+@register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
+def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseSubScalarFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.sub(x, 2.1)
+
+@register_test_case(module_factory=lambda: ElementwiseSubScalarFloatModule())
+def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+class ElementwiseAddScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.add(x, 3.0)
+
+@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
+def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseAddScalarFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.add(x, 3.0, alpha=2)
+
+@register_test_case(module_factory=lambda: ElementwiseAddScalarFloatModule())
+def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 38 additions & 1 deletion
@@ -1676,6 +1676,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       return b.create<arith::SubIOp>(loc, lhs, scaled);
     }
   }
+  if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
+    Type dtype = converter->convertType(subScalar.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
+    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      Value mult = b.create<arith::MulFOp>(loc, other, alpha);
+      return b.create<arith::SubFOp>(loc, self, mult);
+    } else if (dtype.isa<mlir::IntegerType>()) {
+      Value mult = b.create<arith::MulIOp>(loc, other, alpha);
+      return b.create<arith::SubIOp>(loc, self, mult);
+    }
+    subScalar.emitError("unimplemented: dtype other than float and integer "
+                        "types are not supported.");
+    return nullptr;
+  }
+  if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
+    Type dtype = converter->convertType(addScalar.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
+    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      Value mult = b.create<arith::MulFOp>(loc, other, alpha);
+      return b.create<arith::AddFOp>(loc, self, mult);
+    } else if (dtype.isa<mlir::IntegerType>()) {
+      Value mult = b.create<arith::MulIOp>(loc, other, alpha);
+      return b.create<arith::AddIOp>(loc, self, mult);
+    }
+    addScalar.emitError("unimplemented: dtype other than float and integer "
+                        "types are not supported.");
+    return nullptr;
+  }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
     AtenMulTensorOp::Adaptor adaptor(operands);
     Type dtype = converter->convertType(mul.getType())
@@ -2244,7 +2280,8 @@ struct ConvertElementwiseOp : ConversionPattern {
           AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
           AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
           AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
-          AtenEqTensorOp, AtenLtTensorOp>(op))
+          AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp>(
+              op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
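As a quick sanity check of the payload above (an illustrative emulation, not code from this commit), the per-element computation for both ops reduces to scaling the scalar operand by alpha first and then adding or subtracting it:

def emulate_scalar_payload(self_val, other, alpha, op):
    """Rough Python analogue of the linalg payload generated above.

    All three operands are assumed to already be converted to the
    result element type (convertScalarToDtype handles that in the
    C++ code); 'op' selects between the add and sub variants.
    """
    scaled = other * alpha
    return self_val + scaled if op == "add" else self_val - scaled

# Matching the float test cases in elementwise.py:
assert emulate_scalar_payload(5.0, 2.1, 2, "sub") == 5.0 - 2 * 2.1
assert emulate_scalar_payload(5.0, 3.0, 2, "add") == 5.0 + 2 * 3.0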
