MLIR  20.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 /// Verify there are no nested ParallelOps.
34 static bool hasNestedParallelOp(ParallelOp ploop) {
35  auto walkResult =
36  ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37  return walkResult.wasInterrupted();
38 }
39 
40 /// Verify equal iteration spaces.
41 static bool equalIterationSpaces(ParallelOp firstPloop,
42  ParallelOp secondPloop) {
43  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44  return false;
45 
46  auto matchOperands = [&](const OperandRange &lhs,
47  const OperandRange &rhs) -> bool {
48  // TODO: Extend this to support aliases and equal constants.
49  return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50  };
51  return matchOperands(firstPloop.getLowerBound(),
52  secondPloop.getLowerBound()) &&
53  matchOperands(firstPloop.getUpperBound(),
54  secondPloop.getUpperBound()) &&
55  matchOperands(firstPloop.getStep(), secondPloop.getStep());
56 }
57 
58 /// Checks if the parallel loops have mixed access to the same buffers. Returns
59 /// `true` if every load in the second loop from a memref that the first loop
60 /// stores to reads exactly the indices that were written.
///
/// Returns `false` (the walk below is interrupted) when some memref.load in
/// `secondPloop`:
///  - loads from a memref defined by an op in the load's own block,
///  - loads from a memref that `mayAlias` reports may alias a *different*
///    memref stored to in `firstPloop`,
///  - loads from a memref stored to in `firstPloop` where the stores use
///    differing index lists, or where the load's index count differs, or
///  - uses an index that matches neither the mapped store index nor an
///    equivalent, memory-effect-free defining op of the store index.
///
/// `firstToSecondPloopIndices` maps the first loop's block arguments to the
/// second loop's block arguments. `mayAlias(a, b)` is the caller-supplied
/// conservative alias query.
62  ParallelOp firstPloop, ParallelOp secondPloop,
63  const IRMapping &firstToSecondPloopIndices,
66  SmallVector<Value> bufferStoresVec;
// Record, per stored-to memref, the index list of every memref.store in the
// first loop, plus a flat list of stored-to memrefs for the alias check.
67  firstPloop.getBody()->walk([&](memref::StoreOp store) {
68  bufferStores[store.getMemRef()].push_back(store.getIndices());
69  bufferStoresVec.emplace_back(store.getMemRef());
70  });
71  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72  Value loadMem = load.getMemRef();
73  // Stop if the memref is defined in secondPloop body. Careful alias analysis
74  // is needed.
// NOTE(review): the check compares against the load's own block only, so a
// memref defined in an enclosing block of the second loop body is not
// caught here.
75  auto *memrefDef = loadMem.getDefiningOp();
76  if (memrefDef && memrefDef->getBlock() == load->getBlock())
77  return WalkResult::interrupt();
78
// Bail out conservatively if the loaded memref may alias a *different*
// memref written by the first loop. Accesses to the identical memref are
// handled by the index comparison below.
79  for (Value store : bufferStoresVec)
80  if (store != loadMem && mayAlias(store, loadMem))
81  return WalkResult::interrupt();
82
83  auto write = bufferStores.find(loadMem);
84  if (write == bufferStores.end())
85  return WalkResult::advance();
86
87  // Check that at least one store was retrieved
88  if (write->second.empty())
89  return WalkResult::interrupt();
90
91  auto storeIndices = write->second.front();
92
93  // Multiple writes to the same memref are allowed only on the same indices
94  for (const auto &othStoreIndices : write->second) {
95  if (othStoreIndices != storeIndices)
96  return WalkResult::interrupt();
97  }
98
99  // Check that the load indices of secondPloop coincide with store indices of
100  // firstPloop for the same memrefs.
101  auto loadIndices = load.getIndices();
102  if (storeIndices.size() != loadIndices.size())
103  return WalkResult::interrupt();
104  for (int i = 0, e = storeIndices.size(); i < e; ++i) {
// Fast path: the store index, translated into the second loop's induction
// variables, is the very same SSA value as the load index.
105  if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106  loadIndices[i]) {
// Slow path: both indices must be produced by memory-effect-free ops that
// are structurally equivalent (locations ignored). Their operands count as
// equal after mapping through firstToSecondPloopIndices. Only one level of
// defining ops is compared.
107  auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108  auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109  if (storeIndexDefOp && loadIndexDefOp) {
110  if (!isMemoryEffectFree(storeIndexDefOp))
111  return WalkResult::interrupt();
112  if (!isMemoryEffectFree(loadIndexDefOp))
113  return WalkResult::interrupt();
114  if (!OperationEquivalence::isEquivalentTo(
115  storeIndexDefOp, loadIndexDefOp,
116  [&](Value storeIndex, Value loadIndex) {
117  if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118  firstToSecondPloopIndices.lookupOrDefault(loadIndex))
119  return failure();
120  else
121  return success();
122  },
123  /*markEquivalent=*/nullptr,
124  OperationEquivalence::Flags::IgnoreLocations)) {
125  return WalkResult::interrupt();
126  }
127  } else
// At least one index is a block argument (or otherwise has no defining op)
// and did not match directly, so the indices cannot be proven equal.
128  return WalkResult::interrupt();
129  }
130  }
131  return WalkResult::advance();
132  });
133  return !walkResult.wasInterrupted();
134 }
135 
136 /// Analyzes dependencies in the most primitive way by checking simple read and
137 /// write patterns.
///
/// The read/write index check is applied in both directions. First it checks
/// loads in `secondPloop` against stores in `firstPloop` using
/// `firstToSecondPloopIndices`. Then it checks loads in `firstPloop` against
/// stores in `secondPloop` using the inverse induction-variable mapping.
/// Returns success only if both directions pass.
138 static LogicalResult
139 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
140  const IRMapping &firstToSecondPloopIndices,
143  firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144  return failure();
145
// Build the reverse mapping (second loop block args -> first loop block args)
// for the opposite-direction check.
146  IRMapping secondToFirstPloopIndices;
147  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
148  firstPloop.getBody()->getArguments());
150  secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151 }
152 
/// Returns true if `firstPloop` may be fused into `secondPloop`. This requires
/// all of the following:
///  - neither loop contains a nested scf.parallel,
///  - both loops have identical bounds and steps (by SSA value), and
///  - the simple read/write dependency check passes in both directions.
/// The checks short-circuit in that order.
153 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154  const IRMapping &firstToSecondPloopIndices,
156  return !hasNestedParallelOp(firstPloop) &&
157  !hasNestedParallelOp(secondPloop) &&
158  equalIterationSpaces(firstPloop, secondPloop) &&
159  succeeded(verifyDependencies(firstPloop, secondPloop,
160  firstToSecondPloopIndices, mayAlias));
161 }
162 
163 /// Prepends operations of firstPloop's body into secondPloop's body.
164 /// Updates secondPloop with new loop.
///
/// If fusion is illegal, or some user of the first loop's results is not
/// properly dominated by `secondPloop`, the IR is left untouched. Otherwise a
/// new scf.parallel is created right before `secondPloop`. Its init values are
/// the first loop's inits followed by the second loop's. The two bodies and
/// their reductions are merged into it, both original loops are erased, and
/// `secondPloop` is updated in place to refer to the fused loop.
165 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
166  OpBuilder builder,
168  Block *block1 = firstPloop.getBody();
169  Block *block2 = secondPloop.getBody();
170  IRMapping firstToSecondPloopIndices;
171  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
172
173  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
174  mayAlias))
175  return;
176
// NOTE(review): DominanceInfo is rebuilt on every call; if this becomes hot,
// consider computing it once per region and passing it in.
177  DominanceInfo dom;
178  // We are fusing first loop into second, make sure there are no users of the
179  // first loop results between loops.
180  for (Operation *user : firstPloop->getUsers())
181  if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182  return;
183
184  ValueRange inits1 = firstPloop.getInitVals();
185  ValueRange inits2 = secondPloop.getInitVals();
186
// The fused loop's inits/results are laid out as [first loop's, second's].
187  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188  newInitVars.append(inits2.begin(), inits2.end());
189
190  IRRewriter b(builder);
191  b.setInsertionPoint(secondPloop);
192  auto newSecondPloop = b.create<ParallelOp>(
193  secondPloop.getLoc(), secondPloop.getLowerBound(),
194  secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195
196  Block *newBlock = newSecondPloop.getBody();
197  auto term1 = cast<ReduceOp>(block1->getTerminator());
198  auto term2 = cast<ReduceOp>(block2->getTerminator());
199
// block2 is inlined at the front first, then block1 in front of it. The
// first loop's ops therefore precede the second loop's ops. Both bodies are
// remapped onto the new loop's induction variables. The old terminators are
// moved along and erased below.
200  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
201  newBlock->getArguments());
202  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
203  newBlock->getArguments());
204
// NOTE(review): when there are no results, no scf.reduce is created here.
// Presumably the ParallelOp builder inserts a default terminator when no init
// values are given — confirm.
205  ValueRange results = newSecondPloop.getResults();
206  if (!results.empty()) {
207  b.setInsertionPointToEnd(newBlock);
208
// Merge both reduce terminators: operands are concatenated in the same
// [first, second] order as the init values.
209  ValueRange reduceArgs1 = term1.getOperands();
210  ValueRange reduceArgs2 = term2.getOperands();
211  SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212  newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
213
214  auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215
// Move each original reduction region body into the matching region of the
// new reduce op, again in [first, second] order.
216  for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217  term1.getReductions(), term2.getReductions()))) {
218  Block &oldRedBlock = reg.front();
219  Block &newRedBlock = newReduceOp.getReductions()[i].front();
220  b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221  newRedBlock.getArguments());
222  }
223
// Redirect users: the leading results belong to the first loop and the
// trailing results to the second.
224  firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225  secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
226  }
227  term1->erase();
228  term2->erase();
229  firstPloop.erase();
230  secondPloop.erase();
// Hand the fused loop back to the caller so it can act as the "first" loop
// in the next fusion attempt.
231  secondPloop = newSecondPloop;
232 }
233 
235  Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
// Greedily fuses adjacent top-level scf.parallel ops in each block of
// `region`. Loops are first grouped into chains that are separated only by
// ops with no memory effects and no regions. Then each neighbouring pair in a
// chain is offered to fuseIfLegal. `mayAlias` is forwarded to the dependency
// checks.
236  OpBuilder b(region);
237  // Consider every single block and attempt to fuse adjacent loops.
238  SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
239  for (auto &block : region) {
240  ploopChains.clear();
241  ploopChains.push_back({});
242
243  // Not using `walk()` to traverse only top-level parallel loops and also
244  // make sure that there are no side-effecting ops between the parallel
245  // loops.
246  bool noSideEffects = true;
247  for (auto &op : block) {
248  if (auto ploop = dyn_cast<ParallelOp>(op)) {
// Extend the current chain if nothing effectful was seen since the last
// parallel op; otherwise start a new chain with this loop.
249  if (noSideEffects) {
250  ploopChains.back().push_back(ploop);
251  } else {
252  ploopChains.push_back({ploop});
253  noSideEffects = true;
254  }
255  continue;
256  }
257  // TODO: Handle region side effects properly.
258  noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
259  }
// fuseIfLegal takes the second loop by reference and overwrites it with the
// fused loop on success. The next iteration then tries to fuse that result
// with the following loop in the chain.
260  for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
261  for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
262  fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
263  }
264  }
265 }
266 
267 namespace {
/// Pass that fuses adjacent scf.parallel loops in every region nested under
/// the pass's root operation, using AliasAnalysis to answer may-alias queries.
268 struct ParallelLoopFusion
269  : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
270  void runOnOperation() override {
271  auto &AA = getAnalysis<AliasAnalysis>();
272
// Conservative query: two values "may alias" unless AliasAnalysis proves
// NoAlias.
273  auto mayAlias = [&](Value val1, Value val2) -> bool {
274  return !AA.alias(val1, val2).isNo();
275  };
276
// Visit every operation under the root and process each of its regions.
277  getOperation()->walk([&](Operation *child) {
278  for (Region &region : child->getRegions())
280  });
281  }
282 };
283 } // namespace
284 
285 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
286  return std::make_unique<ParallelLoopFusion>();
287 }
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Checks if the parallel loops have mixed access to the same buffers.
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref< bool(Value, Value)> mayAlias)
Prepends operations of firstPloop's body into secondPloop's body.
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:153
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:772
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void naivelyFuseParallelOps(Region &region, llvm::function_ref< bool(Value, Value)> mayAlias)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf.parallel loop.
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.