diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 6bda073b4..e24993dde 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -21740,6 +21740,149 @@ struct GatherElementwise
   }
 };
 
+// (a1 * A) × (b1 * B) → (a1 * b1) * (A × B)
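+//
+// For example (illustrative sketch only; %A, %B and the constant names are
+// hypothetical values, not ops produced by this pattern):
+//   %m = stablehlo.multiply %cst_splat, %A
+//   %d = stablehlo.dot_general %m, %B
+// is rewritten to
+//   %z = stablehlo.dot_general %A, %B
+//   %d = stablehlo.multiply broadcast(%cst_scalar), %z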
+struct FactorScalarsInDotGeneral final
+    : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
+                                     FactorScalarsInDotGeneral> {
+  using CheckedOpRewritePattern<
+      stablehlo::DotGeneralOp,
+      FactorScalarsInDotGeneral>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
+                                    PatternRewriter &rewriter) const {
+    auto lhs = op.getLhs();
+    auto rhs = op.getRhs();
+
+    // First check that at least one of lhs or rhs is a
+    // multiplication by a scalar.
+    if (!canRewriteOperation(op, lhs, rhs))
+      return failure();
+
+    // From v, extract scalar * tensor, and return true if the operation is
+    // not used elsewhere.
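+    // Scalar forms recognized below: a splat constant on either side of the
+    // multiply, or a broadcast_in_dim whose operand has rank 0; if neither
+    // factor is scalar, v itself is used as Z with no scalar.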
+    auto extractMul = [&](Value v, Value &scalar, Value &z) -> bool {
+      auto mulOp = v.getDefiningOp<stablehlo::MulOp>();
+      if (!mulOp) { // no multiply; leave scalar unset (treated as 1)
+        scalar = nullptr;
+        z = v;
+        return true;
+      }
+      if (!isOnlyUsedInOperation(mulOp, op)) {
+        return false;
+      }
+
+      Value mLhs = mulOp.getLhs();
+      Value mRhs = mulOp.getRhs();
+
+      SplatElementsAttr splatAttr;
+      if (matchPattern(mLhs, m_Constant(&splatAttr))) {
+        auto mLhsType = cast<RankedTensorType>(mLhs.getType());
+        auto scalarType = RankedTensorType::get({}, mLhsType.getElementType());
+        scalar =
+            stablehlo::ConstantOp::create(rewriter, op.getLoc(), scalarType,
+                                          splatAttr.resizeSplat(scalarType));
+        z = mRhs;
+      } else if (matchPattern(mRhs, m_Constant(&splatAttr))) {
+        auto mRhsType = cast<RankedTensorType>(mRhs.getType());
+        auto scalarType = RankedTensorType::get({}, mRhsType.getElementType());
+        scalar =
+            stablehlo::ConstantOp::create(rewriter, op.getLoc(), scalarType,
+                                          splatAttr.resizeSplat(scalarType));
+        z = mLhs;
+      } else if (auto lhsBcast =
+                     mLhs.getDefiningOp<stablehlo::BroadcastInDimOp>()) {
+        if (cast<RankedTensorType>(lhsBcast.getOperand().getType())
+                .getRank() == 0) {
+          scalar = lhsBcast.getOperand();
+          z = mRhs;
+        } else {
+          scalar = nullptr;
+          return false;
+        }
+      } else if (auto rhsBcast =
+                     mRhs.getDefiningOp<stablehlo::BroadcastInDimOp>()) {
+        if (cast<RankedTensorType>(rhsBcast.getOperand().getType())
+                .getRank() == 0) {
+          scalar = rhsBcast.getOperand();
+          z = mLhs;
+        } else {
+          scalar = nullptr;
+          return false;
+        }
+      } else { // If both factors are non-scalar, treat whole v as Z, no scalar
+        scalar = nullptr;
+        z = v;
+      }
+      return true;
+    };
+
+    Value lhsScalar, lhsZ;
+    Value rhsScalar, rhsZ;
+
+    auto lhsExtracted = extractMul(lhs, lhsScalar, lhsZ);
+    auto rhsExtracted = extractMul(rhs, rhsScalar, rhsZ);
+
+    // extractMul can still fail here (e.g. a broadcast_in_dim of a non-scalar
+    // operand), so bail out gracefully instead of asserting.
+    if (!lhsExtracted || !rhsExtracted)
+      return failure();
+
+    auto newDot = stablehlo::DotGeneralOp::create(
+        rewriter, op.getLoc(), op.getType(), lhsZ, rhsZ,
+        op.getDotDimensionNumbers(), op.getPrecisionConfigAttr(),
+        op.getAlgorithmAttr());
+
+    Value combinedScalar;
+    if (lhsScalar && rhsScalar) {
+      combinedScalar =
+          stablehlo::MulOp::create(rewriter, op.getLoc(), lhsScalar, rhsScalar);
+    } else {
+      combinedScalar = lhsScalar ? lhsScalar : rhsScalar;
+    }
+
+    auto bcastedScalar = stablehlo::BroadcastInDimOp::create(
+        rewriter, op.getLoc(), newDot.getType(), combinedScalar,
+        rewriter.getDenseI64ArrayAttr({}));
+    rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, bcastedScalar, newDot);
+    return success();
+  }
+
+private:
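+  // Pre-screen the dot_general: rewrite only when at least one operand comes
+  // from a single-use multiply with a recognizable scalar factor. Multi-use
+  // multiplies are skipped, since they would have to be kept alive anyway.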
+  bool canRewriteOperation(stablehlo::DotGeneralOp op, Value lhs,
+                           Value rhs) const {
+    auto lhsMulOp = lhs.getDefiningOp<stablehlo::MulOp>();
+    auto rhsMulOp = rhs.getDefiningOp<stablehlo::MulOp>();
+    if (!lhsMulOp && !rhsMulOp)
+      return false; // nothing to do
+
+    if ((lhsMulOp && !isOnlyUsedInOperation(lhsMulOp, op)) ||
+        (rhsMulOp && !isOnlyUsedInOperation(rhsMulOp, op)))
+      return false; // better to not do anything
+
+    auto isScalar = [&](Value v) -> bool {
+      SplatElementsAttr splatAttr;
+      if (matchPattern(v, m_Constant(&splatAttr)))
+        return true;
+
+      auto bcastOp = v.getDefiningOp<stablehlo::BroadcastInDimOp>();
+      if (bcastOp &&
+          cast<RankedTensorType>(bcastOp.getOperand().getType()).getRank() ==
+              0)
+        return true;
+
+      return false;
+    };
+
+    bool lhsHasScalar = false;
+    if (lhsMulOp) {
+      lhsHasScalar = isScalar(lhsMulOp.getLhs()) || isScalar(lhsMulOp.getRhs());
+    }
+
+    bool rhsHasScalar = false;
+    if (rhsMulOp) {
+      rhsHasScalar = isScalar(rhsMulOp.getLhs()) || isScalar(rhsMulOp.getRhs());
+    }
+
+    return lhsHasScalar || rhsHasScalar;
+  }
+};
+
 struct ChainedMultiplyToPower final
     : public CheckedOpRewritePattern<stablehlo::MulOp,
                                      ChainedMultiplyToPower> {
   using CheckedOpRewritePattern<
@@ -26278,6 +26421,7 @@ struct EnzymeHLOOptPass
         (no_nan || all_finite), context);
     patterns.add(context);
+    patterns.add<FactorScalarsInDotGeneral>(context);
 
     // clang-format off
     patterns.add<
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index dc0d6f760..327dcc639 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp<
   let patterns = ["TransposeSymmetricSimplify"];
 }
 
+def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp<
+    "factor_scalars_in_dot_general"> {
+  let patterns = ["FactorScalarsInDotGeneral"];
+}
+
 def ApplyTransposeElementwisePatterns : EnzymeHLOParameterizedPatternOp<
     "transpose_elementwise"> {
   let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 8d107478d..d5d401038 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -314,6 +314,7 @@ def optimization_passes(
         "dot_general_only_diagonal_access",
         "divide_negated_operands_simplify",
         "multiply_negated_operands_simplify",
+        "factor_scalars_in_dot_general",
     ]
 
     # constant propagation patterns
diff --git a/test/lit_tests/factor_scalars.mlir b/test/lit_tests/factor_scalars.mlir
new file mode 100644
index 000000000..3db93bb85
--- /dev/null
+++ b/test/lit_tests/factor_scalars.mlir
@@ -0,0 +1,80 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s
+
+func.func @main1(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %2 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @main1(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %1 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @main2(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64>
+  %3 = stablehlo.multiply %2, %arg0 : tensor<10x10xf64>
+  %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %4 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @main2(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %1 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @main3(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64>
+  %3 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64>
+  %4 = stablehlo.multiply %2, %3 : tensor<10x10xf64>
+  %5 = stablehlo.dot_general %1, %4, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %5 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @main3(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [1] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %1 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @main4(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  %3 = stablehlo.add %2, %1 : tensor<10x10xf64>
+  return %3 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @main4(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %2 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @main5(%arg0: tensor<10x3xf64>, %arg1: tensor<3x10xf64>) -> tensor<10x10xf64> {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x3xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x3xf64>
+  %2 = stablehlo.constant dense<2.0> : tensor<3x10xf64>
+  %3 = stablehlo.multiply %arg1, %2 : tensor<3x10xf64>
+  %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] : (tensor<10x3xf64>, tensor<3x10xf64>) -> tensor<10x10xf64>
+  return %4 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @main5(%arg0: tensor<10x3xf64>, %arg1: tensor<3x10xf64>) -> tensor<10x10xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<10x3xf64>, tensor<3x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %1 : tensor<10x10xf64>
+// CHECK-NEXT: }
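+
+// Illustrative sketch, not part of the original patch: the pattern also
+// recognizes a scalar introduced through stablehlo.broadcast_in_dim of a
+// rank-0 tensor. The CHECK lines below are assumed output, not verified.
+func.func @main6(%arg0: tensor<10x10xf64>, %arg1: tensor<f64>) -> tensor<10x10xf64> {
+  %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f64>) -> tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %2 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @main6
+// CHECK: %[[DOT:.+]] = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0]
+// CHECK: %[[S:.+]] = stablehlo.broadcast_in_dim %arg1
+// CHECK: stablehlo.multiply %[[S]], %[[DOT]]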