Commit 54d9e24

[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)
1 parent ad9dfe9 commit 54d9e24

9 files changed: +568 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 84 additions & 0 deletions

@@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
   }];
 }
 
+def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
+def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -16814,6 +16869,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    Torch_BoolType:$self_is_result
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
     AllowsTypeRefinement,
     HasValueSemantics,
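
Note: these TableGen entries generate the C++ op classes (AtenRreluWithNoiseOp, AtenRreluWithNoise_Op, AtenRreluWithNoiseBackwardOp) used by the passes below; the parse/print hooks delegate to the default torch-op assembly with the stated operand/result counts (6/1 for the forward ops, 7/1 for the backward op). A minimal usage sketch, assuming a rewrite-pattern context where rewriter, loc, resultType, and the operand Values already exist — the names here are illustrative, not from this commit:

// Operand order follows the TableGen `arguments` list above:
// self, noise, lower, upper, training, generator.
Value rrelu = rewriter.create<AtenRreluWithNoiseOp>(
    loc, resultType, self, noise, lower, upper,
    /*training=*/cstTrue, /*generator=*/cstNone);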

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 57 additions & 0 deletions

@@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -12055,6 +12063,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

@@ -12247,6 +12263,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %true = torch.constant.bool true\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"      torch.prim.If.yield %7 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %5 = torch.prim.If %4 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"      torch.prim.If.yield %7 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 132 additions & 0 deletions

@@ -3489,6 +3489,59 @@ class DecomposeAtenLeakyReluBackwardOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRreluWithNoiseBackwardOp
+    : public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value gradOutput = op.getGradOutput();
+    Value self = op.getSelf();
+    Value noise = op.getNoise();
+    auto resType = cast<BaseTensorType>(op.getType());
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    bool training;
+    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "training should be a bool constant");
+    }
+
+    bool selfIsResult = false;
+    if (!matchPattern(op.getSelfIsResult(),
+                      m_TorchConstantBool(&selfIsResult)) ||
+        selfIsResult)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: self_is_result should be false");
+
+    double lower, upper;
+    if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
+        !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
+      return rewriter.notifyMatchFailure(
+          op, "lower and upper should be float constants");
+    }
+
+    if (training && (upper - lower > 0.000001)) {
+      Value rreluWithNoiseBackwardOutput =
+          rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
+      rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
+    } else {
+      double negative_slope = (upper + lower) / 2;
+      Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getF64FloatAttr(negative_slope));
+      rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
+          op, resType, gradOutput, self, cstNegativeSlope,
+          op.getSelfIsResult());
+    }
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
 public:
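
This pattern mirrors PyTorch's rrelu_with_noise_backward semantics: in training mode with a non-degenerate [lower, upper] interval, the gradient is grad_output times the noise tensor, which the forward pass filled with the sampled slope (alpha where self < 0, 1 elsewhere); otherwise it reduces to leaky_relu_backward with negative_slope = (lower + upper) / 2. A scalar sketch of that rule (illustration only, not code from this commit):

// Per element, the backward is just a multiply by the slope that the
// forward pass recorded in `noise`.
double rreluWithNoiseBackward(double gradOutput, double self, double noise,
                              double lower, double upper, bool training) {
  if (training && (upper - lower) > 1e-6)
    return gradOutput * noise; // noise holds alpha where self < 0, else 1
  double negativeSlope = (upper + lower) / 2;
  // leaky_relu_backward with the midpoint slope
  return gradOutput * (self > 0 ? 1.0 : negativeSlope);
}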
@@ -3588,6 +3641,82 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRreluWithNoiseOp
+    : public OpRewritePattern<AtenRreluWithNoiseOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value noise = op.getNoise();
+    Value lower = op.getLower();
+    Value upper = op.getUpper();
+    auto resType = cast<BaseTensorType>(op.getType());
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    bool training;
+    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
+      return rewriter.notifyMatchFailure(op, "training should be a constant");
+    }
+
+    Value constantZeroFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+    Value constantOneFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value constantTwoFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
+
+    Value alpha;
+    if (training) {
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value emptyTensor = rewriter.create<AtenFullLikeOp>(
+          loc, resType, self, constantZeroFloat, /*dtype=*/none,
+          /*layout=*/none,
+          /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+      alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
+                                             /*from=*/lower, /*to=*/upper,
+                                             /*generator=*/none);
+    } else {
+      Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
+                                              lower, upper);
+      alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
+                                         constantTwoFloat);
+    }
+
+    Value zeroTensor =
+        createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
+    Value positiveOutput =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
+
+    Value scaledSelf;
+    if (training) {
+      scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
+      auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
+                                                      rewriter.getI1Type());
+      Value oneTensor =
+          createRank0Tensor(rewriter, loc, resType, constantOneFloat);
+      Value not_positive = rewriter.create<AtenLtScalarOp>(
+          loc, boolResType, self, constantZeroFloat);
+      noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
+                                               alpha, oneTensor);
+    } else {
+      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
+    }
+
+    Value negativeOutput =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
+    Value rreluOutput = rewriter.create<AtenAddTensorOp>(
+        loc, resType, positiveOutput, negativeOutput, constantOneFloat);
+    rewriter.replaceOp(op, rreluOutput);
+    return success();
+  }
+};
+} // namespace
+
 // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
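
The forward decomposition computes max(0, self) + min(0, alpha * self). In training, alpha is drawn per element from U(lower, upper) via AtenFullLikeOp + AtenUniformOp and the applied slopes are materialized as where(self < 0, alpha, 1); in eval, alpha collapses to the midpoint (lower + upper) / 2. A scalar sketch of the same computation (illustration only, with the tensor ops abstracted to one element):

#include <algorithm>
#include <random>

// The slope written back through `noise` is what the backward pattern's
// grad_output * noise product later reads.
double rreluWithNoise(double x, double &noise, double lower, double upper,
                      bool training, std::mt19937 &gen) {
  double alpha;
  if (training) {
    std::uniform_real_distribution<double> dist(lower, upper);
    alpha = dist(gen);             // AtenFullLikeOp + AtenUniformOp
    noise = (x < 0) ? alpha : 1.0; // AtenLtScalarOp + AtenWhereSelfOp
  } else {
    alpha = (lower + upper) / 2;   // midpoint: AtenAddOp + AtenDivOp
  }
  return std::max(0.0, x) + std::min(0.0, alpha * x); // Maximum + Minimum + Add
}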
@@ -9924,6 +10053,9 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
+        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions

@@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenPadOp>();
   target.addIllegalOp<AtenPreluOp>();
   target.addIllegalOp<AtenRreluOp>();
+  target.addIllegalOp<AtenRreluWithNoiseOp>();
+  target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
   target.addIllegalOp<AtenCeluOp>();
   target.addIllegalOp<AtenToDtypeLayoutOp>();
   target.addIllegalOp<AtenToDeviceOp>();
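
Marking the two ops illegal here is what makes the patterns registered in DecomposeComplexOps.cpp mandatory: LowerToBackendContract keeps re-running the simplification pipeline until no contract-illegal ops remain, so without these lines the new ops could be passed through undecomposed to backends that cannot handle them.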
