Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 144 additions & 3 deletions lib/Transforms/HandshakeOptimizeBitwidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,142 @@ struct ArithSelect : public OpRewritePattern<handshake::SelectOp> {
bool forward;
};

/// Optimizes unsigned right-shifts with a constant as a forward pass.
struct ArithShrUIFW : OpRewritePattern<handshake::ShRUIOp> {
ArithShrUIFW(Pass::Statistic &bitwidthReduced, MLIRContext *ctx,
NameAnalysis &namer)
: OpRewritePattern(ctx), bitwidthReduced(bitwidthReduced), namer(namer) {}

LogicalResult matchAndRewrite(handshake::ShRUIOp op,
PatternRewriter &rewriter) const override {
auto [lhs, lhsExt] = getMinimalValueWithExtType(op.getLhs());
unsigned inputBitwidth = lhs.getType().getDataBitWidth();
unsigned currentBitwidth = op.getType().getDataBitWidth();
if (inputBitwidth >= currentBitwidth)
return failure();

assert(lhsExt != ExtType::NONE && "expected an extension");
APInt numOfShiftPositions;
Value constantControl;
{
auto constantOp = op.getRhs().getDefiningOp<handshake::ConstantOp>();
if (!constantOp)
return failure();
numOfShiftPositions = cast<IntegerAttr>(constantOp.getValue()).getValue();
constantControl = constantOp.getCtrl();
}

// Other pattern (such as canonicalization pattern) should fold this case
// to a useful constant instead
if (numOfShiftPositions.uge(currentBitwidth))
return failure();

// The following optimizations can be performed here:
// * ZEXT: Shift amount is larger than the input bitwidth -> replace with 0.
// * SEXT: Shift amount 'c' is larger than the input bitwidth -> replace
// with 'c' many 0s leading 0s and copies of the sign bit otherwise.
// * Otherwise: We can perform the shift at the (lower) input bitwidth
// enabling other ops to be optimized in the forward pass.
if (lhsExt == ExtType::ZEXT) {
if (numOfShiftPositions.uge(inputBitwidth)) {
// The entire input is shifted away and only 0 bits from the extension
// remain.
auto constant = rewriter.replaceOpWithNewOp<handshake::ConstantOp>(
op, op.getType(), rewriter.getZeroAttr(op.getType().getDataType()),
constantControl);
inheritBB(op, constant);
return success();
}

// Otherwise, we can perform the entire shift at the input bitwidth and
// zero-extend back to the original bitwidth.
modArithOp(op, {lhs, lhsExt}, {op.getRhs(), ExtType::NONE}, inputBitwidth,
ExtType::ZEXT, rewriter, namer);
++bitwidthReduced;
return success();
}

// SEXT case.
// The thing all the logic has in common here is that the top 'c' bits are
// always zero (due to being an unsigned shift) and that any bits between
// the top 'c' and (inputBitwidth - c) bits are copies of the sign-bit of
// the input value.
//
// | 0...0 | s ... s | s X ... Y |
// -------------------------------
// c many 0s in the front.
// (inputBitwidth - c) many sign bits (s).
// Original input (lead by sign-bit s) shifted by c on
// the right.
//
// If c is greater than input bitwidth than there are only c many 0s and
// copies of the sign-bit in the remaining bits.
// In all cases we can perform shifts at the input bitwidth or less and use
// extensions to restore the original output.
// These extension operations can be folded into other operations if
// redundant or leveraged by other patterns.
ChannelVal result;
if (numOfShiftPositions.ult(inputBitwidth)) {
// c is less than the input bitwidth, meaning other bits from the input
// besides the sign-bit are preserved in the output.

// Perform the shift on the bitwidth of lhs.
Value newRhs =
modBitWidth({op.getRhs(), ExtType::NONE}, inputBitwidth, rewriter);
result = rewriter.create<handshake::ShRUIOp>(op.getLoc(), lhs, newRhs);

// Now truncate the result to make the sign-bit after shifting the
// top-bit again.
// Note that this even works when 'c' is greater than the difference
// between the input and current bit width.
result = rewriter.create<handshake::TruncIOp>(
op.getLoc(),
result.getType().withDataType(rewriter.getIntegerType(
inputBitwidth - numOfShiftPositions.getZExtValue())),
result);
} else {
// Our shift amount is larger than the input bitwidth but the input
// bitwidth is sign-extended. The only thing that remains from the input
// is the sign-bit that was copied to the bits [inputWidth:currentBitWidth
// - c].
Value inputBWM1 = rewriter.create<handshake::ConstantOp>(
op.getLoc(),
rewriter.getIntegerAttr(lhs.getType().getDataType(),
inputBitwidth - 1),
constantControl);
// Shift away all values of lhs other than the sign-bit.
ChannelVal signBit =
rewriter.create<handshake::ShRUIOp>(op.getLoc(), lhs, inputBWM1);
// Truncate down to just the sign-bit.
// Note that this even works when 'c' is greater than the difference
// between the input and current bit width.
result = rewriter.create<handshake::TruncIOp>(
op.getLoc(), signBit.getType().withDataType(rewriter.getI1Type()),
signBit);
}
// Result pattern is now | s X ... Y |.

// Fill with the sign-bit up until excluding the top 'c' bits.
// Result now follows the | s ... s | s X ... Y | pattern.
result = rewriter.create<handshake::ExtSIOp>(
op.getLoc(),
op.getType().withDataType(rewriter.getIntegerType(
currentBitwidth - numOfShiftPositions.getZExtValue())),
result);

// Fill the top 'c' bits with zero to turn the result into the desired
// | 0...0 | s ... s | s X ... Y | pattern.
rewriter.replaceOpWithNewOp<handshake::ExtUIOp>(op, op.getType(), result);
++bitwidthReduced;
return success();
}

private:
Pass::Statistic &bitwidthReduced;
/// A reference to the pass's name analysis.
NameAnalysis &namer;
};

