Commit b0cb49c

Add scalar type promotion for mul and div (#454)
1 parent c9c9b68 commit b0cb49c
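
For context (an editor's sketch, not part of the commit): the new e2e tests below check PyTorch's standard dtype promotion for mixed-precision operands, which the TorchToLinalg lowering now reproduces by converting both payload values to the result element type. Assuming stock PyTorch semantics:

# Editor's sketch (assumes stock PyTorch promotion rules, not code from this repo):
# mixed-dtype mul/div promote to the wider operand dtype.
import torch

a = torch.rand(4)                               # float32
b = torch.rand(4, dtype=torch.float64)          # float64
print(torch.mul(a, b).dtype)                    # torch.float64
print(torch.div(a, b).dtype)                    # torch.float64

i = torch.randint(10, (4,), dtype=torch.int32)  # int32
j = torch.randint(10, (4,))                     # int64 by default
print(torch.mul(i, j).dtype)                    # torch.int64

Integer torch.div is not covered here; the lowering in TorchToLinalg.cpp below still reports "unimplemented: non-floating point dtype" for it.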

File tree

4 files changed: +96 -15 lines changed

e2e_testing/torchscript/basic.py
e2e_testing/torchscript/elementwise.py
e2e_testing/torchscript/type_promotion.py
lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

e2e_testing/torchscript/basic.py

Lines changed: 1 addition & 0 deletions
@@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
 class DropoutModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

e2e_testing/torchscript/elementwise.py

Lines changed: 75 additions & 1 deletion
@@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 # ==============================================================================
+
+
 class ElementwiseMulScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -378,7 +380,52 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
 def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
-
+
+
+
+class ElementwiseMulTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorFloatModule())
+def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+class ElementwiseMulTensorIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorIntModule())
+def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.randint(10, [4]).type(torch.int32),
+        torch.randint(10, [4]))
+
+
 # ==============================================================================
 class ElementwiseLogModule(torch.nn.Module):
     def __init__(self):
@@ -553,7 +600,32 @@ def forward(self, x):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
+class ElementwiseDivTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseDivTensorFloatModule())
+def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+
 # ==============================================================================
+
+
 class ElementwiseAndIntegerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -573,3 +645,5 @@ def forward(self, x, y):
 def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
                    torch.randint(-10, 10, (3, 4)))
+
+

e2e_testing/torchscript/type_promotion.py

Lines changed: 2 additions & 0 deletions
@@ -111,3 +111,5 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
 def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
+
+

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 18 additions & 14 deletions
@@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
   }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
-    if (!mul.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      mul.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
+    AtenMulTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(mul.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    } else {
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
     }
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
-    if (!div.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    AtenDivTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(div.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>())
       div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::DivFOp>(loc, lhs, rhs);
  }
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()
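
A loose Python analogue of the new payload logic (editor's sketch, not code from this repo): torch.result_type and Tensor.to stand in for the converted result element type and for convertScalarToDtype, respectively.

# Editor's sketch mirroring the structure of the new mul payload above:
# promote both operands to the result element type, then pick the float
# or integer op (arith.mulf vs. arith.muli in the actual lowering).
import torch

def mul_payload(a, b):
    dtype = torch.result_type(a, b)   # stand-in for the converted result dtype
    lhs = a.to(dtype)                 # stand-in for convertScalarToDtype(lhs)
    rhs = b.to(dtype)                 # stand-in for convertScalarToDtype(rhs)
    return lhs * rhs

print(mul_payload(torch.rand(4), torch.rand(4, dtype=torch.float64)).dtype)    # torch.float64
print(mul_payload(torch.arange(4, dtype=torch.int32), torch.arange(4)).dtype)  # torch.int64

For div, the same operand promotion is applied but only the floating-point path is emitted; integer division still hits the emitError above.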
