MLIR  19.0.0git
Hoisting.cpp
Go to the documentation of this file.
1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/Dominance.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 
37 using llvm::dbgs;
38 
39 #define DEBUG_TYPE "linalg-hoisting"
40 
41 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
42 
43 using namespace mlir;
44 using namespace mlir::linalg;
45 
46 /// Replace `loop` with a new loop that has a different init operand at
47 /// position `index`. The body of this loop is moved over to the new loop.
48 ///
49 /// `newInitOperands` specifies the replacement "init" operands.
50 /// `newYieldValue` is the replacement yield value of the loop at position
51 /// `index`.
52 static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53  scf::ForOp loop,
54  Value newInitOperand,
55  unsigned index,
56  Value newYieldValue) {
57  OpBuilder::InsertionGuard g(rewriter);
58  rewriter.setInsertionPoint(loop.getOperation());
59  auto inits = llvm::to_vector(loop.getInits());
60 
61  // Replace the init value with the new operand.
62  assert(index < inits.size());
63  inits[index] = newInitOperand;
64 
65  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
66  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
67  inits, [](OpBuilder &, Location, Value, ValueRange) {});
68 
69  // Generate the new yield with the replaced operand.
70  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71  yieldOp.setOperand(index, newYieldValue);
72 
73  // Move the loop body to the new op.
74  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
75  newLoop.getBody()->getArguments());
76 
77  // Replace the old loop.
78  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79  return newLoop;
80 }
81 
82 // Hoist out a pair of corresponding vector.extract+vector.broadcast
83 // operations. This function transforms a loop like this:
84 // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
85 // %e = vector.extract %iarg : t1 to t2
86 // %u = "some_use"(%e) : (t2) -> t2
87 // %b = vector.broadcast %u : t2 to t1
88 // scf.yield %b : t1
89 // }
90 // into the following:
91 // %e = vector.extract %v: t1 to t2
92 // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
93 // %u' = "some_use"(%iarg) : (t2) -> t2
94 // scf.yield %u' : t2
95 // }
96 // %res = vector.broadcast %res' : t2 to t1
98  Operation *root) {
99  bool changed = true;
100  while (changed) {
101  changed = false;
102  // First move loop invariant ops outside of their loop. This needs to be
103  // done before as we cannot move ops without interrupting the function walk.
104  root->walk(
105  [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
106 
107  root->walk([&](vector::ExtractOp extractOp) {
108  LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109  << *extractOp.getOperation() << "\n");
110 
111  auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
112  if (!loop)
113  return WalkResult::advance();
114 
115  // Check that the vector to extract from is a BlockArgument.
116  auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
117  if (!blockArg)
118  return WalkResult::advance();
119 
120  // Check that the blockArg is an iter_arg of the loop.
121  OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122  if (!initArg)
123  return WalkResult::advance();
124 
125  // If the iter_arg does not have only one use, it won't be possible to
126  // hoist the extractOp out.
127  if (!blockArg.hasOneUse())
128  return WalkResult::advance();
129 
130  unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131 
132  // Check that the loop yields a broadcast that has just one use.
133  Operation *yieldedVal =
134  loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135  auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
136  if (!broadcast || !broadcast.getResult().hasOneUse())
137  return WalkResult::advance();
138 
139  LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140 
141  Type broadcastInputType = broadcast.getSourceType();
142  if (broadcastInputType != extractOp.getType())
143  return WalkResult::advance();
144 
145  // The position of the extract must be defined outside of the loop if
146  // it is dynamic.
147  for (auto operand : extractOp.getDynamicPosition())
148  if (!loop.isDefinedOutsideOfLoop(operand))
149  return WalkResult::advance();
150 
151  rewriter.modifyOpInPlace(broadcast, [&] {
152  extractOp.getVectorMutable().assign(initArg->get());
153  });
154  loop.moveOutOfLoop(extractOp);
155  rewriter.moveOpAfter(broadcast, loop);
156 
157  scf::ForOp newLoop = replaceWithDifferentYield(
158  rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159 
160  LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161 
162  rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
163  rewriter.modifyOpInPlace(
164  broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
165 
166  changed = true;
167  return WalkResult::interrupt();
168  });
169  }
170 }
171 
172 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
173  LoopLikeOpInterface loop) {
174  Value source = transferRead.getSource();
175 
176  // Skip view-like Ops and retrive the actual soruce Operation
177  while (auto srcOp =
178  dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
179  source = srcOp.getViewSource();
180 
181  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
182  source.getUsers().end());
183  llvm::SmallDenseSet<Operation *, 32> processed;
184  while (!users.empty()) {
185  Operation *user = users.pop_back_val();
186  // If the user has already been processed skip.
187  if (!processed.insert(user).second)
188  continue;
189  if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
190  users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
191  continue;
192  }
193  if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
194  continue;
195  if (!loop->isAncestor(user))
196  continue;
197  return false;
198  }
199  return true;
200 }
201 
203  bool changed = true;
204  while (changed) {
205  changed = false;
206  // First move loop invariant ops outside of their loop. This needs to be
207  // done before as we cannot move ops without interrupting the function walk.
208  root->walk(
209  [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
210 
211  root->walk([&](vector::TransferReadOp transferRead) {
212  if (!isa<MemRefType>(transferRead.getShapedType()))
213  return WalkResult::advance();
214 
215  LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
216  << *transferRead.getOperation() << "\n");
217  auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
218  LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
219  << "\n");
220  if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
221  return WalkResult::advance();
222 
223  LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
224  << "\n");
225 
226  SetVector<Operation *> forwardSlice;
227  getForwardSlice(transferRead.getOperation(), &forwardSlice);
228 
229  // Look for the last TransferWriteOp in the forwardSlice of
230  // `transferRead` that operates on the same memref.
231  vector::TransferWriteOp transferWrite;
232  for (auto *sliceOp : llvm::reverse(forwardSlice)) {
233  auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
234  if (!candidateWrite ||
235  candidateWrite.getSource() != transferRead.getSource())
236  continue;
237  transferWrite = candidateWrite;
238  }
239 
240  // All operands of the TransferRead must be defined outside of the loop.
241  for (auto operand : transferRead.getOperands())
242  if (!loop.isDefinedOutsideOfLoop(operand))
243  return WalkResult::advance();
244 
245  // Only hoist transfer_read / transfer_write pairs and singleton
246  // transfer_reads for now.
247  if (!transferWrite) {
248  // Make sure there are no other accesses to the memref before
249  // hoisting transfer_read.
250  if (noAliasingUseInLoop(transferRead, loop))
251  loop.moveOutOfLoop(transferRead);
252  return WalkResult::advance();
253  }
254 
255  LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
256  << "\n");
257 
258  // Approximate aliasing by checking that:
259  // 1. indices, vector type and permutation map are the same (i.e., the
260  // transfer_read/transfer_write ops are matching),
261  // 2. source operands for transfer.{read|write} do not originate from
262  // Ops implementing ViewLikeOpInterface.
263  // 3. no other operations in the loop access the same memref except
264  // for transfer_read/transfer_write accessing statically disjoint
265  // slices.
266  if (transferRead.getIndices() != transferWrite.getIndices() ||
267  transferRead.getVectorType() != transferWrite.getVectorType() ||
268  transferRead.getPermutationMap() != transferWrite.getPermutationMap())
269  return WalkResult::advance();
270 
271  auto *source = transferRead.getSource().getDefiningOp();
272  if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
273  return WalkResult::advance();
274 
275  source = transferWrite.getSource().getDefiningOp();
276  if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
277  return WalkResult::advance();
278 
279  // TODO: may want to memoize this information for performance but it
280  // likely gets invalidated often.
281  DominanceInfo dom(loop);
282  if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
283  return WalkResult::advance();
284  for (auto &use : transferRead.getSource().getUses()) {
285  if (!loop->isAncestor(use.getOwner()))
286  continue;
287  if (use.getOwner() == transferRead.getOperation() ||
288  use.getOwner() == transferWrite.getOperation())
289  continue;
290  if (auto transferWriteUse =
291  dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
293  cast<VectorTransferOpInterface>(*transferWrite),
294  cast<VectorTransferOpInterface>(*transferWriteUse),
295  /*testDynamicValueUsingBounds=*/true))
296  return WalkResult::advance();
297  } else if (auto transferReadUse =
298  dyn_cast<vector::TransferReadOp>(use.getOwner())) {
300  cast<VectorTransferOpInterface>(*transferWrite),
301  cast<VectorTransferOpInterface>(*transferReadUse),
302  /*testDynamicValueUsingBounds=*/true))
303  return WalkResult::advance();
304  } else {
305  // Unknown use, we cannot prove that it doesn't alias with the
306  // transferRead/transferWrite operations.
307  return WalkResult::advance();
308  }
309  }
310 
311  // Hoist read before.
312  loop.moveOutOfLoop(transferRead);
313 
314  // Hoist write after.
315  transferWrite->moveAfter(loop);
316 
317  // Rewrite `loop` with new yields by cloning and erase the original loop.
318  IRRewriter rewriter(transferRead.getContext());
319  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
320  ArrayRef<BlockArgument> newBBArgs) {
321  return SmallVector<Value>{transferWrite.getVector()};
322  };
323 
324  auto maybeNewLoop = loop.replaceWithAdditionalYields(
325  rewriter, transferRead.getVector(),
326  /*replaceInitOperandUsesInLoop=*/true, yieldFn);
327  if (failed(maybeNewLoop))
328  return WalkResult::interrupt();
329 
330  transferWrite.getVectorMutable().assign(
331  maybeNewLoop->getOperation()->getResults().back());
332  changed = true;
333  // Need to interrupt and restart because erasing the loop messes up
334  // the walk.
335  return WalkResult::interrupt();
336  });
337  }
338 }
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
Definition: Hoisting.cpp:172
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
Definition: Hoisting.cpp:52
#define DBGS()
Definition: Hoisting.cpp:41
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
Definition: Dominance.h:136
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:149
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
void hoistRedundantVectorTransfers(Operation *root)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Definition: VectorOps.cpp:253
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.