MLIR  18.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/OpDefinition.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
25 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::scf;
30 
31 /// Verify there are no nested ParallelOps.
32 static bool hasNestedParallelOp(ParallelOp ploop) {
33  auto walkResult =
34  ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
35  return walkResult.wasInterrupted();
36 }
37 
38 /// Verify equal iteration spaces.
39 static bool equalIterationSpaces(ParallelOp firstPloop,
40  ParallelOp secondPloop) {
41  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
42  return false;
43 
44  auto matchOperands = [&](const OperandRange &lhs,
45  const OperandRange &rhs) -> bool {
46  // TODO: Extend this to support aliases and equal constants.
47  return std::equal(lhs.begin(), lhs.end(), rhs.begin());
48  };
49  return matchOperands(firstPloop.getLowerBound(),
50  secondPloop.getLowerBound()) &&
51  matchOperands(firstPloop.getUpperBound(),
52  secondPloop.getUpperBound()) &&
53  matchOperands(firstPloop.getStep(), secondPloop.getStep());
54 }
55 
56 /// Checks if the parallel loops have mixed access to the same buffers. Returns
57 /// `true` if the first parallel loop writes to the same indices that the second
58 /// loop reads.
60  ParallelOp firstPloop, ParallelOp secondPloop,
61  const IRMapping &firstToSecondPloopIndices) {
63  firstPloop.getBody()->walk([&](memref::StoreOp store) {
64  bufferStores[store.getMemRef()].push_back(store.getIndices());
65  });
66  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
67  // Stop if the memref is defined in secondPloop body. Careful alias analysis
68  // is needed.
69  auto *memrefDef = load.getMemRef().getDefiningOp();
70  if (memrefDef && memrefDef->getBlock() == load->getBlock())
71  return WalkResult::interrupt();
72 
73  auto write = bufferStores.find(load.getMemRef());
74  if (write == bufferStores.end())
75  return WalkResult::advance();
76 
77  // Allow only single write access per buffer.
78  if (write->second.size() != 1)
79  return WalkResult::interrupt();
80 
81  // Check that the load indices of secondPloop coincide with store indices of
82  // firstPloop for the same memrefs.
83  auto storeIndices = write->second.front();
84  auto loadIndices = load.getIndices();
85  if (storeIndices.size() != loadIndices.size())
86  return WalkResult::interrupt();
87  for (int i = 0, e = storeIndices.size(); i < e; ++i) {
88  if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
89  loadIndices[i])
90  return WalkResult::interrupt();
91  }
92  return WalkResult::advance();
93  });
94  return !walkResult.wasInterrupted();
95 }
96 
97 /// Analyzes dependencies in the most primitive way by checking simple read and
98 /// write patterns.
99 static LogicalResult
100 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
101  const IRMapping &firstToSecondPloopIndices) {
102  if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
103  firstToSecondPloopIndices))
104  return failure();
105 
106  IRMapping secondToFirstPloopIndices;
107  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
108  firstPloop.getBody()->getArguments());
110  secondPloop, firstPloop, secondToFirstPloopIndices));
111 }
112 
113 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
114  const IRMapping &firstToSecondPloopIndices) {
115  return !hasNestedParallelOp(firstPloop) &&
116  !hasNestedParallelOp(secondPloop) &&
117  equalIterationSpaces(firstPloop, secondPloop) &&
118  succeeded(verifyDependencies(firstPloop, secondPloop,
119  firstToSecondPloopIndices));
120 }
121 
122 /// Prepends operations of firstPloop's body into secondPloop's body.
123 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
124  OpBuilder b) {
125  IRMapping firstToSecondPloopIndices;
126  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
127  secondPloop.getBody()->getArguments());
128 
129  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
130  return;
131 
132  b.setInsertionPointToStart(secondPloop.getBody());
133  for (auto &op : firstPloop.getBody()->without_terminator())
134  b.clone(op, firstToSecondPloopIndices);
135  firstPloop.erase();
136 }
137 
139  OpBuilder b(region);
140  // Consider every single block and attempt to fuse adjacent loops.
141  for (auto &block : region) {
142  SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
143  // Not using `walk()` to traverse only top-level parallel loops and also
144  // make sure that there are no side-effecting ops between the parallel
145  // loops.
146  bool noSideEffects = true;
147  for (auto &op : block) {
148  if (auto ploop = dyn_cast<ParallelOp>(op)) {
149  if (noSideEffects) {
150  ploopChains.back().push_back(ploop);
151  } else {
152  ploopChains.push_back({ploop});
153  noSideEffects = true;
154  }
155  continue;
156  }
157  // TODO: Handle region side effects properly.
158  noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
159  }
160  for (ArrayRef<ParallelOp> ploops : ploopChains) {
161  for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
162  fuseIfLegal(ploops[i], ploops[i + 1], b);
163  }
164  }
165 }
166 
167 namespace {
168 struct ParallelLoopFusion
169  : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
170  void runOnOperation() override {
171  getOperation()->walk([&](Operation *child) {
172  for (Region &region : child->getRegions())
173  naivelyFuseParallelOps(region);
174  });
175  }
176 };
177 } // namespace
178 
179 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
180  return std::make_unique<ParallelLoopFusion>();
181 }
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices)
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices)
Checks if the parallel loops have mixed access to the same buffers.
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, OpBuilder b)
Prepends operations of firstPloop's body into secondPloop's body.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class helps build Operations.
Definition: Builders.h:206
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Builders.cpp:528
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
void naivelyFuseParallelOps(Region &region)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf.parallel operation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26