diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 390 |
1 files changed, 321 insertions, 69 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 8c0ad52598..02e8bf1013 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -24,6 +24,8 @@ using namespace llvm; using namespace PatternMatch; +#define DEBUG_TYPE "instcombine" + static ConstantInt *getOne(Constant *C) { return ConstantInt::get(cast<IntegerType>(C->getType()), 1); } @@ -218,15 +220,15 @@ Instruction *InstCombiner:: FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, ConstantInt *AndCst) { // We need TD information to know the pointer size unless this is inbounds. - if (!GEP->isInBounds() && DL == 0) - return 0; + if (!GEP->isInBounds() && !DL) + return nullptr; Constant *Init = GV->getInitializer(); if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) - return 0; + return nullptr; uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); - if (ArrayElementCount > 1024) return 0; // Don't blow up on huge arrays. + if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays. // There are many forms of this optimization we can handle, for now, just do // the simple index into a single-dimensional array. @@ -236,7 +238,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, !isa<ConstantInt>(GEP->getOperand(1)) || !cast<ConstantInt>(GEP->getOperand(1))->isZero() || isa<Constant>(GEP->getOperand(2))) - return 0; + return nullptr; // Check that indices after the variable are constants and in-range for the // type they index. Collect the indices. 
This is typically for arrays of @@ -246,18 +248,18 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, Type *EltTy = Init->getType()->getArrayElementType(); for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) { ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(i)); - if (Idx == 0) return 0; // Variable index. + if (!Idx) return nullptr; // Variable index. uint64_t IdxVal = Idx->getZExtValue(); - if ((unsigned)IdxVal != IdxVal) return 0; // Too large array index. + if ((unsigned)IdxVal != IdxVal) return nullptr; // Too large array index. if (StructType *STy = dyn_cast<StructType>(EltTy)) EltTy = STy->getElementType(IdxVal); else if (ArrayType *ATy = dyn_cast<ArrayType>(EltTy)) { - if (IdxVal >= ATy->getNumElements()) return 0; + if (IdxVal >= ATy->getNumElements()) return nullptr; EltTy = ATy->getElementType(); } else { - return 0; // Unknown type. + return nullptr; // Unknown type. } LaterIndices.push_back(IdxVal); @@ -296,7 +298,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, Constant *CompareRHS = cast<Constant>(ICI.getOperand(1)); for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) { Constant *Elt = Init->getAggregateElement(i); - if (Elt == 0) return 0; + if (!Elt) return nullptr; // If this is indexing an array of structures, get the structure element. if (!LaterIndices.empty()) @@ -321,7 +323,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // If we can't compute the result for any of the elements, we have to give // up evaluating the entire conditional. - if (!isa<ConstantInt>(C)) return 0; + if (!isa<ConstantInt>(C)) return nullptr; // Otherwise, we know if the comparison is true or false for this element, // update our state machines. 
@@ -375,7 +377,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, if ((i & 8) == 0 && i >= 64 && SecondTrueElement == Overdefined && SecondFalseElement == Overdefined && TrueRangeEnd == Overdefined && FalseRangeEnd == Overdefined) - return 0; + return nullptr; } // Now that we've scanned the entire array, emit our new comparison(s). We @@ -467,7 +469,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // of this load, replace it with computation that does: // ((magic_cst >> i) & 1) != 0 { - Type *Ty = 0; + Type *Ty = nullptr; // Look for an appropriate type: // - The type of Idx if the magic fits @@ -480,7 +482,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, else if (ArrayElementCount <= 32) Ty = Type::getInt32Ty(Init->getContext()); - if (Ty != 0) { + if (Ty) { Value *V = Builder->CreateIntCast(Idx, Ty, false); V = Builder->CreateLShr(ConstantInt::get(Ty, MagicBitvector), V); V = Builder->CreateAnd(ConstantInt::get(Ty, 1), V); @@ -488,7 +490,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, } } - return 0; + return nullptr; } @@ -533,7 +535,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { // If there are no variable indices, we must have a constant offset, just // evaluate it the general way. - if (i == e) return 0; + if (i == e) return nullptr; Value *VariableIdx = GEP->getOperand(i); // Determine the scale factor of the variable element. For example, this is @@ -543,7 +545,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { // Verify that there are no other variable indices. If so, emit the hard way. for (++i, ++GTI; i != e; ++i, ++GTI) { ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i)); - if (!CI) return 0; + if (!CI) return nullptr; // Compute the aggregate offset of constant indices. 
if (CI->isZero()) continue; @@ -587,7 +589,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { // multiple of the variable scale. int64_t NewOffs = Offset / (int64_t)VariableScale; if (Offset != NewOffs*(int64_t)VariableScale) - return 0; + return nullptr; // Okay, we can do this evaluation. Start by converting the index to intptr. if (VariableIdx->getType() != IntPtrTy) @@ -608,7 +610,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // e.g. "&foo[0] <s &foo[1]" can't be folded to "true" because "foo" could be // the maximum signed value for the pointer type. if (ICmpInst::isSigned(Cond)) - return 0; + return nullptr; // Look through bitcasts. if (BitCastInst *BCI = dyn_cast<BitCastInst>(RHS)) @@ -623,7 +625,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this); // If not, synthesize the offset the hard way. - if (Offset == 0) + if (!Offset) Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, Constant::getNullValue(Offset->getType())); @@ -661,7 +663,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Otherwise, the base pointers are different and the indices are // different, bail out. - return 0; + return nullptr; } // If one of the GEPs has all zero indices, recurse. @@ -729,7 +731,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); } } - return 0; + return nullptr; } /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X". @@ -812,11 +814,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // if it finds it. bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) - return 0; + return nullptr; if (DivRHS->isZero()) - return 0; // The ProdOV computation fails on divide by zero. 
+ return nullptr; // The ProdOV computation fails on divide by zero. if (DivIsSigned && DivRHS->isAllOnesValue()) - return 0; // The overflow computation also screws up here + return nullptr; // The overflow computation also screws up here if (DivRHS->isOne()) { // This eliminates some funny cases with INT_MIN. ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. @@ -850,7 +852,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // overflow variable is set to 0 if it's corresponding bound variable is valid // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; - Constant *LoBound = 0, *HiBound = 0; + Constant *LoBound = nullptr, *HiBound = nullptr; if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) @@ -890,7 +892,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN + HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN } } else if (CmpRHSV.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) @@ -964,20 +966,20 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, uint32_t TypeBits = CmpRHSV.getBitWidth(); uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) - return 0; + return nullptr; if (!ICI.isEquality()) { // If we have an unsigned comparison and an ashr, we can't simplify this. // Similarly for signed comparisons with lshr. if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) - return 0; + return nullptr; // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv // by a power of 2. Since we already have logic to simplify these, // transform to div and then simplify the resultant comparison. 
if (Shr->getOpcode() == Instruction::AShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return 0; + return nullptr; // Revisit the shift (to delete it). Worklist.Add(Shr); @@ -994,7 +996,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, // If the builder folded the binop, just return it. BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (TheDiv == 0) + if (!TheDiv) return &ICI; // Otherwise, fold this div/compare. @@ -1037,7 +1039,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, Mask, Shr->getName()+".mask"); return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); } - return 0; + return nullptr; } @@ -1056,7 +1058,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - ComputeMaskedBits(LHSI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne); // If all the high bits are known, we can do this xform. if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { @@ -1181,10 +1183,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // access. BinaryOperator *Shift = dyn_cast<BinaryOperator>(LHSI->getOperand(0)); if (Shift && !Shift->isShift()) - Shift = 0; + Shift = nullptr; ConstantInt *ShAmt; - ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : 0; + ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : nullptr; // This seemingly simple opportunity to fold away a shift turns out to // be rather complicated. See PR17827 @@ -1777,7 +1779,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } } } - return 0; + return nullptr; } /// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst). 
@@ -1794,7 +1796,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // integer type is the same size as the pointer type. if (DL && LHSCI->getOpcode() == Instruction::PtrToInt && DL->getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { - Value *RHSOp = 0; + Value *RHSOp = nullptr; if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) { RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); } else if (PtrToIntInst *RHSC = dyn_cast<PtrToIntInst>(ICI.getOperand(1))) { @@ -1812,7 +1814,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // Enforce this. if (LHSCI->getOpcode() != Instruction::ZExt && LHSCI->getOpcode() != Instruction::SExt) - return 0; + return nullptr; bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; bool isSignedCmp = ICI.isSigned(); @@ -1821,12 +1823,12 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // Not an extension from the same type? RHSCIOp = CI->getOperand(0); if (RHSCIOp->getType() != LHSCIOp->getType()) - return 0; + return nullptr; // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. if (CI->getOpcode() != LHSCI->getOpcode()) - return 0; + return nullptr; // Deal with equality cases early. if (ICI.isEquality()) @@ -1844,7 +1846,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // If we aren't dealing with a constant on the RHS, exit early ConstantInt *CI = dyn_cast<ConstantInt>(ICI.getOperand(1)); if (!CI) - return 0; + return nullptr; // Compute the constant that would happen if we truncated to SrcTy then // reextended to DestTy. @@ -1873,7 +1875,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // by SimplifyICmpInst, so only deal with the tricky case. if (isSignedCmp || !isSignedExt) - return 0; + return nullptr; // Evaluate the comparison for LT (we invert for GT below). 
LE and GE cases // should have been folded away previously and not enter in here. @@ -1909,12 +1911,12 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // In order to eliminate the add-with-constant, the compare can be its only // use. Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); - if (!AddWithCst->hasOneUse()) return 0; + if (!AddWithCst->hasOneUse()) return nullptr; // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. - if (!CI2->getValue().isPowerOf2()) return 0; + if (!CI2->getValue().isPowerOf2()) return nullptr; unsigned NewWidth = CI2->getValue().countTrailingZeros(); - if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return 0; + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; // The width of the new add formed is 1 more than the bias. ++NewWidth; @@ -1922,7 +1924,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // Check to see that CI1 is an all-ones value with NewWidth bits. if (CI1->getBitWidth() == NewWidth || CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) - return 0; + return nullptr; // This is only really a signed overflow check if the inputs have been // sign-extended; check for that condition. For example, if CI2 is 2^31 and @@ -1930,7 +1932,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; if (IC.ComputeNumSignBits(A) < NeededSignBits || IC.ComputeNumSignBits(B) < NeededSignBits) - return 0; + return nullptr; // In order to replace the original add with a narrower // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant @@ -1946,8 +1948,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // original add had another add which was then immediately truncated, we // could still do the transformation. 
TruncInst *TI = dyn_cast<TruncInst>(U); - if (TI == 0 || - TI->getType()->getPrimitiveSizeInBits() > NewWidth) return 0; + if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) + return nullptr; } // If the pattern matches, truncate the inputs to the narrower type and @@ -1983,11 +1985,11 @@ static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, InstCombiner &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. - if (!isa<IntegerType>(OrigAddV->getType())) return 0; + if (!isa<IntegerType>(OrigAddV->getType())) return nullptr; // If the add is a constant expr, then we don't bother transforming it. Instruction *OrigAdd = dyn_cast<Instruction>(OrigAddV); - if (OrigAdd == 0) return 0; + if (!OrigAdd) return nullptr; Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1); @@ -2008,6 +2010,236 @@ static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, return ExtractValueInst::Create(Call, 1, "uadd.overflow"); } +/// \brief Recognize and process idiom involving test for multiplication +/// overflow. +/// +/// The caller has matched a pattern of the form: +/// I = cmp u (mul(zext A, zext B)), V +/// The function checks if this is a test for overflow and if so replaces +/// multiplication with call to 'mul.with.overflow' intrinsic. +/// +/// \param I Compare instruction. +/// \param MulVal Result of 'mul' instruction. It is one of the arguments of +/// the compare instruction. Must be of integer type. +/// \param OtherVal The other argument of compare instruction. +/// \returns Instruction which must replace the compare instruction, NULL if no +/// replacement required. 
+static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, + Value *OtherVal, InstCombiner &IC) { + assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal); + assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal); + assert(isa<IntegerType>(MulVal->getType())); + Instruction *MulInstr = cast<Instruction>(MulVal); + assert(MulInstr->getOpcode() == Instruction::Mul); + + Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)), + *RHS = cast<Instruction>(MulInstr->getOperand(1)); + assert(LHS->getOpcode() == Instruction::ZExt); + assert(RHS->getOpcode() == Instruction::ZExt); + Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); + + // Calculate type and width of the result produced by mul.with.overflow. + Type *TyA = A->getType(), *TyB = B->getType(); + unsigned WidthA = TyA->getPrimitiveSizeInBits(), + WidthB = TyB->getPrimitiveSizeInBits(); + unsigned MulWidth; + Type *MulType; + if (WidthB > WidthA) { + MulWidth = WidthB; + MulType = TyB; + } else { + MulWidth = WidthA; + MulType = TyA; + } + + // In order to replace the original mul with a narrower mul.with.overflow, + // all uses must ignore upper bits of the product. The number of used low + // bits must be not greater than the width of mul.with.overflow. + if (MulVal->hasNUsesOrMore(2)) + for (User *U : MulVal->users()) { + if (U == &I) + continue; + if (TruncInst *TI = dyn_cast<TruncInst>(U)) { + // Check if truncation ignores bits above MulWidth. + unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits(); + if (TruncWidth > MulWidth) + return nullptr; + } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) { + // Check if AND ignores bits above MulWidth. 
+ if (BO->getOpcode() != Instruction::And) + return nullptr; + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { + const APInt &CVal = CI->getValue(); + if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth) + return nullptr; + } + } else { + // Other uses prohibit this transformation. + return nullptr; + } + } + + // Recognize patterns + switch (I.getPredicate()) { + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp eq/neq mulval, zext trunc mulval + if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal)) + if (Zext->hasOneUse()) { + Value *ZextArg = Zext->getOperand(0); + if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg)) + if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth) + break; //Recognized + } + + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits. + ConstantInt *CI; + Value *ValToMask; + if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) { + if (ValToMask != MulVal) + return nullptr; + const APInt &CVal = CI->getValue() + 1; + if (CVal.isPowerOf2()) { + unsigned MaskWidth = CVal.logBase2(); + if (MaskWidth == MulWidth) + break; // Recognized + } + } + return nullptr; + + case ICmpInst::ICMP_UGT: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp ugt mulval, max + if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { + APInt MaxVal = APInt::getMaxValue(MulWidth); + MaxVal = MaxVal.zext(CI->getBitWidth()); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + case ICmpInst::ICMP_UGE: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp uge mulval, max+1 + if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { + APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + case ICmpInst::ICMP_ULE: + // Recognize pattern: + // 
mulval = mul(zext A, zext B) + // cmp ule mulval, max + if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { + APInt MaxVal = APInt::getMaxValue(MulWidth); + MaxVal = MaxVal.zext(CI->getBitWidth()); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + case ICmpInst::ICMP_ULT: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp ult mulval, max + 1 + if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { + APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + default: + return nullptr; + } + + InstCombiner::BuilderTy *Builder = IC.Builder; + Builder->SetInsertPoint(MulInstr); + Module *M = I.getParent()->getParent()->getParent(); + + // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B) + Value *MulA = A, *MulB = B; + if (WidthA < MulWidth) + MulA = Builder->CreateZExt(A, MulType); + if (WidthB < MulWidth) + MulB = Builder->CreateZExt(B, MulType); + Value *F = + Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType); + CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul"); + IC.Worklist.Add(MulInstr); + + // If there are uses of mul result other than the comparison, we know that + // they are truncation or binary AND. Change them to use result of + // mul.with.overflow and adjust properly mask/size. 
+ if (MulVal->hasNUsesOrMore(2)) { + Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value"); + for (User *U : MulVal->users()) { + if (U == &I || U == OtherVal) + continue; + if (TruncInst *TI = dyn_cast<TruncInst>(U)) { + if (TI->getType()->getPrimitiveSizeInBits() == MulWidth) + IC.ReplaceInstUsesWith(*TI, Mul); + else + TI->setOperand(0, Mul); + } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) { + assert(BO->getOpcode() == Instruction::And); + // Replace (mul & mask) --> zext (mul.with.overflow & short_mask) + ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); + APInt ShortMask = CI->getValue().trunc(MulWidth); + Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask); + Instruction *Zext = + cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType())); + IC.Worklist.Add(Zext); + IC.ReplaceInstUsesWith(*BO, Zext); + } else { + llvm_unreachable("Unexpected Binary operation"); + } + IC.Worklist.Add(cast<Instruction>(U)); + } + } + if (isa<Instruction>(OtherVal)) + IC.Worklist.Add(cast<Instruction>(OtherVal)); + + // The original icmp gets replaced with the overflow value, maybe inverted + // depending on predicate. + bool Inverse = false; + switch (I.getPredicate()) { + case ICmpInst::ICMP_NE: + break; + case ICmpInst::ICMP_EQ: + Inverse = true; + break; + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (I.getOperand(0) == MulVal) + break; + Inverse = true; + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + if (I.getOperand(1) == MulVal) + break; + Inverse = true; + break; + default: + llvm_unreachable("Unexpected predicate"); + } + if (Inverse) { + Value *Res = Builder->CreateExtractValue(Call, 1); + return BinaryOperator::CreateNot(Res); + } + + return ExtractValueInst::Create(Call, 1); +} + // DemandedBitsLHSMask - When performing a comparison against a constant, // it is possible that not all the bits in the LHS are demanded. This helper // method computes the mask that IS demanded. 
@@ -2178,7 +2410,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // See if we are doing a comparison with a constant. if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = 0, *B = 0; + Value *A = nullptr, *B = nullptr; // Match the following pattern, which is a common idiom when writing // overflow-safe integer arithmetic function. The source performs an @@ -2293,15 +2525,15 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { APInt Op0KnownZeroInverted = ~Op0KnownZero; if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) { // If the LHS is an AND with the same constant, look through it. - Value *LHS = 0; - ConstantInt *LHSC = 0; + Value *LHS = nullptr; + ConstantInt *LHSC = nullptr; if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || LHSC->getValue() != Op0KnownZeroInverted) LHS = Op0; // If the LHS is 1 << x, and we know the result is a power of 2 like 8, // then turn "((1 << x)&8) == 0" into "x != 3". - Value *X = 0; + Value *X = nullptr; if (match(LHS, m_Shl(m_One(), m_Value(X)))) { unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros(); return new ICmpInst(ICmpInst::ICMP_NE, X, @@ -2330,15 +2562,15 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { APInt Op0KnownZeroInverted = ~Op0KnownZero; if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) { // If the LHS is an AND with the same constant, look through it. - Value *LHS = 0; - ConstantInt *LHSC = 0; + Value *LHS = nullptr; + ConstantInt *LHSC = nullptr; if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || LHSC->getValue() != Op0KnownZeroInverted) LHS = Op0; // If the LHS is 1 << x, and we know the result is a power of 2 like 8, // then turn "((1 << x)&8) != 0" into "x == 3". 
- Value *X = 0; + Value *X = nullptr; if (match(LHS, m_Shl(m_One(), m_Value(X)))) { unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros(); return new ICmpInst(ICmpInst::ICMP_EQ, X, @@ -2470,7 +2702,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin())) if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) - return 0; + return nullptr; // See if we are doing a comparison between a constant and an instruction that // can be folded into the comparison. @@ -2506,7 +2738,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // If either operand of the select is a constant, we can fold the // comparison into the select arms, which will cause one to be // constant folded and the select turned into a bitwise or. - Value *Op1 = 0, *Op2 = 0; + Value *Op1 = nullptr, *Op2 = nullptr; if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) @@ -2618,7 +2850,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Analyze the case when either Op0 or Op1 is an add instruction. // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). - Value *A = 0, *B = 0, *C = 0, *D = 0; + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; if (BO0 && BO0->getOpcode() == Instruction::Add) A = BO0->getOperand(0), B = BO0->getOperand(1); if (BO1 && BO1->getOpcode() == Instruction::Add) @@ -2713,7 +2945,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Analyze the case when either Op0 or Op1 is a sub instruction. // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). 
- A = 0; B = 0; C = 0; D = 0; + A = nullptr; B = nullptr; C = nullptr; D = nullptr; if (BO0 && BO0->getOpcode() == Instruction::Sub) A = BO0->getOperand(0), B = BO0->getOperand(1); if (BO1 && BO1->getOpcode() == Instruction::Sub) @@ -2739,7 +2971,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { BO0->hasOneUse() && BO1->hasOneUse()) return new ICmpInst(Pred, D, B); - BinaryOperator *SRem = NULL; + // icmp (0-X) < cst --> x > -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { + Value *X; + if (match(BO0, m_Neg(m_Value(X)))) + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(I.getSwappedPredicate(), X, + ConstantExpr::getNeg(RHSC)); + } + + BinaryOperator *SRem = nullptr; // icmp (srem X, Y), Y if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1)) @@ -2877,6 +3119,16 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { (Op0 == A || Op0 == B)) if (Instruction *R = ProcessUAddIdiom(I, Op1, *this)) return R; + + // (zext a) * (zext b) --> llvm.umul.with.overflow. + if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this)) + return R; + } + if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this)) + return R; + } } if (I.isEquality()) { @@ -2918,7 +3170,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { - Value *X = 0, *Y = 0, *Z = 0; + Value *X = nullptr, *Y = nullptr, *Z = nullptr; if (A == C) { X = B; Y = D; Z = A; @@ -3009,7 +3261,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate()); } - return Changed ? 
&I : 0; + return Changed ? &I : nullptr; } /// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible. @@ -3017,13 +3269,13 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { - if (!isa<ConstantFP>(RHSC)) return 0; + if (!isa<ConstantFP>(RHSC)) return nullptr; const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); // Get the width of the mantissa. We don't want to hack on conversions that // might lose information from the integer, e.g. "i64 -> float" int MantissaWidth = LHSI->getType()->getFPMantissaWidth(); - if (MantissaWidth == -1) return 0; // Unknown. + if (MantissaWidth == -1) return nullptr; // Unknown. // Check to see that the input is converted from an integer type that is small // enough that preserves all bits. TODO: check here for "known" sign bits. @@ -3037,7 +3289,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // If the conversion would lose info, don't hack on this. if ((int)InputSize > MantissaWidth) - return 0; + return nullptr; // Otherwise, we can potentially simplify the comparison. We know that it // will always come through as an integer value and we know the constant is @@ -3383,5 +3635,5 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), RHSExt->getOperand(0)); - return Changed ? &I : 0; + return Changed ? &I : nullptr; } |