-
Couldn't load subscription status.
- Fork 15k
[InstCombine] Fold shifts + selects with -1 to scmp(X, 0) #164129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: AZero13 (AZero13) Changes: This fold is needed because the sign function (a three-way comparison against 0) tends to be folded to an ashr-based select, among other forms. Full diff: https://github.com/llvm/llvm-project/pull/164129.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index a8eb9b9cf6a84..71aa94045f6ec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4680,5 +4680,30 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Align(MaskedLoadAlignment->getZExtValue()),
CondVal, FalseVal));
+ // Canonicalize sign function ashr pattern: select (icmp slt X, 1), ashr X,
+ // bitwidth-1, 1 -> scmp(X, 0)
+ // Also handles: select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0)
+ Value *X;
+ unsigned BitWidth = SI.getType()->getScalarSizeInBits();
+ CmpPredicate Pred;
+ Value *CmpLHS, *CmpRHS;
+
+ // Canonicalize sign function ashr patterns:
+ // select (icmp slt X, 1), ashr X, bitwidth-1, 1 -> scmp(X, 0)
+ // select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0)
+ if (match(&SI, m_Select(m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)),
+ m_Value(TrueVal), m_Value(FalseVal))) &&
+ ((Pred == ICmpInst::ICMP_SLT && match(CmpLHS, m_Value(X)) && match(CmpRHS, m_One()) &&
+ match(TrueVal, m_AShr(m_Deferred(X), m_SpecificInt(BitWidth - 1))) &&
+ match(FalseVal, m_One())) ||
+ (Pred == ICmpInst::ICMP_SGT && match(CmpLHS, m_Value(X)) && match(CmpRHS, m_Zero()) &&
+ match(TrueVal, m_One()) &&
+ match(FalseVal, m_AShr(m_Deferred(X), m_SpecificInt(BitWidth - 1)))))) {
+
+ Function *Scmp = Intrinsic::getOrInsertDeclaration(
+ SI.getModule(), Intrinsic::scmp, {SI.getType(), SI.getType()});
+ return CallInst::Create(Scmp, {X, ConstantInt::get(SI.getType(), 0)});
+ }
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index c0be5b986b7fd..cada97aeadbad 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -519,9 +519,7 @@ define <3 x i2> @scmp_unary_shuffle_ops(<3 x i8> %x, <3 x i8> %y) {
define i32 @scmp_sgt_slt(i32 %a) {
; CHECK-LABEL: define i32 @scmp_sgt_slt(
; CHECK-SAME: i32 [[A:%.*]]) {
-; CHECK-NEXT: [[A_LOBIT:%.*]] = ashr i32 [[A]], 31
-; CHECK-NEXT: [[CMP_INV:%.*]] = icmp slt i32 [[A]], 1
-; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP_INV]], i32 [[A_LOBIT]], i32 1
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0)
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
%cmp = icmp sgt i32 %a, 0
@@ -747,3 +745,41 @@ define i8 @scmp_from_select_eq_and_gt_neg3(i32 %x, i32 %y) {
%r = select i1 %eq, i8 0, i8 %sel1
ret i8 %r
}
+
+define i32 @scmp_ashr(i32 %a) {
+; CHECK-LABEL: define i32 @scmp_ashr(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0)
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %a.lobit = ashr i32 %a, 31
+ %cmp.inv = icmp slt i32 %a, 1
+ %retval.0 = select i1 %cmp.inv, i32 %a.lobit, i32 1
+ ret i32 %retval.0
+}
+
+; Test the new SGT pattern: select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0)
+define i8 @scmp_ashr_sgt_pattern(i8 %a) {
+; CHECK-LABEL: define i8 @scmp_ashr_sgt_pattern(
+; CHECK-SAME: i8 [[A:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i8(i8 [[A]], i8 0)
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %a.lobit = ashr i8 %a, 7
+ %cmp = icmp sgt i8 %a, 0
+ %retval = select i1 %cmp, i8 1, i8 %a.lobit
+ ret i8 %retval
+}
+
+; Test the SLT pattern: select (icmp slt X, 1), ashr X, bitwidth-1, 1 -> scmp(X, 0)
+define i8 @scmp_ashr_slt_pattern(i8 %a) {
+; CHECK-LABEL: define i8 @scmp_ashr_slt_pattern(
+; CHECK-SAME: i8 [[A:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i8(i8 [[A]], i8 0)
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %a.lobit = ashr i8 %a, 7
+ %cmp = icmp slt i8 %a, 1
+ %retval = select i1 %cmp, i8 %a.lobit, i8 1
+ ret i8 %retval
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
b22f41e to
5b652c2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some negative tests and provide the alive2 proof.
@dtcxzyw Done! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/65/builds/24492 Here is the relevant piece of the build log for reference: |
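This fold is needed because the sign function (a three-way comparison against 0) tends to be folded to an ashr-based select (e.g. `select (icmp slt X, 1), (ashr X, bitwidth-1), 1`), among other forms, which would otherwise not be recognized as scmp(X, 0).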
Alive2: https://alive2.llvm.org/ce/z/Q59KvH