24 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
25 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
35 return walkResult.wasInterrupted();
40 ParallelOp secondPloop) {
41 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
47 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
49 return matchOperands(firstPloop.getLowerBound(),
50 secondPloop.getLowerBound()) &&
51 matchOperands(firstPloop.getUpperBound(),
52 secondPloop.getUpperBound()) &&
53 matchOperands(firstPloop.getStep(), secondPloop.getStep());
60 ParallelOp firstPloop, ParallelOp secondPloop,
61 const IRMapping &firstToSecondPloopIndices) {
63 firstPloop.getBody()->walk([&](memref::StoreOp store) {
64 bufferStores[store.getMemRef()].push_back(store.getIndices());
66 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
69 auto *memrefDef = load.getMemRef().getDefiningOp();
70 if (memrefDef && memrefDef->getBlock() == load->getBlock())
73 auto write = bufferStores.find(load.getMemRef());
74 if (write == bufferStores.end())
78 if (write->second.size() != 1)
83 auto storeIndices = write->second.front();
84 auto loadIndices = load.getIndices();
85 if (storeIndices.size() != loadIndices.size())
87 for (
int i = 0, e = storeIndices.size(); i < e; ++i) {
88 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
90 return WalkResult::interrupt();
94 return !walkResult.wasInterrupted();
101 const IRMapping &firstToSecondPloopIndices) {
103 firstToSecondPloopIndices))
107 secondToFirstPloopIndices.
map(secondPloop.getBody()->getArguments(),
108 firstPloop.getBody()->getArguments());
110 secondPloop, firstPloop, secondToFirstPloopIndices));
114 const IRMapping &firstToSecondPloopIndices) {
119 firstToSecondPloopIndices));
123 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
126 firstToSecondPloopIndices.
map(firstPloop.getBody()->getArguments(),
127 secondPloop.getBody()->getArguments());
129 if (!
isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
133 for (
auto &op : firstPloop.getBody()->without_terminator())
134 b.
clone(op, firstToSecondPloopIndices);
141 for (
auto &block : region) {
146 bool noSideEffects =
true;
147 for (
auto &op : block) {
148 if (
auto ploop = dyn_cast<ParallelOp>(op)) {
150 ploopChains.back().push_back(ploop);
152 ploopChains.push_back({ploop});
153 noSideEffects =
true;
161 for (
int i = 0, e = ploops.size(); i + 1 < e; ++i)
168 struct ParallelLoopFusion
169 :
public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
170 void runOnOperation()
override {
171 getOperation()->walk([&](
Operation *child) {
180 return std::make_unique<ParallelLoopFusion>();
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices)
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices)
Checks if the parallel loops have mixed access to the same buffers.
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop, OpBuilder b)
Prepends operations of firstPloop's body into secondPloop's body.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation through the provided mapper.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
static WalkResult advance()
static WalkResult interrupt()
void naivelyFuseParallelOps(Region &region)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf.parallel loop.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.
This class represents an efficient way to signal success or failure.