23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/PostOrderIterator.h"
25 #include "llvm/ADT/STLExtras.h"
34 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
35 if (region.
isAncestor(use.getOwner()->getParentRegion()))
43 "expected isolation limit to be an ancestor of the given region");
50 properAncestors.insert(reg);
56 if (properAncestors.count(operand.get().getParentRegion()))
63 for (
Region ®ion : regions)
70 values.insert(operand->
get());
76 for (
Region ®ion : regions)
92 std::deque<Value> worklist(initialCapturedValues.begin(),
93 initialCapturedValues.end());
99 while (!worklist.empty()) {
100 Value currValue = worklist.front();
101 worklist.pop_front();
102 if (visited.count(currValue))
104 visited.insert(currValue);
107 if (!definingOp || visitedOps.count(definingOp)) {
108 finalCapturedValues.insert(currValue);
111 visitedOps.insert(definingOp);
113 if (!cloneOperationIntoRegion(definingOp)) {
116 finalCapturedValues.insert(currValue);
123 if (visited.count(operand))
125 worklist.push_back(operand);
127 clonedOperations.push_back(definingOp);
144 for (
auto value : finalCapturedValues) {
145 newArgTypes.push_back(value.getType());
146 newArgLocs.push_back(value.getLoc());
150 Block *newEntryBlock =
157 return use.getOwner()->getBlock()->getParent() == ®ion;
159 for (
auto [arg, capturedVal] :
160 llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
161 finalCapturedValues)) {
162 map.
map(capturedVal, arg);
166 for (
auto *clonedOp : clonedOperations) {
171 entryBlock, newEntryBlock,
173 return llvm::to_vector(finalCapturedValues);
186 llvm::df_iterator_default_set<Block *, 16> reachable;
188 bool erasedDeadBlocks =
false;
191 worklist.reserve(regions.size());
192 for (
Region ®ion : regions)
193 worklist.push_back(®ion);
194 while (!worklist.empty()) {
195 Region *region = worklist.pop_back_val();
202 for (
Region ®ion : op.getRegions())
203 worklist.push_back(®ion);
209 for (
Block *block : depth_first_ext(®ion->
front(), reachable))
214 for (
Block &block : llvm::make_early_inc_range(*region)) {
215 if (!reachable.count(&block)) {
216 block.dropAllDefinedValueUses();
218 erasedDeadBlocks =
true;
224 for (
Region ®ion : op.getRegions())
225 worklist.push_back(®ion);
229 return success(erasedDeadBlocks);
248 bool wasProvenLive(
Value value) {
251 if (
OpResult result = dyn_cast<OpResult>(value))
252 return wasProvenLive(result.getOwner());
253 return wasProvenLive(cast<BlockArgument>(value));
255 bool wasProvenLive(
BlockArgument arg) {
return liveValues.count(arg); }
256 void setProvedLive(
Value value) {
259 if (
OpResult result = dyn_cast<OpResult>(value))
260 return setProvedLive(result.getOwner());
261 setProvedLive(cast<BlockArgument>(value));
264 changed |= liveValues.insert(arg).second;
268 bool wasProvenLive(
Operation *op) {
return liveOps.count(op); }
272 void resetChanged() {
changed =
false; }
273 bool hasChanged() {
return changed; }
297 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
298 if (
auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
299 return !liveMap.wasProvenLive(*arg);
307 if (isUseSpeciallyKnownDead(use, liveMap))
309 return liveMap.wasProvenLive(use.getOwner());
312 liveMap.setProvedLive(value);
319 liveMap.setProvedLive(op);
322 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
323 if (!branchInterface) {
326 liveMap.setProvedLive(arg);
334 branchInterface.getSuccessorOperands(i);
351 if (liveMap.wasProvenLive(op))
356 return liveMap.setProvedLive(op);
367 for (
Block *block : llvm::post_order(®ion.
front())) {
370 for (
Operation &op : llvm::reverse(block->getOperations()))
377 if (block->isEntryBlock())
380 for (
Value value : block->getArguments()) {
381 if (!liveMap.wasProvenLive(value))
389 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
394 succI < succE; succI++) {
399 unsigned succ = succE - succI - 1;
403 for (
unsigned argI = 0, argE = succOperands.
size(); argI < argE; ++argI) {
406 unsigned arg = argE - argI - 1;
407 if (!liveMap.wasProvenLive(successor->
getArgument(arg)))
408 succOperands.
erase(arg);
416 bool erasedAnything =
false;
417 for (
Region ®ion : regions) {
427 for (
Block *block : llvm::post_order(®ion.
front())) {
431 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
432 if (!liveMap.wasProvenLive(&childOp)) {
433 erasedAnything =
true;
434 childOp.dropAllUses();
437 erasedAnything |= succeeded(
446 block.eraseArguments(
447 [&](
BlockArgument arg) {
return !liveMap.wasProvenLive(arg); });
450 return success(erasedAnything);
474 liveMap.resetChanged();
476 for (
Region ®ion : regions)
478 }
while (liveMap.hasChanged());
508 struct BlockEquivalenceData {
509 BlockEquivalenceData(
Block *block);
513 unsigned getOrderOf(
Value value)
const;
518 llvm::hash_code hash;
526 BlockEquivalenceData::BlockEquivalenceData(
Block *block)
527 : block(block), hash(0) {
530 if (
unsigned numResults = op.getNumResults()) {
531 opOrderIndex.try_emplace(&op, orderIt);
532 orderIt += numResults;
538 hash = llvm::hash_combine(hash, opHash);
542 unsigned BlockEquivalenceData::getOrderOf(
Value value)
const {
543 assert(value.
getParentBlock() == block &&
"expected value of this block");
550 OpResult result = cast<OpResult>(value);
552 assert(opOrderIt != opOrderIndex.end() &&
"expected op to have an order");
562 class BlockMergeCluster {
564 BlockMergeCluster(BlockEquivalenceData &&leaderData)
565 : leaderData(std::move(leaderData)) {}
569 LogicalResult addToCluster(BlockEquivalenceData &blockData);
576 BlockEquivalenceData leaderData;
579 llvm::SmallSetVector<Block *, 1> blocksToMerge;
583 std::set<std::pair<int, int>> operandsToMerge;
587 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
588 if (leaderData.hash != blockData.hash)
590 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
596 auto lhsIt = leaderBlock->
begin(), lhsE = leaderBlock->
end();
597 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
598 for (
int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
600 if (!OperationEquivalence::isEquivalentTo(
601 &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
603 OperationEquivalence::Flags::IgnoreLocations))
608 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
609 for (
int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
610 Value lhsOperand = lhsOperands[operand];
611 Value rhsOperand = rhsOperands[operand];
612 if (lhsOperand == rhsOperand)
621 if (lhsIsInBlock != rhsIsInBlock)
631 auto isValidSuccessorArg = [](
Block *block,
Value operand) {
632 if (operand.getDefiningOp() !=
633 operand.getParentBlock()->getTerminator())
636 operand.getParentBlock());
639 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
640 !isValidSuccessorArg(mergeBlock, rhsOperand))
643 mismatchedOperands.emplace_back(opI, operand);
649 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
659 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
660 lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
665 if (lhsIt != lhsE || rhsIt != rhsE)
669 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
670 blocksToMerge.insert(blockData.block);
678 if (!isa<BranchOpInterface>((*it)->getTerminator()))
697 if (newArguments.empty())
701 unsigned numLists = newArguments.size();
702 unsigned numArgs = newArguments[0].size();
712 for (
unsigned j = 0;
j < numArgs; ++
j) {
713 Value newArg = newArguments[0][
j];
714 firstValueToIdx.try_emplace(newArg,
j);
718 for (
unsigned j = 0;
j < numArgs; ++
j) {
730 unsigned k = firstValueToIdx[newArguments[0][
j]];
734 bool shouldReplaceJ =
true;
735 unsigned replacement = k;
739 for (
unsigned i = 1; i < numLists; ++i)
741 shouldReplaceJ && (newArguments[i][k] == newArguments[i][
j]);
744 idxToReplacement[
j] = replacement;
748 for (
unsigned i = 0; i < numLists; ++i)
749 for (
unsigned j = 0;
j < numArgs; ++
j)
750 if (!idxToReplacement.contains(
j))
751 newArgumentsPruned[i].push_back(newArguments[i][
j]);
756 if (idxToReplacement.contains(idx)) {
759 block->
getArgument(numOldArguments + idxToReplacement[idx]);
761 toErase.push_back(numOldArguments + idx);
766 for (
unsigned idxToErase : llvm::reverse(toErase))
768 return newArgumentsPruned;
771 LogicalResult BlockMergeCluster::merge(
RewriterBase &rewriter) {
773 if (blocksToMerge.empty())
776 Block *leaderBlock = leaderData.block;
777 if (!operandsToMerge.empty()) {
790 blockIterators.reserve(blocksToMerge.size() + 1);
791 blockIterators.push_back(leaderBlock->
begin());
792 for (
Block *mergeBlock : blocksToMerge)
793 blockIterators.push_back(mergeBlock->begin());
797 1 + blocksToMerge.size(),
799 unsigned curOpIndex = 0;
802 unsigned nextOpOffset = it.value().first - curOpIndex;
803 curOpIndex = it.value().first;
806 for (
unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
808 std::advance(blockIter, nextOpOffset);
809 auto &operand = blockIter->getOpOperand(it.value().second);
810 newArguments[i][it.index()] = operand.get();
814 Value operandVal = operand.get();
823 numOldArguments, leaderBlock);
826 auto updatePredecessors = [&](
Block *block,
unsigned clusterIndex) {
828 predIt != predE; ++predIt) {
829 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
830 unsigned succIndex = predIt.getSuccessorIndex();
831 branch.getSuccessorOperands(succIndex).append(
832 newArguments[clusterIndex]);
835 updatePredecessors(leaderBlock, 0);
836 for (
unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
837 updatePredecessors(blocksToMerge[i], i + 1);
841 for (
Block *block : blocksToMerge) {
860 for (
Block &block : llvm::drop_begin(region, 1))
863 bool mergedAnyBlocks =
false;
865 if (blocks.size() == 1)
869 for (
Block *block : blocks) {
870 BlockEquivalenceData data(block);
874 bool hasNonEmptyRegion = llvm::any_of(*block, [](
Operation &op) {
876 [](
Region ®ion) { return !region.empty(); });
878 if (hasNonEmptyRegion)
883 bool argHasExternalUsers = llvm::any_of(
885 return arg.isUsedOutsideOfBlock(block);
887 if (argHasExternalUsers)
891 bool addedToCluster =
false;
892 for (
auto &cluster : clusters)
893 if ((addedToCluster = succeeded(cluster.addToCluster(data))))
896 clusters.emplace_back(std::move(data));
898 for (
auto &cluster : clusters)
899 mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
902 return success(mergedAnyBlocks);
909 llvm::SmallSetVector<Region *, 1> worklist;
910 for (
auto ®ion : regions)
911 worklist.insert(®ion);
912 bool anyChanged =
false;
913 while (!worklist.empty()) {
914 Region *region = worklist.pop_back_val();
916 worklist.insert(region);
921 for (
Block &block : *region)
922 for (
auto &op : block)
923 for (
auto &nestedRegion : op.getRegions())
924 worklist.insert(&nestedRegion);
927 return success(anyChanged);
945 predIt != predE; ++predIt) {
946 auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
951 unsigned succIndex = predIt.getSuccessorIndex();
955 commonValue = branchOperands[argIdx];
958 if (branchOperands[argIdx] != commonValue) {
965 if (commonValue && sameArg) {
966 argsToErase.push_back(argIdx);
974 for (
size_t argIdx : llvm::reverse(argsToErase)) {
979 predIt != predE; ++predIt) {
980 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
981 unsigned succIndex = predIt.getSuccessorIndex();
983 succOperands.
erase(argIdx);
986 return success(!argsToErase.empty());
1016 llvm::SmallSetVector<Region *, 1> worklist;
1017 for (
Region ®ion : regions)
1018 worklist.insert(®ion);
1019 bool anyChanged =
false;
1020 while (!worklist.empty()) {
1021 Region *region = worklist.pop_back_val();
1024 for (
Block &block : *region) {
1029 for (
Region &nestedRegion : op.getRegions())
1030 worklist.insert(&nestedRegion);
1033 return success(anyChanged);
1048 bool eliminatedOpsOrArgs = succeeded(
runRegionDCE(rewriter, regions));
1049 bool mergedIdenticalBlocks =
false;
1050 bool droppedRedundantArguments =
false;
1053 droppedRedundantArguments =
1056 return success(eliminatedBlocks || eliminatedOpsOrArgs ||
1057 mergedIdenticalBlocks || droppedRedundantArguments);
1072 op,
"unsupported case where operation and insertion point are not in "
1073 "the same basic block");
1078 "insertion point does not dominate op");
1086 options.omitUsesFromAbove =
false;
1089 options.omitBlockArguments =
true;
1095 assert(result.succeeded() &&
"expected a backward slice");
1099 if (slice.contains(insertionPoint)) {
1102 "cannot move dependencies before operation in backward slice of op");
1126 for (
auto value : values) {
1131 if (isa<BlockArgument>(value)) {
1134 "unsupported case of moving block argument before insertion point");
1141 "unsupported case of moving definition of value before an insertion "
1142 "point in a different basic block");
1144 prunedValues.push_back(value);
1152 options.omitUsesFromAbove =
false;
1155 options.omitBlockArguments =
true;
1160 for (
auto value : prunedValues) {
1162 assert(result.succeeded() &&
"expected a backward slice");
1167 if (slice.contains(insertionPoint)) {
1170 "cannot move dependencies before operation in backward slice of op");
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region ®ion)
Identify identical blocks within the given region and merge them, inserting new block arguments as ne...
static void propagateLiveness(Region ®ion, LiveMap &liveMap)
static void processValue(Value value, LiveMap &liveMap)
static bool ableToUpdatePredOperands(Block *block)
Returns true if the predecessor terminators of the given block can not have their operands updated.
static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap)
static LogicalResult dropRedundantArguments(RewriterBase &rewriter, Block &block)
If a block's argument is always the same across different invocations, then drop the argument and use...
static SmallVector< SmallVector< Value, 8 >, 2 > pruneRedundantArguments(const SmallVector< SmallVector< Value, 8 >, 2 > &newArguments, RewriterBase &rewriter, unsigned numOldArguments, Block *block)
Prunes the redundant list of new arguments.
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap)
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap)
static LogicalResult deleteDeadness(RewriterBase &rewriter, MutableArrayRef< Region > regions, LiveMap &liveMap)
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
pred_iterator pred_begin()
SuccessorRange getSuccessors()
iterator_range< pred_iterator > getPredecessors()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
IRValueT get() const
Return the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getSuccessor(unsigned index)
unsigned getNumSuccessors()
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
BlockListType & getBlocks()
bool hasOneBlock()
Return true if this region has exactly one block.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, Operation *insertionPoint, DominanceInfo &dominance)
Move SSA values used within an operation before an insertion point, so that the operation itself (or ...
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
SmallVector< Value > makeRegionIsolatedFromAbove(RewriterBase &rewriter, Region ®ion, llvm::function_ref< bool(Operation *)> cloneOperationIntoRegion=[](Operation *) { return false;})
Make a region isolated from above.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef< Region > regions)
This function returns success if any operations or arguments were deleted, failure otherwise.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.