@@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
149149 return convertScalarToDtype (b, loc, newOp, outTy, std::nullopt , outTTy);
150150}
151151
152- template <typename OpTy>
153- static Value createCompareTensorOp (OpBuilder &b, Location loc, OpTy op,
154- Value lhs, Value rhs) {
155- static_assert (std::is_same<OpTy, AtenLtTensorOp>() ||
156- std::is_same<OpTy, AtenLeTensorOp>() ||
157- std::is_same<OpTy, AtenGtTensorOp>() ||
158- std::is_same<OpTy, AtenGeTensorOp>() ||
159- std::is_same<OpTy, AtenEqTensorOp>() ||
160- std::is_same<OpTy, AtenNeTensorOp>(),
161- " unimplemented: op type not supported" );
162-
163- Type lhsDtype = lhs.getType ();
164- Type rhsDtype = rhs.getType ();
165-
166- // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
167- // to be handled.
168- if (lhsDtype != rhsDtype) {
169- op.emitError (" unimplemented: lhs and rhs dtype must be same" );
170- return nullptr ;
171- }
172-
173- Type elementalType = cast<BaseTensorType>(op.getSelf ().getType ()).getDtype ();
174- if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
175- return createLessThan (b, loc, elementalType, lhs, rhs);
176- }
177- if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
178- return createLessThanOrEqual (b, loc, elementalType, lhs, rhs);
179- }
180- if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
181- return createGreaterThan (b, loc, elementalType, lhs, rhs);
182- }
183- if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
184- return createGreaterThanOrEqual (b, loc, elementalType, lhs, rhs);
185- }
186- if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
187- return createEqual (b, loc, elementalType, lhs, rhs);
188- }
189- if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
190- return createNotEqual (b, loc, elementalType, lhs, rhs);
191- }
192- llvm_unreachable (" unimplemented: op type not supported" );
193- }
152+ template <class T , class ... Ts>
153+ struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};
194154
195155template <typename OpTy>
196- static Value createCompareScalarOp (OpBuilder &b, Location loc, OpTy op,
197- Value lhs, Value rhs) {
198- static_assert (std::is_same<OpTy, AtenLtScalarOp>() ||
199- std::is_same<OpTy, AtenLeScalarOp>() ||
200- std::is_same<OpTy, AtenEqScalarOp>() ||
201- std::is_same<OpTy, AtenNeScalarOp>() ||
202- std::is_same<OpTy, AtenGtScalarOp>() ||
203- std::is_same<OpTy, AtenGeScalarOp>(),
204- " unimplemented: op type not supported" );
156+ static Value createCompareOp (OpBuilder &b, Location loc, OpTy op, Value lhs,
157+ Value rhs) {
158+ static_assert (
159+ is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenEqScalarOp,
160+ AtenNeScalarOp, AtenGtScalarOp, AtenGeScalarOp,
161+ AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
162+ AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
163+ " unimplemented: op type not supported" );
205164
206165 Type lhsDtype = lhs.getType ();
207166 Type rhsDtype = rhs.getType ();
@@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
229188 return nullptr ;
230189 }
231190
232- if constexpr (std::is_same <OpTy, AtenLtScalarOp>()) {
191+ if constexpr (is_any_same <OpTy, AtenLtScalarOp, AtenLtTensorOp >()) {
233192 return createLessThan (b, loc, elementalType, lhs, rhs);
234193 }
235- if constexpr (std::is_same <OpTy, AtenLeScalarOp>()) {
194+ if constexpr (is_any_same <OpTy, AtenLeScalarOp, AtenLeTensorOp >()) {
236195 return createLessThanOrEqual (b, loc, elementalType, lhs, rhs);
237196 }
238- if constexpr (std::is_same <OpTy, AtenGtScalarOp>()) {
197+ if constexpr (is_any_same <OpTy, AtenGtScalarOp, AtenGtTensorOp >()) {
239198 return createGreaterThan (b, loc, elementalType, lhs, rhs);
240199 }
241- if constexpr (std::is_same <OpTy, AtenGeScalarOp>()) {
200+ if constexpr (is_any_same <OpTy, AtenGeScalarOp, AtenGeTensorOp >()) {
242201 return createGreaterThanOrEqual (b, loc, elementalType, lhs, rhs);
243202 }
244- if constexpr (std::is_same <OpTy, AtenEqScalarOp>()) {
203+ if constexpr (is_any_same <OpTy, AtenEqScalarOp, AtenEqTensorOp >()) {
245204 return createEqual (b, loc, elementalType, lhs, rhs);
246205 }
247- if constexpr (std::is_same <OpTy, AtenNeScalarOp>()) {
206+ if constexpr (is_any_same <OpTy, AtenNeScalarOp, AtenNeTensorOp >()) {
248207 return createNotEqual (b, loc, elementalType, lhs, rhs);
249208 }
250209 llvm_unreachable (" unimplemented: op type not supported" );
@@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
892851 return b.create <math::Atan2Op>(loc, lhs, rhs);
893852 }
894853 if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
895- return createCompareTensorOp (b, loc, ltTensor, payloadArgs[0 ],
896- payloadArgs[1 ]);
854+ return createCompareOp (b, loc, ltTensor, payloadArgs[0 ], payloadArgs[1 ]);
897855 }
898856 if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
899- return createCompareTensorOp (b, loc, leTensor, payloadArgs[0 ],
900- payloadArgs[1 ]);
857+ return createCompareOp (b, loc, leTensor, payloadArgs[0 ], payloadArgs[1 ]);
901858 }
902859 if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
903- return createCompareTensorOp (b, loc, gtTensor, payloadArgs[0 ],
904- payloadArgs[1 ]);
860+ return createCompareOp (b, loc, gtTensor, payloadArgs[0 ], payloadArgs[1 ]);
905861 }
906862 if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
907- return createCompareTensorOp (b, loc, geTensor, payloadArgs[0 ],
908- payloadArgs[1 ]);
863+ return createCompareOp (b, loc, geTensor, payloadArgs[0 ], payloadArgs[1 ]);
909864 }
910865 if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
911- return createCompareTensorOp (b, loc, eqTensor, payloadArgs[0 ],
912- payloadArgs[1 ]);
866+ return createCompareOp (b, loc, eqTensor, payloadArgs[0 ], payloadArgs[1 ]);
913867 }
914868 if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
915- return createCompareTensorOp (b, loc, neTensor, payloadArgs[0 ],
916- payloadArgs[1 ]);
869+ return createCompareOp (b, loc, neTensor, payloadArgs[0 ], payloadArgs[1 ]);
917870 }
918871 if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
919872 AtenDivTensorOp::Adaptor adaptor (operands);
@@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
996949 }
997950
998951 if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
999- return createCompareScalarOp (b, loc, gtScalar, payloadArgs[0 ], operands[1 ]);
952+ return createCompareOp (b, loc, gtScalar, payloadArgs[0 ], operands[1 ]);
1000953 }
1001954
1002955 if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
1003- return createCompareScalarOp (b, loc, geScalar, payloadArgs[0 ], operands[1 ]);
956+ return createCompareOp (b, loc, geScalar, payloadArgs[0 ], operands[1 ]);
1004957 }
1005958
1006959 if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
1007- return createCompareScalarOp (b, loc, eqScalar, payloadArgs[0 ], operands[1 ]);
960+ return createCompareOp (b, loc, eqScalar, payloadArgs[0 ], operands[1 ]);
1008961 }
1009962
1010963 if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
1011- return createCompareScalarOp (b, loc, neScalar, payloadArgs[0 ], operands[1 ]);
964+ return createCompareOp (b, loc, neScalar, payloadArgs[0 ], operands[1 ]);
1012965 }
1013966
1014967 if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
1015- return createCompareScalarOp (b, loc, ltScalar, payloadArgs[0 ], operands[1 ]);
968+ return createCompareOp (b, loc, ltScalar, payloadArgs[0 ], operands[1 ]);
1016969 }
1017970
1018971 if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
1019- return createCompareScalarOp (b, loc, leScalar, payloadArgs[0 ], operands[1 ]);
972+ return createCompareOp (b, loc, leScalar, payloadArgs[0 ], operands[1 ]);
1020973 }
1021974
1022975 if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
0 commit comments