26#include "llvm/Support/Debug.h" 
   30#define DEBUG_TYPE "linalg-hoisting" 
   32#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 
   47                                            Value newYieldValue) {
 
   50  auto inits = llvm::to_vector(loop.getInits());
 
   53  assert(
index < inits.size());
 
   54  inits[
index] = newInitOperand;
 
   56  scf::ForOp newLoop = scf::ForOp::create(
 
   57      rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
 
   59      loop.getUnsignedCmp());
 
   62  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
 
   63  yieldOp.setOperand(
index, newYieldValue);
 
   66  rewriter.
mergeBlocks(loop.getBody(), newLoop.getBody(),
 
   67                       newLoop.getBody()->getArguments());
 
   70  rewriter.
replaceOp(loop.getOperation(), newLoop->getResults());
 
 
   99    root->
walk([&](vector::ExtractOp extractOp) {
 
  100      LLVM_DEBUG(
DBGS() << 
"Candidate for hoisting: " 
  101                        << *extractOp.getOperation() << 
"\n");
 
  103      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
 
  108      auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
 
  113      OpOperand *initArg = loop.getTiedLoopInit(blockArg);
 
  119      if (!blockArg.hasOneUse())
 
  122      unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
 
  126          loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
 
  127      auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
 
  131      LLVM_DEBUG(
DBGS() << 
"Candidate broadcast: " << 
broadcast << 
"\n");
 
  134      if (broadcastInputType != extractOp.getType())
 
  139      for (
auto operand : extractOp.getDynamicPosition())
 
  140        if (!loop.isDefinedOutsideOfLoop(operand))
 
  144        extractOp.getSourceMutable().assign(initArg->
get());
 
  146      loop.moveOutOfLoop(extractOp);
 
  150          rewriter, loop, extractOp.getResult(), 
index, 
broadcast.getSource());
 
  152      LLVM_DEBUG(
DBGS() << 
"New loop: " << newLoop << 
"\n");
 
 
  165                                LoopLikeOpInterface loop) {
 
  166  Value source = transferRead.getBase();
 
  169  while (
auto viewLike = source.
getDefiningOp<ViewLikeOpInterface>()) {
 
  170    if (viewLike.getViewDest() != source) {
 
  173    source = viewLike.getViewSource();
 
  178  llvm::SmallDenseSet<Operation *, 32> processed;
 
  179  while (!users.empty()) {
 
  182    if (!processed.insert(user).second)
 
  184    if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
 
  185      Value viewDest = viewLike.getViewDest();
 
  191    if (!loop->isAncestor(user))
 
 
  199                                                 bool verifyNonZeroTrip) {
 
  213    if (verifyNonZeroTrip) {
 
  214      root->
walk([&](LoopLikeOpInterface loopLike) {
 
  215        std::optional<SmallVector<OpFoldResult>> lbs =
 
  216            loopLike.getLoopLowerBounds();
 
  217        std::optional<SmallVector<OpFoldResult>> ubs =
 
  218            loopLike.getLoopUpperBounds();
 
  226        for (
auto [lb, 
ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
 
  227          FailureOr<int64_t> maxLb =
 
  233          FailureOr<int64_t> minUb =
 
  238          if (minUb.value() <= maxLb.value())
 
  240          definiteNonZeroTripCountLoops.insert(loopLike);
 
  245    root->
walk([&](vector::TransferReadOp transferRead) {
 
  246      if (!isa<MemRefType>(transferRead.getShapedType()))
 
  249      LLVM_DEBUG(
DBGS() << 
"Candidate for hoisting: " 
  250                        << *transferRead.getOperation() << 
"\n");
 
  251      auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
 
  252      LLVM_DEBUG(
DBGS() << 
"Parent op: " << *transferRead->getParentOp()
 
  254      if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
 
  257      if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
 
  258        LLVM_DEBUG(
DBGS() << 
"Loop may have zero trip count: " << *loop
 
  263      LLVM_DEBUG(
DBGS() << 
"Candidate read: " << *transferRead.getOperation()
 
  271      vector::TransferWriteOp transferWrite;
 
  272      for (
auto *sliceOp : llvm::reverse(forwardSlice)) {
 
  273        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
 
  274        if (!candidateWrite ||
 
  275            candidateWrite.getBase() != transferRead.getBase())
 
  277        transferWrite = candidateWrite;
 
  281      for (
auto operand : transferRead.getOperands())
 
  282        if (!loop.isDefinedOutsideOfLoop(operand))
 
  287      if (!transferWrite) {
 
  291          loop.moveOutOfLoop(transferRead);
 
  295      LLVM_DEBUG(
DBGS() << 
"Candidate: " << *transferWrite.getOperation()
 
  308      if (transferRead.getIndices() != transferWrite.getIndices() ||
 
  309          transferRead.getVectorType() != transferWrite.getVectorType() ||
 
  310          transferRead.getPermutationMap() != transferWrite.getPermutationMap())
 
  315      auto base = transferRead.getBase();
 
  316      auto *source = base.getDefiningOp();
 
  329        if (
auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
 
  330          Value memPreAlignment = assume.getMemref();
 
  332              llvm::count_if(base.getUses(), [&loop](
OpOperand &use) {
 
  333                return loop->isAncestor(use.getOwner());
 
  336          if (numInLoopUses && memPreAlignment.
hasOneUse())
 
  339        if (isa_and_nonnull<ViewLikeOpInterface>(source))
 
  343      if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
 
  352      for (
auto &use : transferRead.getBase().getUses()) {
 
  353        if (!loop->isAncestor(use.getOwner()))
 
  355        if (use.getOwner() == transferRead.getOperation() ||
 
  356            use.getOwner() == transferWrite.getOperation())
 
  358        if (
auto transferWriteUse =
 
  359                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
 
  361                  cast<VectorTransferOpInterface>(*transferWrite),
 
  362                  cast<VectorTransferOpInterface>(*transferWriteUse),
 
  365        } 
else if (
auto transferReadUse =
 
  366                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
 
  368                  cast<VectorTransferOpInterface>(*transferWrite),
 
  369                  cast<VectorTransferOpInterface>(*transferReadUse),
 
  380      loop.moveOutOfLoop(transferRead);
 
  383      transferWrite->moveAfter(loop);
 
  387      IRRewriter rewriter(transferRead.getContext());
 
  393      auto maybeNewLoop = loop.replaceWithAdditionalYields(
 
  394          rewriter, transferRead.getVector(),
 
  396      if (failed(maybeNewLoop))
 
  399      transferWrite.getValueToStoreMutable().assign(
 
  400          maybeNewLoop->getOperation()->getResults().back());
 
 
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
 
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
 
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
 
A class for computing basic dominance information.
 
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
 
IRValueT get() const
Return the current value being used by this operand.
 
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
 
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
 
RAII guard to reset the insertion point of the builder when destroyed.
 
This class helps build Operations.
 
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
 
This class represents an operand of an operation.
 
Operation is the basic unit of execution within MLIR.
 
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
 
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
 
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
 
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
 
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
 
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
 
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
 
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
 
This class provides an abstraction over the different types of ranges over Values.
 
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
 
user_range getUsers() const
 
bool hasOneUse() const
Returns true if this value has exactly one use.
 
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
 
static WalkResult advance()
 
static WalkResult interrupt()
 
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
 
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
 
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
 
Include the generated interface declarations.
 
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
 
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
 
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
 
llvm::SetVector< T, Vector, Set, N > SetVector
 
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
 
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.