26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
37 return walkResult.wasInterrupted();
42 ParallelOp secondPloop) {
43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
49 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
51 return matchOperands(firstPloop.getLowerBound(),
52 secondPloop.getLowerBound()) &&
53 matchOperands(firstPloop.getUpperBound(),
54 secondPloop.getUpperBound()) &&
55 matchOperands(firstPloop.getStep(), secondPloop.getStep());
62 ParallelOp firstPloop, ParallelOp secondPloop,
63 const IRMapping &firstToSecondPloopIndices,
67 firstPloop.getBody()->walk([&](memref::StoreOp store) {
68 bufferStores[store.getMemRef()].push_back(store.getIndices());
69 bufferStoresVec.emplace_back(store.getMemRef());
71 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72 Value loadMem = load.getMemRef();
76 if (memrefDef && memrefDef->getBlock() == load->getBlock())
79 for (
Value store : bufferStoresVec)
80 if (store != loadMem &&
mayAlias(store, loadMem))
83 auto write = bufferStores.find(loadMem);
84 if (write == bufferStores.end())
88 if (write->second.empty())
91 auto storeIndices = write->second.front();
94 for (
const auto &othStoreIndices : write->second) {
95 if (othStoreIndices != storeIndices)
96 return WalkResult::interrupt();
101 auto loadIndices = load.getIndices();
102 if (storeIndices.size() != loadIndices.size())
104 for (
int i = 0, e = storeIndices.size(); i < e; ++i) {
105 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109 if (storeIndexDefOp && loadIndexDefOp) {
110 if (!isMemoryEffectFree(storeIndexDefOp))
111 return WalkResult::interrupt();
112 if (!isMemoryEffectFree(loadIndexDefOp))
113 return WalkResult::interrupt();
114 if (!OperationEquivalence::isEquivalentTo(
115 storeIndexDefOp, loadIndexDefOp,
116 [&](Value storeIndex, Value loadIndex) {
117 if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118 firstToSecondPloopIndices.lookupOrDefault(loadIndex))
124 OperationEquivalence::Flags::IgnoreLocations)) {
125 return WalkResult::interrupt();
128 return WalkResult::interrupt();
134 return !walkResult.wasInterrupted();
141 const IRMapping &firstToSecondPloopIndices,
144 firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias))
148 secondToFirstPloopIndices.
map(secondPloop.getBody()->getArguments(),
149 firstPloop.getBody()->getArguments());
151 secondPloop, firstPloop, secondToFirstPloopIndices,
mayAlias));
155 const IRMapping &firstToSecondPloopIndices,
161 firstToSecondPloopIndices,
mayAlias));
166 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
169 Block *block1 = firstPloop.getBody();
170 Block *block2 = secondPloop.getBody();
174 if (!
isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
186 ValueRange inits2 = secondPloop.getInitVals();
189 newInitVars.append(inits2.begin(), inits2.end());
193 auto newSecondPloop = b.
create<ParallelOp>(
194 secondPloop.getLoc(), secondPloop.getLowerBound(),
195 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
197 Block *newBlock = newSecondPloop.getBody();
202 newBlock->getArguments());
204 newBlock->getArguments());
206 ValueRange results = newSecondPloop.getResults();
207 if (!results.empty()) {
213 newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
215 auto newReduceOp = b.
create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
218 term1.getReductions(), term2.getReductions()))) {
220 Block &newRedBlock = newReduceOp.getReductions()[i].
front();
225 firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
226 secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
232 secondPloop = newSecondPloop;
240 for (
auto &block : region) {
242 ploopChains.push_back({});
247 bool noSideEffects =
true;
248 for (
auto &op : block) {
249 if (
auto ploop = dyn_cast<ParallelOp>(op)) {
251 ploopChains.back().push_back(ploop);
253 ploopChains.push_back({ploop});
254 noSideEffects =
true;
262 for (
int i = 0, e = ploops.size(); i + 1 < e; ++i)
269 struct ParallelLoopFusion
270 :
public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
271 void runOnOperation()
override {
272 auto &AA = getAnalysis<AliasAnalysis>();
275 return !AA.alias(val1, val2).isNo();
278 getOperation()->walk([&](
Operation *child) {
287 return std::make_unique<ParallelLoopFusion>();
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Checks if the parallel loops have mixed access to the same buffers.
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref< bool(Value, Value)> mayAlias)
Prepends operations of firstPloop's body into secondPloop's body.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
user_range getUsers()
Returns a range of all users.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void naivelyFuseParallelOps(Region &region, llvm::function_ref< bool(Value, Value)> mayAlias)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf.parallel operation.
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.