@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp( |
     }
     Type elementType = payloadArgs[0].getType();
     Value constZero =
-        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          payloadArgs[0], constZero);
     return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
   }
+  if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
+    if (!lrelu.getType()
+             .cast<ValueTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      lrelu.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Type elementType = payloadArgs[0].getType();
+    Value constZero =
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                         payloadArgs[0], constZero);
+    Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
+    Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
+    Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
+    Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
+    return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
+  }
   if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
     if (!gelu.getType()
              .cast<ValueTensorType>()
@@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern { |
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
+    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
             AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
             AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
@@ -2969,7 +2989,7 @@ class ConvertTorchToLinalg |
     target.addIllegalOp<AtenBatchNormOp>();
     patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
     target.addIllegalOp<
-        AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
+        AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
         AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
         AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
         AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
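For readers tracing the new payload: the AtenLeakyReluOp case splits the input into a positive and a negative part with two selects against zero, converts the negative_slope scalar (operands[1]) to the element type, scales only the negative part, and sums the two pieces. Because the compare uses the unordered predicate UGT, a NaN input takes the positive branch and propagates unchanged. Below is a minimal scalar sketch of the same decomposition in plain C++; it is illustrative only (the function name is invented here), not code from the patch.

#include <cassert>

// Scalar model of the linalg payload built in the diff:
//   pred         = (x > 0)          // arith.cmpf ugt
//   positivePart = pred ? x : 0     // select
//   negativePart = pred ? 0 : x     // select
//   result       = positivePart + slope * negativePart
// For ordinary floats this equals max(x, 0) + slope * min(x, 0).
static float leakyReluScalar(float x, float slope) {
  bool pred = x > 0.0f;
  float positivePart = pred ? x : 0.0f;
  float negativePart = pred ? 0.0f : x;
  return positivePart + slope * negativePart;
}

int main() {
  assert(leakyReluScalar(2.0f, 0.5f) == 2.0f);   // positive inputs pass through
  assert(leakyReluScalar(-2.0f, 0.5f) == -1.0f); // negative inputs are scaled
  assert(leakyReluScalar(0.0f, 0.5f) == 0.0f);   // zero maps to zero
  return 0;
}

Computing both parts with selects keeps the payload branch-free, matching how the existing elementwise lowerings in this file (e.g. the relu case just above) are written.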