MLIR 22.0.0git
LoopFusion.cpp
Go to the documentation of this file.
1//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements affine fusion.
10//
11//===----------------------------------------------------------------------===//
12
14
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/Builders.h"
25#include "llvm/ADT/DenseMap.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/Support/CommandLine.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/raw_ostream.h"
31#include <iomanip>
32#include <optional>
33#include <sstream>
34
35namespace mlir {
36namespace affine {
37#define GEN_PASS_DEF_AFFINELOOPFUSION
38#include "mlir/Dialect/Affine/Passes.h.inc"
39} // namespace affine
40} // namespace mlir
41
42#define DEBUG_TYPE "affine-fusion"
43
44using namespace mlir;
45using namespace mlir::affine;
46
47namespace {
48/// Loop fusion pass. This pass currently supports a greedy fusion policy,
49/// which fuses loop nests with single-writer/single-reader memref dependences
50/// with the goal of improving locality.
51// TODO: Support fusion of source loop nests which write to multiple
52// memrefs, where each memref can have multiple users (if profitable).
53struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
54 LoopFusion() = default;
55 LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
56 bool maximalFusion, enum FusionMode affineFusionMode) {
57 this->fastMemorySpace = fastMemorySpace;
58 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
59 this->maximalFusion = maximalFusion;
60 this->affineFusionMode = affineFusionMode;
61 }
62
63 void runOnBlock(Block *block);
64 void runOnOperation() override;
65};
66
67} // namespace
68
69/// Returns true if node 'srcId' can be removed after fusing it with node
70/// 'dstId'. The node can be removed if any of the following conditions are met:
71/// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
72/// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
73/// and the fusion slice is maximal.
74/// 3. 'srcId' has output dependences after fusion, the fusion slice is
75/// maximal and the fusion insertion point dominates all the dependences.
77 unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
78 Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
79 const MemRefDependenceGraph &mdg) {
80
81 Operation *dstNodeOp = mdg.getNode(dstId)->op;
82 bool hasOutDepsAfterFusion = false;
83
84 for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
85 Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
86 // Skip dependence with dstOp since it will be removed after fusion.
87 if (depNodeOp == dstNodeOp)
88 continue;
89
90 // Only fusion within the same block is supported. Use domination analysis
91 // when needed.
92 if (depNodeOp->getBlock() != dstNodeOp->getBlock())
93 return false;
94
95 // Check if the insertion point of the fused loop dominates the dependence.
96 // Otherwise, the src loop can't be removed.
97 if (fusedLoopInsPoint != depNodeOp &&
98 !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
99 LDBG() << "Src loop can't be removed: dst loop doesn't "
100 << "dominate dependence";
101 return false;
102 }
103
104 hasOutDepsAfterFusion = true;
105 }
106
107 // If src loop has dependences after fusion or it writes to a live-out or
108 // escaping memref, we can only remove it if the fusion slice is maximal so
109 // that all the dependences are preserved.
110 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
111 std::optional<bool> isMaximal = fusionSlice.isMaximal();
112 if (!isMaximal) {
113 LDBG() << "Src loop can't be removed: can't determine "
114 << "if fusion is maximal";
115 return false;
116 }
117
118 if (!*isMaximal) {
119 LDBG() << "Src loop can't be removed: fusion is not maximal";
120 return false;
121 }
122 }
123
124 return true;
125}
126
127/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
128/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
129/// the program order when the 'mdg' is created. However, program order is not
130/// guaranteed and must not be required by the client. Program order won't be
131/// held if the 'mdg' is reused from a previous fusion step or if the node
132/// creation order changes in the future to support more advance cases.
133// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
134static void getProducerCandidates(unsigned dstId,
135 const MemRefDependenceGraph &mdg,
136 SmallVectorImpl<unsigned> &srcIdCandidates) {
137 // Skip if no input edges along which to fuse.
138 if (mdg.inEdges.count(dstId) == 0)
139 return;
140
141 // Gather memrefs from loads in 'dstId'.
142 auto *dstNode = mdg.getNode(dstId);
143 DenseSet<Value> consumedMemrefs;
144 for (Operation *load : dstNode->loads)
145 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
146
147 // Traverse 'dstId' incoming edges and gather the nodes that contain a store
148 // to one of the consumed memrefs.
149 for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
150 const auto *srcNode = mdg.getNode(srcEdge.id);
151 // Skip if 'srcNode' is not a loop nest.
152 if (!isa<AffineForOp>(srcNode->op))
153 continue;
154
155 if (any_of(srcNode->stores, [&](Operation *op) {
156 auto storeOp = cast<AffineWriteOpInterface>(op);
157 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
158 }))
159 srcIdCandidates.push_back(srcNode->id);
160 }
161
162 llvm::sort(srcIdCandidates);
163 srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
164}
165
166/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
167/// producer-consumer dependence between 'srcId' and 'dstId'.
168static void
169gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
170 const MemRefDependenceGraph &mdg,
171 DenseSet<Value> &producerConsumerMemrefs) {
172 auto *dstNode = mdg.getNode(dstId);
173 auto *srcNode = mdg.getNode(srcId);
174 gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
175 producerConsumerMemrefs);
176}
177
178/// A memref escapes in the context of the fusion pass if either:
179/// 1. it (or its alias) is a block argument, or
180/// 2. created by an op not known to guarantee alias freedom,
181/// 3. it (or its alias) is used by ops other than affine dereferencing ops
182/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
183/// terminator ops, etc.); such ops do not dereference the memref in an affine
184/// way.
185static bool isEscapingMemref(Value memref, Block *block) {
186 Operation *defOp = memref.getDefiningOp();
187 // Check if 'memref' is a block argument.
188 if (!defOp)
189 return true;
190
191 // Check if this is defined to be an alias of another memref.
192 if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
193 if (memref == viewOp.getViewDest() &&
194 isEscapingMemref(viewOp.getViewSource(), block))
195 return true;
196
197 // Any op besides allocating ops wouldn't guarantee alias freedom
199 return true;
200
201 // Check if 'memref' is used by a non-dereferencing op (including unknown ones)
202 // (e.g., call ops, alias creating ops, etc.).
203 return llvm::any_of(memref.getUsers(), [&](Operation *user) {
204 // Ignore users outside of `block`.
205 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
206 if (!ancestorOp)
207 return true;
208 if (ancestorOp->getBlock() != block)
209 return false;
210 return !isa<AffineMapAccessInterface>(*user);
211 });
212}
213
214/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
215/// that escape the block or are accessed in a non-affine way.
216static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
217 DenseSet<Value> &escapingMemRefs) {
218 auto *node = mdg.getNode(id);
219 for (Operation *storeOp : node->stores) {
220 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
221 if (escapingMemRefs.count(memref))
222 continue;
223 if (isEscapingMemref(memref, &mdg.block))
224 escapingMemRefs.insert(memref);
225 }
226}
227
228// Sinks all sequential loops to the innermost levels (while preserving
229// relative order among them) and moves all parallel loops to the
230// outermost (while again preserving relative order among them).
231// This can increase the loop depth at which we can fuse a slice, since we are
232// pushing loop carried dependence to a greater depth in the loop nest.
234 assert(isa<AffineForOp>(node->op));
235 AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
236 node->op = newRootForOp;
237}
238
239/// Get the operation that should act as a dominance filter while replacing
240/// memref uses with a private memref for which `producerStores` and
241/// `sliceInsertionBlock` are provided. This effectively determines in what
242/// part of the IR we should be performing the replacement.
243static Operation *
245 ArrayRef<Operation *> producerStores) {
246 assert(!producerStores.empty() && "expected producer store");
247
248 // We first find the common block that contains the producer stores and
249 // the slice computation. The first ancestor among the ancestors of the
250 // producer stores in that common block is the dominance filter to use for
251 // replacement.
252 Block *commonBlock = nullptr;
253 // Find the common block of all relevant operations.
254 for (Operation *store : producerStores) {
255 Operation *otherOp =
256 !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
257 commonBlock = findInnermostCommonBlockInScope(store, otherOp);
259 assert(commonBlock &&
260 "common block of producer stores and slice should exist");
262 // Find the first ancestor among the ancestors of `producerStores` in
263 // `commonBlock`.
264 Operation *firstAncestor = nullptr;
265 for (Operation *store : producerStores) {
266 Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
267 assert(ancestor && "producer store should be contained in common block");
268 firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
269 ? ancestor
270 : firstAncestor;
272 return firstAncestor;
273}
275/// Returns the amount of additional (redundant) computation that will be done
276/// as a fraction of the total computation if `srcForOp` is fused into
277/// `dstForOp` at depth `depth`. The method returns the compute cost of the
278/// slice and the fused nest's compute cost in the trailing output arguments.
279static std::optional<double> getAdditionalComputeFraction(
280 AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
281 ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
282 int64_t &fusedLoopNestComputeCost) {
283 LDBG() << "Determining additional compute fraction...";
284 // Compute cost of sliced and unsliced src loop nest.
285 // Walk src loop nest and collect stats.
286 LoopNestStats srcLoopNestStats;
287 if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
288 LDBG() << "Failed to get source loop nest stats.";
289 return std::nullopt;
291
292 // Compute cost of dst loop nest.
293 LoopNestStats dstLoopNestStats;
294 if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
295 LDBG() << "Failed to get destination loop nest stats.";
296 return std::nullopt;
297 }
298
299 // Compute op instance count for the src loop nest without iteration slicing.
300 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
301
302 // Compute op cost for the dst loop nest.
303 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
304
305 const ComputationSliceState &slice = depthSliceUnions[depth - 1];
306 // Skip slice union if it wasn't computed for this depth.
307 if (slice.isEmpty()) {
308 LDBG() << "Slice wasn't computed.";
309 return std::nullopt;
312 if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
313 dstLoopNestStats, slice,
314 &fusedLoopNestComputeCost)) {
315 LDBG() << "Unable to compute fusion compute cost";
316 return std::nullopt;
317 }
318
319 double additionalComputeFraction =
320 fusedLoopNestComputeCost /
321 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
322 1;
323
324 return additionalComputeFraction;
325}
326
327// Creates and returns a private (single-user) memref for fused loop rooted at
328// 'forOp', with (potentially reduced) memref size based on the memref region
329// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
330// specifies the block in which the slice was/will be inserted. The method
331// expects that all stores ops to the memref have the same access function.
332// Returns nullptr if the creation failed.
333static Value createPrivateMemRef(AffineForOp forOp,
334 ArrayRef<Operation *> storeOps,
335 unsigned dstLoopDepth,
336 std::optional<unsigned> fastMemorySpace,
337 Block *sliceInsertionBlock,
338 uint64_t localBufSizeThreshold) {
339 assert(!storeOps.empty() && "no source stores supplied");
340
341 // Check if all stores have the same access function; we only support this
342 // case.
343 // TODO: Use union of memref write regions to compute private memref footprint
344 // for store ops with different access functions.
345 if (storeOps.size() > 1 &&
346 !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
347 [](Operation *a, Operation *b) {
348 MemRefAccess aM(cast<AffineWriteOpInterface>(a));
349 MemRefAccess bM(cast<AffineWriteOpInterface>(b));
350 return aM == bM;
351 })) {
352 LDBG() << "Private memref creation unsupported for multiple producer "
353 << "stores with different access functions.";
354 return nullptr;
355 }
356
357 Operation *srcStoreOp = storeOps[0];
358
359 // Create builder to insert alloc op just before 'forOp'.
360 OpBuilder b(forOp);
361 // Builder to create constants at the top level.
362 OpBuilder top(forOp->getParentRegion());
363 // Create new memref type based on slice bounds.
364 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
365 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
366 unsigned rank = oldMemRefType.getRank();
367
368 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
369 MemRefRegion region(srcStoreOp->getLoc());
370 bool validRegion = succeeded(
371 region.compute(srcStoreOp, dstLoopDepth, /*sliceState=*/nullptr,
372 /*addMemRefDimBounds=*/true, /*dropLocalVars=*/false));
373
374 (void)validRegion;
375 assert(validRegion && "unexpected memref region failure");
378 lbs.reserve(rank);
379 // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
380 // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
381 std::optional<int64_t> numElements =
382 region.getConstantBoundingSizeAndShape(&newShape, &lbs);
383 assert(numElements && "non-constant number of elts in local buffer");
384
385 const FlatAffineValueConstraints *cst = region.getConstraints();
386 // 'outerIVs' holds the values that this memory region is symbolic/parametric
387 // on; this would correspond to loop IVs surrounding the level at which the
388 // slice is being materialized.
389 SmallVector<Value, 8> outerIVs;
390 cst->getValues(rank, cst->getNumDimAndSymbolVars(), &outerIVs);
391
392 // Build 'rank' AffineExprs from MemRefRegion 'lbs'
394 offsets.reserve(rank);
395
396 // Outer IVs are considered symbols during memref region computation. Replace
397 // them uniformly with dims so that valid IR is guaranteed.
398 SmallVector<AffineExpr> replacements;
399 for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
400 replacements.push_back(mlir::getAffineDimExpr(j, forOp.getContext()));
401 for (unsigned d = 0; d < rank; ++d) {
402 assert(lbs[d].getNumResults() == 1 &&
403 "invalid private memref bound calculation");
404 offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
405 }
406
407 // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
408 // by 'srcStoreOpInst'.
409 auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
410 assert(eltSize && "memrefs with size elt types expected");
411 uint64_t bufSize = *eltSize * *numElements;
412 Attribute newMemSpace;
413 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
414 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
415 } else {
416 newMemSpace = oldMemRefType.getMemorySpace();
417 }
418 auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
419 /*map=*/AffineMap(), newMemSpace);
420
421 // Create new private memref for fused loop 'forOp'. 'newShape' is always
422 // a constant shape.
423 // TODO: Create/move alloc ops for private memrefs closer to their
424 // consumer loop nests to reduce their live range. Currently they are added
425 // at the beginning of the block, because loop nests can be reordered
426 // during the fusion pass.
427 Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
428
429 // Build an AffineMap to remap access functions based on lower bound offsets.
431 remapExprs.reserve(rank);
432 for (unsigned i = 0; i < rank; i++) {
433 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
434
435 auto remapExpr =
436 simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
437 remapExprs.push_back(remapExpr);
438 }
439
440 auto indexRemap =
441 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
442
443 // Replace all users of 'oldMemRef' with 'newMemRef'.
444 Operation *domFilter =
445 getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
446 auto userFilterFn = [&](Operation *user) {
447 auto domInfo = std::make_unique<DominanceInfo>(
448 domFilter->getParentOfType<FunctionOpInterface>());
449 return domInfo->dominates(domFilter, user);
450 };
451 LogicalResult res = replaceAllMemRefUsesWith(
452 oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
453 /*extraOperands=*/outerIVs,
454 /*symbolOperands=*/{}, userFilterFn);
455 assert(succeeded(res) &&
456 "replaceAllMemrefUsesWith should always succeed here");
457 (void)res;
458 LDBG() << "Created private memref of type: " << newMemRefType;
459 return newMemRef;
460}
461
462// Checks the profitability of fusing a backwards slice of the loop nest
463// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
464// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
465// being produced and consumed, which is an input to the cost model. For
466// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
467// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
468// will be the src loop nest LoadOp which reads from the same memref as dst loop
469// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
470// node, which will be used to check that the write region is the same after
471// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
472// each legal fusion depth. The maximal depth at which fusion is legal is
473// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
474// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
475// the most profitable depth at which to materialize the source loop nest slice.
476// The profitability model executes the following steps:
477// *) Computes the backward computation slice at 'srcOpInst'. This
478// computation slice of the loop nest surrounding 'srcOpInst' is
479// represented by modified src loop bounds in 'sliceState', which are
480// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
481// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
482// loop nest is the total number of dynamic operation instances in the loop
483// nest).
484// *) Computes the cost of fusing a slice of the src loop nest into the dst
485// loop nest at various values of dst loop depth, attempting to fuse
486// the largest computation slice at the maximal dst loop depth (closest to
487// the load) to minimize reuse distance and potentially enable subsequent
488// load/store forwarding.
489// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
490// nest, at which the src computation slice is inserted/fused.
491// NOTE: We attempt to maximize the dst loop depth, but there are cases
492// where a particular setting for 'dstLoopNest' might fuse an unsliced
493// loop (within the src computation slice) at a depth which results in
494// excessive recomputation (see unit tests for examples).
495// *) Compares the total cost of the unfused loop nests to the min cost fused
496// loop nest computed in the previous step, and returns true if the latter
497// is lower.
498// TODO: Extend profitability analysis to support scenarios with multiple
499// stores.
500static bool isFusionProfitable(AffineForOp srcForOp,
501 ArrayRef<Operation *> producerStores,
502 AffineForOp dstForOp,
503 ArrayRef<ComputationSliceState> depthSliceUnions,
504 unsigned maxLegalFusionDepth,
505 unsigned *dstLoopDepth,
506 double computeToleranceThreshold) {
507 LDBG() << "Checking whether fusion is profitable between source nest:";
508 LDBG() << ' ' << srcForOp << " and destination nest:";
509 LDBG() << dstForOp;
510
511 if (maxLegalFusionDepth == 0) {
512 LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
513 return false;
514 }
515
516 // Compute cost of sliced and unsliced src loop nest.
517
518 // Walk src loop nest and collect stats.
519 LoopNestStats srcLoopNestStats;
520 if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
521 return false;
522
523 // Compute cost of dst loop nest.
524 LoopNestStats dstLoopNestStats;
525 if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
526 return false;
527
528 // We limit profitability analysis to only scenarios with
529 // a single producer store for now. Note that some multi-store
530 // producer scenarios will still go through profitability analysis
531 // if only one of the stores is involved in the producer-consumer
532 // relationship of the candidate loops.
533 // TODO: Suppport multiple producer stores in profitability
534 // analysis.
535 if (producerStores.size() > 1) {
536 LDBG() << "Limited profitability analysis. Not "
537 << "supported for multiple producer store case.";
538 int64_t sliceCost;
539 int64_t fusedLoopNestComputeCost;
540 // We will still fuse if fusion obeys the specified compute
541 // tolerance at the max legal depth.
542 auto fraction = getAdditionalComputeFraction(
543 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
544 fusedLoopNestComputeCost);
545 if (!fraction || fraction > computeToleranceThreshold) {
546 LDBG() << "Additional computation exceeds "
547 << "compute tolerance. Not fusing.";
548 return false;
549 }
550 LDBG() << "Considering fusion profitable at max legal depth.";
551 return true;
552 }
553
554 Operation *srcStoreOp = producerStores.front();
555
556 // Search for min cost value for 'dstLoopDepth'. At each value of
557 // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
558 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
559 // of these bounds). Next the union slice bounds are used to calculate
560 // the cost of the slice and the cost of the slice inserted into the dst
561 // loop nest at 'dstLoopDepth'.
562 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
563 double maxStorageReduction = 0.0;
564 std::optional<uint64_t> sliceMemEstimate;
565
566 // The best loop depth at which to materialize the slice.
567 std::optional<unsigned> bestDstLoopDepth;
568
569 // Compute src loop nest write region size.
570 MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
571 if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
572 LDBG() << "Unable to compute MemRefRegion for source operation";
573 return false;
574 }
575
576 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
577 srcWriteRegion.getRegionSize();
578 if (!maybeSrcWriteRegionSizeBytes.has_value())
579 return false;
580 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
581
582 // Compute op instance count for the src loop nest without iteration slicing.
583 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
584
585 // Compute op instance count for the destination loop nest.
586 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
587
588 // Evaluate all depth choices for materializing the slice in the destination
589 // loop nest.
590 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
591 const ComputationSliceState &slice = depthSliceUnions[i - 1];
592 // Skip slice union if it wasn't computed for this depth.
593 if (slice.isEmpty())
594 continue;
595
596 // Compute cost of the slice separately, i.e, the compute cost of the slice
597 // if all outer trip counts are one.
598 int64_t sliceCost;
599
600 int64_t fusedLoopNestComputeCost;
601
602 auto mayAdditionalComputeFraction =
603 getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
604 sliceCost, fusedLoopNestComputeCost);
605 if (!mayAdditionalComputeFraction) {
606 LDBG() << "Can't determine additional compute fraction.";
607 continue;
608 }
609 double additionalComputeFraction = *mayAdditionalComputeFraction;
610
611 // Determine what the slice write MemRefRegion would be, if the src loop
612 // nest slice 'slice' were to be inserted into the dst loop nest at loop
613 // depth 'i'.
614 MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
615 if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
616 LDBG() << "Failed to compute slice write region at loopDepth: " << i;
617 continue;
618 }
619
620 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
621 sliceWriteRegion.getRegionSize();
622 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
623 *maybeSliceWriteRegionSizeBytes == 0) {
624 LDBG() << "Failed to get slice write region size at loopDepth: " << i;
625 continue;
626 }
627 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
628
629 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
630 static_cast<double>(sliceWriteRegionSizeBytes);
631
632 LLVM_DEBUG({
633 std::stringstream msg;
634 msg << " evaluating fusion profitability at depth : " << i << "\n"
635 << std::fixed << std::setprecision(2)
636 << " additional compute fraction: "
637 << 100.0 * additionalComputeFraction << "%\n"
638 << " storage reduction factor: " << storageReduction << "x\n"
639 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
640 << " src write region size: " << srcWriteRegionSizeBytes << "\n"
641 << " slice write region size: " << sliceWriteRegionSizeBytes;
642 LDBG() << msg.str();
643 });
644
645 // TODO: This is a placeholder cost model.
646 // Among all choices that add an acceptable amount of redundant computation
647 // (as per computeToleranceThreshold), we will simply pick the one that
648 // reduces the intermediary size the most.
649 if ((storageReduction > maxStorageReduction) &&
650 (additionalComputeFraction <= computeToleranceThreshold)) {
651 maxStorageReduction = storageReduction;
652 bestDstLoopDepth = i;
653 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
654 sliceMemEstimate = sliceWriteRegionSizeBytes;
655 }
656 }
657
658 // A simple cost model: fuse if it reduces the memory footprint.
659
660 if (!bestDstLoopDepth) {
661 LDBG() << "All fusion choices involve more than the threshold amount of "
662 << "redundant computation; NOT fusing.";
663 return false;
664 }
665
666 if (!bestDstLoopDepth) {
667 LDBG() << "no fusion depth could be evaluated.";
668 return false;
669 }
670
671 // Set dstLoopDepth based on best values from search.
672 *dstLoopDepth = *bestDstLoopDepth;
673
674 LDBG() << " LoopFusion fusion stats:";
675 LDBG() << " best loop depth: " << bestDstLoopDepth;
676 LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
677 LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
678 LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
679
680 auto dstMemSize = getMemoryFootprintBytes(dstForOp);
681 auto srcMemSize = getMemoryFootprintBytes(srcForOp);
682
683 std::optional<double> storageReduction;
684
685 if (!dstMemSize || !srcMemSize) {
686 LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
687 return false;
688 }
689
690 auto srcMemSizeVal = *srcMemSize;
691 auto dstMemSizeVal = *dstMemSize;
692
693 assert(sliceMemEstimate && "expected value");
694 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
695
696 LDBG() << " src mem: " << srcMemSizeVal;
697 LDBG() << " dst mem: " << dstMemSizeVal;
698 LDBG() << " fused mem: " << fusedMem;
699 LDBG() << " slice mem: " << sliceMemEstimate;
700
701 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
702 LDBG() << "Fusion is not profitable; NOT fusing.";
703 return false;
704 }
705 storageReduction =
706 100.0 *
707 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
708
709 double additionalComputeFraction =
710 100.0 * (minFusedLoopNestComputeCost /
711 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
712 1);
713 (void)additionalComputeFraction;
714 LLVM_DEBUG({
715 std::stringstream msg;
716 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
717 << std::setprecision(2) << additionalComputeFraction
718 << "% redundant computation and a ";
719 msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
720 msg << "% storage reduction.";
721 LDBG() << msg.str();
722 });
723
724 return true;
725}
726
727namespace {
728
729// GreedyFusion greedily fuses loop nests which have a producer/consumer or
730// input-reuse relationship on a memref, with the goal of improving locality.
731//
732// The steps of the producer-consumer fusion algorithm are as follows:
733//
734// *) A worklist is initialized with node ids from the dependence graph.
735// *) For each node id in the worklist:
736// *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
737// candidate destination AffineForOp into which fusion will be attempted.
738// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
739// *) For each LoadOp in 'dstLoadOps' do:
740// *) Look up dependent loop nests which have a single store op to the same
741// memref.
742// *) Check if dependences would be violated by the fusion.
743// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
744// bounds to be functions of 'dstLoopNest' IVs and symbols.
745// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
746// at a loop depth determined by the cost model in 'isFusionProfitable'.
747// *) Add the newly fused load/store operations to the state,
748// and also add newly fused load ops to 'dstLoopOps' to be considered
749// as fusion dst load ops in another iteration.
750// *) Remove old src loop nest and its associated state.
751//
752// The steps of the input-reuse fusion algorithm are as follows:
753//
754// *) Initialize 'worklist' with node ids from the dependence graph.
755// *) For each 'dstNode' in the worklist:
756// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
757// loads from the same memref, but which has no dependence paths to/from.
758// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
759// bounds to be functions of 'dstLoopNest' IVs and symbols.
760// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
761// at a loop depth determined by the cost model in 'isFusionProfitable'.
762// This function also checks that the memref write region of 'sibLoopNest',
763// is preserved in the fused loop nest.
764// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
765//
766// Given a graph where top-level operations are vertices in the set 'V' and
767// edges in the set 'E' are dependences between vertices, this algorithm
768// takes O(V) time for initialization, and has runtime O(V + E).
769//
770// This greedy algorithm is not 'maximal' due to the current restriction of
771// fusing along single producer consumer edges, but there is a TODO: to fix
772// this.
773//
774// TODO: Experiment with other fusion policies.
775struct GreedyFusion {
776public:
777 // The data dependence graph to traverse during fusion.
778 MemRefDependenceGraph *mdg;
779 // Worklist of graph nodes visited during the fusion pass.
780 SmallVector<unsigned, 8> worklist;
781 // Parameter for local buffer size threshold.
782 unsigned localBufSizeThreshold;
783 // Parameter for fast memory space.
784 std::optional<unsigned> fastMemorySpace;
785 // If true, ignore any additional (redundant) computation tolerance threshold
786 // that would have prevented fusion.
787 bool maximalFusion;
788 // The amount of additional computation that is tolerated while fusing
789 // pair-wise as a fraction of the total computation.
790 double computeToleranceThreshold;
791
792 using Node = MemRefDependenceGraph::Node;
793
794 GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
795 std::optional<unsigned> fastMemorySpace, bool maximalFusion,
796 double computeToleranceThreshold)
797 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
798 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
799 computeToleranceThreshold(computeToleranceThreshold) {}
800
801 /// Initializes 'worklist' with nodes from 'mdg'.
802 void init() {
803 // TODO: Add a priority queue for prioritizing nodes by different
804 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
805 worklist.clear();
806 for (auto &idAndNode : mdg->nodes) {
807 const Node &node = idAndNode.second;
808 worklist.push_back(node.id);
809 }
810 }
811 /// Run only sibling fusion on the `mdg`.
812 void runSiblingFusionOnly() {
813 fuseSiblingNodes();
814 eraseUnusedMemRefAllocations();
815 }
816
817 /// Run only producer/consumer fusion on the `mdg`.
818 void runProducerConsumerFusionOnly() {
819 fuseProducerConsumerNodes(
820 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
821 eraseUnusedMemRefAllocations();
822 }
823
824 // Run the GreedyFusion pass.
825 // *) First pass through the nodes fuses single-use producer nodes into their
826 // unique consumer.
827 // *) Second pass fuses sibling nodes which share no dependence edges.
828 // *) Third pass fuses any remaining producer nodes into their users.
829 void runGreedyFusion() {
830 // TODO: Run this repeatedly until a fixed-point is reached.
831 fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
832 fuseSiblingNodes();
833 fuseProducerConsumerNodes(
834 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
835 eraseUnusedMemRefAllocations();
836 }
837
838 /// Returns true if a private memref can be created for `memref` given
839 /// the fusion scenario reflected by the other arguments.
840 bool canCreatePrivateMemRef(Value memref,
841 const DenseSet<Value> &srcEscapingMemRefs,
842 unsigned producerId, unsigned consumerId,
843 bool removeSrcNode) {
844 // We can't generate private memrefs if their size can't be computed.
845 if (!getMemRefIntOrFloatEltSizeInBytes(cast<MemRefType>(memref.getType())))
846 return false;
847 const Node *consumerNode = mdg->getNode(consumerId);
848 // If `memref` is an escaping one, do not create a private memref
849 // for the below scenarios, since doing so will leave the escaping
850 // memref unmodified as all the writes originally meant for the
851 // escaping memref would be performed on the private memref:
852 // 1. The source is to be removed after fusion,
853 // OR
854 // 2. The destination writes to `memref`.
855 if (srcEscapingMemRefs.count(memref) > 0 &&
856 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
857 return false;
858
859 // Don't create a private memref if 'srcNode' has in edges on
860 // 'memref' or 'dstNode' has out edges on 'memref'.
861 if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
862 mdg->getOutEdgeCount(consumerId, memref) > 0)
863 return false;
864
865 // If 'srcNode' will be removed but it has out edges on 'memref' to
866 // nodes other than 'dstNode', we have to preserve dependences and
867 // cannot create a private memref.
868 if (removeSrcNode &&
869 any_of(mdg->outEdges[producerId], [&](const auto &edge) {
870 return edge.value == memref && edge.id != consumerId;
871 }))
872 return false;
873
874 return true;
875 }
876
877 /// Perform fusions with node `dstId` as the destination of fusion, with
878 /// No fusion is performed when producers with a user count greater than
879 /// `maxSrcUserCount` for any of the memrefs involved.
880 void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
881 LDBG() << "Evaluating dst loop " << dstId;
882 // Skip if this node was removed (fused into another node).
883 if (mdg->nodes.count(dstId) == 0)
884 return;
885 // Get 'dstNode' into which to attempt fusion.
886 auto *dstNode = mdg->getNode(dstId);
887 // Skip if 'dstNode' is not a loop nest.
888 if (!isa<AffineForOp>(dstNode->op))
889 return;
890 // Skip if 'dstNode' is a loop nest returning values.
891 // TODO: support loop nests that return values.
892 if (dstNode->op->getNumResults() > 0)
893 return;
894
895 LDBG() << "Evaluating dst loop " << dstId;
896
897 // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
898 // while preserving relative order. This can increase the maximum loop
899 // depth at which we can fuse a slice of a producer loop nest into a
900 // consumer loop nest.
901 sinkSequentialLoops(dstNode);
902 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
903
904 // Try to fuse 'dstNode' with candidate producer loops until a fixed point
905 // is reached. Fusing two loops may expose new fusion opportunities.
906 bool dstNodeChanged;
907 do {
908 // Gather src loop candidates for 'dstNode' and visit them in "quasi"
909 // reverse program order to minimize the number of iterations needed to
910 // reach the fixed point. Note that this is a best effort approach since
911 // 'getProducerCandidates' does not always guarantee that program order
912 // in 'srcIdCandidates'.
913 dstNodeChanged = false;
914 SmallVector<unsigned, 16> srcIdCandidates;
915 getProducerCandidates(dstId, *mdg, srcIdCandidates);
916
917 for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
918 // Get 'srcNode' from which to attempt fusion into 'dstNode'.
919 auto *srcNode = mdg->getNode(srcId);
920 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
921
922 LDBG() << "Trying to fuse producer loop nest " << srcId
923 << " with consumer loop nest " << dstId;
924 LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
925 LDBG() << "Producer loop nest:";
926 LDBG() << *srcNode->op << " and consumer loop nest:";
927 LDBG() << *dstNode->op;
928
929 LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
930
931 // Skip if 'srcNode' is a loop nest returning values.
932 // TODO: support loop nests that return values.
933 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
934 continue;
935
936 DenseSet<Value> producerConsumerMemrefs;
937 gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
938 producerConsumerMemrefs);
939
940 // Skip if 'srcNode' out edge count on any memref is greater than
941 // 'maxSrcUserCount'.
942 if (any_of(producerConsumerMemrefs, [&](Value memref) {
943 return mdg->getOutEdgeCount(srcNode->id, memref) >
944 maxSrcUserCount;
945 }))
946 continue;
947
948 // Gather memrefs in 'srcNode' that are written and escape out of the
949 // block (e.g., memref block arguments, returned memrefs,
950 // memrefs passed to function calls, etc.).
951 DenseSet<Value> srcEscapingMemRefs;
952 gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
953
954 // Compute an operation list insertion point for the fused loop
955 // nest which preserves dependences.
956 Operation *fusedLoopInsPoint =
957 mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
958 if (fusedLoopInsPoint == nullptr)
959 continue;
960
961 // It's possible this fusion is at an inner depth (i.e., there are
962 // common surrounding affine loops for the source and destination for
963 // ops). We need to get this number because the call to canFuseLoops
964 // needs to be passed the absolute depth. The max legal depth and the
965 // depths we try below are however *relative* and as such don't include
966 // the common depth.
967 SmallVector<AffineForOp, 4> surroundingLoops;
968 getAffineForIVs(*dstAffineForOp, &surroundingLoops);
969 unsigned numSurroundingLoops = surroundingLoops.size();
970
971 // Compute the innermost common loop depth for dstNode
972 // producer-consumer loads/stores.
973 SmallVector<Operation *, 2> dstMemrefOps;
974 for (Operation *op : dstNode->loads)
975 if (producerConsumerMemrefs.count(
976 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
977 dstMemrefOps.push_back(op);
978 for (Operation *op : dstNode->stores)
979 if (producerConsumerMemrefs.count(
980 cast<AffineWriteOpInterface>(op).getMemRef()))
981 dstMemrefOps.push_back(op);
982 if (dstMemrefOps.empty())
983 continue;
984 unsigned dstLoopDepthTest =
985 getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
986
987 // Check the feasibility of fusing src loop nest into dst loop nest
988 // at loop depths in range [1, dstLoopDepthTest].
989 unsigned maxLegalFusionDepth = 0;
990 SmallVector<ComputationSliceState, 8> depthSliceUnions;
991 depthSliceUnions.resize(dstLoopDepthTest);
992 FusionStrategy strategy(FusionStrategy::ProducerConsumer);
993 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
994 FusionResult result =
995 affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
996 /*dstLoopDepth=*/i + numSurroundingLoops,
997 &depthSliceUnions[i - 1], strategy);
998 if (result.value == FusionResult::Success) {
999 maxLegalFusionDepth = i;
1000 LDBG() << "Found valid slice for depth: " << i;
1001 }
1002 }
1003
1004 if (maxLegalFusionDepth == 0) {
1005 LDBG() << "Can't fuse: fusion is not legal at any depth";
1006 continue;
1007 }
1008
1009 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1010
1011 double computeToleranceThresholdToUse = computeToleranceThreshold;
1012
1013 // Cyclic dependences in the source nest may be violated when performing
1014 // slicing-based fusion. They aren't actually violated in cases where no
1015 // redundant execution of the source happens (1:1 pointwise dep on the
1016 // producer-consumer memref access for example). Check this and allow
1017 // fusion accordingly.
1018 if (hasCyclicDependence(srcAffineForOp)) {
1019 LDBG() << "Source nest has a cyclic dependence.";
1020 // Maximal fusion does not check for compute tolerance threshold; so
1021 // perform the maximal fusion only when the redundanation computation
1022 // is zero.
1023 if (maximalFusion) {
1024 auto srcForOp = cast<AffineForOp>(srcNode->op);
1025 auto dstForOp = cast<AffineForOp>(dstNode->op);
1026 int64_t sliceCost;
1027 int64_t fusedLoopNestComputeCost;
1028 auto fraction = getAdditionalComputeFraction(
1029 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1030 sliceCost, fusedLoopNestComputeCost);
1031 if (!fraction || fraction > 0) {
1032 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1033 << "and non-zero additional compute.";
1034 return;
1035 }
1036 } else {
1037 // Set redundant computation tolerance to zero regardless of what
1038 // the user specified. Without this, fusion would be invalid.
1039 LDBG() << "Setting compute tolerance to zero since "
1040 << "source has a cylic dependence.";
1041 computeToleranceThresholdToUse = 0;
1042 }
1043 }
1044
1045 // Check if fusion would be profitable. We skip profitability analysis
1046 // for maximal fusion since we already know the maximal legal depth to
1047 // fuse.
1048 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1049 if (!maximalFusion) {
1050 // Retrieve producer stores from the src loop.
1051 SmallVector<Operation *, 2> producerStores;
1052 for (Operation *op : srcNode->stores)
1053 if (producerConsumerMemrefs.count(
1054 cast<AffineWriteOpInterface>(op).getMemRef()))
1055 producerStores.push_back(op);
1056
1057 assert(!producerStores.empty() && "Expected producer store");
1058 if (!isFusionProfitable(srcAffineForOp, producerStores,
1059 dstAffineForOp, depthSliceUnions,
1060 maxLegalFusionDepth, &bestDstLoopDepth,
1061 computeToleranceThresholdToUse)) {
1062 continue;
1063 }
1064 }
1065
1066 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1067 ComputationSliceState &bestSlice =
1068 depthSliceUnions[bestDstLoopDepth - 1];
1069 assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1070
1071 // Determine if 'srcId' can be removed after fusion, taking into
1072 // account remaining dependences, escaping memrefs and the fusion
1073 // insertion point.
1074 bool removeSrcNode = canRemoveSrcNodeAfterFusion(
1075 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1076 *mdg);
1077
1078 DenseSet<Value> privateMemrefs;
1079 for (Value memref : producerConsumerMemrefs) {
1080 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1081 removeSrcNode)) {
1082 // Create a private version of this memref.
1083 LDBG() << "Creating private memref for " << memref;
1084 // Create a private version of this memref.
1085 privateMemrefs.insert(memref);
1086 }
1087 }
1088
1089 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
1090 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1091 dstNodeChanged = true;
1092
1093 LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
1094 << " at depth " << bestDstLoopDepth << ":";
1095 LDBG() << dstAffineForOp;
1096
1097 // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1098 if (fusedLoopInsPoint != dstAffineForOp)
1099 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1100
1101 // Update edges between 'srcNode' and 'dstNode'.
1102 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1103 removeSrcNode);
1104
1105 // Create private memrefs.
1106 if (!privateMemrefs.empty()) {
1107 // Note the block into which fusion was performed. This can be used to
1108 // place `alloc`s that create private memrefs.
1109 Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
1110
1111 // Gather stores for all the private-to-be memrefs.
1112 DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1113 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1114 Value storeMemRef = storeOp.getMemRef();
1115 if (privateMemrefs.count(storeMemRef) > 0)
1116 privateMemRefToStores[storeMemRef].push_back(storeOp);
1117 });
1118
1119 // Replace original memrefs with private memrefs. Note that all the
1120 // loads and stores on these memrefs will be replaced with a new
1121 // loads and stores. Any reference to the original ones becomes
1122 // invalid after this point.
1123 for (auto &memrefToStoresPair : privateMemRefToStores) {
1124 ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
1125 Value newMemRef = createPrivateMemRef(
1126 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1127 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1128 if (!newMemRef)
1129 continue;
1130 // Create new node in dependence graph for 'newMemRef' alloc op.
1131 unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1132 // Add edge from 'newMemRef' node to dstNode.
1133 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1134 }
1135 // One or more entries for 'newMemRef' alloc op are inserted into
1136 // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1137 // reallocate, update dstNode.
1138 dstNode = mdg->getNode(dstId);
1139 }
1140
1141 // Collect dst loop stats after memref privatization transformation.
1142 LoopNestStateCollector dstLoopCollector;
1143 dstLoopCollector.collect(dstAffineForOp);
1144
1145 // Clear and add back loads and stores.
1146 mdg->clearNodeLoadAndStores(dstNode->id);
1147 mdg->addToNode(
1148 dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts,
1149 dstLoopCollector.memrefLoads, dstLoopCollector.memrefStores,
1150 dstLoopCollector.memrefFrees);
1151
1152 if (removeSrcNode) {
1153 LDBG() << "Removing src loop " << srcId << " after fusion";
1154 // srcNode is no longer valid after it is removed from mdg.
1155 srcAffineForOp.erase();
1156 mdg->removeNode(srcId);
1157 srcNode = nullptr;
1158 }
1159 }
1160 } while (dstNodeChanged);
1161 }
1162
1163 /// Visit each node in the graph, and for each node, attempt to fuse it with
1164 /// producer-consumer candidates. No fusion is performed when producers with a
1165 /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1166 /// are encountered.
1167 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1168 LDBG() << "--- Producer/Consumer Fusion ---";
1169 init();
1170 while (!worklist.empty()) {
1171 unsigned dstId = worklist.back();
1172 worklist.pop_back();
1173 performFusionsIntoDest(dstId, maxSrcUserCount);
1174 }
1175 }
1176
1177 // Visits each node in the graph, and for each node, attempts to fuse it with
1178 // its sibling nodes (nodes which share a parent, but no dependence edges).
1179 void fuseSiblingNodes() {
1180 LDBG() << "--- Sibling Fusion ---";
1181 init();
1182 while (!worklist.empty()) {
1183 unsigned dstId = worklist.back();
1184 worklist.pop_back();
1185
1186 // Skip if this node was removed (fused into another node).
1187 if (mdg->nodes.count(dstId) == 0)
1188 continue;
1189 // Get 'dstNode' into which to attempt fusion.
1190 auto *dstNode = mdg->getNode(dstId);
1191 // Skip if 'dstNode' is not a loop nest.
1192 if (!isa<AffineForOp>(dstNode->op))
1193 continue;
1194 // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1195 fuseWithSiblingNodes(dstNode);
1196 }
1197 }
1198
1199 // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1200 void fuseWithSiblingNodes(Node *dstNode) {
1201 DenseSet<unsigned> visitedSibNodeIds;
1202 std::pair<unsigned, Value> idAndMemref;
1203 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1204
1205 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1206 unsigned sibId = idAndMemref.first;
1207 Value memref = idAndMemref.second;
1208 // TODO: Check that 'sibStoreOpInst' post-dominates all other
1209 // stores to the same memref in 'sibNode' loop nest.
1210 auto *sibNode = mdg->getNode(sibId);
1211 // Compute an operation list insertion point for the fused loop
1212 // nest which preserves dependences.
1213 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1214 Operation *insertPointInst =
1215 sibNode->op->isBeforeInBlock(dstNode->op)
1216 ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1217 : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1218 if (insertPointInst == nullptr)
1219 continue;
1220
1221 // Check if fusion would be profitable and at what depth.
1222
1223 // Get unique 'sibNode' load op to 'memref'.
1224 SmallVector<Operation *, 2> sibLoadOpInsts;
1225 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1226 // Currently findSiblingNodeToFuse searches for siblings with one load.
1227 Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
1228
1229 // Gather 'dstNode' load ops to 'memref'.
1230 SmallVector<Operation *, 2> dstLoadOpInsts;
1231 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1232
1233 // It's possible this fusion is at an inner depth (i.e., there are common
1234 // surrounding affine loops for the source and destination for ops). We
1235 // need to get this number because the call to canFuseLoops needs to be
1236 // passed the absolute depth. The max legal depth and the depths we try
1237 // below are however *relative* and as such don't include the common
1238 // depth.
1239 SmallVector<AffineForOp, 4> surroundingLoops;
1240 getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1241 unsigned numSurroundingLoops = surroundingLoops.size();
1242 SmallVector<AffineForOp, 4> dstLoopIVs;
1243 getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1244 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1245 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1246
1247 // Compute loop depth and slice union for fusion.
1248 SmallVector<ComputationSliceState, 8> depthSliceUnions;
1249 depthSliceUnions.resize(dstLoopDepthTest);
1250 unsigned maxLegalFusionDepth = 0;
1251 FusionStrategy strategy(memref);
1252 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1253 FusionResult result =
1254 affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1255 /*dstLoopDepth=*/i + numSurroundingLoops,
1256 &depthSliceUnions[i - 1], strategy);
1257
1258 if (result.value == FusionResult::Success)
1259 maxLegalFusionDepth = i;
1260 }
1261
1262 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1263
1264 // Skip if fusion is not feasible at any loop depths.
1265 if (maxLegalFusionDepth == 0)
1266 continue;
1267
1268 double computeToleranceThresholdToUse = computeToleranceThreshold;
1269
1270 // Cyclic dependences in the source nest may be violated when performing
1271 // slicing-based fusion. They aren't actually violated in cases where no
1272 // redundant execution of the source happens (1:1 pointwise dep on the
1273 // producer-consumer memref access for example). Check this and allow
1274 // fusion accordingly.
1275 if (hasCyclicDependence(sibAffineForOp)) {
1276 LDBG() << "Source nest has a cyclic dependence.";
1277 // Maximal fusion does not check for compute tolerance threshold; so
1278 // perform the maximal fusion only when the redundanation computation is
1279 // zero.
1280 if (maximalFusion) {
1281 auto dstForOp = cast<AffineForOp>(dstNode->op);
1282 int64_t sliceCost;
1283 int64_t fusedLoopNestComputeCost;
1284 auto fraction = getAdditionalComputeFraction(
1285 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1286 sliceCost, fusedLoopNestComputeCost);
1287 if (!fraction || fraction > 0) {
1288 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1289 << "and non-zero additional compute.";
1290 return;
1291 }
1292 } else {
1293 // Set redundant computation tolerance to zero regardless of what the
1294 // user specified. Without this, fusion would be invalid.
1295 LDBG() << "Setting compute tolerance to zero since "
1296 << "source has a cyclic dependence.";
1297 computeToleranceThresholdToUse = 0.0;
1298 }
1299 }
1300
1301 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1302 if (!maximalFusion) {
1303 // Check if fusion would be profitable. For sibling fusion, the sibling
1304 // load op is treated as the src "store" op for fusion profitability
1305 // purposes. The footprint of the load in the slice relative to the
1306 // unfused source's determines reuse.
1307 if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
1308 depthSliceUnions, maxLegalFusionDepth,
1309 &bestDstLoopDepth,
1310 computeToleranceThresholdToUse))
1311 continue;
1312 }
1313
1314 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1315
1316 const ComputationSliceState &bestSlice =
1317 depthSliceUnions[bestDstLoopDepth - 1];
1318 assert(!bestSlice.isEmpty() &&
1319 "Fusion depth has no computed slice union");
1320
1321 // Do not perform sibling fusion if it isn't maximal. We always remove the
1322 // sibling node and as such fusion shouldn't be performed if a part of the
1323 // slice is used in the destination.
1324 auto isMaximal = bestSlice.isMaximal();
1325 if (!isMaximal.value_or(false)) {
1326 LDBG() << "Slice isn't maximal; not performing sibling fusion.";
1327 continue;
1328 }
1329
1330 // Check if source loop is being inserted in the innermost
1331 // destination loop. Based on this, the fused loop may be optimized
1332 // further inside `fuseLoops`.
1333 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1334 // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1335 affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
1336 isInnermostInsertion);
1337
1338 auto dstForInst = cast<AffineForOp>(dstNode->op);
1339 // Update operation position of fused loop nest (if needed).
1340 if (insertPointInst != dstForInst)
1341 dstForInst->moveBefore(insertPointInst);
1342
1343 LDBG() << "Fused sibling nest " << sibId << " into destination nest "
1344 << dstNode->id << " at depth " << bestDstLoopDepth << ":";
1345 LDBG() << dstAffineForOp;
1346
1347 // Update data dependence graph state post fusion.
1348 updateStateAfterSiblingFusion(sibNode, dstNode);
1349
1350 // Remove old sibling loop nest.
1351 // Get op before we invalidate the MDG node.
1352 Operation *op = sibNode->op;
1353 mdg->removeNode(sibNode->id);
1354 op->erase();
1355 }
1356 }
1357
1358 // Searches block argument uses and the graph from 'dstNode' looking for a
1359 // fusion candidate sibling node which shares no dependences with 'dstNode'
1360 // but which loads from the same memref. Returns true and sets
1361 // 'idAndMemrefToFuse' on success. Returns false otherwise.
1362 bool findSiblingNodeToFuse(Node *dstNode,
1363 DenseSet<unsigned> *visitedSibNodeIds,
1364 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1365 // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1366 // on 'memref'.
1367 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1368 // Skip if 'outEdge' is not a read-after-write dependence.
1369 // TODO: Remove restrict to single load op restriction.
1370 if (sibNode->getLoadOpCount(memref) != 1)
1371 return false;
1372 // Skip if there exists a path of dependent edges between
1373 // 'sibNode' and 'dstNode'.
1374 if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1375 mdg->hasDependencePath(dstNode->id, sibNode->id))
1376 return false;
1377 // Skip sib node if it loads to (and stores from) the same memref on
1378 // which it also has an input dependence edge.
1379 DenseSet<Value> loadAndStoreMemrefSet;
1380 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1381 if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1382 return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1383 }))
1384 return false;
1385
1386 // Check that all stores are to the same memref if any.
1387 DenseSet<Value> storeMemrefs;
1388 for (auto *storeOpInst : sibNode->stores) {
1389 storeMemrefs.insert(
1390 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1391 }
1392 return storeMemrefs.size() <= 1;
1393 };
1394
1395 // Search for siblings which load the same memref block argument.
1396 Block *block = dstNode->op->getBlock();
1397 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1398 for (Operation *user : block->getArgument(i).getUsers()) {
1399 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1400 if (!loadOp)
1401 continue;
1402 // Gather loops surrounding 'use'.
1403 SmallVector<AffineForOp, 4> loops;
1404 getAffineForIVs(*user, &loops);
1405 // Skip 'use' if it is not within a loop nest.
1406 // Find the surrounding affine.for nested immediately within the
1407 // block.
1408 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1409 return loop->getBlock() == &mdg->block;
1410 });
1411 // Skip 'use' if it is not within a loop nest in `block`.
1412 if (it == loops.end())
1413 continue;
1414 Node *sibNode = mdg->getForOpNode(*it);
1415 assert(sibNode != nullptr);
1416 // Skip 'use' if it not a sibling to 'dstNode'.
1417 if (sibNode->id == dstNode->id)
1418 continue;
1419 // Skip 'use' if it has been visited.
1420 if (visitedSibNodeIds->count(sibNode->id) > 0)
1421 continue;
1422 // Skip 'use' if it does not load from the same memref as 'dstNode'.
1423 auto memref = loadOp.getMemRef();
1424 if (dstNode->getLoadOpCount(memref) == 0)
1425 continue;
1426 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1427 if (canFuseWithSibNode(sibNode, memref)) {
1428 visitedSibNodeIds->insert(sibNode->id);
1429 idAndMemrefToFuse->first = sibNode->id;
1430 idAndMemrefToFuse->second = memref;
1431 return true;
1432 }
1433 }
1434 }
1435
1436 // Search for siblings by following edges through an intermediate src node.
1437 // Collect candidate 'dstNode' input edges in 'inEdges'.
1438 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1439 mdg->forEachMemRefInputEdge(
1440 dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1441 // Add 'inEdge' if it is a read-after-write dependence or an edge
1442 // from a memref defining op (e.g. view-like op or alloc op).
1443 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1444 (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
1445 inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
1446 inEdges.push_back(inEdge);
1447 });
1448
1449 // Search for sibling nodes to fuse by visiting output edges from each input
1450 // edge in 'inEdges'.
1451 for (auto &inEdge : inEdges) {
1452 // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1453 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1454 mdg->forEachMemRefOutputEdge(
1455 inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1456 unsigned sibNodeId = outEdge.id;
1457 if (visitedSibNodeIds->count(sibNodeId) > 0)
1458 return;
1459 // Skip output edge if not a sibling using the same memref.
1460 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1461 return;
1462 auto *sibNode = mdg->getNode(sibNodeId);
1463 if (!isa<AffineForOp>(sibNode->op))
1464 return;
1465 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1466 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1467 // Add candidate 'outEdge' to sibling node.
1468 outEdges.push_back(outEdge);
1469 }
1470 });
1471
1472 // Add first candidate if any were returned.
1473 if (!outEdges.empty()) {
1474 visitedSibNodeIds->insert(outEdges[0].id);
1475 idAndMemrefToFuse->first = outEdges[0].id;
1476 idAndMemrefToFuse->second = outEdges[0].value;
1477 return true;
1478 }
1479 }
1480 return false;
1481 }
1482
1483 /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1484 /// into 'dstNode'.
1485 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1486 // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1487 mdg->updateEdges(sibNode->id, dstNode->id);
1488
1489 // Collect dst loop stats after memref privatization transformation.
1490 auto dstForInst = cast<AffineForOp>(dstNode->op);
1491 LoopNestStateCollector dstLoopCollector;
1492 dstLoopCollector.collect(dstForInst);
1493 // Clear and add back loads and stores
1494 mdg->clearNodeLoadAndStores(dstNode->id);
1495 mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1496 dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
1497 dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
1498 }
1499
1500 // Clean up any allocs with no users.
1501 void eraseUnusedMemRefAllocations() {
1502 for (auto &pair : mdg->memrefEdgeCount) {
1503 if (pair.second > 0)
1504 continue;
1505 auto memref = pair.first;
1506 // Skip if there exist other uses (return operation or function calls).
1507 if (!memref.use_empty())
1508 continue;
1509 // Use list expected to match the dep graph info.
1510 auto *op = memref.getDefiningOp();
1511 if (isa_and_nonnull<memref::AllocOp>(op))
1512 op->erase();
1513 }
1514 }
1515};
1516
1517} // namespace
1518
1519/// Run fusion on `block`.
1520void LoopFusion::runOnBlock(Block *block) {
1521 MemRefDependenceGraph g(*block);
1522 if (!g.init()) {
1523 LDBG() << "MDG init failed";
1524 return;
1525 }
1526
1527 std::optional<unsigned> fastMemorySpaceOpt;
1528 if (fastMemorySpace.hasValue())
1529 fastMemorySpaceOpt = fastMemorySpace;
1530 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1531 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1532 maximalFusion, computeToleranceThreshold);
1533
1534 if (affineFusionMode == FusionMode::ProducerConsumer)
1535 fusion.runProducerConsumerFusionOnly();
1536 else if (affineFusionMode == FusionMode::Sibling)
1537 fusion.runSiblingFusionOnly();
1538 else
1539 fusion.runGreedyFusion();
1540}
1541
1542void LoopFusion::runOnOperation() {
1543 // Call fusion on every op that has at least two affine.for nests (in post
1544 // order).
1545 getOperation()->walk([&](Operation *op) {
1546 for (Region &region : op->getRegions()) {
1547 for (Block &block : region.getBlocks()) {
1548 auto affineFors = block.getOps<AffineForOp>();
1549 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1550 runOnBlock(&block);
1551 }
1552 }
1553 });
1554}
1555
1557 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1558 bool maximalFusion, enum FusionMode affineFusionMode) {
1559 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1560 maximalFusion, affineFusionMode);
1561}
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:246
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static void getProducerCandidates(unsigned dstId, const MemRefDependenceGraph &mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
static Operation * getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock, ArrayRef< Operation * > producerStores)
Get the operation that should act as a dominance filter while replacing memref uses with a private me...
static bool isFusionProfitable(AffineForOp srcForOp, ArrayRef< Operation * > producerStores, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
static std::optional< double > getAdditionalComputeFraction(AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, ArrayRef< ComputationSliceState > depthSliceUnions, int64_t &sliceCost, int64_t &fusedLoopNestComputeCost)
Returns the amount of additional (redundant) computation that will be done as a fraction of the total...
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, const MemRefDependenceGraph &mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
static Value createPrivateMemRef(AffineForOp forOp, ArrayRef< Operation * > storeOps, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, Block *sliceInsertionBlock, uint64_t localBufSizeThreshold)
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
auto load
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition Block.cpp:74
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:193
iterator begin()
Definition Block.h:143
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
void erase()
Remove this operation from its parent block and delete it.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition Utils.cpp:851
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition Utils.cpp:2173
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
Definition Passes.h:35
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition Utils.cpp:1601
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
bool hasCyclicDependence(AffineForOp root)
Returns true if the affine nest rooted at root has a cyclic dependence among its affine memory access...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
Definition Utils.cpp:2391
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition Utils.cpp:1419
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool hasSingleEffect(Operation *op)
Returns "true" if op has only an effect of type EffectTy.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition Utils.h:318
bool isEmpty() const
Returns true if the computation slice is empty.
Definition Utils.h:350
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition Utils.cpp:1106
SmallVector< Operation *, 4 > memrefFrees
Definition Utils.h:49
SmallVector< Operation *, 4 > loadOpInsts
Definition Utils.h:41
SmallVector< Operation *, 4 > memrefStores
Definition Utils.h:47
void collect(Operation *opToWalk)
Definition Utils.cpp:43
SmallVector< Operation *, 4 > memrefLoads
Definition Utils.h:45
SmallVector< Operation *, 4 > storeOpInsts
Definition Utils.h:43
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition Utils.h:147
Block & block
The block for which this graph is created to perform fusion.
Definition Utils.h:270
void clearNodeLoadAndStores(unsigned id)
Definition Utils.cpp:797
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition Utils.cpp:720
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition Utils.h:144
const Node * getNode(unsigned id) const
Definition Utils.cpp:452
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
Definition Utils.cpp:784
DenseMap< Value, unsigned > memrefEdgeCount
Definition Utils.h:150
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
Definition Utils.h:489
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition Utils.cpp:1160
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition Utils.cpp:1257
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition Utils.cpp:1438
FlatAffineValueConstraints * getConstraints()
Definition Utils.h:535
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.