@@ -2148,6 +2148,71 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
 };
 } // namespace
 
+// Ref:
+// https://github.com/pytorch/pytorch/blob/5314ae2660a778b87987030182f787bb6cb092c0/aten/src/ATen/native/transformers/attention.cpp#L663-L673
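+//
+// _safe_softmax behaves like a regular softmax except that slices of the
+// input that are entirely -inf along `dim` (i.e. fully masked out) produce
+// 0 rather than the NaN a plain softmax would yield:
+//   masked     = (self == -inf)
+//   maskedRows = all(masked, dim, keepdim=True)
+//   result     = where(maskedRows, 0, softmax(self, dim, dtype))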
+namespace {
+class DecomposeAten_SafeSoftmaxOp
+    : public OpRewritePattern<Aten_SafeSoftmaxOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_SafeSoftmaxOp op,
+                                PatternRewriter &rewriter) const override {
+    BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
+    if (!resultTensorType.hasDtype() || !resultTensorType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have sizes and dtype");
+    }
+    SmallVector<int64_t> sizes(resultTensorType.getSizes());
+
+    int64_t dimInt;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
+      return rewriter.notifyMatchFailure(op, "Unsupported: non-constant dim");
+
+    dimInt = toPositiveDim(dimInt, sizes.size());
+    if (!isValidDim(dimInt, sizes.size()))
+      return rewriter.notifyMatchFailure(op, "dim int is not valid");
+
+    Location loc = op.getLoc();
+    Value softmax = rewriter.create<AtenSoftmaxIntOp>(
+        loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype());
+
+    Type resultTensorDtype = resultTensorType.getDtype();
+
+    Value negInfinity = getConstantWithGivenDtypeAndValue(
+        rewriter, loc, -std::numeric_limits<double>::infinity(),
+        resultTensorDtype);
+
+    auto boolDtype = rewriter.getI1Type();
+    auto boolTensorType =
+        resultTensorType.getWithSizesAndDtype(sizes, boolDtype);
+    Value masked = rewriter.create<AtenEqScalarOp>(loc, boolTensorType,
+                                                   op.getSelf(), negInfinity);
+
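+    // A slice is fully masked iff every element along `dim` compares equal
+    // to -inf; reduce with keepdim=true so the mask broadcasts over `softmax`.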
+    sizes[dimInt] = 1;
+    auto maskedRowsType =
+        resultTensorType.getWithSizesAndDtype(sizes, boolDtype);
+    Value cstTrue =
+        rewriter.create<Torch::ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
+    Value maskedRows = rewriter.create<AtenAllDimOp>(
+        loc, maskedRowsType, masked, op.getDim(), cstTrue);
+    Value cstZero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0.0,
+                                                      resultTensorDtype);
+    rewriter.replaceOpWithNewOp<AtenWhereScalarSelfOp>(
+        op, resultTensorType, maskedRows, cstZero, softmax);
+    return success();
+  }
+};
+} // namespace
+
 // Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
 //   newGrad = gradOutput * output
 //   result = newGrad - output * sum(newGrad, dim))
@@ -9608,6 +9673,7 @@ class DecomposeComplexOpsPass
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_SafeSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);