diff --git a/lib/Transforms/HandshakeOptimizeBitwidths.cpp b/lib/Transforms/HandshakeOptimizeBitwidths.cpp
index ba3be7ec50..4eec5c8c2c 100644
--- a/lib/Transforms/HandshakeOptimizeBitwidths.cpp
+++ b/lib/Transforms/HandshakeOptimizeBitwidths.cpp
@@ -1230,6 +1230,164 @@ struct ArithSelect : public OpRewritePattern {
   bool forward;
 };
 
+/// Optimizes unsigned right-shifts with a constant as a forward pass.
+///
+/// Matches a handshake::ShRUIOp whose shifted operand is an extension of a
+/// narrower value and whose shift amount is produced by a constant. The shift
+/// is re-expressed at the narrower input bitwidth (or replaced by a zero
+/// constant when every input bit is shifted out) and extensions restore the
+/// original result type.
+struct ArithShrUIFW : OpRewritePattern<handshake::ShRUIOp> {
+  /// \param bitwidthReduced statistic incremented when a shift is narrowed.
+  /// \param ctx context forwarded to the base pattern.
+  /// \param namer the pass's name analysis, forwarded to modArithOp.
+  ArithShrUIFW(Pass::Statistic &bitwidthReduced, MLIRContext *ctx,
+               NameAnalysis &namer)
+      : OpRewritePattern(ctx), bitwidthReduced(bitwidthReduced), namer(namer) {}
+
+  /// Fails without touching the IR when the shifted operand has no narrower
+  /// minimal form, when the shift amount does not come from a constant, or
+  /// when the shift amount is at least the result bitwidth.
+  LogicalResult matchAndRewrite(handshake::ShRUIOp op,
+                                PatternRewriter &rewriter) const override {
+    auto [lhs, lhsExt] = getMinimalValueWithExtType(op.getLhs());
+    unsigned inputBitwidth = lhs.getType().getDataBitWidth();
+    unsigned currentBitwidth = op.getType().getDataBitWidth();
+    if (inputBitwidth >= currentBitwidth)
+      return failure();
+
+    assert(lhsExt != ExtType::NONE && "expected an extension");
+    APInt numOfShiftPositions;
+    Value constantControl;
+    {
+      auto constantOp = op.getRhs().getDefiningOp<handshake::ConstantOp>();
+      if (!constantOp)
+        return failure();
+      numOfShiftPositions = cast<IntegerAttr>(constantOp.getValue()).getValue();
+      constantControl = constantOp.getCtrl();
+    }
+
+    // Other patterns (such as canonicalization patterns) should fold this case
+    // to a useful constant instead.
+    if (numOfShiftPositions.uge(currentBitwidth))
+      return failure();
+
+    // The following optimizations can be performed here:
+    // * ZEXT: Shift amount is at least the input bitwidth -> replace with 0.
+    // * SEXT: Shift amount 'c' is at least the input bitwidth -> replace with
+    //   'c' many leading 0s and copies of the sign bit otherwise.
+    // * Otherwise: We can perform the shift at the (lower) input bitwidth
+    //   enabling other ops to be optimized in the forward pass.
+    if (lhsExt == ExtType::ZEXT) {
+      if (numOfShiftPositions.uge(inputBitwidth)) {
+        // The entire input is shifted away and only 0 bits from the extension
+        // remain.
+        // Create the constant and run inheritBB while 'op' is still alive:
+        // replacing 'op' erases it, so it must not be accessed afterwards.
+        auto constant = rewriter.create<handshake::ConstantOp>(
+            op.getLoc(), op.getType(),
+            rewriter.getZeroAttr(op.getType().getDataType()), constantControl);
+        inheritBB(op, constant);
+        rewriter.replaceOp(op, constant->getResults());
+        return success();
+      }
+
+      // Otherwise, we can perform the entire shift at the input bitwidth and
+      // zero-extend back to the original bitwidth.
+      modArithOp(op, {lhs, lhsExt}, {op.getRhs(), ExtType::NONE}, inputBitwidth,
+                 ExtType::ZEXT, rewriter, namer);
+      ++bitwidthReduced;
+      return success();
+    }
+
+    // SEXT case.
+    // NOTE(review): every extension kind other than ZEXT is treated as a
+    // sign-extension from here on -- confirm ExtType has no further kinds.
+    //
+    // What all the logic here has in common is that the top 'c' bits are
+    // always zero (due to being an unsigned shift) and that any bits between
+    // the top 'c' bits and the low (inputBitwidth - c) bits are copies of the
+    // sign-bit of the input value.
+    //
+    //   | 0...0 | s ... s | s X ... Y |
+    //   -------------------------------
+    //   * 0...0:     'c' many 0s in the front.
+    //   * s ... s:   (currentBitwidth - inputBitwidth) many sign bits (s).
+    //   * s X ... Y: Original input (led by sign-bit s) shifted by 'c' to the
+    //                right; (inputBitwidth - c) bits remain.
+    //
+    // If 'c' is at least the input bitwidth then there are only 'c' many 0s
+    // and copies of the sign-bit in the remaining bits.
+    // In all cases we can perform shifts at the input bitwidth or less and use
+    // extensions to restore the original output.
+    // These extension operations can be folded into other operations if
+    // redundant or leveraged by other patterns.
+    // NOTE(review): for c == 0 (or inputBitwidth == 1 in the else-branch) some
+    // casts below keep the width unchanged -- presumably such shifts are
+    // folded before this pattern runs; confirm.
+    ChannelVal result;
+    if (numOfShiftPositions.ult(inputBitwidth)) {
+      // c is less than the input bitwidth, meaning other bits from the input
+      // besides the sign-bit are preserved in the output.
+
+      // Perform the shift on the bitwidth of lhs.
+      Value newRhs =
+          modBitWidth({op.getRhs(), ExtType::NONE}, inputBitwidth, rewriter);
+      result = rewriter.create<handshake::ShRUIOp>(op.getLoc(), lhs, newRhs);
+
+      // Now truncate the result to make the sign-bit after shifting the
+      // top-bit again.
+      // Note that this even works when 'c' is greater than the difference
+      // between the input and current bit width.
+      result = rewriter.create<handshake::TruncIOp>(
+          op.getLoc(),
+          result.getType().withDataType(rewriter.getIntegerType(
+              inputBitwidth - numOfShiftPositions.getZExtValue())),
+          result);
+    } else {
+      // Our shift amount is at least the input bitwidth but the input is
+      // sign-extended. The only thing that remains from the input is the
+      // sign-bit that was copied to the bits
+      // [inputWidth:currentBitWidth - c].
+      Value inputBWM1 = rewriter.create<handshake::ConstantOp>(
+          op.getLoc(),
+          rewriter.getIntegerAttr(lhs.getType().getDataType(),
+                                  inputBitwidth - 1),
+          constantControl);
+      // Shift away all values of lhs other than the sign-bit.
+      ChannelVal signBit =
+          rewriter.create<handshake::ShRUIOp>(op.getLoc(), lhs, inputBWM1);
+      // Truncate down to just the sign-bit.
+      // Note that this even works when 'c' is greater than the difference
+      // between the input and current bit width.
+      result = rewriter.create<handshake::TruncIOp>(
+          op.getLoc(), signBit.getType().withDataType(rewriter.getI1Type()),
+          signBit);
+    }
+    // Result pattern is now | s X ... Y |.
+
+    // Fill with the sign-bit up until excluding the top 'c' bits.
+    // Result now follows the | s ... s | s X ... Y | pattern.
+    result = rewriter.create<handshake::ExtSIOp>(
+        op.getLoc(),
+        op.getType().withDataType(rewriter.getIntegerType(
+            currentBitwidth - numOfShiftPositions.getZExtValue())),
+        result);
+
+    // Fill the top 'c' bits with zero to turn the result into the desired
+    // | 0...0 | s ... s | s X ... Y | pattern.
+    rewriter.replaceOpWithNewOp<handshake::ExtUIOp>(op, op.getType(), result);
+    ++bitwidthReduced;
+    return success();
+  }
+
+private:
+  /// Statistic counting the shifts rewritten at a reduced bitwidth.
+  Pass::Statistic &bitwidthReduced;
+  /// A reference to the pass's name analysis.
+  NameAnalysis &namer;
+};
+
 /// Optimizes the bitwidth of shift-type operations. The first template
 /// parameter is meant to be either handshake::ShLIOp, handshake::ShRSIOp, or
 /// handshake::ShRUIOp.
