MLIR  21.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 /// Verify there are no nested ParallelOps.
34 static bool hasNestedParallelOp(ParallelOp ploop) {
35  auto walkResult =
36  ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37  return walkResult.wasInterrupted();
38 }
39 
40 /// Verify equal iteration spaces.
41 static bool equalIterationSpaces(ParallelOp firstPloop,
42  ParallelOp secondPloop) {
43  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44  return false;
45 
46  auto matchOperands = [&](const OperandRange &lhs,
47  const OperandRange &rhs) -> bool {
48  // TODO: Extend this to support aliases and equal constants.
49  return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50  };
51  return matchOperands(firstPloop.getLowerBound(),
52  secondPloop.getLowerBound()) &&
53  matchOperands(firstPloop.getUpperBound(),
54  secondPloop.getUpperBound()) &&
55  matchOperands(firstPloop.getStep(), secondPloop.getStep());
56 }
57 
58 /// Checks if the parallel loops have mixed access to the same buffers. Returns
59 /// `true` if the first parallel loop writes to the same indices that the second
60 /// loop reads.
62  ParallelOp firstPloop, ParallelOp secondPloop,
63  const IRMapping &firstToSecondPloopIndices,
66  SmallVector<Value> bufferStoresVec;
67  firstPloop.getBody()->walk([&](memref::StoreOp store) {
68  bufferStores[store.getMemRef()].push_back(store.getIndices());
69  bufferStoresVec.emplace_back(store.getMemRef());
70  });
71  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72  Value loadMem = load.getMemRef();
73  // Stop if the memref is defined in secondPloop body. Careful alias analysis
74  // is needed.
75  auto *memrefDef = loadMem.getDefiningOp();
76  if (memrefDef && memrefDef->getBlock() == load->getBlock())
77  return WalkResult::interrupt();
78 
79  for (Value store : bufferStoresVec)
80  if (store != loadMem && mayAlias(store, loadMem))
81  return WalkResult::interrupt();
82 
83  auto write = bufferStores.find(loadMem);
84  if (write == bufferStores.end())
85  return WalkResult::advance();
86 
87  // Check that at last one store was retrieved
88  if (write->second.empty())
89  return WalkResult::interrupt();
90 
91  auto storeIndices = write->second.front();
92 
93  // Multiple writes to the same memref are allowed only on the same indices
94  for (const auto &othStoreIndices : write->second) {
95  if (othStoreIndices != storeIndices)
96  return WalkResult::interrupt();
97  }
98 
99  // Check that the load indices of secondPloop coincide with store indices of
100  // firstPloop for the same memrefs.
101  auto loadIndices = load.getIndices();
102  if (storeIndices.size() != loadIndices.size())
103  return WalkResult::interrupt();
104  for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105  if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106  loadIndices[i]) {
107  auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108  auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109  if (storeIndexDefOp && loadIndexDefOp) {
110  if (!isMemoryEffectFree(storeIndexDefOp))
111  return WalkResult::interrupt();
112  if (!isMemoryEffectFree(loadIndexDefOp))
113  return WalkResult::interrupt();
114  if (!OperationEquivalence::isEquivalentTo(
115  storeIndexDefOp, loadIndexDefOp,
116  [&](Value storeIndex, Value loadIndex) {
117  if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118  firstToSecondPloopIndices.lookupOrDefault(loadIndex))
119  return failure();
120  else
121  return success();
122  },
123  /*markEquivalent=*/nullptr,
124  OperationEquivalence::Flags::IgnoreLocations)) {
125  return WalkResult::interrupt();
126  }
127  } else {
128  return WalkResult::interrupt();
129  }
130  }
131  }
132  return WalkResult::advance();
133  });
134  return !walkResult.wasInterrupted();
135 }
136 
137 /// Analyzes dependencies in the most primitive way by checking simple read and
138 /// write patterns.
139 static LogicalResult
140 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
141  const IRMapping &firstToSecondPloopIndices,
144  firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
145  return failure();
146 
147  IRMapping secondToFirstPloopIndices;
148  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
149  firstPloop.getBody()->getArguments());
151  secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
152 }
153 
154 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
155  const IRMapping &firstToSecondPloopIndices,
157  return !hasNestedParallelOp(firstPloop) &&
158  !hasNestedParallelOp(secondPloop) &&
159  equalIterationSpaces(firstPloop, secondPloop) &&
160  succeeded(verifyDependencies(firstPloop, secondPloop,
161  firstToSecondPloopIndices, mayAlias));
162 }
163 
164 /// Prepends operations of firstPloop's body into secondPloop's body.
165 /// Updates secondPloop with new loop.
166 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
167  OpBuilder builder,
169  Block *block1 = firstPloop.getBody();
170  Block *block2 = secondPloop.getBody();
171  IRMapping firstToSecondPloopIndices;
172  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
173 
174  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
175  mayAlias))
176  return;
177 
178  DominanceInfo dom;
179  // We are fusing first loop into second, make sure there are no users of the
180  // first loop results between loops.
181  for (Operation *user : firstPloop->getUsers())
182  if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
183  return;
184 
185  ValueRange inits1 = firstPloop.getInitVals();
186  ValueRange inits2 = secondPloop.getInitVals();
187 
188  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
189  newInitVars.append(inits2.begin(), inits2.end());
190 
191  IRRewriter b(builder);
192  b.setInsertionPoint(secondPloop);
193  auto newSecondPloop = b.create<ParallelOp>(
194  secondPloop.getLoc(), secondPloop.getLowerBound(),
195  secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
196 
197  Block *newBlock = newSecondPloop.getBody();
198  auto term1 = cast<ReduceOp>(block1->getTerminator());
199  auto term2 = cast<ReduceOp>(block2->getTerminator());
200 
201  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
202  newBlock->getArguments());
203  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
204  newBlock->getArguments());
205 
206  ValueRange results = newSecondPloop.getResults();
207  if (!results.empty()) {
208  b.setInsertionPointToEnd(newBlock);
209 
210  ValueRange reduceArgs1 = term1.getOperands();
211  ValueRange reduceArgs2 = term2.getOperands();
212  SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
213  newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
214 
215  auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
216 
217  for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
218  term1.getReductions(), term2.getReductions()))) {
219  Block &oldRedBlock = reg.front();
220  Block &newRedBlock = newReduceOp.getReductions()[i].front();
221  b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
222  newRedBlock.getArguments());
223  }
224 
225  firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
226  secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
227  }
228  term1->erase();
229  term2->erase();
230  firstPloop.erase();
231  secondPloop.erase();
232  secondPloop = newSecondPloop;
233 }
234 
236  Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
237  OpBuilder b(region);
238  // Consider every single block and attempt to fuse adjacent loops.
239  SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
240  for (auto &block : region) {
241  ploopChains.clear();
242  ploopChains.push_back({});
243 
244  // Not using `walk()` to traverse only top-level parallel loops and also
245  // make sure that there are no side-effecting ops between the parallel
246  // loops.
247  bool noSideEffects = true;
248  for (auto &op : block) {
249  if (auto ploop = dyn_cast<ParallelOp>(op)) {
250  if (noSideEffects) {
251  ploopChains.back().push_back(ploop);
252  } else {
253  ploopChains.push_back({ploop});
254  noSideEffects = true;
255  }
256  continue;
257  }
258  // TODO: Handle region side effects properly.
259  noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
260  }
261  for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
262  for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
263  fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
264  }
265  }
266 }
267 
268 namespace {
269 struct ParallelLoopFusion
270  : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
271  void runOnOperation() override {
272  auto &AA = getAnalysis<AliasAnalysis>();
273 
274  auto mayAlias = [&](Value val1, Value val2) -> bool {
275  return !AA.alias(val1, val2).isNo();
276  };
277 
278  getOperation()->walk([&](Operation *child) {
279  for (Region &region : child->getRegions())
281  });
282  }
283 };
284 } // namespace
285 
286 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
287  return std::make_unique<ParallelLoopFusion>();
288 }
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Checks if the parallel loops have mixed access to the same buffers.
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref< bool(Value, Value)> mayAlias)
Prepends operations of firstPloop's body into secondPloop's body.
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:324
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:730
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void naivelyFuseParallelOps(Region &region, llvm::function_ref< bool(Value, Value)> mayAlias)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf....
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.