/// Optimizes the bitwidth of shift-type operations. The first template
/// parameter is meant to be either handshake::ShLIOp, handshake::ShRSIOp, or
/// handshake::ShRUIOp. In both modes (forward and backward), the matched
Expand Down Expand Up @@ -1736,9 +1872,14 @@ void HandshakeOptimizeBitwidthsPass::addArithPatterns(
// is dangerous if the shift is used as multiplication.
// Therefore, removing "ArithShift<handshake::ShLIOp>" from the patterns for
// now
patterns.add<ArithShift<handshake::ShRSIOp>, ArithShift<handshake::ShRUIOp>,
ArithSelect>(bitwidthReduced, forward, ctx,
getAnalysis<NameAnalysis>());
patterns.add<ArithShift<handshake::ShRSIOp>, ArithSelect>(
bitwidthReduced, forward, ctx, getAnalysis<NameAnalysis>());
if (!forward)
patterns.add<ArithShift<handshake::ShRUIOp>>(bitwidthReduced, forward, ctx,
getAnalysis<NameAnalysis>());
else
patterns.add<ArithShrUIFW>(bitwidthReduced, ctx,
getAnalysis<NameAnalysis>());

patterns.add<ArithExtToTruncOpt>(bitwidthReduced, ctx,
getAnalysis<NameAnalysis>());
Expand Down
89 changes: 84 additions & 5 deletions test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,13 @@ handshake.func @shrsiFW(%arg0: !handshake.channel<i16>, %start: !handshake.contr
// CHECK-LABEL: handshake.func @shruiFW(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i16} : <>, <i16>
// CHECK: %[[VAL_3:.*]] = shrui %[[VAL_0]], %[[VAL_2]] : <i16>
// CHECK: %[[VAL_4:.*]] = trunci %[[VAL_3]] : <i16> to <i12>
// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : <i12> to <i32>
// CHECK: end %[[VAL_5]] : <i32>
// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i32} : <>, <i32>
// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_2]] : <i32> to <i16>
// CHECK: %[[VAL_4:.*]] = shrui %[[VAL_0]], %[[VAL_3]] : <i16>
// CHECK: %[[VAL_5:.*]] = trunci %[[VAL_4]] : <i16> to <i12>
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i12> to <i28>
// CHECK: %[[VAL_7:.*]] = extui %[[VAL_6]] : <i28> to <i32>
// CHECK: end %[[VAL_7]] : <i32>
// CHECK: }
handshake.func @shruiFW(%arg0: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%cst = handshake.constant %start {value = 4 : i4} : <>, <i4>
Expand All @@ -185,6 +187,83 @@ handshake.func @shruiFW(%arg0: !handshake.channel<i16>, %start: !handshake.contr
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @shrui_edge_case(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i29>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i32} : <>, <i32>
// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_2]] : <i32> to <i29>
// CHECK: %[[VAL_4:.*]] = shrui %[[VAL_0]], %[[VAL_3]] : <i29>
// CHECK: %[[VAL_5:.*]] = trunci %[[VAL_4]] : <i29> to <i25>
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i25> to <i28>
// CHECK: %[[VAL_7:.*]] = extui %[[VAL_6]] : <i28> to <i32>
// CHECK: end %[[VAL_7]] : <i32>
// CHECK: }
handshake.func @shrui_edge_case(%arg0: !handshake.channel<i29>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%cst = handshake.constant %start {value = 4 : i4} : <>, <i4>
%ext0 = extsi %arg0 : <i29> to <i32>
%extCst = extsi %cst : <i4> to <i32>
%res = shrui %ext0, %extCst : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @shrui_si_overflow(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 15 : i16} : <>, <i16>
// CHECK: %[[VAL_3:.*]] = shrui %[[VAL_0]], %[[VAL_2]] : <i16>
// CHECK: %[[VAL_4:.*]] = trunci %[[VAL_3]] : <i16> to <i1>
// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : <i1> to <i14>
// CHECK: %[[VAL_6:.*]] = extui %[[VAL_5]] : <i14> to <i32>
// CHECK: end %[[VAL_6]] : <i32>
// CHECK: }
handshake.func @shrui_si_overflow(%arg0: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%cst = handshake.constant %start {value = 18 : i8} : <>, <i8>
%ext0 = extsi %arg0 : <i16> to <i32>
%extCst = extsi %cst : <i8> to <i32>
%res = shrui %ext0, %extCst : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @shrui_ui_FW(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i32} : <>, <i32>
// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_2]] : <i32> to <i16>
// CHECK: %[[VAL_4:.*]] = shrui %[[VAL_0]], %[[VAL_3]] : <i16>
// CHECK: %[[VAL_5:.*]] = extui %[[VAL_4]] : <i16> to <i32>
// CHECK: end %[[VAL_5]] : <i32>
// CHECK: }
handshake.func @shrui_ui_FW(%arg0: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%cst = handshake.constant %start {value = 4 : i4} : <>, <i4>
%ext0 = extui %arg0 : <i16> to <i32>
%extCst = extui %cst : <i4> to <i32>
%res = shrui %ext0, %extCst : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @shrui_ui_overflowFW(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 0 : i32} : <>, <i32>
// CHECK: end %[[VAL_2]] : <i32>
// CHECK: }
handshake.func @shrui_ui_overflowFW(%arg0: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%cst = handshake.constant %start {value = 17 : i16} : <>, <i16>
%ext0 = extui %arg0 : <i16> to <i32>
%extCst = extui %cst : <i16> to <i32>
%res = shrui %ext0, %extCst : <i32>
end %res : <i32>
}


// -----

// CHECK-LABEL: handshake.func @cmpiFW(
Expand Down
22 changes: 22 additions & 0 deletions test/Transforms/HandshakeOptimizeBitwidths/gh792.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: dynamatic-opt --handshake-optimize-bitwidths --remove-operation-names %s --split-input-file | FileCheck %s

// CHECK-LABEL: handshake.func @test(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i2>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i32>,
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["var0", "var2", "start"], resNames = ["out0", "end"]} {
// CHECK: %[[VAL_3:.*]] = source : <>
// CHECK: %[[VAL_4:.*]] = constant %[[VAL_3]] {value = 1 : i2} : <>, <i2>
// CHECK: %[[VAL_5:.*]] = shrui %[[VAL_0]], %[[VAL_4]] : <i2>
// CHECK: %[[VAL_6:.*]] = trunci %[[VAL_5]] : <i2> to <i1>
// CHECK: %[[VAL_7:.*]] = extsi %[[VAL_6]] : <i1> to <i11>
// CHECK: %[[VAL_8:.*]] = extui %[[VAL_7]] : <i11> to <i32>
// CHECK: end %[[VAL_8]], %[[VAL_2]] : <i32>, <>
// CHECK: }
handshake.func @test(%arg0: !handshake.channel<i2>, %arg1: !handshake.channel<i32>, %arg2: !handshake.control<>, ...) -> (!handshake.channel<i32>, !handshake.control<>) attributes {argNames = ["var0", "var2", "start"], resNames = ["out0", "end"]} {
%11 = extsi %arg0 : <i2> to <i16>
%13 = source : <>
%14 = constant %13 {value = 5 : i16} : <>, <i16>
%16 = shrui %11, %14 : <i16>
%17 = extui %16 : <i16> to <i32>
end %17, %arg2 : <i32>, <>
}
Loading