21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/DebugLog.h"
25 #define DEBUG_TYPE "licm"
43 for (
OpOperand &operand : child->getOpOperands()) {
45 if (op->
isAncestor(operand.get().getParentRegion()->getParentOp()))
47 if (!condition(operand))
52 return !op->
walk(walkFn).wasInterrupted();
58 op, [&](
OpOperand &operand) {
return definedOutside(operand.
get()); });
68 for (
Region *region : regions) {
69 LDBG() <<
"Original loop:\n"
73 std::queue<Operation *> worklist;
78 auto definedOutside = [&](
Value value) {
79 return isDefinedOutsideRegion(value, region);
82 while (!worklist.empty()) {
89 LDBG() <<
"Checking op: "
91 if (!shouldMoveOutOfRegion(op, region) ||
95 LDBG() <<
"Moving loop-invariant op: "
97 moveOutOfRegion(op, region);
103 if (user->getParentRegion() == region)
113 loopLike.getLoopRegions(),
115 return loopLike.isDefinedOutsideOfLoop(value);
123 class MatchingSubsets {
126 void insert(SubsetOpInterface op,
bool collectHoistableOps =
true) {
127 allSubsetOps.push_back(op);
128 if (!collectHoistableOps)
130 if (
auto extractionOp =
131 dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
132 insertExtractionOp(extractionOp);
133 if (
auto insertionOp =
134 dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
135 insertInsertionOp(insertionOp);
142 auto getHoistableSubsetOps() {
143 return llvm::make_filter_range(
144 llvm::zip(extractions, insertions), [&](
auto pair) {
145 auto [extractionOp, insertionOp] = pair;
147 if (extractionOp && insertionOp &&
148 extractionOp->getResult(0).getType() !=
149 insertionOp.getSourceOperand().get().getType())
152 return allDisjoint(extractionOp, insertionOp);
161 LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
163 bool collectHoistableOps =
true);
169 static bool isEquivalent(
Value v1,
Value v2) {
return true; }
174 bool allDisjoint(SubsetExtractionOpInterface extractionOp,
175 SubsetInsertionOpInterface insertionOp)
const {
176 for (SubsetOpInterface other : allSubsetOps) {
177 if (other == extractionOp || other == insertionOp)
180 !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
183 !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
192 void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
196 auto other = cast<SubsetOpInterface>(it.value().getOperation());
197 if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
198 extractions[it.index()] = extractionOp;
203 extractions.push_back(extractionOp);
204 insertions.push_back({});
210 void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
214 auto other = cast<SubsetOpInterface>(it.value().getOperation());
215 if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
216 insertions[it.index()] = insertionOp;
221 extractions.push_back({});
222 insertions.push_back(insertionOp);
243 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
245 bool collectHoistableOps) {
247 Value value = iterArg;
257 Value nextValue = {};
260 if (
auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
265 auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
271 if (
failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
274 nextValue = nestedLoop.getTiedLoopResult(&use);
278 auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
283 if (
auto insertionOp =
284 dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
288 auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner());
289 if (!dstOp || !dstOp.hasPureTensorSemantics())
294 if (&use != &insertionOp.getDestinationOperand())
300 nextValue = insertionOp.getUpdatedDestination();
314 if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
325 LoopLikeOpInterface loopLike,
328 BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
329 int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
330 MatchingSubsets subsets;
331 if (
failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
335 for (
auto it : subsets.getHoistableSubsetOps()) {
336 auto extractionOp = std::get<0>(it);
337 auto insertionOp = std::get<1>(it);
342 return loopLike.isDefinedOutsideOfLoop(operand.
get()) ||
343 &operand == &extractionOp.getSourceOperand();
349 return loopLike.isDefinedOutsideOfLoop(operand.
get()) ||
350 &operand == &insertionOp.getSourceOperand() ||
351 &operand == &insertionOp.getDestinationOperand();
359 if (extractionOp && insertionOp) {
364 return {insertionOp.getSourceOperand().get()};
366 FailureOr<LoopLikeOpInterface> newLoop =
367 loopLike.replaceWithAdditionalYields(
368 rewriter, extractionOp.getResult(),
369 true, newYieldValuesFn);
375 iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
376 OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
377 OpResult newLoopResult = loopLike.getLoopResults()->back();
381 insertionOp.getDestinationOperand().get());
382 extractionOp.getSourceOperand().set(
383 loopLike.getTiedLoopInit(iterArg)->get());
385 insertionOp.getUpdatedDestination());
386 insertionOp.getSourceOperand().set(newLoopResult);
387 insertionOp.getDestinationOperand().set(loopResult);
396 LoopLikeOpInterface loopLike) {
401 i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
403 loopLike.getRegionIterArgs()[i]);
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg)
Hoist all subset ops that operate on the idx-th region iter_arg of the given loop-like op and index i...
static OpOperand * getSingleTerminatorUse(Value value)
If the given value has a single use by an op that is a terminator, return that use.
static bool canBeHoisted(Operation *op, function_ref< bool(OpOperand &)> condition)
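Checks whether the given op can be hoisted by checking that every operand used by the op, or by any op nested inside it, satisfies `condition` (operands whose values are defined inside the op itself are skipped).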
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
bool hasOneUse() const
Returns true if this value has exactly one use.
static WalkResult advance()
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)
Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.