MLIR  22.0.0git
Hoisting.cpp
Go to the documentation of this file.
1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
24 #include "mlir/IR/Dominance.h"
26 #include "llvm/Support/Debug.h"
27 
28 using llvm::dbgs;
29 
30 #define DEBUG_TYPE "linalg-hoisting"
31 
32 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 
37 /// Replace `loop` with a new loop that has a different init operand at
38 /// position `index`. The body of this loop is moved over to the new loop.
39 ///
40 /// `newInitOperands` specifies the replacement "init" operands.
41 /// `newYieldValue` is the replacement yield value of the loop at position
42 /// `index`.
43 static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
44  scf::ForOp loop,
45  Value newInitOperand,
46  unsigned index,
47  Value newYieldValue) {
48  OpBuilder::InsertionGuard g(rewriter);
49  rewriter.setInsertionPoint(loop.getOperation());
50  auto inits = llvm::to_vector(loop.getInits());
51 
52  // Replace the init value with the new operand.
53  assert(index < inits.size());
54  inits[index] = newInitOperand;
55 
56  scf::ForOp newLoop = scf::ForOp::create(
57  rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
58  loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
59  loop.getUnsignedCmp());
60 
61  // Generate the new yield with the replaced operand.
62  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
63  yieldOp.setOperand(index, newYieldValue);
64 
65  // Move the loop body to the new op.
66  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
67  newLoop.getBody()->getArguments());
68 
69  // Replace the old loop.
70  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
71  return newLoop;
72 }
73 
74 // Hoist out a pair of corresponding vector.extract+vector.broadcast
75 // operations. This function transforms a loop like this:
76 // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
77 // %e = vector.extract %iarg : t1 to t2
78 // %u = "some_use"(%e) : (t2) -> t2
79 // %b = vector.broadcast %u : t2 to t1
80 // scf.yield %b : t1
81 // }
82 // into the following:
83 // %e = vector.extract %v: t1 to t2
84 // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
85 // %u' = "some_use"(%iarg) : (t2) -> t2
86 // scf.yield %u' : t2
87 // }
88 // %res = vector.broadcast %res' : t2 to t1
90  Operation *root) {
91  bool changed = true;
92  while (changed) {
93  changed = false;
94  // First move loop invariant ops outside of their loop. This needs to be
95  // done before as we cannot move ops without interrupting the function walk.
96  root->walk(
97  [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
98 
99  root->walk([&](vector::ExtractOp extractOp) {
100  LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
101  << *extractOp.getOperation() << "\n");
102 
103  auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
104  if (!loop)
105  return WalkResult::advance();
106 
107  // Check that the vector to extract from is a BlockArgument.
108  auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
109  if (!blockArg)
110  return WalkResult::advance();
111 
112  // Check that the blockArg is an iter_arg of the loop.
113  OpOperand *initArg = loop.getTiedLoopInit(blockArg);
114  if (!initArg)
115  return WalkResult::advance();
116 
117  // If the iter_arg does not have only one use, it won't be possible to
118  // hoist the extractOp out.
119  if (!blockArg.hasOneUse())
120  return WalkResult::advance();
121 
122  unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
123 
124  // Check that the loop yields a broadcast that has just one use.
125  Operation *yieldedVal =
126  loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
127  auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
128  if (!broadcast || !broadcast.getResult().hasOneUse())
129  return WalkResult::advance();
130 
131  LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
132 
133  Type broadcastInputType = broadcast.getSourceType();
134  if (broadcastInputType != extractOp.getType())
135  return WalkResult::advance();
136 
137  // The position of the extract must be defined outside of the loop if
138  // it is dynamic.
139  for (auto operand : extractOp.getDynamicPosition())
140  if (!loop.isDefinedOutsideOfLoop(operand))
141  return WalkResult::advance();
142 
143  rewriter.modifyOpInPlace(broadcast, [&] {
144  extractOp.getVectorMutable().assign(initArg->get());
145  });
146  loop.moveOutOfLoop(extractOp);
147  rewriter.moveOpAfter(broadcast, loop);
148 
149  scf::ForOp newLoop = replaceWithDifferentYield(
150  rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
151 
152  LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
153 
154  rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
155  rewriter.modifyOpInPlace(
156  broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
157 
158  changed = true;
159  return WalkResult::interrupt();
160  });
161  }
162 }
163 
164 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
165  LoopLikeOpInterface loop) {
166  Value source = transferRead.getBase();
167 
168  // Skip view-like Ops and retrive the actual soruce Operation
169  while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
170  if (viewLike.getViewDest() != source) {
171  break;
172  }
173  source = viewLike.getViewSource();
174  }
175 
176  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
177  source.getUsers().end());
178  llvm::SmallDenseSet<Operation *, 32> processed;
179  while (!users.empty()) {
180  Operation *user = users.pop_back_val();
181  // If the user has already been processed skip.
182  if (!processed.insert(user).second)
183  continue;
184  if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
185  Value viewDest = viewLike.getViewDest();
186  users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
187  continue;
188  }
189  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
190  continue;
191  if (!loop->isAncestor(user))
192  continue;
193  return false;
194  }
195  return true;
196 }
197 
199  bool verifyNonZeroTrip) {
  // Hoists vector.transfer_read ops on memrefs out of an immediately
  // enclosing scf.for / affine.for. When a vector.transfer_write on the same
  // base is found in the read's forward slice and passes the aliasing checks
  // below, the read is hoisted before the loop, the write after it, and the
  // loop is rewritten to carry the written vector as an extra iter_arg.
  // A lone transfer_read is hoisted only if noAliasingUseInLoop() holds.
  // Each round runs loop-invariant code motion first. A read/write pair
  // hoist interrupts the walk and triggers another round; hoisting a lone
  // transfer_read does not set `changed`.
  // When `verifyNonZeroTrip` is set, only loops proven to run at least once
  // are hoisted from.
200  bool changed = true;
201  while (changed) {
202  changed = false;
203  // First move loop invariant ops outside of their loop. This needs to be
204  // done before as we cannot move ops without interrupting the function walk.
205  root->walk(
206  [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
207 
208  // Find all loops that are certain to have non zero trip count. Any loops
209  // that are not part of this set cannot be hoisted from, since hoisting from
210  // a potentially zero trip count loop may cause a vector transfer to be
211  // executed when it shouldn't be.
212  llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
213  if (verifyNonZeroTrip) {
214  root->walk([&](LoopLikeOpInterface loopLike) {
215  std::optional<SmallVector<OpFoldResult>> lbs =
216  loopLike.getLoopLowerBounds();
217  std::optional<SmallVector<OpFoldResult>> ubs =
218  loopLike.getLoopUpperBounds();
219  // If loop bounds cannot be found, assume possibly zero trip count.
220  if (!lbs || !ubs)
221  return;
222 
223  // Otherwise, use ValueBounds to find the maximum lower bound and
224  // minimum upper bound. If the bounds are found, and maxLb is less
225  // than the minUb, then the loop will not have zero trip count.
226  for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
        // NOTE(review): original lines 228-229 and 234-235 (the right-hand
        // sides computing `maxLb` and `minUb`) are missing from this copy of
        // the file. Restore them from the upstream source before building.
227  FailureOr<int64_t> maxLb =
230  /*stopCondition=*/nullptr, /*closedUB=*/true);
231  if (failed(maxLb))
232  return;
233  FailureOr<int64_t> minUb =
236  if (failed(minUb))
237  return;
238  if (minUb.value() <= maxLb.value())
239  return;
        // NOTE(review): this insert runs inside the per-dimension loop. A
        // multi-dimensional loop is therefore recorded as soon as its first
        // dimension passes, even if a later dimension returns early. This is
        // harmless for the single-IV scf.for / affine.for hoisted from below,
        // but consider moving the insert after the `for` loop.
240  definiteNonZeroTripCountLoops.insert(loopLike);
241  }
242  });
243  }
244 
245  root->walk([&](vector::TransferReadOp transferRead) {
      // Only reads from memrefs (buffers) are candidates.
246  if (!isa<MemRefType>(transferRead.getShapedType()))
247  return WalkResult::advance();
248 
249  LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
250  << *transferRead.getOperation() << "\n");
251  auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
252  LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
253  << "\n");
      // The immediate parent must be an scf.for or affine.for.
254  if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
255  return WalkResult::advance();
256 
257  if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
258  LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
259  << "\n");
260  return WalkResult::advance();
261  }
262 
263  LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
264  << "\n");
265 
266  SetVector<Operation *> forwardSlice;
267  getForwardSlice(transferRead.getOperation(), &forwardSlice);
268 
269  // Look for the last TransferWriteOp in the forwardSlice of
270  // `transferRead` that operates on the same memref.
      // NOTE(review): iterating in reverse and overwriting `transferWrite`
      // leaves it set to the matching candidate that appears *first* in
      // `forwardSlice`'s order. Confirm this matches the "last" wording above.
