@@ -3675,59 +3675,6 @@ class DecomposeAtenLeakyReluBackwardOp
36753675};
36763676} // namespace
36773677
3678- namespace {
3679- class DecomposeAtenRreluWithNoiseBackwardOp
3680- : public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
3681- public:
3682- using OpRewritePattern::OpRewritePattern;
3683- LogicalResult matchAndRewrite (AtenRreluWithNoiseBackwardOp op,
3684- PatternRewriter &rewriter) const override {
3685- Location loc = op.getLoc ();
3686- Value gradOutput = op.getGradOutput ();
3687- Value self = op.getSelf ();
3688- Value noise = op.getNoise ();
3689- auto resType = cast<BaseTensorType>(op.getType ());
3690- if (!resType.hasDtype ()) {
3691- return rewriter.notifyMatchFailure (op, " result should have dtype" );
3692- }
3693-
3694- bool training;
3695- if (!matchPattern (op.getTraining (), m_TorchConstantBool (&training))) {
3696- return rewriter.notifyMatchFailure (op,
3697- " training should be a bool constant" );
3698- }
3699-
3700- bool selfIsResult = false ;
3701- if (!matchPattern (op.getSelfIsResult (),
3702- m_TorchConstantBool (&selfIsResult)) ||
3703- selfIsResult)
3704- return rewriter.notifyMatchFailure (
3705- op, " unimplemented: self_is_result should be false" );
3706-
3707- double lower, upper;
3708- if (!matchPattern (op.getLower (), m_TorchConstantFloat (&lower)) ||
3709- !matchPattern (op.getUpper (), m_TorchConstantFloat (&upper))) {
3710- return rewriter.notifyMatchFailure (
3711- op, " lower and upper should be float constants" );
3712- }
3713-
3714- if (training && (upper - lower > 0.000001 )) {
3715- Value rreluWithNoiseBackwardOutput =
3716- rewriter.create <AtenMulTensorOp>(loc, resType, gradOutput, noise);
3717- rewriter.replaceOp (op, rreluWithNoiseBackwardOutput);
3718- } else {
3719- double negative_slope = (upper + lower) / 2 ;
3720- Value cstNegativeSlope = rewriter.create <ConstantFloatOp>(
3721- loc, rewriter.getF64FloatAttr (negative_slope));
3722- rewriter.replaceOpWithNewOp <AtenLeakyReluBackwardOp>(
3723- op, resType, gradOutput, self, cstNegativeSlope,
3724- op.getSelfIsResult ());
3725- }
3726- return success ();
3727- }
3728- };
3729- } // namespace
3730-
37313678namespace {
37323679class DecomposeAtenPreluOp : public OpRewritePattern <AtenPreluOp> {
37333680public:
@@ -3823,109 +3770,6 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
38233770};
38243771} // namespace
38253772
3826- namespace {
3827- class DecomposeAtenRreluWithNoiseOp
3828- : public OpRewritePattern<AtenRreluWithNoiseOp> {
3829- public:
3830- using OpRewritePattern::OpRewritePattern;
3831- LogicalResult matchAndRewrite (AtenRreluWithNoiseOp op,
3832- PatternRewriter &rewriter) const override {
3833- Location loc = op.getLoc ();
3834- Value self = op.getSelf ();
3835- Value noise = op.getNoise ();
3836- Value lower = op.getLower ();
3837- Value upper = op.getUpper ();
3838- auto resType = cast<BaseTensorType>(op.getType ());
3839- Value cstNone = rewriter.create <ConstantNoneOp>(loc);
3840- Value cstFalse =
3841- rewriter.create <ConstantBoolOp>(loc, rewriter.getBoolAttr (false ));
3842- Value result =
3843- rewriter
3844- .create <AtenRreluWithNoiseFunctionalOp>(
3845- loc, resType, self, noise, lower, upper, cstFalse, cstNone)
3846- ->getResult (0 );
3847- rewriter.replaceOp (op, result);
3848- return success ();
3849- }
3850- };
3851- } // namespace
3852-
3853- namespace {
3854- class DecomposeAtenRreluWithNoiseFunctionalOp
3855- : public OpRewritePattern<AtenRreluWithNoiseFunctionalOp> {
3856- public:
3857- using OpRewritePattern::OpRewritePattern;
3858- LogicalResult matchAndRewrite (AtenRreluWithNoiseFunctionalOp op,
3859- PatternRewriter &rewriter) const override {
3860- Location loc = op.getLoc ();
3861- Value self = op.getSelf ();
3862- Value noise = op.getNoise ();
3863- Value lower = op.getLower ();
3864- Value upper = op.getUpper ();
3865- auto resType = cast<BaseTensorType>(op.getResultTypes ()[0 ]);
3866- if (!resType.hasDtype ()) {
3867- return rewriter.notifyMatchFailure (op, " result should have dtype" );
3868- }
3869-
3870- bool training;
3871- if (!matchPattern (op.getTraining (), m_TorchConstantBool (&training))) {
3872- return rewriter.notifyMatchFailure (op, " training should be a constant" );
3873- }
3874-
3875- Value constantZeroFloat =
3876- rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (0.0 ));
3877- Value constantOneFloat =
3878- rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (1.0 ));
3879- Value constantTwoFloat =
3880- rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (2.0 ));
3881-
3882- Value alpha;
3883- if (training) {
3884- Value none = rewriter.create <ConstantNoneOp>(loc);
3885- Value emptyTensor = rewriter.create <AtenFullLikeOp>(
3886- loc, resType, self, constantZeroFloat, /* dtype=*/ none,
3887- /* layout=*/ none,
3888- /* device=*/ none, /* pin_memoty=*/ none, /* memory_format=*/ none);
3889- alpha = rewriter.create <AtenUniformOp>(loc, resType, emptyTensor,
3890- /* from=*/ lower, /* to=*/ upper,
3891- /* generator=*/ none);
3892- } else {
3893- Value half = rewriter.create <AtenAddOp>(loc, constantTwoFloat.getType (),
3894- lower, upper);
3895- alpha = rewriter.create <AtenDivOp>(loc, constantTwoFloat.getType (), half,
3896- constantTwoFloat);
3897- }
3898-
3899- Value zeroTensor =
3900- createRank0Tensor (rewriter, loc, resType, constantZeroFloat);
3901- Value positiveOutput =
3902- rewriter.create <AtenMaximumOp>(loc, resType, zeroTensor, self);
3903-
3904- Value scaledSelf;
3905- if (training) {
3906- scaledSelf = rewriter.create <AtenMulTensorOp>(loc, resType, self, alpha);
3907- auto boolResType = resType.getWithSizesAndDtype (resType.getSizes (),
3908- rewriter.getI1Type ());
3909- Value oneTensor =
3910- createRank0Tensor (rewriter, loc, resType, constantOneFloat);
3911- Value not_positive = rewriter.create <AtenLeScalarOp>(
3912- loc, boolResType, self, constantZeroFloat);
3913- noise = rewriter.create <AtenWhereSelfOp>(loc, resType, not_positive,
3914- alpha, oneTensor);
3915- } else {
3916- scaledSelf = rewriter.create <AtenMulScalarOp>(loc, resType, self, alpha);
3917- }
3918-
3919- Value negativeOutput =
3920- rewriter.create <AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
3921- Value rreluOutput = rewriter.create <AtenAddTensorOp>(
3922- loc, resType, positiveOutput, negativeOutput, constantOneFloat);
3923- rewriter.replaceOp (op, {rreluOutput, noise});
3924- return success ();
3925- }
3926- };
3927- } // namespace
3928-
39293773// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
39303774namespace {
39313775class DecomposeAtenCeluOp : public OpRewritePattern <AtenCeluOp> {
@@ -11590,11 +11434,6 @@ class DecomposeComplexOpsPass
1159011434 addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
1159111435 addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
1159211436 addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
11593- addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
11594- addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseFunctionalOp>(
11595- patterns);
11596- addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
11597- patterns);
1159811437 addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
1159911438 addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
1160011439 addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
0 commit comments