diff options
Diffstat (limited to 'lib/Transforms/Utils/LowerInvoke.cpp')
-rw-r--r-- | lib/Transforms/Utils/LowerInvoke.cpp | 441 |
1 files changed, 301 insertions, 140 deletions
diff --git a/lib/Transforms/Utils/LowerInvoke.cpp b/lib/Transforms/Utils/LowerInvoke.cpp index b0d8fb8..54724b5 100644 --- a/lib/Transforms/Utils/LowerInvoke.cpp +++ b/lib/Transforms/Utils/LowerInvoke.cpp @@ -41,13 +41,17 @@ #include "llvm/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/ADT/Statistic.h" #include "llvm/Support/CommandLine.h" #include <csetjmp> using namespace llvm; namespace { - Statistic<> NumLowered("lowerinvoke", "Number of invoke & unwinds replaced"); + Statistic<> NumInvokes("lowerinvoke", "Number of invokes replaced"); + Statistic<> NumUnwinds("lowerinvoke", "Number of unwinds replaced"); + Statistic<> NumSpilled("lowerinvoke", + "Number of registers live across unwind edges"); cl::opt<bool> ExpensiveEHSupport("enable-correct-eh-support", cl::desc("Make the -lowerinvoke pass insert expensive, but correct, EH code")); @@ -65,10 +69,14 @@ namespace { public: bool doInitialization(Module &M); bool runOnFunction(Function &F); + private: void createAbortMessage(); void writeAbortMessage(Instruction *IB); bool insertCheapEHSupport(Function &F); + void splitLiveRangesLiveAcrossInvokes(std::vector<InvokeInst*> &Invokes); + void rewriteExpensiveInvoke(InvokeInst *II, unsigned InvokeNo, + AllocaInst *InvokeNum, SwitchInst *CatchSwitch); bool insertExpensiveEHSupport(Function &F); }; @@ -97,9 +105,9 @@ bool LowerInvoke::doInitialization(Module &M) { { // The type is recursive, so use a type holder. std::vector<const Type*> Elements; + Elements.push_back(JmpBufTy); OpaqueType *OT = OpaqueType::get(); Elements.push_back(PointerType::get(OT)); - Elements.push_back(JmpBufTy); PATypeHolder JBLType(StructType::get(Elements)); OT->refineAbstractTypeTo(JBLType.get()); // Complete the cycle. JBLinkTy = JBLType.get(); @@ -220,7 +228,7 @@ bool LowerInvoke::insertCheapEHSupport(Function &F) { // Remove the invoke instruction now. BB->getInstList().erase(II); - ++NumLowered; Changed = true; + ++NumInvokes; Changed = true; } else if (UnwindInst *UI = dyn_cast<UnwindInst>(BB->getTerminator())) { // Insert a new call to write(2, AbortMessage, AbortMessageLength); writeAbortMessage(UI); @@ -236,163 +244,316 @@ bool LowerInvoke::insertCheapEHSupport(Function &F) { // Remove the unwind instruction now. BB->getInstList().erase(UI); - ++NumLowered; Changed = true; + ++NumUnwinds; Changed = true; } return Changed; } -bool LowerInvoke::insertExpensiveEHSupport(Function &F) { - bool Changed = false; +/// rewriteExpensiveInvoke - Insert code and hack the function to replace the +/// specified invoke instruction with a call. +void LowerInvoke::rewriteExpensiveInvoke(InvokeInst *II, unsigned InvokeNo, + AllocaInst *InvokeNum, + SwitchInst *CatchSwitch) { + ConstantUInt *InvokeNoC = ConstantUInt::get(Type::UIntTy, InvokeNo); + + // Insert a store of the invoke num before the invoke and store zero into the + // location afterward. + new StoreInst(InvokeNoC, InvokeNum, true, II); // volatile + new StoreInst(Constant::getNullValue(Type::UIntTy), InvokeNum, false, + II->getNormalDest()->begin()); // nonvolatile. + + // Add a switch case to our unwind block. + CatchSwitch->addCase(InvokeNoC, II->getUnwindDest()); + + // Insert a normal call instruction. + std::string Name = II->getName(); II->setName(""); + CallInst *NewCall = new CallInst(II->getCalledValue(), + std::vector<Value*>(II->op_begin()+3, + II->op_end()), Name, + II); + NewCall->setCallingConv(II->getCallingConv()); + II->replaceAllUsesWith(NewCall); + + // Replace the invoke with an uncond branch. + new BranchInst(II->getNormalDest(), NewCall->getParent()); + II->eraseFromParent(); +} - // If a function uses invoke, we have an alloca for the jump buffer. - AllocaInst *JmpBuf = 0; +/// MarkBlocksLiveIn - Insert BB and all of its predescessors into LiveBBs until +/// we reach blocks we've already seen. +static void MarkBlocksLiveIn(BasicBlock *BB, std::set<BasicBlock*> &LiveBBs) { + if (!LiveBBs.insert(BB).second) return; // already been here. + + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + MarkBlocksLiveIn(*PI, LiveBBs); +} - // If this function contains an unwind instruction, two blocks get added: one - // to actually perform the longjmp, and one to terminate the program if there - // is no handler. - BasicBlock *UnwindBlock = 0, *TermBlock = 0; - std::vector<LoadInst*> JBPtrs; +// First thing we need to do is scan the whole function for values that are +// live across unwind edges. Each value that is live across an unwind edge +// we spill into a stack location, guaranteeing that there is nothing live +// across the unwind edge. This process also splits all critical edges +// coming out of invoke's. +void LowerInvoke:: +splitLiveRangesLiveAcrossInvokes(std::vector<InvokeInst*> &Invokes) { + // First step, split all critical edges from invoke instructions. + for (unsigned i = 0, e = Invokes.size(); i != e; ++i) { + InvokeInst *II = Invokes[i]; + SplitCriticalEdge(II, 0, this); + SplitCriticalEdge(II, 1, this); + assert(!isa<PHINode>(II->getNormalDest()) && + !isa<PHINode>(II->getUnwindDest()) && + "critical edge splitting left single entry phi nodes?"); + } - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) - if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) { - if (JmpBuf == 0) - JmpBuf = new AllocaInst(JBLinkTy, 0, "jblink", F.begin()->begin()); - - // On the entry to the invoke, we must install our JmpBuf as the top of - // the stack. - LoadInst *OldEntry = new LoadInst(JBListHead, "oldehlist", II); - - // Store this old value as our 'next' field, and store our alloca as the - // current jblist. - std::vector<Value*> Idx; - Idx.push_back(Constant::getNullValue(Type::IntTy)); - Idx.push_back(ConstantUInt::get(Type::UIntTy, 0)); - Value *NextFieldPtr = new GetElementPtrInst(JmpBuf, Idx, "NextField", II); - new StoreInst(OldEntry, NextFieldPtr, II); - new StoreInst(JmpBuf, JBListHead, II); - - // Call setjmp, passing in the address of the jmpbuffer. - Idx[1] = ConstantUInt::get(Type::UIntTy, 1); - Value *JmpBufPtr = new GetElementPtrInst(JmpBuf, Idx, "TheJmpBuf", II); - Value *SJRet = new CallInst(SetJmpFn, JmpBufPtr, "sjret", II); - - // Compare the return value to zero. - Value *IsNormal = BinaryOperator::create(Instruction::SetEQ, SJRet, - Constant::getNullValue(SJRet->getType()), - "notunwind", II); - // Create the receiver block if there is a critical edge to the normal - // destination. - SplitCriticalEdge(II, 0, this); + Function *F = Invokes.back()->getParent()->getParent(); + + // To avoid having to handle incoming arguments specially, we lower each arg + // to a copy instruction in the entry block. This ensure that the argument + // value itself cannot be live across the entry block. + BasicBlock::iterator AfterAllocaInsertPt = F->begin()->begin(); + while (isa<AllocaInst>(AfterAllocaInsertPt) && + isa<ConstantInt>(cast<AllocaInst>(AfterAllocaInsertPt)->getArraySize())) + ++AfterAllocaInsertPt; + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI) { + CastInst *NC = new CastInst(AI, AI->getType(), AI->getName()+".tmp", + AfterAllocaInsertPt); + AI->replaceAllUsesWith(NC); + NC->setOperand(0, AI); + } + + // Finally, scan the code looking for instructions with bad live ranges. + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ++II) { + // Ignore obvious cases we don't have to handle. In particular, most + // instructions either have no uses or only have a single use inside the + // current block. Ignore them quickly. + Instruction *Inst = II; + if (Inst->use_empty()) continue; + if (Inst->hasOneUse() && + cast<Instruction>(Inst->use_back())->getParent() == BB && + !isa<PHINode>(Inst->use_back())) continue; - // There should not be any PHI nodes in II->getNormalDest() now. It has - // a single predecessor, so any PHI nodes are unneeded. Remove them now - // by replacing them with their single input value. - assert(II->getNormalDest()->getSinglePredecessor() && - "Split crit edge doesn't have a single predecessor!"); - - BasicBlock::iterator InsertLoc = II->getNormalDest()->begin(); - while (PHINode *PN = dyn_cast<PHINode>(InsertLoc)) { - PN->replaceAllUsesWith(PN->getIncomingValue(0)); - PN->eraseFromParent(); - InsertLoc = II->getNormalDest()->begin(); + // Avoid iterator invalidation by copying users to a temporary vector. + std::vector<Instruction*> Users; + for (Value::use_iterator UI = Inst->use_begin(), E = Inst->use_end(); + UI != E; ++UI) { + Instruction *User = cast<Instruction>(*UI); + if (User->getParent() != BB || isa<PHINode>(User)) + Users.push_back(User); } - - // Insert a normal call instruction on the normal execution path. - std::string Name = II->getName(); II->setName(""); - CallInst *NewCall = new CallInst(II->getCalledValue(), - std::vector<Value*>(II->op_begin()+3, - II->op_end()), Name, - InsertLoc); - NewCall->setCallingConv(II->getCallingConv()); - II->replaceAllUsesWith(NewCall); - - // If we got this far, then no exception was thrown and we can pop our - // jmpbuf entry off. - new StoreInst(OldEntry, JBListHead, InsertLoc); - - // Now we change the invoke into a branch instruction. - new BranchInst(II->getNormalDest(), II->getUnwindDest(), IsNormal, II); - - // Remove the InvokeInst now. - BB->getInstList().erase(II); - ++NumLowered; Changed = true; - } else if (UnwindInst *UI = dyn_cast<UnwindInst>(BB->getTerminator())) { - if (UnwindBlock == 0) { - // Create two new blocks, the unwind block and the terminate block. Add - // them at the end of the function because they are not hot. - UnwindBlock = new BasicBlock("unwind", &F); - TermBlock = new BasicBlock("unwinderror", &F); - - // Insert return instructions. These really should be "barrier"s, as - // they are unreachable. - new ReturnInst(F.getReturnType() == Type::VoidTy ? 0 : - Constant::getNullValue(F.getReturnType()), UnwindBlock); - new ReturnInst(F.getReturnType() == Type::VoidTy ? 0 : - Constant::getNullValue(F.getReturnType()), TermBlock); + // Scan all of the uses and see if the live range is live across an unwind + // edge. If we find a use live across an invoke edge, create an alloca + // and spill the value. + AllocaInst *SpillLoc = 0; + std::set<InvokeInst*> InvokesWithStoreInserted; + + // Find all of the blocks that this value is live in. + std::set<BasicBlock*> LiveBBs; + LiveBBs.insert(Inst->getParent()); + while (!Users.empty()) { + Instruction *U = Users.back(); + Users.pop_back(); + + BasicBlock *UseBlock; + if (!isa<PHINode>(U)) { + MarkBlocksLiveIn(U->getParent(), LiveBBs); + } else { + // Uses for a PHI node occur in their predecessor block. + PHINode *PN = cast<PHINode>(U); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == Inst) + MarkBlocksLiveIn(PN->getIncomingBlock(i), LiveBBs); + } + } + + // Now that we know all of the blocks that this thing is live in, see if + // it includes any of the unwind locations. + bool NeedsSpill = false; + for (unsigned i = 0, e = Invokes.size(); i != e; ++i) { + BasicBlock *UnwindBlock = Invokes[i]->getUnwindDest(); + if (UnwindBlock != BB && LiveBBs.count(UnwindBlock)) { + NeedsSpill = true; + } } - // Load the JBList, if it's null, then there was no catch! - LoadInst *Ptr = new LoadInst(JBListHead, "ehlist", UI); - Value *NotNull = BinaryOperator::create(Instruction::SetNE, Ptr, - Constant::getNullValue(Ptr->getType()), - "notnull", UI); - new BranchInst(UnwindBlock, TermBlock, NotNull, UI); - - // Remember the loaded value so we can insert the PHI node as needed. - JBPtrs.push_back(Ptr); - - // Remove the UnwindInst now. - BB->getInstList().erase(UI); - ++NumLowered; Changed = true; + // If we decided we need a spill, do it. + if (NeedsSpill) { + ++NumSpilled; + DemoteRegToStack(*Inst, true); + } } +} + +bool LowerInvoke::insertExpensiveEHSupport(Function &F) { + std::vector<ReturnInst*> Returns; + std::vector<UnwindInst*> Unwinds; + std::vector<InvokeInst*> Invokes; - // If an unwind instruction was inserted, we need to set up the Unwind and - // term blocks. - if (UnwindBlock) { - // In the unwind block, we know that the pointer coming in on the JBPtrs - // list are non-null. - Instruction *RI = UnwindBlock->getTerminator(); - - Value *RecPtr; - if (JBPtrs.size() == 1) - RecPtr = JBPtrs[0]; - else { - // If there is more than one unwind in this function, make a PHI node to - // merge in all of the loaded values. - PHINode *PN = new PHINode(JBPtrs[0]->getType(), "jbptrs", RI); - for (unsigned i = 0, e = JBPtrs.size(); i != e; ++i) - PN->addIncoming(JBPtrs[i], JBPtrs[i]->getParent()); - RecPtr = PN; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { + // Remember all return instructions in case we insert an invoke into this + // function. + Returns.push_back(RI); + } else if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) { + Invokes.push_back(II); + } else if (UnwindInst *UI = dyn_cast<UnwindInst>(BB->getTerminator())) { + Unwinds.push_back(UI); } - // Now that we have a pointer to the whole record, remove the entry from the - // JBList. + if (Unwinds.empty() && Invokes.empty()) return false; + + NumInvokes += Invokes.size(); + NumUnwinds += Unwinds.size(); + + // If we have an invoke instruction, insert a setjmp that dominates all + // invokes. After the setjmp, use a cond branch that goes to the original + // code path on zero, and to a designated 'catch' block of nonzero. + Value *OldJmpBufPtr = 0; + if (!Invokes.empty()) { + // First thing we need to do is scan the whole function for values that are + // live across unwind edges. Each value that is live across an unwind edge + // we spill into a stack location, guaranteeing that there is nothing live + // across the unwind edge. This process also splits all critical edges + // coming out of invoke's. + splitLiveRangesLiveAcrossInvokes(Invokes); + + BasicBlock *EntryBB = F.begin(); + + // Create an alloca for the incoming jump buffer ptr and the new jump buffer + // that needs to be restored on all exits from the function. This is an + // alloca because the value needs to be live across invokes. + AllocaInst *JmpBuf = + new AllocaInst(JBLinkTy, 0, "jblink", F.begin()->begin()); + std::vector<Value*> Idx; Idx.push_back(Constant::getNullValue(Type::IntTy)); - Idx.push_back(ConstantUInt::get(Type::UIntTy, 0)); - Value *NextFieldPtr = new GetElementPtrInst(RecPtr, Idx, "NextField", RI); - Value *NextRec = new LoadInst(NextFieldPtr, "NextRecord", RI); - new StoreInst(NextRec, JBListHead, RI); - - // Now that we popped the top of the JBList, get a pointer to the jmpbuf and - // longjmp. - Idx[1] = ConstantUInt::get(Type::UIntTy, 1); - Idx[0] = new GetElementPtrInst(RecPtr, Idx, "JmpBuf", RI); - Idx[1] = ConstantInt::get(Type::IntTy, 1); - new CallInst(LongJmpFn, Idx, "", RI); - - // Now we set up the terminate block. - RI = TermBlock->getTerminator(); - - // Insert a new call to write(2, AbortMessage, AbortMessageLength); - writeAbortMessage(RI); - - // Insert a call to abort() - (new CallInst(AbortFn, std::vector<Value*>(), "", RI))->setTailCall(); + Idx.push_back(ConstantUInt::get(Type::UIntTy, 1)); + OldJmpBufPtr = new GetElementPtrInst(JmpBuf, Idx, "OldBuf", + EntryBB->getTerminator()); + + // Copy the JBListHead to the alloca. + Value *OldBuf = new LoadInst(JBListHead, "oldjmpbufptr", true, + EntryBB->getTerminator()); + new StoreInst(OldBuf, OldJmpBufPtr, true, EntryBB->getTerminator()); + + // Add the new jumpbuf to the list. + new StoreInst(JmpBuf, JBListHead, true, EntryBB->getTerminator()); + + // Create the catch block. The catch block is basically a big switch + // statement that goes to all of the invoke catch blocks. + BasicBlock *CatchBB = new BasicBlock("setjmp.catch", &F); + + // Create an alloca which keeps track of which invoke is currently + // executing. For normal calls it contains zero. + AllocaInst *InvokeNum = new AllocaInst(Type::UIntTy, 0, "invokenum", + EntryBB->begin()); + new StoreInst(ConstantInt::get(Type::UIntTy, 0), InvokeNum, true, + EntryBB->getTerminator()); + + // Insert a load in the Catch block, and a switch on its value. By default, + // we go to a block that just does an unwind (which is the correct action + // for a standard call). + BasicBlock *UnwindBB = new BasicBlock("unwindbb", &F); + Unwinds.push_back(new UnwindInst(UnwindBB)); + + Value *CatchLoad = new LoadInst(InvokeNum, "invoke.num", true, CatchBB); + SwitchInst *CatchSwitch = + new SwitchInst(CatchLoad, UnwindBB, Invokes.size(), CatchBB); + + // Now that things are set up, insert the setjmp call itself. + + // Split the entry block to insert the conditional branch for the setjmp. + BasicBlock *ContBlock = EntryBB->splitBasicBlock(EntryBB->getTerminator(), + "setjmp.cont"); + + Idx[1] = ConstantUInt::get(Type::UIntTy, 0); + Value *JmpBufPtr = new GetElementPtrInst(JmpBuf, Idx, "TheJmpBuf", + EntryBB->getTerminator()); + Value *SJRet = new CallInst(SetJmpFn, JmpBufPtr, "sjret", + EntryBB->getTerminator()); + + // Compare the return value to zero. + Value *IsNormal = BinaryOperator::createSetEQ(SJRet, + Constant::getNullValue(SJRet->getType()), + "notunwind", EntryBB->getTerminator()); + // Nuke the uncond branch. + EntryBB->getTerminator()->eraseFromParent(); + + // Put in a new condbranch in its place. + new BranchInst(ContBlock, CatchBB, IsNormal, EntryBB); + + // At this point, we are all set up, rewrite each invoke instruction. + for (unsigned i = 0, e = Invokes.size(); i != e; ++i) + rewriteExpensiveInvoke(Invokes[i], i+1, InvokeNum, CatchSwitch); } - return Changed; + // We know that there is at least one unwind. + + // Create three new blocks, the block to load the jmpbuf ptr and compare + // against null, the block to do the longjmp, and the error block for if it + // is null. Add them at the end of the function because they are not hot. + BasicBlock *UnwindHandler = new BasicBlock("dounwind", &F); + BasicBlock *UnwindBlock = new BasicBlock("unwind", &F); + BasicBlock *TermBlock = new BasicBlock("unwinderror", &F); + + // If this function contains an invoke, restore the old jumpbuf ptr. + Value *BufPtr; + if (OldJmpBufPtr) { + // Before the return, insert a copy from the saved value to the new value. + BufPtr = new LoadInst(OldJmpBufPtr, "oldjmpbufptr", UnwindHandler); + new StoreInst(BufPtr, JBListHead, UnwindHandler); + } else { + BufPtr = new LoadInst(JBListHead, "ehlist", UnwindHandler); + } + + // Load the JBList, if it's null, then there was no catch! + Value *NotNull = BinaryOperator::createSetNE(BufPtr, + Constant::getNullValue(BufPtr->getType()), + "notnull", UnwindHandler); + new BranchInst(UnwindBlock, TermBlock, NotNull, UnwindHandler); + + // Create the block to do the longjmp. + // Get a pointer to the jmpbuf and longjmp. + std::vector<Value*> Idx; + Idx.push_back(Constant::getNullValue(Type::IntTy)); + Idx.push_back(ConstantUInt::get(Type::UIntTy, 0)); + Idx[0] = new GetElementPtrInst(BufPtr, Idx, "JmpBuf", UnwindBlock); + Idx[1] = ConstantInt::get(Type::IntTy, 1); + new CallInst(LongJmpFn, Idx, "", UnwindBlock); + new UnreachableInst(UnwindBlock); + + // Set up the term block ("throw without a catch"). + new UnreachableInst(TermBlock); + + // Insert a new call to write(2, AbortMessage, AbortMessageLength); + writeAbortMessage(TermBlock->getTerminator()); + + // Insert a call to abort() + (new CallInst(AbortFn, std::vector<Value*>(), "", + TermBlock->getTerminator()))->setTailCall(); + + + // Replace all unwinds with a branch to the unwind handler. + for (unsigned i = 0, e = Unwinds.size(); i != e; ++i) { + new BranchInst(UnwindHandler, Unwinds[i]); + Unwinds[i]->eraseFromParent(); + } + + // Finally, for any returns from this function, if this function contains an + // invoke, restore the old jmpbuf pointer to its input value. + if (OldJmpBufPtr) { + for (unsigned i = 0, e = Returns.size(); i != e; ++i) { + ReturnInst *R = Returns[i]; + + // Before the return, insert a copy from the saved value to the new value. + Value *OldBuf = new LoadInst(OldJmpBufPtr, "oldjmpbufptr", true, R); + new StoreInst(OldBuf, JBListHead, true, R); + } + } + + return true; } bool LowerInvoke::runOnFunction(Function &F) { |