271  vector::TransferWriteOp transferWrite;
272  for (auto *sliceOp : llvm::reverse(forwardSlice)) {
273  auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
274  if (!candidateWrite ||
275  candidateWrite.getBase() != transferRead.getBase())
276  continue;
277  transferWrite = candidateWrite;
278  }
279 
280  // All operands of the TransferRead must be defined outside of the loop.
281  for (auto operand : transferRead.getOperands())
282  if (!loop.isDefinedOutsideOfLoop(operand))
283  return WalkResult::advance();
284 
285  // Only hoist transfer_read / transfer_write pairs and singleton
286  // transfer_reads for now.
287  if (!transferWrite) {
288  // Make sure there are no other accesses to the memref before
289  // hoisting transfer_read.
290  if (noAliasingUseInLoop(transferRead, loop))
291  loop.moveOutOfLoop(transferRead);
      // The walk keeps going after a singleton hoist; `changed` stays as is.
292  return WalkResult::advance();
293  }
294 
295  LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
296  << "\n");
297 
298  // Approximate aliasing by checking that:
299  // 1. indices, vector type and permutation map are the same (i.e., the
300  // transfer_read/transfer_write ops are matching),
301  // 2. source operands for transfer.{read|write} do not originate from
302  // nor have users that are Ops implementing ViewLikeOpInterface.
303  // 3. no other operations in the loop access the same memref except
304  // for transfer_read/transfer_write accessing statically disjoint
305  // slices.
306 
307  // Check 1.
      // NOTE(review): only the read's operands are verified to be
      // loop-invariant (above). The write's other operands (e.g. a mask, if
      // present) are not compared here. Confirm they cannot be defined inside
      // the loop before the write is moved after it.
