Commit d2970f7

Revert "[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)"
This reverts commit 54d9e24.
1 parent 00adc10 commit d2970f7

9 files changed: 0 additions & 630 deletions
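
For context, aten::rrelu_with_noise is PyTorch's randomized leaky ReLU variant that also records the sampled negative-region slopes in its noise argument. A hedged illustration of the op whose torch-mlir support this commit reverts, in plain PyTorch (values are arbitrary):

    import torch

    x = torch.randn(2, 3)
    noise = torch.empty_like(x)
    # In training mode, negative elements are scaled by per-element slopes
    # drawn uniformly from [lower, upper]; `noise` records those slopes.
    y = torch.ops.aten.rrelu_with_noise(x, noise, 0.125, 1.0 / 3.0, True)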

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 0 additions & 111 deletions
@@ -309,59 +309,6 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
   }];
 }

-def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    AnyTorchOptionalGeneratorType:$generator
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 1);
-    }
-    void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 1);
-    }
-  }];
-}
-
-def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
-  let arguments = (ins
-    Torch_NonValueTensorType:$self,
-    Torch_NonValueTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    AnyTorchOptionalGeneratorType:$generator
-  );
-  let results = (outs
-    AnyTorchOptionalNonValueTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 1);
-    }
-    void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 1);
-    }
-  }];
-}
-
 def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -17488,64 +17435,6 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }

-def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$grad_output,
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    Torch_BoolType:$self_is_result
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 7, 1);
-    }
-    void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 7, 1);
-    }
-  }];
-}
-
-def Torch_AtenRreluWithNoiseFunctionalOp : Torch_Op<"aten.rrelu_with_noise_functional", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    AnyTorchOptionalGeneratorType:$generator
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result0,
-    AnyTorchOptionalTensorType:$noise_out
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoiseFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 2);
-    }
-    void AtenRreluWithNoiseFunctionalOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 2);
-    }
-  }];
-}
-
 def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
     AllowsTypeRefinement,
    HasValueSemantics,
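
GeneratedTorchOps.td is emitted by build_tools/torch_ods_gen.py rather than edited by hand, so a revert like this one also drops the matching registration calls there. A hedged sketch of those calls, with the signature strings copied from the summaries above (their exact form and placement in the reverted commit are assumptions):

    # Helpers from torch_ods_gen.py; emit_with_mutating_variants also
    # produces the trailing-underscore in-place variant seen above.
    emit_with_mutating_variants(
        "aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)"
    )
    emit(
        "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
    )
    emit(
        "aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)"
    )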

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 0 additions & 37 deletions
@@ -6694,10 +6694,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -7300,16 +7296,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<list<int>, list<int>> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
-"    %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
-"    return %2 : !torch.tuple<list<int>, list<int>>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12424,14 +12410,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
-"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-"    return %4 : !torch.int\n"
-"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
@@ -12622,21 +12600,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<int, int> {\n"
-"    %none = torch.constant.none\n"
-"    %str = torch.constant.str \"AssertionError: \"\n"
-"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %2 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
-"    return %3 : !torch.tuple<int, int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 0 additions & 161 deletions
@@ -3675,59 +3675,6 @@ class DecomposeAtenLeakyReluBackwardOp
 };
 } // namespace

-namespace {
-class DecomposeAtenRreluWithNoiseBackwardOp
-    : public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value gradOutput = op.getGradOutput();
-    Value self = op.getSelf();
-    Value noise = op.getNoise();
-    auto resType = cast<BaseTensorType>(op.getType());
-    if (!resType.hasDtype()) {
-      return rewriter.notifyMatchFailure(op, "result should have dtype");
-    }
-
-    bool training;
-    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
-      return rewriter.notifyMatchFailure(op,
-                                         "training should be a bool constant");
-    }
-
-    bool selfIsResult = false;
-    if (!matchPattern(op.getSelfIsResult(),
-                      m_TorchConstantBool(&selfIsResult)) ||
-        selfIsResult)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: self_is_result should be false");
-
-    double lower, upper;
-    if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
-        !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
-      return rewriter.notifyMatchFailure(
-          op, "lower and upper should be float constants");
-    }
-
-    if (training && (upper - lower > 0.000001)) {
-      Value rreluWithNoiseBackwardOutput =
-          rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
-      rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
-    } else {
-      double negative_slope = (upper + lower) / 2;
-      Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
-          loc, rewriter.getF64FloatAttr(negative_slope));
-      rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
-          op, resType, gradOutput, self, cstNegativeSlope,
-          op.getSelfIsResult());
-    }
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
 public:
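
Stated as plain math, the deleted pattern rewrote the backward op as grad_output * noise when training with a non-degenerate [lower, upper] range, and otherwise as leaky_relu_backward with the midpoint slope. A hedged Python restatement of that logic (not code from the commit):

    import torch

    def rrelu_with_noise_backward_reference(grad_output, self, noise,
                                            lower, upper, training):
        if training and upper - lower > 1e-6:
            # `noise` already holds the per-element slopes, so the gradient
            # is a plain rescaling of grad_output.
            return grad_output * noise
        negative_slope = (lower + upper) / 2
        return torch.where(self > 0, grad_output,
                           grad_output * negative_slope)
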
@@ -3823,109 +3770,6 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
 };
 } // namespace

-namespace {
-class DecomposeAtenRreluWithNoiseOp
-    : public OpRewritePattern<AtenRreluWithNoiseOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value self = op.getSelf();
-    Value noise = op.getNoise();
-    Value lower = op.getLower();
-    Value upper = op.getUpper();
-    auto resType = cast<BaseTensorType>(op.getType());
-    Value cstNone = rewriter.create<ConstantNoneOp>(loc);
-    Value cstFalse =
-        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
-    Value result =
-        rewriter
-            .create<AtenRreluWithNoiseFunctionalOp>(
-                loc, resType, self, noise, lower, upper, cstFalse, cstNone)
-            ->getResult(0);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-class DecomposeAtenRreluWithNoiseFunctionalOp
-    : public OpRewritePattern<AtenRreluWithNoiseFunctionalOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenRreluWithNoiseFunctionalOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value self = op.getSelf();
-    Value noise = op.getNoise();
-    Value lower = op.getLower();
-    Value upper = op.getUpper();
-    auto resType = cast<BaseTensorType>(op.getResultTypes()[0]);
-    if (!resType.hasDtype()) {
-      return rewriter.notifyMatchFailure(op, "result should have dtype");
-    }
-
-    bool training;
-    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
-      return rewriter.notifyMatchFailure(op, "training should be a constant");
-    }
-
-    Value constantZeroFloat =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
-    Value constantOneFloat =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
-    Value constantTwoFloat =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
-
-    Value alpha;
-    if (training) {
-      Value none = rewriter.create<ConstantNoneOp>(loc);
-      Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-          loc, resType, self, constantZeroFloat, /*dtype=*/none,
-          /*layout=*/none,
-          /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
-      alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
-                                             /*from=*/lower, /*to=*/upper,
-                                             /*generator=*/none);
-    } else {
-      Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
-                                              lower, upper);
-      alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
-                                         constantTwoFloat);
-    }
-
-    Value zeroTensor =
-        createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
-    Value positiveOutput =
-        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
-
-    Value scaledSelf;
-    if (training) {
-      scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
-      auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
-                                                      rewriter.getI1Type());
-      Value oneTensor =
-          createRank0Tensor(rewriter, loc, resType, constantOneFloat);
-      Value not_positive = rewriter.create<AtenLeScalarOp>(
-          loc, boolResType, self, constantZeroFloat);
-      noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
-                                               alpha, oneTensor);
-    } else {
-      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
-    }
-
-    Value negativeOutput =
-        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
-    Value rreluOutput = rewriter.create<AtenAddTensorOp>(
-        loc, resType, positiveOutput, negativeOutput, constantOneFloat);
-    rewriter.replaceOp(op, {rreluOutput, noise});
-    return success();
-  }
-};
-} // namespace
-
 // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
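
The functional decomposition above computes out = max(0, x) + min(0, alpha * x), sampling alpha per element in training mode and fixing it to the midpoint (lower + upper) / 2 in eval mode. A hedged Python restatement of that math (not code from the commit):

    import torch

    def rrelu_with_noise_functional_reference(self, noise, lower, upper,
                                              training):
        zero = torch.zeros((), dtype=self.dtype)
        if training:
            alpha = torch.empty_like(self).uniform_(lower, upper)
            scaled = self * alpha
            # The returned noise holds the sampled slope where x <= 0, else 1.
            noise = torch.where(self <= 0, alpha, torch.ones_like(self))
        else:
            scaled = self * ((lower + upper) / 2)
        out = torch.maximum(zero, self) + torch.minimum(zero, scaled)
        return out, noise
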
@@ -11590,11 +11434,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseFunctionalOp>(
-        patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
-        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 0 additions & 3 deletions
@@ -500,9 +500,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenPadOp>();
   target.addIllegalOp<AtenPreluOp>();
   target.addIllegalOp<AtenRreluOp>();
-  target.addIllegalOp<AtenRreluWithNoiseOp>();
-  target.addIllegalOp<AtenRreluWithNoiseFunctionalOp>();
-  target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
   target.addIllegalOp<AtenCeluOp>();
   target.addIllegalOp<AtenToDtypeLayoutOp>();
   target.addIllegalOp<AtenToDeviceOp>();
