26 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
37 return walkResult.wasInterrupted();
42 ParallelOp secondPloop) {
43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
49 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
51 return matchOperands(firstPloop.getLowerBound(),
52 secondPloop.getLowerBound()) &&
53 matchOperands(firstPloop.getUpperBound(),
54 secondPloop.getUpperBound()) &&
55 matchOperands(firstPloop.getStep(), secondPloop.getStep());
62 ParallelOp firstPloop, ParallelOp secondPloop,
63 const IRMapping &firstToSecondPloopIndices,
67 firstPloop.getBody()->walk([&](memref::StoreOp store) {
68 bufferStores[store.getMemRef()].push_back(store.getIndices());
69 bufferStoresVec.emplace_back(store.getMemRef());
71 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72 Value loadMem = load.getMemRef();
76 if (memrefDef && memrefDef->getBlock() == load->getBlock())
79 for (
Value store : bufferStoresVec)
80 if (store != loadMem &&
mayAlias(store, loadMem))
83 auto write = bufferStores.find(loadMem);
84 if (write == bufferStores.end())
88 if (!write->second.size())
91 auto storeIndices = write->second.front();
94 for (
const auto &othStoreIndices : write->second) {
95 if (othStoreIndices != storeIndices)
96 return WalkResult::interrupt();
101 auto loadIndices = load.getIndices();
102 if (storeIndices.size() != loadIndices.size())
104 for (
int i = 0, e = storeIndices.size(); i < e; ++i) {
105 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109 if (storeIndexDefOp && loadIndexDefOp) {
110 if (!isMemoryEffectFree(storeIndexDefOp))
111 return WalkResult::interrupt();
112 if (!isMemoryEffectFree(loadIndexDefOp))
113 return WalkResult::interrupt();
114 if (!OperationEquivalence::isEquivalentTo(
115 storeIndexDefOp, loadIndexDefOp,
116 [&](Value storeIndex, Value loadIndex) {
117 if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118 firstToSecondPloopIndices.lookupOrDefault(loadIndex))
124 OperationEquivalence::Flags::IgnoreLocations)) {
125 return WalkResult::interrupt();
128 return WalkResult::interrupt();
133 return !walkResult.wasInterrupted();
140 const IRMapping &firstToSecondPloopIndices,
143 firstPloop, secondPloop, firstToSecondPloopIndices,
mayAlias))
147 secondToFirstPloopIndices.
map(secondPloop.getBody()->getArguments(),
148 firstPloop.getBody()->getArguments());
150 secondPloop, firstPloop, secondToFirstPloopIndices,
mayAlias));
154 const IRMapping &firstToSecondPloopIndices,
160 firstToSecondPloopIndices,
mayAlias));
165 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
168 Block *block1 = firstPloop.getBody();
169 Block *block2 = secondPloop.getBody();
173 if (!
isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
185 ValueRange inits2 = secondPloop.getInitVals();
188 newInitVars.append(inits2.begin(), inits2.end());
192 auto newSecondPloop = b.
create<ParallelOp>(
193 secondPloop.getLoc(), secondPloop.getLowerBound(),
194 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
196 Block *newBlock = newSecondPloop.getBody();
201 newBlock->getArguments());
203 newBlock->getArguments());
205 ValueRange results = newSecondPloop.getResults();
206 if (!results.empty()) {
212 newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
214 auto newReduceOp = b.
create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
217 term1.getReductions(), term2.getReductions()))) {
219 Block &newRedBlock = newReduceOp.getReductions()[i].
front();
224 firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225 secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
231 secondPloop = newSecondPloop;
239 for (
auto &block : region) {
241 ploopChains.push_back({});
246 bool noSideEffects =
true;
247 for (
auto &op : block) {
248 if (
auto ploop = dyn_cast<ParallelOp>(op)) {
250 ploopChains.back().push_back(ploop);
252 ploopChains.push_back({ploop});
253 noSideEffects =
true;
261 for (
int i = 0, e = ploops.size(); i + 1 < e; ++i)
268 struct ParallelLoopFusion
269 :
public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
270 void runOnOperation()
override {
271 auto &AA = getAnalysis<AliasAnalysis>();
274 return !AA.alias(val1, val2).isNo();
277 getOperation()->walk([&](
Operation *child) {
286 return std::make_unique<ParallelLoopFusion>();
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
static bool haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Checks if the parallel loops have mixed access to the same buffers.
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static LogicalResult verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
Analyzes dependencies in the most primitive way by checking simple read and write patterns.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias)
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref< bool(Value, Value)> mayAlias)
Prepends operations of firstPloop's body into secondPloop's body.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
user_range getUsers()
Returns a range of all users.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void naivelyFuseParallelOps(Region ®ion, llvm::function_ref< bool(Value, Value)> mayAlias)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf....
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.