308  if (transferRead.getIndices() != transferWrite.getIndices() ||
309  transferRead.getVectorType() != transferWrite.getVectorType() ||
310  transferRead.getPermutationMap() != transferWrite.getPermutationMap())
311  return WalkResult::advance();
312 
313  // Check 2. Note, since both xfer Ops share the source, we only need to
314  // look at one of them.
315  auto base = transferRead.getBase();
316  auto *source = base.getDefiningOp();
317  if (source) {
318  // NOTE: We treat `memref.assume_alignment` as a special case.
319  //
320  // The idea is that it is safe to look past AssumeAlignmentOp (i.e.
321  // MemRef _before_ alignment) iff:
322  // 1. It has exactly two uses (these have to be the xfer Ops
323  // being looked at).
324  // 2. The original MemRef has only one use (i.e.
325  // AssumeAlignmentOp).
326  //
327  // Relaxing these conditions will most likely require proper alias
328  // analysis.
329  if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
330  Value memPreAlignment = assume.getMemref();
331  auto numInLoopUses =
332  llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
333  return loop->isAncestor(use.getOwner());
334  });
335 
          // NOTE(review): the code only requires at least one in-loop use,
          // not "exactly two uses" as condition 1 above states. Remaining
          // in-loop uses are screened by Check 3. Reconcile the comment and
          // the code.
