From a02c86d32d6566ab20763bbff022aed017da891f Mon Sep 17 00:00:00 2001 From: snonk Date: Sat, 22 Nov 2025 12:49:23 -0600 Subject: [PATCH 1/3] add factor scalars --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 89 +++++++++++++++++++ .../jax/TransformOps/TransformOps.td | 5 ++ .../structured_tensors/factor_scalars.mlir | 68 ++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 test/lit_tests/structured_tensors/factor_scalars.mlir diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 6bda073b4..c7e8b3ee4 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -21740,6 +21740,94 @@ struct GatherElementwise } }; +struct FactorScalarsInDotGeneral final + : public CheckedOpRewritePattern { + using CheckedOpRewritePattern< + stablehlo::DotGeneralOp, + FactorScalarsInDotGeneral>::CheckedOpRewritePattern; + + LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op, + PatternRewriter &rewriter) const { + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + + // From v, extract scalar * tensor, and return true if the operation is not + // used elsewhere. 
+ auto extractMul = [&](Value v, Value &scalar, Value &z) -> bool { + auto mulOp = v.getDefiningOp(); + if (!mulOp) { // set default scalar to 1 + scalar = nullptr; + z = v; + return true; + } + if (!isOnlyUsedInOperation(mulOp, op)) { + return false; + } + + Value mLhs = mulOp.getLhs(); + Value mRhs = mulOp.getRhs(); + + SplatElementsAttr splatAttr; + auto mLhsIsSplat = matchPattern(mLhs, m_Constant(&splatAttr)); + auto mRhsIsSplat = matchPattern(mRhs, m_Constant(&splatAttr)); + + if (mLhsIsSplat) { + scalar = mLhs; + z = mRhs; + } else if (mRhsIsSplat) { + scalar = mRhs; + z = mLhs; + } else { + // If both are non-scalar, treat whole v as Z, no scalar + scalar = nullptr; + z = v; + } + return true; + }; + + Value lhsScalar, lhsZ; + Value rhsScalar, rhsZ; + + if (!extractMul(lhs, lhsScalar, lhsZ) || !extractMul(rhs, rhsScalar, rhsZ)) + return failure(); + + if (!lhsScalar && !rhsScalar) { // nothing to do + return failure(); + } + + auto rhsZT = rhsZ.getDefiningOp(); + auto lhsZT = lhsZ.getDefiningOp(); + if (lhsZ == rhsZ || rhsZT && rhsZT.getOperand() == lhsZ || + lhsZT && lhsZT.getOperand() == rhsZ) { + auto precision = + op.getPrecisionConfig().value_or(stablehlo::PrecisionConfigAttr()); + auto algorithm = + op.getAlgorithm().value_or(stablehlo::DotAlgorithmAttr()); + + auto newDot = rewriter.create( + op.getLoc(), op.getType(), lhsZ, rhsZ, op.getDotDimensionNumbers(), + precision, algorithm); + + Value combinedScalar; + if (lhsScalar && rhsScalar) { + combinedScalar = rewriter.create( + op.getLoc(), lhsScalar, rhsScalar); + } else { + combinedScalar = lhsScalar ? 
lhsScalar : rhsScalar; + } + + // Multiply (a*b) * dot_general(Z, Z) + Value result = rewriter.create( + op.getLoc(), combinedScalar, newDot.getResult()); + rewriter.replaceOp(op, result); + return success(); + } + + return failure(); + } +}; + struct ChainedMultiplyToPower final : public CheckedOpRewritePattern { using CheckedOpRewritePattern< @@ -26278,6 +26366,7 @@ struct EnzymeHLOOptPass (no_nan || all_finite), context); patterns.add(context); + patterns.add(context); // clang-format off patterns.add< diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index dc0d6f760..327dcc639 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp< let patterns = ["TransposeSymmetricSimplify"]; } +def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp< + "factor_scalars_in_dot_general"> { + let patterns = ["FactorScalarsInDotGeneral"]; +} + def ApplyTransposeElementwisePatterns : EnzymeHLOParameterizedPatternOp< "transpose_elementwise"> { let arguments = (ins OptionalAttr:$benefit, BoolAttr:$parameter); diff --git a/test/lit_tests/structured_tensors/factor_scalars.mlir b/test/lit_tests/structured_tensors/factor_scalars.mlir new file mode 100644 index 000000000..dd7bec9d9 --- /dev/null +++ b/test/lit_tests/structured_tensors/factor_scalars.mlir @@ -0,0 +1,68 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s + +func.func @pass1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { + %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> + %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> + %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : 
(tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + return %2 : tensor<10x10xf64> +} + +// CHECK: func.func @pass1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64> +// CHECK-NEXT: return %1 : tensor<10x10xf64> +// CHECK-NEXT: } + +func.func @pass2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { + %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> + %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> + %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64> + %3 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> + %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + return %4 : tensor<10x10xf64> +} + +// CHECK: func.func @pass2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.multiply %cst, %cst : tensor<10x10xf64> +// CHECK-NEXT: %2 = stablehlo.multiply %1, %0 : tensor<10x10xf64> +// CHECK-NEXT: return %2 : tensor<10x10xf64> +// CHECK-NEXT: } + +func.func @pass3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { + %0 = stablehlo.constant 
dense<4.0> : tensor<10x10xf64> + %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> + %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64> + %3 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64> + %4 = stablehlo.multiply %2, %3 : tensor<10x10xf64> + %5 = stablehlo.dot_general %1, %4, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + return %5 : tensor<10x10xf64> +} + +//CHECK: func.func @pass3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +//CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<10x10xf64> +//CHECK-NEXT: %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> +//CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64> +//CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +//CHECK-NEXT: %2 = stablehlo.multiply %cst_0, %cst : tensor<10x10xf64> +//CHECK-NEXT: %3 = stablehlo.multiply %2, %1 : tensor<10x10xf64> +//CHECK-NEXT: return %3 : tensor<10x10xf64> +//CHECK-NEXT: } + +func.func @fail1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { + %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> + %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> + %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %3 = stablehlo.add %2, %1 : tensor<10x10xf64> + return %3 : tensor<10x10xf64> +} + +// CHECK: func.func @fail1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> 
: tensor<10x10xf64> +// CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10x10xf64> +// CHECK-NEXT: return %2 : tensor<10x10xf64> +// CHECK-NEXT: } From 1bebb6fa9bfd686939bd5a0426cb1909e949b731 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 24 Nov 2025 19:00:26 -0600 Subject: [PATCH 2/3] chore: minor cleanups --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 14 ++----- src/enzyme_ad/jax/primitives.py | 1 + .../factor_scalars.mlir | 39 ++++++++++--------- 3 files changed, 24 insertions(+), 30 deletions(-) rename test/lit_tests/{structured_tensors => }/factor_scalars.mlir (68%) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index c7e8b3ee4..490370175 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -21769,13 +21769,10 @@ struct FactorScalarsInDotGeneral final Value mRhs = mulOp.getRhs(); SplatElementsAttr splatAttr; - auto mLhsIsSplat = matchPattern(mLhs, m_Constant(&splatAttr)); - auto mRhsIsSplat = matchPattern(mRhs, m_Constant(&splatAttr)); - - if (mLhsIsSplat) { + if (matchPattern(mLhs, m_Constant(&splatAttr))) { scalar = mLhs; z = mRhs; - } else if (mRhsIsSplat) { + } else if (matchPattern(mRhs, m_Constant(&splatAttr))) { scalar = mRhs; z = mLhs; } else { @@ -21800,14 +21797,9 @@ struct FactorScalarsInDotGeneral final auto lhsZT = lhsZ.getDefiningOp(); if (lhsZ == rhsZ || rhsZT && rhsZT.getOperand() == lhsZ || lhsZT && lhsZT.getOperand() == rhsZ) { - auto precision = - op.getPrecisionConfig().value_or(stablehlo::PrecisionConfigAttr()); - auto algorithm = - op.getAlgorithm().value_or(stablehlo::DotAlgorithmAttr()); - auto newDot = rewriter.create( op.getLoc(), op.getType(), lhsZ, rhsZ, 
op.getDotDimensionNumbers(), - precision, algorithm); + op.getPrecisionConfigAttr(), op.getAlgorithmAttr()); Value combinedScalar; if (lhsScalar && rhsScalar) { diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 8d107478d..d5d401038 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -314,6 +314,7 @@ def optimization_passes( "dot_general_only_diagonal_access", "divide_negated_operands_simplify", "multiply_negated_operands_simplify", + "factor_scalars_in_dot_general", ] # constant propagation patterns diff --git a/test/lit_tests/structured_tensors/factor_scalars.mlir b/test/lit_tests/factor_scalars.mlir similarity index 68% rename from test/lit_tests/structured_tensors/factor_scalars.mlir rename to test/lit_tests/factor_scalars.mlir index dd7bec9d9..0a513fa43 100644 --- a/test/lit_tests/structured_tensors/factor_scalars.mlir +++ b/test/lit_tests/factor_scalars.mlir @@ -1,68 +1,69 @@ // RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s -func.func @pass1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> - %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> return %2 : tensor<10x10xf64> } -// CHECK: func.func @pass1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { 
+// CHECK: func.func @main1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { // CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> -// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> // CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64> // CHECK-NEXT: return %1 : tensor<10x10xf64> // CHECK-NEXT: } -func.func @pass2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64> - %3 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> - %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %3 = stablehlo.multiply %2, %arg0 : tensor<10x10xf64> + %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> return %4 : tensor<10x10xf64> } -// CHECK: func.func @pass2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { -// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> -// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> -// 
CHECK-NEXT: %1 = stablehlo.multiply %cst, %cst : tensor<10x10xf64> +// CHECK: func.func @main2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<10x10xf64> +// CHECK-NEXT: %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.multiply %cst_0, %cst : tensor<10x10xf64> // CHECK-NEXT: %2 = stablehlo.multiply %1, %0 : tensor<10x10xf64> // CHECK-NEXT: return %2 : tensor<10x10xf64> // CHECK-NEXT: } -func.func @pass3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64> %3 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64> %4 = stablehlo.multiply %2, %3 : tensor<10x10xf64> - %5 = stablehlo.dot_general %1, %4, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %5 = stablehlo.dot_general %1, %4, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> return %5 : tensor<10x10xf64> } -//CHECK: func.func @pass3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +//CHECK: func.func @main3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { //CHECK-NEXT: %cst = stablehlo.constant 
dense<2.000000e+00> : tensor<10x10xf64> //CHECK-NEXT: %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> //CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64> -//CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +//CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> //CHECK-NEXT: %2 = stablehlo.multiply %cst_0, %cst : tensor<10x10xf64> //CHECK-NEXT: %3 = stablehlo.multiply %2, %1 : tensor<10x10xf64> //CHECK-NEXT: return %3 : tensor<10x10xf64> //CHECK-NEXT: } -func.func @fail1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main4(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> - %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> + %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> %3 = stablehlo.add %2, %1 : tensor<10x10xf64> return %3 : tensor<10x10xf64> } -// CHECK: func.func @fail1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK: func.func @main4(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { // CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> // CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<10x10xf64> -// CHECK-NEXT: %1 = 
stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> // CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10x10xf64> // CHECK-NEXT: return %2 : tensor<10x10xf64> // CHECK-NEXT: } From 203e88fbdc4accc5dd4b8de555f2f13679936f69 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 24 Nov 2025 22:12:49 -0600 Subject: [PATCH 3/3] fix: generalize to non-square cases --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 119 +++++++++++++++++----- test/lit_tests/factor_scalars.mlir | 53 ++++++---- 2 files changed, 123 insertions(+), 49 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 490370175..e24993dde 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -21740,6 +21740,7 @@ struct GatherElementwise } }; +// (a1 * A) × (b1 * B) → (a1 * b1) * (A × B) struct FactorScalarsInDotGeneral final : public CheckedOpRewritePattern { @@ -21752,6 +21753,11 @@ struct FactorScalarsInDotGeneral final auto lhs = op.getLhs(); auto rhs = op.getRhs(); + // first check if at least one of lhs or rhs has a + multiplication with a scalar + if (!canRewriteOperation(op, lhs, rhs)) + return failure(); + // From v, extract scalar * tensor, and return true if the operation is not // used elsewhere.
auto extractMul = [&](Value v, Value &scalar, Value &z) -> bool { @@ -21770,13 +21776,40 @@ struct FactorScalarsInDotGeneral final SplatElementsAttr splatAttr; if (matchPattern(mLhs, m_Constant(&splatAttr))) { - scalar = mLhs; + auto mLhsType = cast(mLhs.getType()); + auto scalarType = RankedTensorType::get({}, mLhsType.getElementType()); + scalar = + stablehlo::ConstantOp::create(rewriter, op.getLoc(), scalarType, + splatAttr.resizeSplat(scalarType)); z = mRhs; } else if (matchPattern(mRhs, m_Constant(&splatAttr))) { - scalar = mRhs; + auto mRhsType = cast(mRhs.getType()); + auto scalarType = RankedTensorType::get({}, mRhsType.getElementType()); + scalar = + stablehlo::ConstantOp::create(rewriter, op.getLoc(), scalarType, + splatAttr.resizeSplat(scalarType)); z = mLhs; - } else { - // If both are non-scalar, treat whole v as Z, no scalar + } else if (auto lhsBcast = + mLhs.getDefiningOp()) { + if (cast(lhsBcast.getOperand().getType()).getRank() == + 0) { + scalar = lhsBcast.getOperand(); + z = mRhs; + } else { + scalar = nullptr; + return false; + } + } else if (auto rhsBcast = + mRhs.getDefiningOp()) { + if (cast(rhsBcast.getOperand().getType()).getRank() == + 0) { + scalar = rhsBcast.getOperand(); + z = mLhs; + } else { + scalar = nullptr; + return false; + } + } else { // If both are non-scalar, treat whole v as Z, no scalar scalar = nullptr; z = v; } @@ -21786,37 +21819,67 @@ struct FactorScalarsInDotGeneral final Value lhsScalar, lhsZ; Value rhsScalar, rhsZ; - if (!extractMul(lhs, lhsScalar, lhsZ) || !extractMul(rhs, rhsScalar, rhsZ)) - return failure(); + auto lhsExtracted = extractMul(lhs, lhsScalar, lhsZ); + auto rhsExtracted = extractMul(rhs, rhsScalar, rhsZ); - if (!lhsScalar && !rhsScalar) { // nothing to do - return failure(); + assert(lhsExtracted && rhsExtracted); + + auto newDot = stablehlo::DotGeneralOp::create( + rewriter, op.getLoc(), op.getType(), lhsZ, rhsZ, + op.getDotDimensionNumbers(), op.getPrecisionConfigAttr(), + op.getAlgorithmAttr()); + 
+ Value combinedScalar; + if (lhsScalar && rhsScalar) { + combinedScalar = + stablehlo::MulOp::create(rewriter, op.getLoc(), lhsScalar, rhsScalar); + } else { + combinedScalar = lhsScalar ? lhsScalar : rhsScalar; } - auto rhsZT = rhsZ.getDefiningOp(); - auto lhsZT = lhsZ.getDefiningOp(); - if (lhsZ == rhsZ || rhsZT && rhsZT.getOperand() == lhsZ || - lhsZT && lhsZT.getOperand() == rhsZ) { - auto newDot = rewriter.create( - op.getLoc(), op.getType(), lhsZ, rhsZ, op.getDotDimensionNumbers(), - op.getPrecisionConfigAttr(), op.getAlgorithmAttr()); + auto bcastedScalar = stablehlo::BroadcastInDimOp::create( + rewriter, op.getLoc(), newDot.getType(), combinedScalar, + rewriter.getDenseI64ArrayAttr({})); + rewriter.replaceOpWithNewOp(op, bcastedScalar, newDot); + return success(); + } + +private: + bool canRewriteOperation(stablehlo::DotGeneralOp op, Value lhs, + Value rhs) const { + auto lhsMulOp = lhs.getDefiningOp(); + auto rhsMulOp = rhs.getDefiningOp(); + if (!lhsMulOp && !rhsMulOp) + return false; // nothing to do + + if ((lhsMulOp && !isOnlyUsedInOperation(lhsMulOp, op)) || + (rhsMulOp && !isOnlyUsedInOperation(rhsMulOp, op))) + return false; // better to not do anything + + auto isScalar = [&](Value v) -> bool { + SplatElementsAttr splatAttr; + if (matchPattern(v, m_Constant(&splatAttr))) + return true; - Value combinedScalar; - if (lhsScalar && rhsScalar) { - combinedScalar = rewriter.create( - op.getLoc(), lhsScalar, rhsScalar); - } else { - combinedScalar = lhsScalar ? 
lhsScalar : rhsScalar; - } + auto bcastOp = v.getDefiningOp(); + if (bcastOp && + cast(bcastOp.getOperand().getType()).getRank() == 0) + return true; - // Multiply (a*b) * dot_general(Z, Z) - Value result = rewriter.create( - op.getLoc(), combinedScalar, newDot.getResult()); - rewriter.replaceOp(op, result); - return success(); + return false; + }; + + bool lhsHasScalar = false; + if (lhsMulOp) { + lhsHasScalar = isScalar(lhsMulOp.getLhs()) || isScalar(lhsMulOp.getRhs()); } - return failure(); + bool rhsHasScalar = false; + if (rhsMulOp) { + rhsHasScalar = isScalar(rhsMulOp.getLhs()) || isScalar(rhsMulOp.getRhs()); + } + + return lhsHasScalar || rhsHasScalar; } }; diff --git a/test/lit_tests/factor_scalars.mlir b/test/lit_tests/factor_scalars.mlir index 0a513fa43..3db93bb85 100644 --- a/test/lit_tests/factor_scalars.mlir +++ b/test/lit_tests/factor_scalars.mlir @@ -1,20 +1,20 @@ -// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s +// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s -func.func @main1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main1(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> return %2 : tensor<10x10xf64> } -// CHECK: func.func @main1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK: func.func @main1(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { // CHECK-NEXT: %cst = stablehlo.constant 
dense<4.000000e+00> : tensor<10x10xf64> // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> // CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64> // CHECK-NEXT: return %1 : tensor<10x10xf64> // CHECK-NEXT: } -func.func @main2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main2(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64> @@ -23,16 +23,14 @@ func.func @main2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> te return %4 : tensor<10x10xf64> } -// CHECK: func.func @main2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { -// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<10x10xf64> -// CHECK-NEXT: %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> +// CHECK: func.func @main2(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { +// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64> // CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> -// CHECK-NEXT: %1 = stablehlo.multiply %cst_0, %cst : tensor<10x10xf64> -// CHECK-NEXT: %2 = stablehlo.multiply %1, %0 : tensor<10x10xf64> -// CHECK-NEXT: return %2 : tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64> +// CHECK-NEXT: return %1 : tensor<10x10xf64> // CHECK-NEXT: } -func.func @main3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main3(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { %0 = stablehlo.constant 
dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64> @@ -42,17 +40,14 @@ func.func @main3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> te return %5 : tensor<10x10xf64> } -//CHECK: func.func @main3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { -//CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<10x10xf64> -//CHECK-NEXT: %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> -//CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64> -//CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> -//CHECK-NEXT: %2 = stablehlo.multiply %cst_0, %cst : tensor<10x10xf64> -//CHECK-NEXT: %3 = stablehlo.multiply %2, %1 : tensor<10x10xf64> -//CHECK-NEXT: return %3 : tensor<10x10xf64> +//CHECK: func.func @main3(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { +//CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64> +//CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [1] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> +//CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64> +//CHECK-NEXT: return %1 : tensor<10x10xf64> //CHECK-NEXT: } -func.func @main4(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +func.func @main4(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64> %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64> %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> @@ -60,10 +55,26 @@ func.func @main4(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> te 
return %3 : tensor<10x10xf64> } -// CHECK: func.func @main4(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} { +// CHECK: func.func @main4(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> { // CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64> // CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<10x10xf64> // CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64> // CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10x10xf64> // CHECK-NEXT: return %2 : tensor<10x10xf64> // CHECK-NEXT: } + +func.func @main5(%arg0: tensor<10x3xf64>, %arg1: tensor<3x10xf64>) -> tensor<10x10xf64> { + %0 = stablehlo.constant dense<4.0> : tensor<10x3xf64> + %1 = stablehlo.multiply %0, %arg0 : tensor<10x3xf64> + %2 = stablehlo.constant dense<2.0> : tensor<3x10xf64> + %3 = stablehlo.multiply %arg1, %2 : tensor<3x10xf64> + %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] : (tensor<10x3xf64>, tensor<3x10xf64>) -> tensor<10x10xf64> + return %4 : tensor<10x10xf64> +} + +// CHECK: func.func @main5(%arg0: tensor<10x3xf64>, %arg1: tensor<3x10xf64>) -> tensor<10x10xf64> { +// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64> +// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<10x3xf64>, tensor<3x10xf64>) -> tensor<10x10xf64> +// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64> +// CHECK-NEXT: return %1 : tensor<10x10xf64> +// CHECK-NEXT: }