
Commit d59d0b6

[Linalg] Promote type for compare tensor op (#3416)
1 parent: 661be2d

3 files changed, +76 -75 lines changed

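In lib/Conversion/TorchToLinalg/Uncategorized.cpp this commit deletes the separate createCompareTensorOp helper, which bailed out with "unimplemented: lhs and rhs dtype must be same" whenever the two tensor operands disagreed on dtype, and routes the tensor comparison ops (lt/le/gt/ge/eq/ne) through the same createCompareOp template that already served the scalar variants, so mismatched operand types get promoted instead of rejected. Two new e2e tests exercise int-vs-float comparisons (ElementwiseIntTensorLtFloatTensorModule_basic and ElementwiseFloatTensorGtIntTensorModule_basic), and xfail_sets.py is updated accordingly.

The merge is made possible by a small is_any_same trait built on std::disjunction, introduced at the top of the diff below. As a rough standalone sketch of how that trait behaves (the op types here are empty stand-ins, not the real Aten op classes):

    #include <type_traits>

    // True iff T is the same type as at least one of Ts.
    template <class T, class... Ts>
    struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};

    struct AtenGtScalarOp {};  // stand-in
    struct AtenGtTensorOp {};  // stand-in

    // Both the scalar and tensor variant can now satisfy one static_assert
    // and one `if constexpr` branch.
    static_assert(is_any_same<AtenGtTensorOp, AtenGtScalarOp, AtenGtTensorOp>());
    static_assert(!is_any_same<int, float, double>());

    int main() { return 0; }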

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 28 additions & 75 deletions

@@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
   return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
 }

-template <typename OpTy>
-static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
-                                   Value lhs, Value rhs) {
-  static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
-                    std::is_same<OpTy, AtenLeTensorOp>() ||
-                    std::is_same<OpTy, AtenGtTensorOp>() ||
-                    std::is_same<OpTy, AtenGeTensorOp>() ||
-                    std::is_same<OpTy, AtenEqTensorOp>() ||
-                    std::is_same<OpTy, AtenNeTensorOp>(),
-                "unimplemented: op type not supported");
-
-  Type lhsDtype = lhs.getType();
-  Type rhsDtype = rhs.getType();
-
-  // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
-  // to be handled.
-  if (lhsDtype != rhsDtype) {
-    op.emitError("unimplemented: lhs and rhs dtype must be same");
-    return nullptr;
-  }
-
-  Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
-  if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
-    return createLessThan(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
-    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
-    return createGreaterThan(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
-    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
-    return createEqual(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
-    return createNotEqual(b, loc, elementalType, lhs, rhs);
-  }
-  llvm_unreachable("unimplemented: op type not supported");
-}
+template <class T, class... Ts>
+struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};

 template <typename OpTy>
-static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
-                                   Value lhs, Value rhs) {
-  static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
-                    std::is_same<OpTy, AtenLeScalarOp>() ||
-                    std::is_same<OpTy, AtenEqScalarOp>() ||
-                    std::is_same<OpTy, AtenNeScalarOp>() ||
-                    std::is_same<OpTy, AtenGtScalarOp>() ||
-                    std::is_same<OpTy, AtenGeScalarOp>(),
-                "unimplemented: op type not supported");
+static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs,
+                             Value rhs) {
+  static_assert(
+      is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenEqScalarOp,
+                  AtenNeScalarOp, AtenGtScalarOp, AtenGeScalarOp,
+                  AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
+                  AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
+      "unimplemented: op type not supported");

   Type lhsDtype = lhs.getType();
   Type rhsDtype = rhs.getType();
@@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
     return nullptr;
   }

-  if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>()) {
     return createLessThan(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenLeScalarOp, AtenLeTensorOp>()) {
     return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenGtScalarOp, AtenGtTensorOp>()) {
     return createGreaterThan(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenGeScalarOp, AtenGeTensorOp>()) {
     return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>()) {
     return createEqual(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenNeScalarOp, AtenNeTensorOp>()) {
     return createNotEqual(b, loc, elementalType, lhs, rhs);
   }
   llvm_unreachable("unimplemented: op type not supported");
@@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<math::Atan2Op>(loc, lhs, rhs);
   }
   if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
     AtenDivTensorOp::Adaptor adaptor(operands);
@@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }

   if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
   }

   if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]);
   }

   if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
   }

   if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]);
   }

   if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
   }

   if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]);
   }

   if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
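The hunks above show only the renamed dispatch; the dtype-promotion logic that the tensor ops now share sits in the unshown body of createCompareOp, between the static_assert and the if constexpr chain. As a loose illustration of the idea rather than the code in this file (the helper name, the isa checks, and the signed-conversion choice are assumptions), a payload that compares a float lhs against an int rhs would first convert the integer side so a single arith.cmpf can be emitted:

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    using namespace mlir;

    // Hypothetical helper, illustrative only: promote an integer payload value
    // to the float type of `lhs`, then emit a floating-point greater-than.
    static Value createPromotedGreaterThan(OpBuilder &b, Location loc, Value lhs,
                                           Value rhs) {
      Type lhsTy = lhs.getType();
      if (isa<FloatType>(lhsTy) && isa<IntegerType>(rhs.getType())) {
        // Assumes a signed integer source; the real lowering would decide this
        // from the Torch dtype rather than the bare MLIR type.
        rhs = b.create<arith::SIToFPOp>(loc, lhsTy, rhs);
      }
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, lhs, rhs);
    }

In the file itself the conversion presumably goes through convertScalarToDtype (the utility visible at line 149 in the context above), which covers the full dtype matrix rather than this single float-vs-int case.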

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,7 @@
     "InterpolateDynamicModule_sizes_nearest",
     "InterpolateStaticModule_scales_bilinear_align_corners",
     "InterpolateDynamicModule_scales_recompute_bilinear",
+    "ElementwiseFloatTensorGtIntTensorModule_basic",
 }

 LINALG_CRASHING_SET = {
@@ -2707,6 +2708,7 @@
     "ElementwiseTanIntModule_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseUnaryIntModule_basic",
+    "ElementwiseFloatTensorGtIntTensorModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
@@ -3786,6 +3788,7 @@
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",
+    "ElementwiseFloatTensorGtIntTensorModule_basic",
     "ElementwiseFmodTensor_Float_basic",
     "ElementwiseFmodTensor_Int_Float_basic",
     "ElementwiseFmodTensor_Int_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 45 additions & 0 deletions

@@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))


+class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int64, True),
+            ([-1], torch.float64, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.lt(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
+def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64))
+
+
+class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1], torch.int32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.gt(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
+def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(3, 5, high=10).to(torch.float32),
+        tu.randint(5, high=10, dtype=torch.int32),
+    )
+
+
 # ==============================================================================