In both modes (forward and backward), the matched @@ -1736,9 +1872,14 @@ void HandshakeOptimizeBitwidthsPass::addArithPatterns( // is dangerous if the shift is used as multiplication. // Therefore, removing "ArithShift" from the patterns for // now - patterns.add, ArithShift, - ArithSelect>(bitwidthReduced, forward, ctx, - getAnalysis()); + patterns.add, ArithSelect>( + bitwidthReduced, forward, ctx, getAnalysis()); + if (!forward) + patterns.add>(bitwidthReduced, forward, ctx, + getAnalysis()); + else + patterns.add(bitwidthReduced, ctx, + getAnalysis()); patterns.add(bitwidthReduced, ctx, getAnalysis()); diff --git a/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir b/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir index 7b8226a3ed..94d72cf12a 100644 --- a/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir +++ b/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir @@ -171,11 +171,13 @@ handshake.func @shrsiFW(%arg0: !handshake.channel, %start: !handshake.contr // CHECK-LABEL: handshake.func @shruiFW( // CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) 
-> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { -// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i16} : <>, -// CHECK: %[[VAL_3:.*]] = shrui %[[VAL_0]], %[[VAL_2]] : -// CHECK: %[[VAL_4:.*]] = trunci %[[VAL_3]] : to -// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : to -// CHECK: end %[[VAL_5]] : +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i32} : <>, +// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_2]] : to +// CHECK: %[[VAL_4:.*]] = shrui %[[VAL_0]], %[[VAL_3]] : +// CHECK: %[[VAL_5:.*]] = trunci %[[VAL_4]] : to +// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to +// CHECK: %[[VAL_7:.*]] = extui %[[VAL_6]] : to +// CHECK: end %[[VAL_7]] : // CHECK: } handshake.func @shruiFW(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { %cst = handshake.constant %start {value = 4 : i4} : <>, @@ -185,6 +187,83 @@ handshake.func @shruiFW(%arg0: !handshake.channel, %start: !handshake.contr end %res : } +// ----- + +// CHECK-LABEL: handshake.func @shrui_edge_case( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) 
-> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i32} : <>, +// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_2]] : to +// CHECK: %[[VAL_4:.*]] = shrui %[[VAL_0]], %[[VAL_3]] : +// CHECK: %[[VAL_5:.*]] = trunci %[[VAL_4]] : to +// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to +// CHECK: %[[VAL_7:.*]] = extui %[[VAL_6]] : to +// CHECK: end %[[VAL_7]] : +// CHECK: } +handshake.func @shrui_edge_case(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 4 : i4} : <>, + %ext0 = extsi %arg0 : to + %extCst = extsi %cst : to + %res = shrui %ext0, %extCst : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @shrui_si_overflow( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 15 : i16} : <>, +// CHECK: %[[VAL_3:.*]] = shrui %[[VAL_0]], %[[VAL_2]] : +// CHECK: %[[VAL_4:.*]] = trunci %[[VAL_3]] : to +// CHECK: %[[VAL_5:.*]] = extsi %[[VAL_4]] : to +// CHECK: %[[VAL_6:.*]] = extui %[[VAL_5]] : to +// CHECK: end %[[VAL_6]] : +// CHECK: } +handshake.func @shrui_si_overflow(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 18 : i8} : <>, + %ext0 = extsi %arg0 : to + %extCst = extsi %cst : to + %res = shrui %ext0, %extCst : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @shrui_ui_FW( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) 
-> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 4 : i32} : <>, +// CHECK: %[[VAL_3:.*]] = trunci %[[VAL_2]] : to +// CHECK: %[[VAL_4:.*]] = shrui %[[VAL_0]], %[[VAL_3]] : +// CHECK: %[[VAL_5:.*]] = extui %[[VAL_4]] : to +// CHECK: end %[[VAL_5]] : +// CHECK: } +handshake.func @shrui_ui_FW(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 4 : i4} : <>, + %ext0 = extui %arg0 : to + %extCst = extui %cst : to + %res = shrui %ext0, %extCst : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @shrui_ui_overflowFW( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 0 : i32} : <>, +// CHECK: end %[[VAL_2]] : +// CHECK: } +handshake.func @shrui_ui_overflowFW(%arg0: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %cst = handshake.constant %start {value = 17 : i16} : <>, + %ext0 = extui %arg0 : to + %extCst = extui %cst : to + %res = shrui %ext0, %extCst : + end %res : +} + + // ----- // CHECK-LABEL: handshake.func @cmpiFW( diff --git a/test/Transforms/HandshakeOptimizeBitwidths/gh792.mlir b/test/Transforms/HandshakeOptimizeBitwidths/gh792.mlir new file mode 100644 index 0000000000..ba611cd0f4 --- /dev/null +++ b/test/Transforms/HandshakeOptimizeBitwidths/gh792.mlir @@ -0,0 +1,22 @@ +// RUN: dynamatic-opt --handshake-optimize-bitwidths --remove-operation-names %s --split-input-file | FileCheck %s + +// CHECK-LABEL: handshake.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) 
-> (!handshake.channel, !handshake.control<>) attributes {argNames = ["var0", "var2", "start"], resNames = ["out0", "end"]} { +// CHECK: %[[VAL_3:.*]] = source : <> +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_3]] {value = 1 : i2} : <>, +// CHECK: %[[VAL_5:.*]] = shrui %[[VAL_0]], %[[VAL_4]] : +// CHECK: %[[VAL_6:.*]] = trunci %[[VAL_5]] : to +// CHECK: %[[VAL_7:.*]] = extsi %[[VAL_6]] : to +// CHECK: %[[VAL_8:.*]] = extui %[[VAL_7]] : to +// CHECK: end %[[VAL_8]], %[[VAL_2]] : , <> +// CHECK: } +handshake.func @test(%arg0: !handshake.channel, %arg1: !handshake.channel, %arg2: !handshake.control<>, ...) -> (!handshake.channel, !handshake.control<>) attributes {argNames = ["var0", "var2", "start"], resNames = ["out0", "end"]} { + %11 = extsi %arg0 : to + %13 = source : <> + %14 = constant %13 {value = 5 : i16} : <>, + %16 = shrui %11, %14 : + %17 = extui %16 : to + end %17, %arg2 : , <> +}