25 #include "llvm/ADT/DepthFirstIterator.h"
26 #include "llvm/ADT/PostOrderIterator.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
37 for (
auto &use : llvm::make_early_inc_range(orig.
getUses())) {
38 if (region.
isAncestor(use.getOwner()->getParentRegion()))
46 "expected isolation limit to be an ancestor of the given region");
53 properAncestors.insert(reg);
59 if (properAncestors.count(operand.get().getParentRegion()))
66 for (
Region ®ion : regions)
73 values.insert(operand->
get());
79 for (
Region ®ion : regions)
95 std::deque<Value> worklist(initialCapturedValues.begin(),
96 initialCapturedValues.end());
102 while (!worklist.empty()) {
103 Value currValue = worklist.front();
104 worklist.pop_front();
105 if (visited.count(currValue))
107 visited.insert(currValue);
110 if (!definingOp || visitedOps.count(definingOp)) {
111 finalCapturedValues.insert(currValue);
114 visitedOps.insert(definingOp);
116 if (!cloneOperationIntoRegion(definingOp)) {
119 finalCapturedValues.insert(currValue);
126 if (visited.count(operand))
128 worklist.push_back(operand);
130 clonedOperations.push_back(definingOp);
147 for (
auto value : finalCapturedValues) {
148 newArgTypes.push_back(value.getType());
149 newArgLocs.push_back(value.getLoc());
153 Block *newEntryBlock =
160 return use.getOwner()->getBlock()->getParent() == ®ion;
162 for (
auto [arg, capturedVal] :
163 llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
164 finalCapturedValues)) {
165 map.
map(capturedVal, arg);
169 for (
auto *clonedOp : clonedOperations) {
174 entryBlock, newEntryBlock,
176 return llvm::to_vector(finalCapturedValues);
189 llvm::df_iterator_default_set<Block *, 16> reachable;
191 bool erasedDeadBlocks =
false;
194 worklist.reserve(regions.size());
195 for (
Region ®ion : regions)
196 worklist.push_back(®ion);
197 while (!worklist.empty()) {
198 Region *region = worklist.pop_back_val();
203 if (std::next(region->
begin()) == region->
end()) {
205 for (
Region ®ion : op.getRegions())
206 worklist.push_back(®ion);
212 for (
Block *block : depth_first_ext(®ion->
front(), reachable))
217 for (
Block &block : llvm::make_early_inc_range(*region)) {
218 if (!reachable.count(&block)) {
219 block.dropAllDefinedValueUses();
221 erasedDeadBlocks =
true;
227 for (
Region ®ion : op.getRegions())
228 worklist.push_back(®ion);
232 return success(erasedDeadBlocks);
251 bool wasProvenLive(
Value value) {
254 if (
OpResult result = dyn_cast<OpResult>(value))
255 return wasProvenLive(result.getOwner());
256 return wasProvenLive(cast<BlockArgument>(value));
258 bool wasProvenLive(
BlockArgument arg) {
return liveValues.count(arg); }
259 void setProvedLive(
Value value) {
262 if (
OpResult result = dyn_cast<OpResult>(value))
263 return setProvedLive(result.getOwner());
264 setProvedLive(cast<BlockArgument>(value));
267 changed |= liveValues.insert(arg).second;
271 bool wasProvenLive(
Operation *op) {
return liveOps.count(op); }
275 void resetChanged() {
changed =
false; }
276 bool hasChanged() {
return changed; }
300 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
301 if (
auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
302 return !liveMap.wasProvenLive(*arg);
310 if (isUseSpeciallyKnownDead(use, liveMap))
312 return liveMap.wasProvenLive(use.getOwner());
315 liveMap.setProvedLive(value);
322 liveMap.setProvedLive(op);
325 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
326 if (!branchInterface) {
329 liveMap.setProvedLive(arg);
337 branchInterface.getSuccessorOperands(i);
354 if (liveMap.wasProvenLive(op))
359 return liveMap.setProvedLive(op);
370 for (
Block *block : llvm::post_order(®ion.
front())) {
373 for (
Operation &op : llvm::reverse(block->getOperations()))
380 if (block->isEntryBlock())
383 for (
Value value : block->getArguments()) {
384 if (!liveMap.wasProvenLive(value))
392 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
397 succI < succE; succI++) {
402 unsigned succ = succE - succI - 1;
406 for (
unsigned argI = 0, argE = succOperands.
size(); argI < argE; ++argI) {
409 unsigned arg = argE - argI - 1;
410 if (!liveMap.wasProvenLive(successor->
getArgument(arg)))
411 succOperands.
erase(arg);
419 bool erasedAnything =
false;
420 for (
Region ®ion : regions) {
423 bool hasSingleBlock = llvm::hasSingleElement(region);
430 for (
Block *block : llvm::post_order(®ion.
front())) {
434 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
435 if (!liveMap.wasProvenLive(&childOp)) {
436 erasedAnything =
true;
437 childOp.dropAllUses();
440 erasedAnything |= succeeded(
449 block.eraseArguments(
450 [&](
BlockArgument arg) {
return !liveMap.wasProvenLive(arg); });
453 return success(erasedAnything);
477 liveMap.resetChanged();
479 for (
Region ®ion : regions)
481 }
while (liveMap.hasChanged());
510 struct BlockEquivalenceData {
511 BlockEquivalenceData(
Block *block);
515 unsigned getOrderOf(
Value value)
const;
520 llvm::hash_code hash;
528 BlockEquivalenceData::BlockEquivalenceData(
Block *block)
529 : block(block), hash(0) {
532 if (
unsigned numResults = op.getNumResults()) {
533 opOrderIndex.try_emplace(&op, orderIt);
534 orderIt += numResults;
540 hash = llvm::hash_combine(hash, opHash);
544 unsigned BlockEquivalenceData::getOrderOf(
Value value)
const {
545 assert(value.
getParentBlock() == block &&
"expected value of this block");
552 OpResult result = cast<OpResult>(value);
554 assert(opOrderIt != opOrderIndex.end() &&
"expected op to have an order");
563 class BlockMergeCluster {
565 BlockMergeCluster(BlockEquivalenceData &&leaderData)
566 : leaderData(std::move(leaderData)) {}
570 LogicalResult addToCluster(BlockEquivalenceData &blockData);
577 BlockEquivalenceData leaderData;
580 llvm::SmallSetVector<Block *, 1> blocksToMerge;
584 std::set<std::pair<int, int>> operandsToMerge;
588 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
589 if (leaderData.hash != blockData.hash)
591 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
597 auto lhsIt = leaderBlock->
begin(), lhsE = leaderBlock->
end();
598 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
599 for (
int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
601 if (!OperationEquivalence::isEquivalentTo(
602 &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
604 OperationEquivalence::Flags::IgnoreLocations))
609 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
610 for (
int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
611 Value lhsOperand = lhsOperands[operand];
612 Value rhsOperand = rhsOperands[operand];
613 if (lhsOperand == rhsOperand)
622 if (lhsIsInBlock != rhsIsInBlock)
632 auto isValidSuccessorArg = [](
Block *block,
Value operand) {
633 if (operand.getDefiningOp() !=
634 operand.getParentBlock()->getTerminator())
637 operand.getParentBlock());
640 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
641 !isValidSuccessorArg(mergeBlock, rhsOperand))
644 mismatchedOperands.emplace_back(opI, operand);
650 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
660 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
661 lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
666 if (lhsIt != lhsE || rhsIt != rhsE)
670 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
671 blocksToMerge.insert(blockData.block);
679 if (!isa<BranchOpInterface>((*it)->getTerminator()))
698 if (newArguments.empty())
702 unsigned numLists = newArguments.size();
703 unsigned numArgs = newArguments[0].size();
713 for (
unsigned j = 0;
j < numArgs; ++
j) {
714 Value newArg = newArguments[0][
j];
715 firstValueToIdx.try_emplace(newArg,
j);
719 for (
unsigned j = 0;
j < numArgs; ++
j) {
731 unsigned k = firstValueToIdx[newArguments[0][
j]];
735 bool shouldReplaceJ =
true;
736 unsigned replacement = k;
740 for (
unsigned i = 1; i < numLists; ++i)
742 shouldReplaceJ && (newArguments[i][k] == newArguments[i][
j]);
745 idxToReplacement[
j] = replacement;
749 for (
unsigned i = 0; i < numLists; ++i)
750 for (
unsigned j = 0;
j < numArgs; ++
j)
751 if (!idxToReplacement.contains(
j))
752 newArgumentsPruned[i].push_back(newArguments[i][
j]);
757 if (idxToReplacement.contains(idx)) {
760 block->
getArgument(numOldArguments + idxToReplacement[idx]);
762 toErase.push_back(numOldArguments + idx);
767 for (
unsigned idxToErase : llvm::reverse(toErase))
769 return newArgumentsPruned;
772 LogicalResult BlockMergeCluster::merge(
RewriterBase &rewriter) {
774 if (blocksToMerge.empty())
777 Block *leaderBlock = leaderData.block;
778 if (!operandsToMerge.empty()) {
791 blockIterators.reserve(blocksToMerge.size() + 1);
792 blockIterators.push_back(leaderBlock->
begin());
793 for (
Block *mergeBlock : blocksToMerge)
794 blockIterators.push_back(mergeBlock->begin());
798 1 + blocksToMerge.size(),
800 unsigned curOpIndex = 0;
803 unsigned nextOpOffset = it.value().first - curOpIndex;
804 curOpIndex = it.value().first;
807 for (
unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
809 std::advance(blockIter, nextOpOffset);
810 auto &operand = blockIter->getOpOperand(it.value().second);
811 newArguments[i][it.index()] = operand.get();
815 Value operandVal = operand.get();
824 numOldArguments, leaderBlock);
827 auto updatePredecessors = [&](
Block *block,
unsigned clusterIndex) {
829 predIt != predE; ++predIt) {
830 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
831 unsigned succIndex = predIt.getSuccessorIndex();
832 branch.getSuccessorOperands(succIndex).append(
833 newArguments[clusterIndex]);
836 updatePredecessors(leaderBlock, 0);
837 for (
unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
838 updatePredecessors(blocksToMerge[i], i + 1);
842 for (
Block *block : blocksToMerge) {
854 if (region.
empty() || llvm::hasSingleElement(region))
861 for (
Block &block : llvm::drop_begin(region, 1))
864 bool mergedAnyBlocks =
false;
866 if (blocks.size() == 1)
870 for (
Block *block : blocks) {
871 BlockEquivalenceData data(block);
875 bool hasNonEmptyRegion = llvm::any_of(*block, [](
Operation &op) {
877 [](
Region ®ion) { return !region.empty(); });
879 if (hasNonEmptyRegion)
884 bool argHasExternalUsers = llvm::any_of(
886 return arg.isUsedOutsideOfBlock(block);
888 if (argHasExternalUsers)
892 bool addedToCluster =
false;
893 for (
auto &cluster : clusters)
894 if ((addedToCluster = succeeded(cluster.addToCluster(data))))
897 clusters.emplace_back(std::move(data));
899 for (
auto &cluster : clusters)
900 mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
903 return success(mergedAnyBlocks);
910 llvm::SmallSetVector<Region *, 1> worklist;
911 for (
auto ®ion : regions)
912 worklist.insert(®ion);
913 bool anyChanged =
false;
914 while (!worklist.empty()) {
915 Region *region = worklist.pop_back_val();
917 worklist.insert(region);
922 for (
Block &block : *region)
923 for (
auto &op : block)
924 for (
auto &nestedRegion : op.getRegions())
925 worklist.insert(&nestedRegion);
928 return success(anyChanged);
946 predIt != predE; ++predIt) {
947 auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
952 unsigned succIndex = predIt.getSuccessorIndex();
956 commonValue = branchOperands[argIdx];
959 if (branchOperands[argIdx] != commonValue) {
966 if (commonValue && sameArg) {
967 argsToErase.push_back(argIdx);
975 for (
size_t argIdx : llvm::reverse(argsToErase)) {
980 predIt != predE; ++predIt) {
981 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
982 unsigned succIndex = predIt.getSuccessorIndex();
984 succOperands.
erase(argIdx);
987 return success(!argsToErase.empty());
1017 llvm::SmallSetVector<Region *, 1> worklist;
1018 for (
Region ®ion : regions)
1019 worklist.insert(®ion);
1020 bool anyChanged =
false;
1021 while (!worklist.empty()) {
1022 Region *region = worklist.pop_back_val();
1025 for (
Block &block : *region) {
1030 for (
Region &nestedRegion : op.getRegions())
1031 worklist.insert(&nestedRegion);
1034 return success(anyChanged);
1049 bool eliminatedOpsOrArgs = succeeded(
runRegionDCE(rewriter, regions));
1050 bool mergedIdenticalBlocks =
false;
1051 bool droppedRedundantArguments =
false;
1054 droppedRedundantArguments =
1057 return success(eliminatedBlocks || eliminatedOpsOrArgs ||
1058 mergedIdenticalBlocks || droppedRedundantArguments);
1073 op,
"unsupported case where operation and insertion point are not in "
1074 "the same basic block");
1079 "insertion point does not dominate op");
1087 options.omitUsesFromAbove =
false;
1090 options.omitBlockArguments =
true;
1098 if (slice.contains(insertionPoint)) {
1101 "cannot move dependencies before operation in backward slice of op");
1125 for (
auto value : values) {
1130 if (isa<BlockArgument>(value)) {
1133 "unsupported case of moving block argument before insertion point");
1140 "unsupported case of moving definition of value before an insertion "
1141 "point in a different basic block");
1143 prunedValues.push_back(value);
1151 options.omitUsesFromAbove =
false;
1154 options.omitBlockArguments =
true;
1159 for (
auto value : prunedValues) {
1164 if (slice.contains(insertionPoint)) {
1167 "cannot move dependencies before operation in backward slice of op");
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region ®ion)
Identify identical blocks within the given region and merge them, inserting new block arguments as ne...
static void propagateLiveness(Region ®ion, LiveMap &liveMap)
static void processValue(Value value, LiveMap &liveMap)
static bool ableToUpdatePredOperands(Block *block)
Returns true if the predecessor terminators of the given block can not have their operands updated.
static void eraseTerminatorSuccessorOperands(Operation *terminator, LiveMap &liveMap)
static LogicalResult dropRedundantArguments(RewriterBase &rewriter, Block &block)
If a block's argument is always the same across different invocations, then drop the argument and use...
static SmallVector< SmallVector< Value, 8 >, 2 > pruneRedundantArguments(const SmallVector< SmallVector< Value, 8 >, 2 > &newArguments, RewriterBase &rewriter, unsigned numOldArguments, Block *block)
Prunes the redundant list of new arguments.
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap)
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap)
static LogicalResult deleteDeadness(RewriterBase &rewriter, MutableArrayRef< Region > regions, LiveMap &liveMap)
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
pred_iterator pred_begin()
SuccessorRange getSuccessors()
iterator_range< pred_iterator > getPredecessors()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
IRValueT get() const
Return the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getSuccessor(unsigned index)
unsigned getNumSuccessors()
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
BlockListType & getBlocks()
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
This class models how operands are forwarded to block arguments in control flow.
void erase(unsigned subStart, unsigned subLen=1)
Erase operands forwarded to the successor.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool computeTopologicalSorting(MutableArrayRef< Operation * > ops, function_ref< bool(Value, Operation *)> isOperandReady=nullptr)
Compute a topological ordering of the given ops.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, Operation *insertionPoint, DominanceInfo &dominance)
Move SSA values used within an operation before an insertion point, so that the operation itself (or ...
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef< Region > regions)
Erase the unreachable blocks within the provided regions.
SmallVector< Value > makeRegionIsolatedFromAbove(RewriterBase &rewriter, Region ®ion, llvm::function_ref< bool(Operation *)> cloneOperationIntoRegion=[](Operation *) { return false;})
Make a region isolated from above.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef< Region > regions)
This function returns success if any operations or arguments were deleted, failure otherwise.
LogicalResult simplifyRegions(RewriterBase &rewriter, MutableArrayRef< Region > regions, bool mergeBlocks=true)
Run a set of structural simplifications over the given regions.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
static llvm::hash_code ignoreHashValue(Value)
Helper that can be used with computeHash above to ignore operation operands/result mapping.
static llvm::hash_code computeHash(Operation *op, function_ref< llvm::hash_code(Value)> hashOperands=[](Value v) { return hash_value(v);}, function_ref< llvm::hash_code(Value)> hashResults=[](Value v) { return hash_value(v);}, Flags flags=Flags::None)
Compute a hash for the given operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.