MLIR  14.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 /// Verify there are no nested ParallelOps.
27 static bool hasNestedParallelOp(ParallelOp ploop) {
28  auto walkResult =
29  ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
30  return walkResult.wasInterrupted();
31 }
32 
33 /// Verify equal iteration spaces.
34 static bool equalIterationSpaces(ParallelOp firstPloop,
35  ParallelOp secondPloop) {
36  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
37  return false;
38 
39  auto matchOperands = [&](const OperandRange &lhs,
40  const OperandRange &rhs) -> bool {
41  // TODO: Extend this to support aliases and equal constants.
42  return std::equal(lhs.begin(), lhs.end(), rhs.begin());
43  };
44  return matchOperands(firstPloop.getLowerBound(),
45  secondPloop.getLowerBound()) &&
46  matchOperands(firstPloop.getUpperBound(),
47  secondPloop.getUpperBound()) &&
48  matchOperands(firstPloop.getStep(), secondPloop.getStep());
49 }
50 
51 /// Checks if the parallel loops have mixed access to the same buffers. Returns
52 /// `true` if the first parallel loop writes to the same indices that the second
53 /// loop reads.
55  ParallelOp firstPloop, ParallelOp secondPloop,
56  const BlockAndValueMapping &firstToSecondPloopIndices) {
58  firstPloop.getBody()->walk([&](memref::StoreOp store) {
59  bufferStores[store.getMemRef()].push_back(store.indices());
60  });
61  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
62  // Stop if the memref is defined in secondPloop body. Careful alias analysis
63  // is needed.
64  auto *memrefDef = load.getMemRef().getDefiningOp();
65  if (memrefDef && memrefDef->getBlock() == load->getBlock())
66  return WalkResult::interrupt();
67 
68  auto write = bufferStores.find(load.getMemRef());
69  if (write == bufferStores.end())
70  return WalkResult::advance();
71 
72  // Allow only single write access per buffer.
73  if (write->second.size() != 1)
74  return WalkResult::interrupt();
75 
76  // Check that the load indices of secondPloop coincide with store indices of
77  // firstPloop for the same memrefs.
78  auto storeIndices = write->second.front();
79  auto loadIndices = load.indices();
80  if (storeIndices.size() != loadIndices.size())
81  return WalkResult::interrupt();
82  for (int i = 0, e = storeIndices.size(); i < e; ++i) {
83  if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
84  loadIndices[i])
85  return WalkResult::interrupt();
86  }
87  return WalkResult::advance();
88  });
89  return !walkResult.wasInterrupted();
90 }
91 
92 /// Analyzes dependencies in the most primitive way by checking simple read and
93 /// write patterns.
94 static LogicalResult
95 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
96  const BlockAndValueMapping &firstToSecondPloopIndices) {
97  if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
98  firstToSecondPloopIndices))
99  return failure();
100 
101  BlockAndValueMapping secondToFirstPloopIndices;
102  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
103  firstPloop.getBody()->getArguments());
105  secondPloop, firstPloop, secondToFirstPloopIndices));
106 }
107 
108 static bool
109 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
110  const BlockAndValueMapping &firstToSecondPloopIndices) {
111  return !hasNestedParallelOp(firstPloop) &&
112  !hasNestedParallelOp(secondPloop) &&
113  equalIterationSpaces(firstPloop, secondPloop) &&
114  succeeded(verifyDependencies(firstPloop, secondPloop,
115  firstToSecondPloopIndices));
116 }
117 
118 /// Prepends operations of firstPloop's body into secondPloop's body.
119 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
120  OpBuilder b) {
121  BlockAndValueMapping firstToSecondPloopIndices;
122  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
123  secondPloop.getBody()->getArguments());
124 
125  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
126  return;
127 
128  b.setInsertionPointToStart(secondPloop.getBody());
129  for (auto &op : firstPloop.getBody()->without_terminator())
130  b.clone(op, firstToSecondPloopIndices);
131  firstPloop.erase();
132 }
133 
135  OpBuilder b(region);
136  // Consider every single block and attempt to fuse adjacent loops.
137  for (auto &block : region) {
138  SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
139  // Not using `walk()` to traverse only top-level parallel loops and also
140  // make sure that there are no side-effecting ops between the parallel
141  // loops.
142  bool noSideEffects = true;
143  for (auto &op : block) {
144  if (auto ploop = dyn_cast<ParallelOp>(op)) {
145  if (noSideEffects) {
146  ploopChains.back().push_back(ploop);
147  } else {
148  ploopChains.push_back({ploop});
149  noSideEffects = true;
150  }
151  continue;
152  }
153  // TODO: Handle region side effects properly.
154  noSideEffects &=
155  MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
156  }
157  for (ArrayRef<ParallelOp> ploops : ploopChains) {
158  for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
159  fuseIfLegal(ploops[i], ploops[i + 1], b);
160  }
161  }
162 }
163 
164 namespace {
165 struct ParallelLoopFusion
166  : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
167  void runOnOperation() override {
168  getOperation()->walk([&](Operation *child) {
169  for (Region &region : child->getRegions())
170  naivelyFuseParallelOps(region);
171  });
172  }
173 };
174 } // namespace
175 
176 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
177  return std::make_unique<ParallelLoopFusion>();
178 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:423
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:457
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:424
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const BlockAndValueMapping &firstToSecondPloopIndices)
Analyzes dependencies in the most primitive way by checking simple read and write patterns...
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const BlockAndValueMapping &firstToSecondPloopIndices)
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, OpBuilder b)
Prepends operations of firstPloop&#39;s body into secondPloop&#39;s body.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const BlockAndValueMapping &firstToSecondPloopIndices)
Checks if the parallel loops have mixed access to the same buffers.
void naivelyFuseParallelOps(Region &region)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
Block * lookupOrDefault(Block *from) const
Lookup a mapped value within the map.
This class implements the operand iterators for the Operation class.
This class helps build Operations.
Definition: Builders.h:177
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.