@@ -1676,6 +1676,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
16761676 return b.create <arith::SubIOp>(loc, lhs, scaled);
16771677 }
16781678 }
1679+ if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
1680+ Type dtype = converter->convertType (subScalar.getType ())
1681+ .cast <RankedTensorType>()
1682+ .getElementType ();
1683+ Value self = convertScalarToDtype (b, loc, payloadArgs[0 ], dtype);
1684+ Value other = convertScalarToDtype (b, loc, operands[1 ], dtype);
1685+ Value alpha = convertScalarToDtype (b, loc, operands[2 ], dtype);
1686+ if (dtype.isa <mlir::FloatType>()) {
1687+ Value mult = b.create <arith::MulFOp>(loc, other, alpha);
1688+ return b.create <arith::SubFOp>(loc, self, mult);
1689+ } else if (dtype.isa <mlir::IntegerType>()) {
1690+ Value mult = b.create <arith::MulIOp>(loc, other, alpha);
1691+ return b.create <arith::SubIOp>(loc, self, mult);
1692+ }
1693+ subScalar.emitError (" unimplemented: dtype other than float and integer "
1694+ " types are not supported." );
1695+ return nullptr ;
1696+ }
1697+ if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
1698+ Type dtype = converter->convertType (addScalar.getType ())
1699+ .cast <RankedTensorType>()
1700+ .getElementType ();
1701+ Value self = convertScalarToDtype (b, loc, payloadArgs[0 ], dtype);
1702+ Value other = convertScalarToDtype (b, loc, operands[1 ], dtype);
1703+ Value alpha = convertScalarToDtype (b, loc, operands[2 ], dtype);
1704+ if (dtype.isa <mlir::FloatType>()) {
1705+ Value mult = b.create <arith::MulFOp>(loc, other, alpha);
1706+ return b.create <arith::AddFOp>(loc, self, mult);
1707+ } else if (dtype.isa <mlir::IntegerType>()) {
1708+ Value mult = b.create <arith::MulIOp>(loc, other, alpha);
1709+ return b.create <arith::AddIOp>(loc, self, mult);
1710+ }
1711+ addScalar.emitError (" unimplemented: dtype other than float and integer "
1712+ " types are not supported." );
1713+ return nullptr ;
1714+ }
16791715 if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
16801716 AtenMulTensorOp::Adaptor adaptor (operands);
16811717 Type dtype = converter->convertType (mul.getType ())
@@ -2244,7 +2280,8 @@ struct ConvertElementwiseOp : ConversionPattern {
22442280 AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
22452281 AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
22462282 AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
2247- AtenEqTensorOp, AtenLtTensorOp>(op))
2283+ AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp>(
2284+ op))
22482285 return rewriter.notifyMatchFailure (op, " not a supported elementwise op" );
22492286
22502287 if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
0 commit comments