336  if (numInLoopUses && memPreAlignment.hasOneUse())
337  source = memPreAlignment.getDefiningOp();
338  }
339  if (isa_and_nonnull<ViewLikeOpInterface>(source))
340  return WalkResult::advance();
341  }
342 
343  if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
344  return WalkResult::advance();
345 
346  // Check 3.
347  // TODO: may want to memoize this information for performance but it
348  // likely gets invalidated often.
349  DominanceInfo dom(loop);
      // The read must come strictly before the write for the pair to be
      // forwarded through an iter_arg.
350  if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
351  return WalkResult::advance();
      // Every other in-loop use of the base must be a transfer op proven
      // disjoint from `transferWrite`.
      // NOTE(review): original lines 360 and 367 (the heads of the two
      // disjointness checks) are missing from this copy. Judging by the
      // argument lists, they are presumably calls to isDisjointTransferSet;
      // restore them from the upstream source.
352  for (auto &use : transferRead.getBase().getUses()) {
353  if (!loop->isAncestor(use.getOwner()))
354  continue;
355  if (use.getOwner() == transferRead.getOperation() ||
356  use.getOwner() == transferWrite.getOperation())
357  continue;
358  if (auto transferWriteUse =
359  dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
361  cast<VectorTransferOpInterface>(*transferWrite),
362  cast<VectorTransferOpInterface>(*transferWriteUse),
363  /*testDynamicValueUsingBounds=*/true))
364  return WalkResult::advance();
365  } else if (auto transferReadUse =
366  dyn_cast<vector::TransferReadOp>(use.getOwner())) {
368  cast<VectorTransferOpInterface>(*transferWrite),
369  cast<VectorTransferOpInterface>(*transferReadUse),
370  /*testDynamicValueUsingBounds=*/true))
371  return WalkResult::advance();
372  } else {
373  // Unknown use, we cannot prove that it doesn't alias with the
374  // transferRead/transferWrite operations.
375  return WalkResult::advance();
376  }
377  }
378 
379  // Hoist read before.
380  loop.moveOutOfLoop(transferRead);
381 
382  // Hoist write after.
383  transferWrite->moveAfter(loop);
384 
385  // Rewrite `loop` with new yields by cloning and erase the original
386  // loop.
387  IRRewriter rewriter(transferRead.getContext());
      // The new iter_arg's yielded value is the vector that was being written.
388  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
389  ArrayRef<BlockArgument> newBBArgs) {
390  return SmallVector<Value>{transferWrite.getVector()};
391  };
392 
393  auto maybeNewLoop = loop.replaceWithAdditionalYields(
394  rewriter, transferRead.getVector(),
395  /*replaceInitOperandUsesInLoop=*/true, yieldFn);
      // NOTE(review): on failure, the read and write have already been moved
      // out of the loop and are not restored, and `changed` stays false.
      // Confirm this path cannot produce invalid IR.
396  if (failed(maybeNewLoop))
397  return WalkResult::interrupt();
398 
      // The hoisted write now stores the loop's last (newly added) result.
399  transferWrite.getValueToStoreMutable().assign(
400  maybeNewLoop->getOperation()->getResults().back());
401  changed = true;
402  // Need to interrupt and restart because erasing the loop messes up
403  // the walk.
404  return WalkResult::interrupt();
405  });
406  }
407 }
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
Definition: Hoisting.cpp:164
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
Definition: Hoisting.cpp:43
#define DBGS()
Definition: Hoisting.cpp:32
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:89
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:198
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:314
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.