144 changes: 144 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -21740,6 +21740,149 @@ struct GatherElementwise
}
};

// (a1 * A) × (b1 * B) → (a1 * b1) * (A × B)
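// e.g. dot_general(mul(splat(4.0), A), B) -> mul(broadcast(4.0),
// dot_general(A, B)), letting constant folding later combine the factored
// scalars (see test/lit_tests/factor_scalars.mlir).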
struct FactorScalarsInDotGeneral final
: public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
FactorScalarsInDotGeneral> {
using CheckedOpRewritePattern<
stablehlo::DotGeneralOp,
FactorScalarsInDotGeneral>::CheckedOpRewritePattern;

LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
PatternRewriter &rewriter) const {
auto lhs = op.getLhs();
auto rhs = op.getRhs();

// First check whether at least one of lhs or rhs is a
// multiplication by a scalar.
if (!canRewriteOperation(op, lhs, rhs))
return failure();

// Decompose v into scalar * z. Returns false only if the defining multiply
// has other users; scalar is left null when no scalar factor is found.
auto extractMul = [&](Value v, Value &scalar, Value &z) -> bool {
auto mulOp = v.getDefiningOp<stablehlo::MulOp>();
if (!mulOp) { // no multiply: treat v itself as z, with no scalar factor
scalar = nullptr;
z = v;
return true;
}
if (!isOnlyUsedInOperation(mulOp, op)) {
return false;
}

Value mLhs = mulOp.getLhs();
Value mRhs = mulOp.getRhs();

SplatElementsAttr splatAttr;
if (matchPattern(mLhs, m_Constant(&splatAttr))) {
auto mLhsType = cast<RankedTensorType>(mLhs.getType());
auto scalarType = RankedTensorType::get({}, mLhsType.getElementType());
scalar =
stablehlo::ConstantOp::create(rewriter, op.getLoc(), scalarType,
splatAttr.resizeSplat(scalarType));
z = mRhs;
} else if (matchPattern(mRhs, m_Constant(&splatAttr))) {
auto mRhsType = cast<RankedTensorType>(mRhs.getType());
auto scalarType = RankedTensorType::get({}, mRhsType.getElementType());
scalar =
stablehlo::ConstantOp::create(rewriter, op.getLoc(), scalarType,
splatAttr.resizeSplat(scalarType));
z = mLhs;
} else if (auto lhsBcast = mLhs.getDefiningOp<stablehlo::BroadcastInDimOp>();
lhsBcast &&
cast<RankedTensorType>(lhsBcast.getOperand().getType()).getRank() == 0) {
scalar = lhsBcast.getOperand();
z = mRhs;
} else if (auto rhsBcast = mRhs.getDefiningOp<stablehlo::BroadcastInDimOp>();
rhsBcast &&
cast<RankedTensorType>(rhsBcast.getOperand().getType()).getRank() == 0) {
scalar = rhsBcast.getOperand();
z = mLhs;
} else { // No scalar factor on either operand; treat the whole v as z.
// Do not return false here: canRewriteOperation may have accepted the op
// because the scalar sits on the other dot operand, and failing would
// trip the assertion below.
scalar = nullptr;
z = v;
}
return true;
};

Value lhsScalar, lhsZ;
Value rhsScalar, rhsZ;

auto lhsExtracted = extractMul(lhs, lhsScalar, lhsZ);
auto rhsExtracted = extractMul(rhs, rhsScalar, rhsZ);

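// canRewriteOperation already verified that any multiply feeding the dot
// is used only by this op, which is the sole remaining failure mode of
// extractMul, so extraction always succeeds here.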
assert(lhsExtracted && rhsExtracted);

auto newDot = stablehlo::DotGeneralOp::create(
rewriter, op.getLoc(), op.getType(), lhsZ, rhsZ,
op.getDotDimensionNumbers(), op.getPrecisionConfigAttr(),
op.getAlgorithmAttr());

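// At least one side carried a scalar factor (otherwise canRewriteOperation
// would have rejected the op); fold the two factors into one when both do.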
Value combinedScalar;
if (lhsScalar && rhsScalar) {
combinedScalar =
stablehlo::MulOp::create(rewriter, op.getLoc(), lhsScalar, rhsScalar);
} else {
combinedScalar = lhsScalar ? lhsScalar : rhsScalar;
}

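// Broadcast the rank-0 combined scalar up to the dot result shape and
// multiply it back in.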
auto bcastedScalar = stablehlo::BroadcastInDimOp::create(
rewriter, op.getLoc(), newDot.getType(), combinedScalar,
rewriter.getDenseI64ArrayAttr({}));
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, bcastedScalar, newDot);
return success();
}

private:
bool canRewriteOperation(stablehlo::DotGeneralOp op, Value lhs,
Value rhs) const {
auto lhsMulOp = lhs.getDefiningOp<stablehlo::MulOp>();
auto rhsMulOp = rhs.getDefiningOp<stablehlo::MulOp>();
if (!lhsMulOp && !rhsMulOp)
return false; // nothing to do

if ((lhsMulOp && !isOnlyUsedInOperation(lhsMulOp, op)) ||
(rhsMulOp && !isOnlyUsedInOperation(rhsMulOp, op)))
return false; // a multiply has other users; factoring would duplicate work

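// A scalar factor is either a splat constant or a broadcast of a rank-0
// tensor.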
auto isScalar = [&](Value v) -> bool {
SplatElementsAttr splatAttr;
if (matchPattern(v, m_Constant(&splatAttr)))
return true;

auto bcastOp = v.getDefiningOp<stablehlo::BroadcastInDimOp>();
if (bcastOp &&
cast<RankedTensorType>(bcastOp.getOperand().getType()).getRank() == 0)
return true;

return false;
};

bool lhsHasScalar = false;
if (lhsMulOp) {
lhsHasScalar = isScalar(lhsMulOp.getLhs()) || isScalar(lhsMulOp.getRhs());
}

bool rhsHasScalar = false;
if (rhsMulOp) {
rhsHasScalar = isScalar(rhsMulOp.getLhs()) || isScalar(rhsMulOp.getRhs());
}

return lhsHasScalar || rhsHasScalar;
}
};

struct ChainedMultiplyToPower final
: public CheckedOpRewritePattern<stablehlo::MulOp, ChainedMultiplyToPower> {
using CheckedOpRewritePattern<
@@ -26278,6 +26421,7 @@ struct EnzymeHLOOptPass
(no_nan || all_finite), context);

patterns.add<TransposeSymmetricSimplify>(context);
patterns.add<FactorScalarsInDotGeneral>(context);

// clang-format off
patterns.add<
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -608,6 +608,11 @@ def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp<
let patterns = ["TransposeSymmetricSimplify"];
}

def ApplyFactorScalarsInDotGeneral : EnzymeHLOPatternOp<
"factor_scalars_in_dot_general"> {
let patterns = ["FactorScalarsInDotGeneral"];
}

def ApplyTransposeElementwisePatterns : EnzymeHLOParameterizedPatternOp<
"transpose_elementwise"> {
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
@@ -314,6 +314,7 @@ def optimization_passes(
"dot_general_only_diagonal_access",
"divide_negated_operands_simplify",
"multiply_negated_operands_simplify",
"factor_scalars_in_dot_general",
]

# constant propagation patterns
80 changes: 80 additions & 0 deletions test/lit_tests/factor_scalars.mlir
@@ -0,0 +1,80 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=factor_scalars_in_dot_general" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s

func.func @main1(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
%0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
%1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
%2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
return %2 : tensor<10x10xf64>
}

// CHECK: func.func @main1(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
// CHECK-NEXT: return %1 : tensor<10x10xf64>
// CHECK-NEXT: }

func.func @main2(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
%0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
%1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
%2 = stablehlo.constant dense<2.0> : tensor<10x10xf64>
%3 = stablehlo.multiply %2, %arg0 : tensor<10x10xf64>
%4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
return %4 : tensor<10x10xf64>
}

// CHECK: func.func @main2(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64>
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
// CHECK-NEXT: return %1 : tensor<10x10xf64>
// CHECK-NEXT: }

func.func @main3(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
%0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
%1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
%2 = stablehlo.constant dense<2.0> : tensor<10x10xf64>
%3 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf64>) -> tensor<10x10xf64>
%4 = stablehlo.multiply %2, %3 : tensor<10x10xf64>
%5 = stablehlo.dot_general %1, %4, contracting_dims = [1] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
return %5 : tensor<10x10xf64>
}

// CHECK: func.func @main3(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64>
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [1] x [1] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
// CHECK-NEXT: return %1 : tensor<10x10xf64>
// CHECK-NEXT: }

func.func @main4(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
%0 = stablehlo.constant dense<4.0> : tensor<10x10xf64>
%1 = stablehlo.multiply %0, %arg0 : tensor<10x10xf64>
%2 = stablehlo.dot_general %1, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
%3 = stablehlo.add %2, %1 : tensor<10x10xf64>
return %3 : tensor<10x10xf64>
}

// CHECK: func.func @main4(%arg0: tensor<10x10xf64>) -> tensor<10x10xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<4.000000e+00> : tensor<10x10xf64>
// CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<10x10xf64>
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [0] x [0] : (tensor<10x10xf64>, tensor<10x10xf64>) -> tensor<10x10xf64>
// CHECK-NEXT: %2 = stablehlo.add %1, %0 : tensor<10x10xf64>
// CHECK-NEXT: return %2 : tensor<10x10xf64>
// CHECK-NEXT: }

func.func @main5(%arg0: tensor<10x3xf64>, %arg1: tensor<3x10xf64>) -> tensor<10x10xf64> {
%0 = stablehlo.constant dense<4.0> : tensor<10x3xf64>
%1 = stablehlo.multiply %0, %arg0 : tensor<10x3xf64>
%2 = stablehlo.constant dense<2.0> : tensor<3x10xf64>
%3 = stablehlo.multiply %arg1, %2 : tensor<3x10xf64>
%4 = stablehlo.dot_general %1, %3, contracting_dims = [1] x [0] : (tensor<10x3xf64>, tensor<3x10xf64>) -> tensor<10x10xf64>
return %4 : tensor<10x10xf64>
}

// CHECK: func.func @main5(%arg0: tensor<10x3xf64>, %arg1: tensor<3x10xf64>) -> tensor<10x10xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<8.000000e+00> : tensor<10x10xf64>
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<10x3xf64>, tensor<3x10xf64>) -> tensor<10x10xf64>
// CHECK-NEXT: %1 = stablehlo.multiply %cst, %0 : tensor<10x10xf64>
// CHECK-NEXT: return %1 : tensor<10x10xf64>
// CHECK-NEXT: }