
Commit 9bfcbae

add factor scalars
1 parent aded9a2 commit 9bfcbae

3 files changed: +162 −0 lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 89 additions & 0 deletions
@@ -21701,6 +21701,94 @@ struct GatherElementwise
   }
 };
 
+struct FactorScalarsInDotGeneral final
+    : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
+                                     FactorScalarsInDotGeneral> {
+  using CheckedOpRewritePattern<
+      stablehlo::DotGeneralOp,
+      FactorScalarsInDotGeneral>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
+                                    PatternRewriter &rewriter) const {
+    auto lhs = op.getLhs();
+    auto rhs = op.getRhs();
+
+    // Try to decompose v into (splat scalar) * tensor. Returns true on
+    // success, with `scalar` left null when there is no scalar factor;
+    // returns false when v is a multiply that is also used outside this
+    // dot_general, since factoring could then not remove it.
+    auto extractMul = [&](Value v, Value &scalar, Value &z) -> bool {
+      auto mulOp = v.getDefiningOp<stablehlo::MulOp>();
+      if (!mulOp) { // no multiply; the scalar defaults to an implicit 1
+        scalar = nullptr;
+        z = v;
+        return true;
+      }
+      if (!isOnlyUsedInOperation(mulOp, op)) {
+        return false;
+      }
+
+      Value mLhs = mulOp.getLhs();
+      Value mRhs = mulOp.getRhs();
+
+      SplatElementsAttr splatAttr;
+      auto mLhsIsSplat = matchPattern(mLhs, m_Constant(&splatAttr));
+      auto mRhsIsSplat = matchPattern(mRhs, m_Constant(&splatAttr));
+
+      if (mLhsIsSplat) {
+        scalar = mLhs;
+        z = mRhs;
+      } else if (mRhsIsSplat) {
+        scalar = mRhs;
+        z = mLhs;
+      } else {
+        // Neither operand is a splat constant: treat all of v as Z, no scalar.
+        scalar = nullptr;
+        z = v;
+      }
+      return true;
+    };
+
+    Value lhsScalar, lhsZ;
+    Value rhsScalar, rhsZ;
+
+    if (!extractMul(lhs, lhsScalar, lhsZ) || !extractMul(rhs, rhsScalar, rhsZ))
+      return failure();
+
+    if (!lhsScalar && !rhsScalar) { // nothing to factor out
+      return failure();
+    }
+
+    auto rhsZT = rhsZ.getDefiningOp<stablehlo::TransposeOp>();
+    auto lhsZT = lhsZ.getDefiningOp<stablehlo::TransposeOp>();
+    if (lhsZ == rhsZ || (rhsZT && rhsZT.getOperand() == lhsZ) ||
+        (lhsZT && lhsZT.getOperand() == rhsZ)) {
+      auto precision =
+          op.getPrecisionConfig().value_or(stablehlo::PrecisionConfigAttr());
+      auto algorithm =
+          op.getAlgorithm().value_or(stablehlo::DotAlgorithmAttr());
+
+      auto newDot = rewriter.create<stablehlo::DotGeneralOp>(
+          op.getLoc(), op.getType(), lhsZ, rhsZ, op.getDotDimensionNumbers(),
+          precision, algorithm);
+
+      Value combinedScalar;
+      if (lhsScalar && rhsScalar) {
+        combinedScalar = rewriter.create<stablehlo::MulOp>(
+            op.getLoc(), lhsScalar, rhsScalar);
+      } else {
+        combinedScalar = lhsScalar ? lhsScalar : rhsScalar;
+      }
+
+      // Rebuild as (a*b) * dot_general(Z, Z).
+      Value result = rewriter.create<stablehlo::MulOp>(
+          op.getLoc(), combinedScalar, newDot.getResult());
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct ChainedMultiplyToPower final
     : public CheckedOpRewritePattern<stablehlo::MulOp, ChainedMultiplyToPower> {
   using CheckedOpRewritePattern<
@@ -26164,6 +26252,7 @@ struct EnzymeHLOOptPass
                      (no_nan || all_finite), context);
 
     patterns.add<TransposeSymmetricSimplify>(context);
+    patterns.add<FactorScalarsInDotGeneral>(context);
 
     // clang-format off
     patterns.add<
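
For reference, the algebra behind the rewrite: dot_general is bilinear in its
operands, so a splat-constant factor commutes out of either side. In the
shared-operand case the pattern matches (the two Z values equal, or one a
transpose of the other), with splat scalars a and b (notation ours, not from
the commit):

\[
  \operatorname{dot\_general}(a Z,\; b Z) \;=\; (a \cdot b)\,\operatorname{dot\_general}(Z,\; Z)
\]

A minimal standalone sketch of the splat test that extractMul relies on, using
MLIR's matcher API; the helper name isSplatConstant is ours and not part of
this commit:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// Sketch only: mirrors the check inside extractMul. A value qualifies as a
// factorable "scalar" when its defining op is a constant whose elements are
// all identical, which MLIR exposes as a SplatElementsAttr.
static bool isSplatConstant(mlir::Value v) {
  mlir::SplatElementsAttr splat;
  // m_Constant(&splat) matches a constant-like op and binds its value
  // attribute only when that attribute is a splat.
  return mlir::matchPattern(v, mlir::m_Constant(&splat));
}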

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
@@ -604,6 +604,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp<
   let patterns = ["TransposeSymmetricSimplify"];
 }
 
+def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp<
+    "factor_scalars_in_dot_general"> {
+  let patterns = ["FactorScalarsInDotGeneral"];
+}
+
 def ApplyTransposeElementwisePatterns : EnzymeHLOParameterizedPatternOp<
     "transpose_elementwise"> {
   let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
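
Note: an EnzymeHLOPatternOp def like this exposes the C++ pattern to the
transform interpreter under the mnemonic "factor_scalars_in_dot_general"; the
new test below selects it via
--enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" in its RUN
line.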
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
+
+func.func @pass1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %2 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @pass1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %1 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @pass2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64>
+  %3 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %4 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @pass2(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.multiply %cst, %cst : tensor<10x10xf64>
+// CHECK-NEXT: %2 = stablehlo.multiply %1, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %2 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @pass3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.constant dense<2.0> : tensor<10x10xf64>
+  %3 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64>
+  %4 = stablehlo.multiply %2, %3 : tensor<10x10xf64>
+  %5 = stablehlo.dot_general %1, %4, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  return %5 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @pass3(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %cst_0 = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %2 = stablehlo.multiply %cst_0, %cst : tensor<10x10xf64>
+// CHECK-NEXT: %3 = stablehlo.multiply %2, %1 : tensor<10x10xf64>
+// CHECK-NEXT: return %3 : tensor<10x10xf64>
+// CHECK-NEXT: }
+
+func.func @fail1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+  %0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
+  %1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
+  %2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+  %3 = stablehlo.add %2, %1 : tensor<10x10xf64>
+  return %3 : tensor<10x10xf64>
+}
+
+// CHECK: func.func @fail1(%arg0: tensor<10x10xf64> {enzymexla.memory_effects = []}) -> tensor<10x10xf64> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
+// CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<10x10xf64>
+// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
+// CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10x10xf64>
+// CHECK-NEXT: return %2 : tensor<10x10xf64>
+// CHECK-NEXT: }
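
As a worked check of @pass1 above (notation ours): with contracting_dims =
[0] x [0], dot_general computes \( C_{ij} = \sum_k \mathrm{lhs}_{ki}\,\mathrm{rhs}_{kj} \),
i.e. \( \mathrm{lhs}^{\top}\mathrm{rhs} \), so

\[
  (4A)^{\top} A \;=\; 4\,(A^{\top} A),
\]

which is exactly what the CHECK lines expect: the new dot_general consumes
%arg0 twice and the splat constant multiplies its result afterwards. In
@fail1, by contrast, the multiply %1 also feeds the stablehlo.add, so
isOnlyUsedInOperation rejects the match and the IR is left unchanged.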
