21#include "llvm/Support/Debug.h"
22#include "llvm/Support/DebugLog.h"
25#define DEBUG_TYPE "licm"
43 for (
OpOperand &operand : child->getOpOperands()) {
45 if (op->
isAncestor(operand.get().getParentRegion()->getParentOp()))
47 if (!condition(operand))
52 return !op->
walk(walkFn).wasInterrupted();
58 op, [&](
OpOperand &operand) {
return definedOutside(operand.
get()); });
68 for (
Region *region : regions) {
69 LDBG() <<
"Original loop:\n"
73 std::queue<Operation *> worklist;
78 auto definedOutside = [&](
Value value) {
79 return isDefinedOutsideRegion(value, region);
82 while (!worklist.empty()) {
89 LDBG() <<
"Checking op: "
91 if (!shouldMoveOutOfRegion(op, region) ||
95 LDBG() <<
"Moving loop-invariant op: "
97 moveOutOfRegion(op, region);
103 if (user->getParentRegion() == region)
113 loopLike.getLoopRegions(),
115 return loopLike.isDefinedOutsideOfLoop(value);
123class MatchingSubsets {
126 void insert(SubsetOpInterface op,
bool collectHoistableOps =
true) {
127 allSubsetOps.push_back(op);
128 if (!collectHoistableOps)
130 if (
auto extractionOp =
131 dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
132 insertExtractionOp(extractionOp);
133 if (
auto insertionOp =
134 dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
135 insertInsertionOp(insertionOp);
142 auto getHoistableSubsetOps() {
143 return llvm::make_filter_range(
144 llvm::zip(extractions, insertions), [&](
auto pair) {
145 auto [extractionOp, insertionOp] = pair;
147 if (extractionOp && insertionOp &&
148 extractionOp->getResult(0).getType() !=
149 insertionOp.getSourceOperand().get().getType())
152 return allDisjoint(extractionOp, insertionOp);
161 LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
162 BlockArgument iterArg,
163 bool collectHoistableOps =
true);
169 static bool isEquivalent(Value v1, Value v2) {
return true; }
174 bool allDisjoint(SubsetExtractionOpInterface extractionOp,
175 SubsetInsertionOpInterface insertionOp)
const {
176 for (SubsetOpInterface other : allSubsetOps) {
177 if (other == extractionOp || other == insertionOp)
180 !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
183 !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
192 void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
193 for (
auto it : llvm::enumerate(insertions)) {
196 auto other = cast<SubsetOpInterface>(it.value().getOperation());
197 if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
198 extractions[it.index()] = extractionOp;
203 extractions.push_back(extractionOp);
204 insertions.push_back({});
210 void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
211 for (
auto it : llvm::enumerate(extractions)) {
214 auto other = cast<SubsetOpInterface>(it.value().getOperation());
215 if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
216 insertions[it.index()] = insertionOp;
221 extractions.push_back({});
222 insertions.push_back(insertionOp);
225 SmallVector<SubsetExtractionOpInterface> extractions;
226 SmallVector<SubsetInsertionOpInterface> insertions;
227 SmallVector<SubsetOpInterface> allSubsetOps;
243MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
245 bool collectHoistableOps) {
247 Value value = iterArg;
253 OpOperand *yieldedOperand =
nullptr;
257 Value nextValue = {};
259 for (OpOperand &use : value.
getUses()) {
260 if (
auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
265 auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
271 if (
failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
278 nextValue = nestedLoop.getTiedLoopResult(&use);
282 auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
285 insert(subsetOp, collectHoistableOps);
287 if (
auto insertionOp =
288 dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
292 auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner());
293 if (!dstOp || !dstOp.hasPureTensorSemantics())
298 if (&use != &insertionOp.getDestinationOperand())
304 nextValue = insertionOp.getUpdatedDestination();
318 if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
329 LoopLikeOpInterface loopLike,
332 BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
333 int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
334 MatchingSubsets subsets;
335 if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
339 for (
auto it : subsets.getHoistableSubsetOps()) {
340 auto extractionOp = std::get<0>(it);
341 auto insertionOp = std::get<1>(it);
346 return loopLike.isDefinedOutsideOfLoop(operand.
get()) ||
347 &operand == &extractionOp.getSourceOperand();
353 return loopLike.isDefinedOutsideOfLoop(operand.
get()) ||
354 &operand == &insertionOp.getSourceOperand() ||
355 &operand == &insertionOp.getDestinationOperand();
363 if (extractionOp && insertionOp) {
368 return {insertionOp.getSourceOperand().get()};
370 FailureOr<LoopLikeOpInterface> newLoop =
371 loopLike.replaceWithAdditionalYields(
372 rewriter, extractionOp.getResult(),
373 true, newYieldValuesFn);
379 iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
380 OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
381 OpResult newLoopResult = loopLike.getLoopResults()->back();
385 insertionOp.getDestinationOperand().get());
386 extractionOp.getSourceOperand().set(
387 loopLike.getTiedLoopInit(iterArg)->get());
389 insertionOp.getUpdatedDestination());
390 insertionOp.getSourceOperand().set(newLoopResult);
391 insertionOp.getDestinationOperand().set(loopResult);
400 LoopLikeOpInterface loopLike) {
405 i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
407 loopLike.getRegionIterArgs()[i]);
static OpOperand * getSingleTerminatorUse(Value value)
If the given value has a single use by an op that is a terminator, return that use.
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg)
Hoist all subset ops that operate on the idx-th region iter_arg of the given loop-like op and index i...
static bool canBeHoisted(Operation *op, function_ref< bool(OpOperand &)> condition)
Checks whether the given op can be hoisted by checking that.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
bool hasOneUse() const
Returns true if this value has exactly one use.
static WalkResult advance()
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
Include the generated interface declarations.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)
Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.
llvm::function_ref< Fn > function_ref
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.