MLIR  16.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21 
22 using namespace mlir;
23 using namespace mlir::scf;
24 
25 /// Verify there are no nested ParallelOps.
26 static bool hasNestedParallelOp(ParallelOp ploop) {
27  auto walkResult =
28  ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29  return walkResult.wasInterrupted();
30 }
31 
32 /// Verify equal iteration spaces.
33 static bool equalIterationSpaces(ParallelOp firstPloop,
34  ParallelOp secondPloop) {
35  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36  return false;
37 
38  auto matchOperands = [&](const OperandRange &lhs,
39  const OperandRange &rhs) -> bool {
40  // TODO: Extend this to support aliases and equal constants.
41  return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42  };
43  return matchOperands(firstPloop.getLowerBound(),
44  secondPloop.getLowerBound()) &&
45  matchOperands(firstPloop.getUpperBound(),
46  secondPloop.getUpperBound()) &&
47  matchOperands(firstPloop.getStep(), secondPloop.getStep());
48 }
49 
50 /// Checks if the parallel loops have mixed access to the same buffers. Returns
51 /// `true` if the first parallel loop writes to the same indices that the second
52 /// loop reads.
54  ParallelOp firstPloop, ParallelOp secondPloop,
55  const BlockAndValueMapping &firstToSecondPloopIndices) {
57  firstPloop.getBody()->walk([&](memref::StoreOp store) {
58  bufferStores[store.getMemRef()].push_back(store.getIndices());
59  });
60  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
61  // Stop if the memref is defined in secondPloop body. Careful alias analysis
62  // is needed.
63  auto *memrefDef = load.getMemRef().getDefiningOp();
64  if (memrefDef && memrefDef->getBlock() == load->getBlock())
65  return WalkResult::interrupt();
66 
67  auto write = bufferStores.find(load.getMemRef());
68  if (write == bufferStores.end())
69  return WalkResult::advance();
70 
71  // Allow only single write access per buffer.
72  if (write->second.size() != 1)
73  return WalkResult::interrupt();
74 
75  // Check that the load indices of secondPloop coincide with store indices of
76  // firstPloop for the same memrefs.
77  auto storeIndices = write->second.front();
78  auto loadIndices = load.getIndices();
79  if (storeIndices.size() != loadIndices.size())
80  return WalkResult::interrupt();
81  for (int i = 0, e = storeIndices.size(); i < e; ++i) {
82  if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
83  loadIndices[i])
84  return WalkResult::interrupt();
85  }
86  return WalkResult::advance();
87  });
88  return !walkResult.wasInterrupted();
89 }
90 
91 /// Analyzes dependencies in the most primitive way by checking simple read and
92 /// write patterns.
93 static LogicalResult
94 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
95  const BlockAndValueMapping &firstToSecondPloopIndices) {
96  if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
97  firstToSecondPloopIndices))
98  return failure();
99 
100  BlockAndValueMapping secondToFirstPloopIndices;
101  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
102  firstPloop.getBody()->getArguments());
104  secondPloop, firstPloop, secondToFirstPloopIndices));
105 }
106 
107 static bool
108 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
109  const BlockAndValueMapping &firstToSecondPloopIndices) {
110  return !hasNestedParallelOp(firstPloop) &&
111  !hasNestedParallelOp(secondPloop) &&
112  equalIterationSpaces(firstPloop, secondPloop) &&
113  succeeded(verifyDependencies(firstPloop, secondPloop,
114  firstToSecondPloopIndices));
115 }
116 
117 /// Prepends operations of firstPloop's body into secondPloop's body.
118 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
119  OpBuilder b) {
120  BlockAndValueMapping firstToSecondPloopIndices;
121  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
122  secondPloop.getBody()->getArguments());
123 
124  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
125  return;
126 
127  b.setInsertionPointToStart(secondPloop.getBody());
128  for (auto &op : firstPloop.getBody()->without_terminator())
129  b.clone(op, firstToSecondPloopIndices);
130  firstPloop.erase();
131 }
132 
134  OpBuilder b(region);
135  // Consider every single block and attempt to fuse adjacent loops.
136  for (auto &block : region) {
137  SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
138  // Not using `walk()` to traverse only top-level parallel loops and also
139  // make sure that there are no side-effecting ops between the parallel
140  // loops.
141  bool noSideEffects = true;
142  for (auto &op : block) {
143  if (auto ploop = dyn_cast<ParallelOp>(op)) {
144  if (noSideEffects) {
145  ploopChains.back().push_back(ploop);
146  } else {
147  ploopChains.push_back({ploop});
148  noSideEffects = true;
149  }
150  continue;
151  }
152  // TODO: Handle region side effects properly.
153  noSideEffects &=
154  MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
155  }
156  for (ArrayRef<ParallelOp> ploops : ploopChains) {
157  for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
158  fuseIfLegal(ploops[i], ploops[i + 1], b);
159  }
160  }
161 }
162 
163 namespace {
164 struct ParallelLoopFusion
165  : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
166  void runOnOperation() override {
167  getOperation()->walk([&](Operation *child) {
168  for (Region &region : child->getRegions())
169  naivelyFuseParallelOps(region);
170  });
171  }
172 };
173 } // namespace
174 
175 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
176  return std::make_unique<ParallelLoopFusion>();
177 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:492
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:414
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const BlockAndValueMapping &firstToSecondPloopIndices)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const BlockAndValueMapping &firstToSecondPloopIndices)
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, OpBuilder b)
Prepends operations of firstPloop's body into secondPloop's body.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const BlockAndValueMapping &firstToSecondPloopIndices)
Checks if the parallel loops have mixed access to the same buffers.
void naivelyFuseParallelOps(Region &region)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf.parallel operation.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
Block * lookupOrDefault(Block *from) const
Lookup a mapped value within the map.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
This class helps build Operations.
Definition: Builders.h:192
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.