MLIR  20.0.0git
Hoisting.cpp
Go to the documentation of this file.
1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/Dominance.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 
37 using llvm::dbgs;
38 
// Debug type name for this file; DBGS() below also prints it as a prefix.
39 #define DEBUG_TYPE "linalg-hoisting"
40 
// Streams to dbgs() after writing a "[linalg-hoisting] " prefix; the result is
// the stream, so further `<<` operands can be chained onto it.
41 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
42 
43 using namespace mlir;
44 using namespace mlir::linalg;
45 
46 /// Replace `loop` with a new loop that has a different init operand at
47 /// position `index`. The body of this loop is moved over to the new loop.
48 ///
49 /// `newInitOperands` specifies the replacement "init" operands.
50 /// `newYieldValue` is the replacement yield value of the loop at position
51 /// `index`.
52 static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53  scf::ForOp loop,
54  Value newInitOperand,
55  unsigned index,
56  Value newYieldValue) {
57  OpBuilder::InsertionGuard g(rewriter);
58  rewriter.setInsertionPoint(loop.getOperation());
59  auto inits = llvm::to_vector(loop.getInits());
60 
61  // Replace the init value with the new operand.
62  assert(index < inits.size());
63  inits[index] = newInitOperand;
64 
65  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
66  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
67  inits, [](OpBuilder &, Location, Value, ValueRange) {});
68 
69  // Generate the new yield with the replaced operand.
70  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71  yieldOp.setOperand(index, newYieldValue);
72 
73  // Move the loop body to the new op.
74  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
75  newLoop.getBody()->getArguments());
76 
77  // Replace the old loop.
78  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79  return newLoop;
80 }
81 
82 // Hoist out a pair of corresponding vector.extract+vector.broadcast
83 // operations. This function transforms a loop like this:
84 // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
85 // %e = vector.extract %iarg : t1 to t2
86 // %u = "some_use"(%e) : (t2) -> t2
87 // %b = vector.broadcast %u : t2 to t1
88 // scf.yield %b : t1
89 // }
90 // into the following:
91 // %e = vector.extract %v: t1 to t2
92 // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
93 // %u' = "some_use"(%iarg) : (t2) -> t2
94 // scf.yield %u' : t2
95 // }
96 // %res = vector.broadcast %res' : t2 to t1
98  Operation *root) {
99  bool changed = true;
100  while (changed) {
101  changed = false;
102  // First move loop invariant ops outside of their loop. This needs to be
103  // done before as we cannot move ops without interrupting the function walk.
104  root->walk(
105  [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
106 
107  root->walk([&](vector::ExtractOp extractOp) {
108  LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109  << *extractOp.getOperation() << "\n");
110 
111  auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
112  if (!loop)
113  return WalkResult::advance();
114 
115  // Check that the vector to extract from is a BlockArgument.
116  auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
117  if (!blockArg)
118  return WalkResult::advance();
119 
120  // Check that the blockArg is an iter_arg of the loop.
121  OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122  if (!initArg)
123  return WalkResult::advance();
124 
125  // If the iter_arg does not have only one use, it won't be possible to
126  // hoist the extractOp out.
127  if (!blockArg.hasOneUse())
128  return WalkResult::advance();
129 
130  unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131 
132  // Check that the loop yields a broadcast that has just one use.
133  Operation *yieldedVal =
134  loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135  auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
136  if (!broadcast || !broadcast.getResult().hasOneUse())
137  return WalkResult::advance();
138 
139  LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140 
141  Type broadcastInputType = broadcast.getSourceType();
142  if (broadcastInputType != extractOp.getType())
143  return WalkResult::advance();
144 
145  // The position of the extract must be defined outside of the loop if
146  // it is dynamic.
147  for (auto operand : extractOp.getDynamicPosition())
148  if (!loop.isDefinedOutsideOfLoop(operand))
149  return WalkResult::advance();
150 
151  rewriter.modifyOpInPlace(broadcast, [&] {
152  extractOp.getVectorMutable().assign(initArg->get());
153  });
154  loop.moveOutOfLoop(extractOp);
155  rewriter.moveOpAfter(broadcast, loop);
156 
157  scf::ForOp newLoop = replaceWithDifferentYield(
158  rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159 
160  LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161 
162  rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
163  rewriter.modifyOpInPlace(
164  broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
165 
166  changed = true;
167  return WalkResult::interrupt();
168  });
169  }
170 }
171 
172 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
173  LoopLikeOpInterface loop) {
174  Value source = transferRead.getSource();
175 
176  // Skip view-like Ops and retrive the actual soruce Operation
177  while (auto srcOp =
178  dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
179  source = srcOp.getViewSource();
180 
181  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
182  source.getUsers().end());
183  llvm::SmallDenseSet<Operation *, 32> processed;
184  while (!users.empty()) {
185  Operation *user = users.pop_back_val();
186  // If the user has already been processed skip.
187  if (!processed.insert(user).second)
188  continue;
189  if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
190  users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
191  continue;
192  }
193  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
194  continue;
195  if (!loop->isAncestor(user))
196  continue;
197  return false;
198  }
199  return true;
200 }
201 
203  bool verifyNonZeroTrip) {
// Hoists vector.transfer_read ops on memrefs out of their immediately
// enclosing scf.for / affine.for.
//  - A read with a matching vector.transfer_write to the same memref in its
//    forward slice is hoisted as a pair: the read goes before the loop, the
//    write goes after it, and the value is carried through a new
//    iter_arg/yield.
//  - A read with no such write is hoisted alone when noAliasingUseInLoop()
//    allows it.
// Loop-invariant code motion runs before each round. The rounds repeat until
// no read/write pair is hoisted. When `verifyNonZeroTrip` is set, only loops
// that the bounds check below records in definiteNonZeroTripCountLoops are
// considered.
204  bool changed = true;
205  while (changed) {
206  changed = false;
207  // First move loop invariant ops outside of their loop. This needs to be
208  // done before as we cannot move ops without interrupting the function walk.
209  root->walk(
210  [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
211 
212  // Find all loops that are certain to have non zero trip count. Any loops
213  // that are not part of this set cannot be hoisted from, since hoisting from
214  // a potentially zero trip count loop may cause a vector transfer to be
215  // executed when it shouldn't be.
216  llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
217  if (verifyNonZeroTrip) {
218  root->walk([&](LoopLikeOpInterface loopLike) {
219  std::optional<SmallVector<OpFoldResult>> lbs =
220  loopLike.getLoopLowerBounds();
221  std::optional<SmallVector<OpFoldResult>> ubs =
222  loopLike.getLoopUpperBounds();
223  // If loop bounds cannot be found, assume possibly zero trip count.
224  if (!lbs || !ubs)
225  return;
226 
227  // Otherwise, use ValueBounds to find the maximum lower bound and
228  // minimum upper bound. If the bounds are found, and maxLb is less
229  // than the minUb, then the loop will not have zero trip count.
230  for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231  FailureOr<int64_t> maxLb =
234  /*stopCondition=*/nullptr, /*closedUB=*/true);
235  if (failed(maxLb))
236  return;
237  FailureOr<int64_t> minUb =
240  if (failed(minUb))
241  return;
242  if (minUb.value() <= maxLb.value())
243  return;
// NOTE(review): this insert runs inside the per-dimension loop. A loop is
// therefore recorded as soon as its first (lb, ub) pair passes, even if a
// later pair would hit one of the `return`s above. Requiring every pair to
// pass would mean inserting after the for-loop. Likely latent for the
// single-induction-variable scf.for / affine.for ops hoisted from below;
// confirm before relying on this set for other loop kinds.
244  definiteNonZeroTripCountLoops.insert(loopLike);
245  }
246  });
247  }
248 
249  root->walk([&](vector::TransferReadOp transferRead) {
250  if (!isa<MemRefType>(transferRead.getShapedType()))
251  return WalkResult::advance();
252 
253  LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
254  << *transferRead.getOperation() << "\n");
255  auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
256  LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
257  << "\n");
258  if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
259  return WalkResult::advance();
260 
261  if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
262  LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
263  << "\n");
264  return WalkResult::advance();
265  }
266 
267  LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
268  << "\n");
269 
270  SetVector<Operation *> forwardSlice;
271  getForwardSlice(transferRead.getOperation(), &forwardSlice);
272 
273  // Look for the last TransferWriteOp in the forwardSlice of
274  // `transferRead` that operates on the same memref.
// NOTE(review): the reverse scan below keeps overwriting `transferWrite`
// without breaking. The write it ends up with is therefore the matching
// write that comes FIRST in `forwardSlice` order, not the last one as the
// comment above says. Confirm which one is intended.
275  vector::TransferWriteOp transferWrite;
276  for (auto *sliceOp : llvm::reverse(forwardSlice)) {
277  auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
278  if (!candidateWrite ||
279  candidateWrite.getSource() != transferRead.getSource())
280  continue;
281  transferWrite = candidateWrite;
282  }
283 
284  // All operands of the TransferRead must be defined outside of the loop.
285  for (auto operand : transferRead.getOperands())
286  if (!loop.isDefinedOutsideOfLoop(operand))
287  return WalkResult::advance();
288 
289  // Only hoist transfer_read / transfer_write pairs and singleton
290  // transfer_reads for now.
291  if (!transferWrite) {
292  // Make sure there are no other accesses to the memref before
293  // hoisting transfer_read.
// A singleton-read hoist does not set `changed`, so it alone never
// triggers another round.
294  if (noAliasingUseInLoop(transferRead, loop))
295  loop.moveOutOfLoop(transferRead);
296  return WalkResult::advance();
297  }
298 
299  LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
300  << "\n");
301 
302  // Approximate aliasing by checking that:
303  // 1. indices, vector type and permutation map are the same (i.e., the
304  // transfer_read/transfer_write ops are matching),
305  // 2. source operands for transfer.{read|write} do not originate from
306  // Ops implementing ViewLikeOpInterface.
307  // 3. no other operations in the loop access the same memref except
308  // for transfer_read/transfer_write accessing statically disjoint
309  // slices.
310  if (transferRead.getIndices() != transferWrite.getIndices() ||
311  transferRead.getVectorType() != transferWrite.getVectorType() ||
312  transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313  return WalkResult::advance();
314 
// Check 2 (the `source &&` guard is redundant with isa_and_nonnull).
315  auto *source = transferRead.getSource().getDefiningOp();
316  if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317  return WalkResult::advance();
318 
319  source = transferWrite.getSource().getDefiningOp();
320  if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
321  return WalkResult::advance();
322 
323  // TODO: may want to memoize this information for performance but it
324  // likely gets invalidated often.
325  DominanceInfo dom(loop);
326  if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
327  return WalkResult::advance();
// Check 3: every other in-loop use of the memref must be a transfer op that
// is provably disjoint from `transferWrite`; any other use aborts.
328  for (auto &use : transferRead.getSource().getUses()) {
329  if (!loop->isAncestor(use.getOwner()))
330  continue;
331  if (use.getOwner() == transferRead.getOperation() ||
332  use.getOwner() == transferWrite.getOperation())
333  continue;
334  if (auto transferWriteUse =
335  dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
337  cast<VectorTransferOpInterface>(*transferWrite),
338  cast<VectorTransferOpInterface>(*transferWriteUse),
339  /*testDynamicValueUsingBounds=*/true))
340  return WalkResult::advance();
341  } else if (auto transferReadUse =
342  dyn_cast<vector::TransferReadOp>(use.getOwner())) {
344  cast<VectorTransferOpInterface>(*transferWrite),
345  cast<VectorTransferOpInterface>(*transferReadUse),
346  /*testDynamicValueUsingBounds=*/true))
347  return WalkResult::advance();
348  } else {
349  // Unknown use, we cannot prove that it doesn't alias with the
350  // transferRead/transferWrite operations.
351  return WalkResult::advance();
352  }
353  }
354 
355  // Hoist read before.
356  loop.moveOutOfLoop(transferRead);
357 
358  // Hoist write after.
359  transferWrite->moveAfter(loop);
360 
361  // Rewrite `loop` with new yields by cloning and erase the original loop.
362  IRRewriter rewriter(transferRead.getContext());
363  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
364  ArrayRef<BlockArgument> newBBArgs) {
365  return SmallVector<Value>{transferWrite.getVector()};
366  };
367 
368  auto maybeNewLoop = loop.replaceWithAdditionalYields(
369  rewriter, transferRead.getVector(),
370  /*replaceInitOperandUsesInLoop=*/true, yieldFn);
// NOTE(review): transferRead and transferWrite were already moved out of the
// loop above. Bailing out here leaves transferWrite after the loop while its
// vector operand is presumably still defined inside it, i.e. invalid IR.
// Confirm whether this failure path can be reached.
371  if (failed(maybeNewLoop))
372  return WalkResult::interrupt();
373 
// Make the hoisted write store the value carried out by the new iter_arg,
// which is the last result of the new loop.
374  transferWrite.getVectorMutable().assign(
375  maybeNewLoop->getOperation()->getResults().back());
376  changed = true;
377  // Need to interrupt and restart because erasing the loop messes up
378  // the walk.
379  return WalkResult::interrupt();
380  });
381  }
382 }
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
Definition: Hoisting.cpp:172
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
Definition: Hoisting.cpp:52
#define DBGS()
Definition: Hoisting.cpp:41
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:153
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:283
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.