MLIR  21.0.0git
LoopFusion.cpp
1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements affine fusion.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/Transforms/Passes.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <iomanip>
33 #include <optional>
34 #include <sstream>
35 
36 namespace mlir {
37 namespace affine {
38 #define GEN_PASS_DEF_AFFINELOOPFUSION
39 #include "mlir/Dialect/Affine/Passes.h.inc"
40 } // namespace affine
41 } // namespace mlir
42 
43 #define DEBUG_TYPE "affine-fusion"
44 
45 using namespace mlir;
46 using namespace mlir::affine;
47 
48 namespace {
49 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
50 /// which fuses loop nests with single-writer/single-reader memref dependences
51 /// with the goal of improving locality.
52 // TODO: Support fusion of source loop nests which write to multiple
53 // memrefs, where each memref can have multiple users (if profitable).
54 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
55  LoopFusion() = default;
56  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
57  bool maximalFusion, enum FusionMode affineFusionMode) {
58  this->fastMemorySpace = fastMemorySpace;
59  this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
60  this->maximalFusion = maximalFusion;
61  this->affineFusionMode = affineFusionMode;
62  }
63 
64  void runOnBlock(Block *block);
65  void runOnOperation() override;
66 };
67 
68 } // namespace
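// An illustrative usage sketch (assumes the `createLoopFusionPass` declaration
// in mlir/Dialect/Affine/Passes.h and a standard pass-manager setup; the pass
// can also be run from the command line with `mlir-opt --affine-loop-fusion`).
// Note that the constructor above takes the local buffer threshold in bytes
// and converts it to KiB for the pass option.
//
//   #include "mlir/Dialect/Affine/Passes.h"
//   #include "mlir/Pass/PassManager.h"
//
//   void addAffineFusion(mlir::OpPassManager &pm) {
//     // Fuse greedily; promote private buffers of up to 8 KiB (8192 bytes)
//     // to fast memory space 1.
//     pm.addPass(mlir::affine::createLoopFusionPass(
//         /*fastMemorySpace=*/1, /*localBufSizeThreshold=*/8192,
//         /*maximalFusion=*/false, mlir::affine::FusionMode::Greedy));
//   }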
69 
70 /// Returns true if node 'srcId' can be removed after fusing it with node
71 /// 'dstId'. The node can be removed if any of the following conditions are met:
72 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
73 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
74 /// and the fusion slice is maximal.
75 /// 3. 'srcId' has output dependences after fusion, the fusion slice is
76 /// maximal and the fusion insertion point dominates all the dependences.
77 static bool canRemoveSrcNodeAfterFusion(
78  unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
79  Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
80  const MemRefDependenceGraph &mdg) {
81 
82  Operation *dstNodeOp = mdg.getNode(dstId)->op;
83  bool hasOutDepsAfterFusion = false;
84 
85  for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
86  Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
87  // Skip dependence with dstOp since it will be removed after fusion.
88  if (depNodeOp == dstNodeOp)
89  continue;
90 
91  // Only fusion within the same block is supported. Use domination analysis
92  // when needed.
93  if (depNodeOp->getBlock() != dstNodeOp->getBlock())
94  return false;
95 
96  // Check if the insertion point of the fused loop dominates the dependence.
97  // Otherwise, the src loop can't be removed.
98  if (fusedLoopInsPoint != depNodeOp &&
99  !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
100  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
101  "dominate dependence\n");
102  return false;
103  }
104 
105  hasOutDepsAfterFusion = true;
106  }
107 
108  // If src loop has dependences after fusion or it writes to a live-out or
109  // escaping memref, we can only remove it if the fusion slice is maximal so
110  // that all the dependences are preserved.
111  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
112  std::optional<bool> isMaximal = fusionSlice.isMaximal();
113  if (!isMaximal) {
114  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
115  "if fusion is maximal\n");
116  return false;
117  }
118 
119  if (!*isMaximal) {
120  LLVM_DEBUG(llvm::dbgs()
121  << "Src loop can't be removed: fusion is not maximal\n");
122  return false;
123  }
124  }
125 
126  return true;
127 }
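// For illustration (hypothetical numbers): suppose the source nest writes all
// 128 elements of an escaping memref but the fused slice only recomputes the
// 32 elements the consumer reads. The slice is then not maximal, so removing
// the source nest would silently drop the remaining 96 writes; per the
// conditions above, the function returns false in that case.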
128 
129 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
130 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to
131 /// the program order when the 'mdg' is created. However, program order is not
132 // guaranteed and must not be relied upon by the client. Program order won't
133 // hold if the 'mdg' is reused from a previous fusion step or if the node
134 // creation order changes in the future to support more advanced cases.
135 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
136 static void getProducerCandidates(unsigned dstId,
137  const MemRefDependenceGraph &mdg,
138  SmallVectorImpl<unsigned> &srcIdCandidates) {
139  // Skip if no input edges along which to fuse.
140  if (mdg.inEdges.count(dstId) == 0)
141  return;
142 
143  // Gather memrefs from loads in 'dstId'.
144  auto *dstNode = mdg.getNode(dstId);
145  DenseSet<Value> consumedMemrefs;
146  for (Operation *load : dstNode->loads)
147  consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148 
149  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
150  // to one of the consumed memrefs.
151  for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
152  const auto *srcNode = mdg.getNode(srcEdge.id);
153  // Skip if 'srcNode' is not a loop nest.
154  if (!isa<AffineForOp>(srcNode->op))
155  continue;
156 
157  if (any_of(srcNode->stores, [&](Operation *op) {
158  auto storeOp = cast<AffineWriteOpInterface>(op);
159  return consumedMemrefs.count(storeOp.getMemRef()) > 0;
160  }))
161  srcIdCandidates.push_back(srcNode->id);
162  }
163 
164  llvm::sort(srcIdCandidates);
165  srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
166 }
167 
168 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
169 /// producer-consumer dependence between 'srcId' and 'dstId'.
170 static void
171 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
172  const MemRefDependenceGraph &mdg,
173  DenseSet<Value> &producerConsumerMemrefs) {
174  auto *dstNode = mdg.getNode(dstId);
175  auto *srcNode = mdg.getNode(srcId);
176  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
177  producerConsumerMemrefs);
178 }
179 
180 /// A memref escapes in the context of the fusion pass if:
181 /// 1. it (or its alias) is a block argument, or
182 /// 2. it is created by an op not known to guarantee alias freedom, or
183 /// 3. it (or its alias) is used by ops other than affine dereferencing ops
184 /// (e.g., by call ops, memref load/store ops, alias-creating ops, unknown ops,
185 /// terminator ops, etc.); such ops do not dereference the memref in an affine
186 /// way.
187 static bool isEscapingMemref(Value memref, Block *block) {
188  Operation *defOp = memref.getDefiningOp();
189  // Check if 'memref' is a block argument.
190  if (!defOp)
191  return true;
192 
193  // Check if this is defined to be an alias of another memref.
194  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
195  if (isEscapingMemref(viewOp.getViewSource(), block))
196  return true;
197 
198  // Any op other than an allocating op does not guarantee alias freedom.
199  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
200  return true;
201 
202  // Check if 'memref' is used by a non-dereferencing op (including unknown ones)
203  // (e.g., call ops, alias creating ops, etc.).
204  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
205  // Ignore users outside of `block`.
206  Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
207  if (!ancestorOp)
208  return true;
209  if (ancestorOp->getBlock() != block)
210  return false;
211  return !isa<AffineMapAccessInterface>(*user);
212  });
213 }
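// For illustration (hypothetical IR): a memref produced by `memref.alloc` and
// accessed only through `affine.load`/`affine.store` inside `block` does not
// escape, so it remains a candidate for privatization. A function-argument
// memref, an alias of an escaping memref (e.g. via `memref.subview`), or a
// memref passed to `func.call` or to a terminator is treated as escaping by
// the rules above.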
214 
215 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
216 /// that escape the block or are accessed in a non-affine way.
217 static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
218  DenseSet<Value> &escapingMemRefs) {
219  auto *node = mdg.getNode(id);
220  for (Operation *storeOp : node->stores) {
221  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222  if (escapingMemRefs.count(memref))
223  continue;
224  if (isEscapingMemref(memref, &mdg.block))
225  escapingMemRefs.insert(memref);
226  }
227 }
228 
229 // Sinks all sequential loops to the innermost levels (while preserving
230 // relative order among them) and moves all parallel loops to the
231 // outermost (while again preserving relative order among them).
232 // This can increase the loop depth at which we can fuse a slice, since we are
233 // pushing loop-carried dependences to a greater depth in the loop nest.
234 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
235  assert(isa<AffineForOp>(node->op));
236  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
237  node->op = newRootForOp;
238 }
239 
240 /// Get the operation that should act as a dominance filter while replacing
241 /// memref uses with a private memref for which `producerStores` and
242 /// `sliceInsertionBlock` are provided. This effectively determines in what
243 /// part of the IR we should be performing the replacement.
244 static Operation *
245 getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
246  ArrayRef<Operation *> producerStores) {
247  assert(!producerStores.empty() && "expected producer store");
248 
249  // We first find the common block that contains the producer stores and
250  // the slice computation. The first ancestor among the ancestors of the
251  // producer stores in that common block is the dominance filter to use for
252  // replacement.
253  Block *commonBlock = nullptr;
254  // Find the common block of all relevant operations.
255  for (Operation *store : producerStores) {
256  Operation *otherOp =
257  !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
258  commonBlock = findInnermostCommonBlockInScope(store, otherOp);
259  }
260  assert(commonBlock &&
261  "common block of producer stores and slice should exist");
262 
263  // Find the first ancestor among the ancestors of `producerStores` in
264  // `commonBlock`.
265  Operation *firstAncestor = nullptr;
266  for (Operation *store : producerStores) {
267  Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
268  assert(ancestor && "producer store should be contained in common block");
269  firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
270  ? ancestor
271  : firstAncestor;
272  }
273  return firstAncestor;
274 }
275 
276 /// Returns the amount of additional (redundant) computation that will be done
277 /// as a fraction of the total computation if `srcForOp` is fused into
278 /// `dstForOp` at depth `depth`. The method returns the compute cost of the
279 /// slice and the fused nest's compute cost in the trailing output arguments.
280 static std::optional<double> getAdditionalComputeFraction(
281  AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282  ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
283  int64_t &fusedLoopNestComputeCost) {
284  LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
285  // Compute cost of sliced and unsliced src loop nest.
286  // Walk src loop nest and collect stats.
287  LoopNestStats srcLoopNestStats;
288  if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
289  LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
290  return std::nullopt;
291  }
292 
293  // Compute cost of dst loop nest.
294  LoopNestStats dstLoopNestStats;
295  if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
296  LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
297  return std::nullopt;
298  }
299 
300  // Compute op instance count for the src loop nest without iteration slicing.
301  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
302 
303  // Compute op cost for the dst loop nest.
304  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
305 
306  const ComputationSliceState &slice = depthSliceUnions[depth - 1];
307  // Skip slice union if it wasn't computed for this depth.
308  if (slice.isEmpty()) {
309  LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
310  return std::nullopt;
311  }
312 
313  if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
314  dstLoopNestStats, slice,
315  &fusedLoopNestComputeCost)) {
316  LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
317  return std::nullopt;
318  }
319 
320  double additionalComputeFraction =
321  fusedLoopNestComputeCost /
322  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
323  1;
324 
325  return additionalComputeFraction;
326 }
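// Worked example of the fraction computed above (illustrative numbers only):
// with srcLoopNestCost = 1000, dstLoopNestCost = 4000, and a fused nest whose
// cost comes out to fusedLoopNestComputeCost = 6000, the additional compute
// fraction is 6000 / (1000 + 4000) - 1 = 0.2, i.e. 20% redundant computation.
// Callers compare this value against the compute tolerance threshold.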
327 
328 // Creates and returns a private (single-user) memref for fused loop rooted at
329 // 'forOp', with (potentially reduced) memref size based on the memref region
330 // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
331 // specifies the block in which the slice was/will be inserted. The method
332 // expects that all stores ops to the memref have the same access function.
333 // Returns nullptr if the creation failed.
334 static Value createPrivateMemRef(AffineForOp forOp,
335  ArrayRef<Operation *> storeOps,
336  unsigned dstLoopDepth,
337  std::optional<unsigned> fastMemorySpace,
338  Block *sliceInsertionBlock,
339  uint64_t localBufSizeThreshold) {
340  assert(!storeOps.empty() && "no source stores supplied");
341 
342  // Check if all stores have the same access function; we only support this
343  // case.
344  // TODO: Use union of memref write regions to compute private memref footprint
345  // for store ops with different access functions.
346  if (storeOps.size() > 1 &&
347  !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
348  [](Operation *a, Operation *b) {
349  MemRefAccess aM(cast<AffineWriteOpInterface>(a));
350  MemRefAccess bM(cast<AffineWriteOpInterface>(b));
351  return aM == bM;
352  })) {
353  LLVM_DEBUG(llvm::dbgs()
354  << "Private memref creation unsupported for multiple producer "
355  "stores with different access functions.\n");
356  return nullptr;
357  }
358 
359  Operation *srcStoreOp = storeOps[0];
360 
361  // Create builder to insert alloc op just before 'forOp'.
362  OpBuilder b(forOp);
363  // Builder to create constants at the top level.
364  OpBuilder top(forOp->getParentRegion());
365  // Create new memref type based on slice bounds.
366  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
367  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
368  unsigned rank = oldMemRefType.getRank();
369 
370  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
371  MemRefRegion region(srcStoreOp->getLoc());
372  bool validRegion = succeeded(
373  region.compute(srcStoreOp, dstLoopDepth, /*sliceState=*/nullptr,
374  /*addMemRefDimBounds=*/true, /*dropLocalVars=*/false));
375 
376  (void)validRegion;
377  assert(validRegion && "unexpected memref region failure");
378  SmallVector<int64_t, 4> newShape;
379  SmallVector<AffineMap, 4> lbs;
380  lbs.reserve(rank);
381  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
382  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
383  std::optional<int64_t> numElements =
384  region.getConstantBoundingSizeAndShape(&newShape, &lbs);
385  assert(numElements && "non-constant number of elts in local buffer");
386 
387  const FlatAffineValueConstraints *cst = region.getConstraints();
388  // 'outerIVs' holds the values that this memory region is symbolic/parametric
389  // on; this would correspond to loop IVs surrounding the level at which the
390  // slice is being materialized.
391  SmallVector<Value, 8> outerIVs;
392  cst->getValues(rank, cst->getNumDimAndSymbolVars(), &outerIVs);
393 
394  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
395  SmallVector<AffineExpr, 4> offsets;
396  offsets.reserve(rank);
397 
398  // Outer IVs are considered symbols during memref region computation. Replace
399  // them uniformly with dims so that valid IR is guaranteed.
400  SmallVector<AffineExpr> replacements;
401  for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
402  replacements.push_back(mlir::getAffineDimExpr(j, forOp.getContext()));
403  for (unsigned d = 0; d < rank; ++d) {
404  assert(lbs[d].getNumResults() == 1 &&
405  "invalid private memref bound calculation");
406  offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
407  }
408 
409  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
410  // by 'srcStoreOpInst'.
411  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
412  assert(eltSize && "memrefs with size elt types expected");
413  uint64_t bufSize = *eltSize * *numElements;
414  Attribute newMemSpace;
415  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
416  newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
417  } else {
418  newMemSpace = oldMemRefType.getMemorySpace();
419  }
420  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
421  /*map=*/AffineMap(), newMemSpace);
422 
423  // Create new private memref for fused loop 'forOp'. 'newShape' is always
424  // a constant shape.
425  // TODO: Create/move alloc ops for private memrefs closer to their
426  // consumer loop nests to reduce their live range. Currently they are added
427  // at the beginning of the block, because loop nests can be reordered
428  // during the fusion pass.
429  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
430 
431  // Build an AffineMap to remap access functions based on lower bound offsets.
432  SmallVector<AffineExpr, 4> remapExprs;
433  remapExprs.reserve(rank);
434  for (unsigned i = 0; i < rank; i++) {
435  auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
436 
437  auto remapExpr =
438  simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
439  remapExprs.push_back(remapExpr);
440  }
441 
442  auto indexRemap =
443  AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
444 
445  // Replace all users of 'oldMemRef' with 'newMemRef'.
446  Operation *domFilter =
447  getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
448  LogicalResult res = replaceAllMemRefUsesWith(
449  oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
450  /*extraOperands=*/outerIVs,
451  /*symbolOperands=*/{}, domFilter);
452  assert(succeeded(res) &&
453  "replaceAllMemrefUsesWith should always succeed here");
454  (void)res;
455  LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
456  << '\n');
457  return newMemRef;
458 }
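// Worked example (illustrative): assume one outer IV %i and a rank-1 memref
// whose write region at this depth is [4 * %i, 4 * %i + 4). Then newShape is
// {4}, offsets[0] becomes 4 * d0 (the symbol for %i replaced by dim d0), and
// the index remap is (d0, d1) -> (d1 - 4 * d0). An access to element
// 4 * %i + k of the old memref is thus rewritten to element k of the
// 4-element private memref.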
459 
460 // Checks the profitability of fusing a backwards slice of the loop nest
461 // `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
462 // 'srcStoreOpInst' is used to calculate the storage reduction on the memref
463 // being produced and consumed, which is an input to the cost model. For
464 // producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
465 // as we are slicing w.r.t. that producer. For input-reuse fusion, 'srcOpInst'
466 // will be the src loop nest LoadOp which reads from the same memref as dst loop
467 // nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
468 // node, which will be used to check that the write region is the same after
469 // input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
470 // each legal fusion depth. The maximal depth at which fusion is legal is
471 // provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
472 // the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
473 // the most profitable depth at which to materialize the source loop nest slice.
474 // The profitability model executes the following steps:
475 // *) Computes the backward computation slice at 'srcOpInst'. This
476 // computation slice of the loop nest surrounding 'srcOpInst' is
477 // represented by modified src loop bounds in 'sliceState', which are
478 // functions of loop IVs in the loop nest surrounding 'srcOpInst'.
479 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
480 // loop nest is the total number of dynamic operation instances in the loop
481 // nest).
482 // *) Computes the cost of fusing a slice of the src loop nest into the dst
483 // loop nest at various values of dst loop depth, attempting to fuse
484 // the largest computation slice at the maximal dst loop depth (closest to
485 // the load) to minimize reuse distance and potentially enable subsequent
486 // load/store forwarding.
487 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
488 // nest, at which the src computation slice is inserted/fused.
489 // NOTE: We attempt to maximize the dst loop depth, but there are cases
490 // where a particular setting for 'dstLoopNest' might fuse an unsliced
491 // loop (within the src computation slice) at a depth which results in
492 // excessive recomputation (see unit tests for examples).
493 // *) Compares the total cost of the unfused loop nests to the min cost fused
494 // loop nest computed in the previous step, and returns true if the latter
495 // is lower.
496 // TODO: Extend profitability analysis to support scenarios with multiple
497 // stores.
498 static bool isFusionProfitable(AffineForOp srcForOp,
499  ArrayRef<Operation *> producerStores,
500  AffineForOp dstForOp,
501  ArrayRef<ComputationSliceState> depthSliceUnions,
502  unsigned maxLegalFusionDepth,
503  unsigned *dstLoopDepth,
504  double computeToleranceThreshold) {
505  LLVM_DEBUG({
506  llvm::dbgs()
507  << "Checking whether fusion is profitable between source nest:\n";
508  llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
509  llvm::dbgs() << dstForOp << "\n";
510  });
511 
512  if (maxLegalFusionDepth == 0) {
513  LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
514  return false;
515  }
516 
517  // Compute cost of sliced and unsliced src loop nest.
518 
519  // Walk src loop nest and collect stats.
520  LoopNestStats srcLoopNestStats;
521  if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
522  return false;
523 
524  // Compute cost of dst loop nest.
525  LoopNestStats dstLoopNestStats;
526  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
527  return false;
528 
529  // We limit profitability analysis to only scenarios with
530  // a single producer store for now. Note that some multi-store
531  // producer scenarios will still go through profitability analysis
532  // if only one of the stores is involved in the producer-consumer
533  // relationship of the candidate loops.
534  // TODO: Support multiple producer stores in profitability
535  // analysis.
536  if (producerStores.size() > 1) {
537  LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
538  "supported for multiple producer store case.\n");
539  int64_t sliceCost;
540  int64_t fusedLoopNestComputeCost;
541  // We will still fuse if fusion obeys the specified compute
542  // tolerance at the max legal depth.
543  auto fraction = getAdditionalComputeFraction(
544  srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
545  fusedLoopNestComputeCost);
546  if (!fraction || fraction > computeToleranceThreshold) {
547  LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
548  "compute tolerance. Not fusing.\n");
549  return false;
550  }
551  LLVM_DEBUG(llvm::dbgs()
552  << "Considering fusion profitable at max legal depth.\n");
553  return true;
554  }
555 
556  Operation *srcStoreOp = producerStores.front();
557 
558  // Search for min cost value for 'dstLoopDepth'. At each value of
559  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
560 // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
561  // of these bounds). Next the union slice bounds are used to calculate
562  // the cost of the slice and the cost of the slice inserted into the dst
563  // loop nest at 'dstLoopDepth'.
564  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
565  double maxStorageReduction = 0.0;
566  std::optional<uint64_t> sliceMemEstimate;
567 
568  // The best loop depth at which to materialize the slice.
569  std::optional<unsigned> bestDstLoopDepth;
570 
571  // Compute src loop nest write region size.
572  MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
573  if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
574  LLVM_DEBUG(llvm::dbgs()
575  << "Unable to compute MemRefRegion for source operation\n");
576  return false;
577  }
578 
579  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
580  srcWriteRegion.getRegionSize();
581  if (!maybeSrcWriteRegionSizeBytes.has_value())
582  return false;
583  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
584 
585  // Compute op instance count for the src loop nest without iteration slicing.
586  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
587 
588  // Compute op instance count for the destination loop nest.
589  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
590 
591  // Evaluate all depth choices for materializing the slice in the destination
592  // loop nest.
593  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
594  const ComputationSliceState &slice = depthSliceUnions[i - 1];
595  // Skip slice union if it wasn't computed for this depth.
596  if (slice.isEmpty())
597  continue;
598 
599  // Compute cost of the slice separately, i.e., the compute cost of the slice
600  // if all outer trip counts are one.
601  int64_t sliceCost;
602 
603  int64_t fusedLoopNestComputeCost;
604 
605  auto mayAdditionalComputeFraction =
606  getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
607  sliceCost, fusedLoopNestComputeCost);
608  if (!mayAdditionalComputeFraction) {
609  LLVM_DEBUG(llvm::dbgs()
610  << "Can't determine additional compute fraction.\n");
611  continue;
612  }
613  double additionalComputeFraction = *mayAdditionalComputeFraction;
614 
615  // Determine what the slice write MemRefRegion would be, if the src loop
616  // nest slice 'slice' were to be inserted into the dst loop nest at loop
617  // depth 'i'.
618  MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
619  if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
620  LLVM_DEBUG(llvm::dbgs()
621  << "Failed to compute slice write region at loopDepth: " << i
622  << "\n");
623  continue;
624  }
625 
626  std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
627  sliceWriteRegion.getRegionSize();
628  if (!maybeSliceWriteRegionSizeBytes.has_value() ||
629  *maybeSliceWriteRegionSizeBytes == 0) {
630  LLVM_DEBUG(llvm::dbgs()
631  << "Failed to get slice write region size at loopDepth: " << i
632  << "\n");
633  continue;
634  }
635  int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
636 
637  double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
638  static_cast<double>(sliceWriteRegionSizeBytes);
639 
640  LLVM_DEBUG({
641  std::stringstream msg;
642  msg << " evaluating fusion profitability at depth : " << i << "\n"
643  << std::fixed << std::setprecision(2)
644  << " additional compute fraction: "
645  << 100.0 * additionalComputeFraction << "%\n"
646  << " storage reduction factor: " << storageReduction << "x\n"
647  << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
648  << " src write region size: " << srcWriteRegionSizeBytes << "\n"
649  << " slice write region size: " << sliceWriteRegionSizeBytes
650  << "\n";
651  llvm::dbgs() << msg.str();
652  });
653 
654  // TODO: This is a placeholder cost model.
655  // Among all choices that add an acceptable amount of redundant computation
656  // (as per computeToleranceThreshold), we will simply pick the one that
657  // reduces the intermediary size the most.
658  if ((storageReduction > maxStorageReduction) &&
659  (additionalComputeFraction <= computeToleranceThreshold)) {
660  maxStorageReduction = storageReduction;
661  bestDstLoopDepth = i;
662  minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
663  sliceMemEstimate = sliceWriteRegionSizeBytes;
664  }
665  }
666 
667  // A simple cost model: fuse if it reduces the memory footprint.
668 
669  if (!bestDstLoopDepth) {
670  LLVM_DEBUG(
671  llvm::dbgs()
672  << "All fusion choices involve more than the threshold amount of "
673  "redundant computation; NOT fusing.\n");
674  return false;
675  }
676 
681 
682  // Set dstLoopDepth based on best values from search.
683  *dstLoopDepth = *bestDstLoopDepth;
684 
685  LLVM_DEBUG(
686  llvm::dbgs() << " LoopFusion fusion stats:"
687  << "\n best loop depth: " << bestDstLoopDepth
688  << "\n src loop nest compute cost: " << srcLoopNestCost
689  << "\n dst loop nest compute cost: " << dstLoopNestCost
690  << "\n fused loop nest compute cost: "
691  << minFusedLoopNestComputeCost << "\n");
692 
693  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
694  auto srcMemSize = getMemoryFootprintBytes(srcForOp);
695 
696  std::optional<double> storageReduction;
697 
698  if (!dstMemSize || !srcMemSize) {
699  LLVM_DEBUG(llvm::dbgs()
700  << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
701  return false;
702  }
703 
704  auto srcMemSizeVal = *srcMemSize;
705  auto dstMemSizeVal = *dstMemSize;
706 
707  assert(sliceMemEstimate && "expected value");
708  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
709 
710  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
711  << " dst mem: " << dstMemSizeVal << "\n"
712  << " fused mem: " << fusedMem << "\n"
713  << " slice mem: " << sliceMemEstimate << "\n");
714 
715  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
716  LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
717  return false;
718  }
719  storageReduction =
720  100.0 *
721  (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
722 
723  double additionalComputeFraction =
724  100.0 * (minFusedLoopNestComputeCost /
725  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
726  1);
727  (void)additionalComputeFraction;
728  LLVM_DEBUG({
729  std::stringstream msg;
730  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
731  << std::setprecision(2) << additionalComputeFraction
732  << "% redundant computation and a ";
733  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
734  msg << "% storage reduction.\n";
735  llvm::dbgs() << msg.str();
736  });
737 
738  return true;
739 }
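// Worked example of the depth search above (illustrative numbers): suppose
// maxLegalFusionDepth = 3 and the per-depth results are
//   depth 1: additional compute 0%,  storage reduction 1x
//   depth 2: additional compute 5%,  storage reduction 16x
//   depth 3: additional compute 50%, storage reduction 64x
// With computeToleranceThreshold = 0.3 (30%), depth 3 is rejected and depth 2
// is chosen, since it maximizes storage reduction among admissible depths;
// the final memory-footprint check must then still pass before fusing.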
740 
741 namespace {
742 
743 // GreedyFusion greedily fuses loop nests which have a producer/consumer or
744 // input-reuse relationship on a memref, with the goal of improving locality.
745 //
746 // The steps of the producer-consumer fusion algorithm are as follows:
747 //
748 // *) A worklist is initialized with node ids from the dependence graph.
749 // *) For each node id in the worklist:
750 // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
751 // candidate destination AffineForOp into which fusion will be attempted.
752 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
753 // *) For each LoadOp in 'dstLoadOps' do:
754 // *) Look up dependent loop nests which have a single store op to the same
755 // memref.
756 // *) Check if dependences would be violated by the fusion.
757 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop
758 // bounds to be functions of 'dstLoopNest' IVs and symbols.
759 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
760 // at a loop depth determined by the cost model in 'isFusionProfitable'.
761 // *) Add the newly fused load/store operations to the state,
762 // and also add newly fused load ops to 'dstLoopOps' to be considered
763 // as fusion dst load ops in another iteration.
764 // *) Remove old src loop nest and its associated state.
765 //
766 // The steps of the input-reuse fusion algorithm are as follows:
767 //
768 // *) Initialize 'worklist' with node ids from the dependence graph.
769 // *) For each 'dstNode' in the worklist:
770 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
771 // loads from the same memref, but which has no dependence paths to/from.
772 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop
773 // bounds to be functions of 'dstLoopNest' IVs and symbols.
774 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
775 // at a loop depth determined by the cost model in 'isFusionProfitable'.
776 // This function also checks that the memref write region of 'sibLoopNest',
777 // is preserved in the fused loop nest.
778 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
779 //
780 // Given a graph where top-level operations are vertices in the set 'V' and
781 // edges in the set 'E' are dependences between vertices, this algorithm
782 // takes O(V) time for initialization, and has runtime O(V + E).
783 //
784 // This greedy algorithm is not 'maximal' due to the current restriction of
785 // fusing along single producer consumer edges, but there is a TODO: to fix
786 // this.
787 //
788 // TODO: Experiment with other fusion policies.
789 struct GreedyFusion {
790 public:
791  // The data dependence graph to traverse during fusion.
792  MemRefDependenceGraph *mdg;
793  // Worklist of graph nodes visited during the fusion pass.
794  SmallVector<unsigned, 8> worklist;
795  // Parameter for local buffer size threshold.
796  unsigned localBufSizeThreshold;
797  // Parameter for fast memory space.
798  std::optional<unsigned> fastMemorySpace;
799  // If true, ignore any additional (redundant) computation tolerance threshold
800  // that would have prevented fusion.
801  bool maximalFusion;
802  // The amount of additional computation that is tolerated while fusing
803  // pair-wise as a fraction of the total computation.
804  double computeToleranceThreshold;
805 
806  using Node = MemRefDependenceGraph::Node;
807 
808  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
809  std::optional<unsigned> fastMemorySpace, bool maximalFusion,
810  double computeToleranceThreshold)
811  : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
812  fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
813  computeToleranceThreshold(computeToleranceThreshold) {}
814 
815  /// Initializes 'worklist' with nodes from 'mdg'.
816  void init() {
817  // TODO: Add a priority queue for prioritizing nodes by different
818  // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
819  worklist.clear();
820  for (auto &idAndNode : mdg->nodes) {
821  const Node &node = idAndNode.second;
822  worklist.push_back(node.id);
823  }
824  }
825  /// Run only sibling fusion on the `mdg`.
826  void runSiblingFusionOnly() {
827  fuseSiblingNodes();
828  eraseUnusedMemRefAllocations();
829  }
830 
831  /// Run only producer/consumer fusion on the `mdg`.
832  void runProducerConsumerFusionOnly() {
833  fuseProducerConsumerNodes(
834  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
835  eraseUnusedMemRefAllocations();
836  }
837 
838  // Run the GreedyFusion pass.
839  // *) First pass through the nodes fuses single-use producer nodes into their
840  // unique consumer.
841  // *) Second pass fuses sibling nodes which share no dependence edges.
842  // *) Third pass fuses any remaining producer nodes into their users.
843  void runGreedyFusion() {
844  // TODO: Run this repeatedly until a fixed-point is reached.
845  fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
846  fuseSiblingNodes();
847  fuseProducerConsumerNodes(
848  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
849  eraseUnusedMemRefAllocations();
850  }
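 // For illustration of the phase ordering above: given producer nest A writing
 // %tmp that is read only by nest B, and nests C and D that merely load the
 // same input %in, the first phase fuses A into B (single-user producer), the
 // sibling phase fuses C and D so loads of %in are reused, and the final phase
 // retries any remaining producers that have multiple users.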
851 
852  /// Returns true if a private memref can be created for `memref` given
853  /// the fusion scenario reflected by the other arguments.
854  bool canCreatePrivateMemRef(Value memref,
855  const DenseSet<Value> &srcEscapingMemRefs,
856  unsigned producerId, unsigned consumerId,
857  bool removeSrcNode) {
858  // We can't generate private memrefs if their size can't be computed.
859  if (!getMemRefIntOrFloatEltSizeInBytes(cast<MemRefType>(memref.getType())))
860  return false;
861  const Node *consumerNode = mdg->getNode(consumerId);
862  // If `memref` is an escaping one, do not create a private memref
863  // for the below scenarios, since doing so will leave the escaping
864  // memref unmodified as all the writes originally meant for the
865  // escaping memref would be performed on the private memref:
866  // 1. The source is to be removed after fusion,
867  // OR
868  // 2. The destination writes to `memref`.
869  if (srcEscapingMemRefs.count(memref) > 0 &&
870  (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
871  return false;
872 
873  // Don't create a private memref if 'srcNode' has in edges on
874  // 'memref' or 'dstNode' has out edges on 'memref'.
875  if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
876  mdg->getOutEdgeCount(consumerId, memref) > 0)
877  return false;
878 
879  // If 'srcNode' will be removed but it has out edges on 'memref' to
880  // nodes other than 'dstNode', we have to preserve dependences and
881  // cannot create a private memref.
882  if (removeSrcNode &&
883  any_of(mdg->outEdges[producerId], [&](const auto &edge) {
884  return edge.value == memref && edge.id != consumerId;
885  }))
886  return false;
887 
888  return true;
889  }
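 // For illustration: if %out is a function-argument (escaping) memref written
 // by the producer and the producer is erased after fusion, privatizing %out
 // would redirect all of its writes to the new buffer and leave the
 // caller-visible %out stale; the checks above therefore refuse to privatize
 // in that scenario.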
890 
891  /// Perform fusions with node `dstId` as the destination of fusion, with
892  /// other nodes as sources. No fusion is performed when producers with a user
893  /// count greater than `maxSrcUserCount` for any of the memrefs involved are encountered.
894  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
895  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
896  // Skip if this node was removed (fused into another node).
897  if (mdg->nodes.count(dstId) == 0)
898  return;
899  // Get 'dstNode' into which to attempt fusion.
900  auto *dstNode = mdg->getNode(dstId);
901  // Skip if 'dstNode' is not a loop nest.
902  if (!isa<AffineForOp>(dstNode->op))
903  return;
904  // Skip if 'dstNode' is a loop nest returning values.
905  // TODO: support loop nests that return values.
906  if (dstNode->op->getNumResults() > 0)
907  return;
908 
909  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
910 
911  // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
912  // while preserving relative order. This can increase the maximum loop
913  // depth at which we can fuse a slice of a producer loop nest into a
914  // consumer loop nest.
915  sinkSequentialLoops(dstNode);
916  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
917 
918  // Try to fuse 'dstNode' with candidate producer loops until a fixed point
919  // is reached. Fusing two loops may expose new fusion opportunities.
920  bool dstNodeChanged;
921  do {
922  // Gather src loop candidates for 'dstNode' and visit them in "quasi"
923  // reverse program order to minimize the number of iterations needed to
924  // reach the fixed point. Note that this is a best effort approach since
925  // 'getProducerCandidates' does not always guarantee program order
926  // in 'srcIdCandidates'.
927  dstNodeChanged = false;
928  SmallVector<unsigned, 16> srcIdCandidates;
929  getProducerCandidates(dstId, *mdg, srcIdCandidates);
930 
931  for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
932  // Get 'srcNode' from which to attempt fusion into 'dstNode'.
933  auto *srcNode = mdg->getNode(srcId);
934  auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
935 
936  LLVM_DEBUG(llvm::dbgs()
937  << "Trying to fuse producer loop nest " << srcId
938  << " with consumer loop nest " << dstId << "\n");
939  LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
940  << computeToleranceThreshold << '\n');
941  LLVM_DEBUG(llvm::dbgs()
942  << "Producer loop nest:\n"
943  << *srcNode->op << "\n and consumer loop nest:\n"
944  << *dstNode->op << '\n');
945 
946  LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
947  << " for dst loop " << dstId << "\n");
948 
949  // Skip if 'srcNode' is a loop nest returning values.
950  // TODO: support loop nests that return values.
951  if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
952  continue;
953 
954  DenseSet<Value> producerConsumerMemrefs;
955  gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
956  producerConsumerMemrefs);
957 
958  // Skip if 'srcNode' out edge count on any memref is greater than
959  // 'maxSrcUserCount'.
960  if (any_of(producerConsumerMemrefs, [&](Value memref) {
961  return mdg->getOutEdgeCount(srcNode->id, memref) >
962  maxSrcUserCount;
963  }))
964  continue;
965 
966  // Gather memrefs in 'srcNode' that are written and escape out of the
967  // block (e.g., memref block arguments, returned memrefs,
968  // memrefs passed to function calls, etc.).
969  DenseSet<Value> srcEscapingMemRefs;
970  gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
971 
972  // Compute an operation list insertion point for the fused loop
973  // nest which preserves dependences.
974  Operation *fusedLoopInsPoint =
975  mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
976  if (fusedLoopInsPoint == nullptr)
977  continue;
978 
979  // It's possible this fusion is at an inner depth (i.e., there are
980  // common surrounding affine loops for the source and destination for
981  // ops). We need to get this number because the call to canFuseLoops
982  // needs to be passed the absolute depth. The max legal depth and the
983  // depths we try below are however *relative* and as such don't include
984  // the common depth.
985  SmallVector<AffineForOp, 4> surroundingLoops;
986  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
987  unsigned numSurroundingLoops = surroundingLoops.size();
988 
989  // Compute the innermost common loop depth for dstNode
990  // producer-consumer loads/stores.
991  SmallVector<Operation *, 2> dstMemrefOps;
992  for (Operation *op : dstNode->loads)
993  if (producerConsumerMemrefs.count(
994  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
995  dstMemrefOps.push_back(op);
996  for (Operation *op : dstNode->stores)
997  if (producerConsumerMemrefs.count(
998  cast<AffineWriteOpInterface>(op).getMemRef()))
999  dstMemrefOps.push_back(op);
1000  unsigned dstLoopDepthTest =
1001  getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
1002 
1003  // Check the feasibility of fusing src loop nest into dst loop nest
1004  // at loop depths in range [1, dstLoopDepthTest].
1005  unsigned maxLegalFusionDepth = 0;
1006  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1007  depthSliceUnions.resize(dstLoopDepthTest);
1008  FusionStrategy strategy(FusionStrategy::ProducerConsumer);
1009  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1010  FusionResult result =
1011  affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
1012  /*dstLoopDepth=*/i + numSurroundingLoops,
1013  &depthSliceUnions[i - 1], strategy);
1014  if (result.value == FusionResult::Success) {
1015  maxLegalFusionDepth = i;
1016  LLVM_DEBUG(llvm::dbgs()
1017  << "Found valid slice for depth: " << i << '\n');
1018  }
1019  }
1020 
1021  if (maxLegalFusionDepth == 0) {
1022  LLVM_DEBUG(llvm::dbgs()
1023  << "Can't fuse: fusion is not legal at any depth\n");
1024  continue;
1025  }
1026 
1027  LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1028  << maxLegalFusionDepth << '\n');
1029 
1030  double computeToleranceThresholdToUse = computeToleranceThreshold;
1031 
1032  // Cyclic dependences in the source nest may be violated when performing
1033  // slicing-based fusion. They aren't actually violated in cases where no
1034  // redundant execution of the source happens (1:1 pointwise dep on the
1035  // producer-consumer memref access for example). Check this and allow
1036  // fusion accordingly.
1037  if (hasCyclicDependence(srcAffineForOp)) {
1038  LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
1039  // Maximal fusion does not check the compute tolerance threshold; so
1040  // perform maximal fusion only when the redundant computation
1041  // is zero.
1042  if (maximalFusion) {
1043  auto srcForOp = cast<AffineForOp>(srcNode->op);
1044  auto dstForOp = cast<AffineForOp>(dstNode->op);
1045  int64_t sliceCost;
1046  int64_t fusedLoopNestComputeCost;
1047  auto fraction = getAdditionalComputeFraction(
1048  srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1049  sliceCost, fusedLoopNestComputeCost);
1050  if (!fraction || fraction > 0) {
1051  LLVM_DEBUG(
1052  llvm::dbgs()
1053  << "Can't perform maximal fusion with a cyclic dependence "
1054  "and non-zero additional compute.\n");
1055  return;
1056  }
1057  } else {
1058  // Set redundant computation tolerance to zero regardless of what
1059  // the user specified. Without this, fusion would be invalid.
1060  LLVM_DEBUG(llvm::dbgs()
1061  << "Setting compute tolerance to zero since "
1062  "source has a cyclic dependence.\n");
1063  computeToleranceThresholdToUse = 0;
1064  }
1065  }
1066 
1067  // Check if fusion would be profitable. We skip profitability analysis
1068  // for maximal fusion since we already know the maximal legal depth to
1069  // fuse.
1070  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1071  if (!maximalFusion) {
1072  // Retrieve producer stores from the src loop.
1073  SmallVector<Operation *, 2> producerStores;
1074  for (Operation *op : srcNode->stores)
1075  if (producerConsumerMemrefs.count(
1076  cast<AffineWriteOpInterface>(op).getMemRef()))
1077  producerStores.push_back(op);
1078 
1079  assert(!producerStores.empty() && "Expected producer store");
1080  if (!isFusionProfitable(srcAffineForOp, producerStores,
1081  dstAffineForOp, depthSliceUnions,
1082  maxLegalFusionDepth, &bestDstLoopDepth,
1083  computeToleranceThresholdToUse)) {
1084  continue;
1085  }
1086  }
1087 
1088  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1089  ComputationSliceState &bestSlice =
1090  depthSliceUnions[bestDstLoopDepth - 1];
1091  assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1092 
1093  // Determine if 'srcId' can be removed after fusion, taking into
1094  // account remaining dependences, escaping memrefs and the fusion
1095  // insertion point.
1096  bool removeSrcNode = canRemoveSrcNodeAfterFusion(
1097  srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1098  *mdg);
1099 
1100  DenseSet<Value> privateMemrefs;
1101  for (Value memref : producerConsumerMemrefs) {
1102  if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1103  removeSrcNode)) {
1104  // Create a private version of this memref.
1105  LLVM_DEBUG(llvm::dbgs()
1106  << "Creating private memref for " << memref << '\n');
1108  privateMemrefs.insert(memref);
1109  }
1110  }
1111 
1112  // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
1113  fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1114  dstNodeChanged = true;
1115 
1116  LLVM_DEBUG(llvm::dbgs()
1117  << "Fused src loop " << srcId << " into dst loop " << dstId
1118  << " at depth " << bestDstLoopDepth << ":\n"
1119  << dstAffineForOp << "\n");
1120 
1121  // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1122  if (fusedLoopInsPoint != dstAffineForOp)
1123  dstAffineForOp->moveBefore(fusedLoopInsPoint);
1124 
1125  // Update edges between 'srcNode' and 'dstNode'.
1126  mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1127  removeSrcNode);
1128 
1129  // Create private memrefs.
1130  if (!privateMemrefs.empty()) {
1131  // Note the block into which fusion was performed. This can be used to
1132  // place `alloc`s that create private memrefs.
1133  Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
1134 
1135  // Gather stores for all the private-to-be memrefs.
1136  DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1137  dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1138  Value storeMemRef = storeOp.getMemRef();
1139  if (privateMemrefs.count(storeMemRef) > 0)
1140  privateMemRefToStores[storeMemRef].push_back(storeOp);
1141  });
1142 
1143  // Replace original memrefs with private memrefs. Note that all the
1144  // loads and stores on these memrefs will be replaced with new
1145  // loads and stores. Any reference to the original ones becomes
1146  // invalid after this point.
1147  for (auto &memrefToStoresPair : privateMemRefToStores) {
1148  ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
1149  Value newMemRef = createPrivateMemRef(
1150  dstAffineForOp, storesForMemref, bestDstLoopDepth,
1151  fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1152  if (!newMemRef)
1153  continue;
1154  // Create new node in dependence graph for 'newMemRef' alloc op.
1155  unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1156  // Add edge from 'newMemRef' node to dstNode.
1157  mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1158  }
1159  // One or more entries for 'newMemRef' alloc op are inserted into
1160  // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1161  // reallocate, update dstNode.
1162  dstNode = mdg->getNode(dstId);
1163  }
1164 
1165  // Collect dst loop stats after memref privatization transformation.
1166  LoopNestStateCollector dstLoopCollector;
1167  dstLoopCollector.collect(dstAffineForOp);
1168 
1169  // Clear and add back loads and stores.
1170  mdg->clearNodeLoadAndStores(dstNode->id);
1171  mdg->addToNode(
1172  dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts,
1173  dstLoopCollector.memrefLoads, dstLoopCollector.memrefStores,
1174  dstLoopCollector.memrefFrees);
1175 
1176  if (removeSrcNode) {
1177  LLVM_DEBUG(llvm::dbgs()
1178  << "Removing src loop " << srcId << " after fusion\n");
1179  // srcNode is no longer valid after it is removed from mdg.
1180  srcAffineForOp.erase();
1181  mdg->removeNode(srcId);
1182  srcNode = nullptr;
1183  }
1184  }
1185  } while (dstNodeChanged);
1186  }
1187 
1188  /// Visit each node in the graph, and for each node, attempt to fuse it with
1189  /// producer-consumer candidates. No fusion is performed when producers with a
1190  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1191  /// are encountered.
1192  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1193  LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1194  init();
1195  while (!worklist.empty()) {
1196  unsigned dstId = worklist.back();
1197  worklist.pop_back();
1198  performFusionsIntoDest(dstId, maxSrcUserCount);
1199  }
1200  }
1201 
1202  // Visits each node in the graph, and for each node, attempts to fuse it with
1203  // its sibling nodes (nodes which share a parent, but no dependence edges).
1204  void fuseSiblingNodes() {
1205  LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1206  init();
1207  while (!worklist.empty()) {
1208  unsigned dstId = worklist.back();
1209  worklist.pop_back();
1210 
1211  // Skip if this node was removed (fused into another node).
1212  if (mdg->nodes.count(dstId) == 0)
1213  continue;
1214  // Get 'dstNode' into which to attempt fusion.
1215  auto *dstNode = mdg->getNode(dstId);
1216  // Skip if 'dstNode' is not a loop nest.
1217  if (!isa<AffineForOp>(dstNode->op))
1218  continue;
1219  // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1220  fuseWithSiblingNodes(dstNode);
1221  }
1222  }
1223 
1224  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1225  void fuseWithSiblingNodes(Node *dstNode) {
1226  DenseSet<unsigned> visitedSibNodeIds;
1227  std::pair<unsigned, Value> idAndMemref;
1228  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1229 
1230  while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1231  unsigned sibId = idAndMemref.first;
1232  Value memref = idAndMemref.second;
1233  // TODO: Check that 'sibStoreOpInst' post-dominates all other
1234  // stores to the same memref in 'sibNode' loop nest.
1235  auto *sibNode = mdg->getNode(sibId);
1236  // Compute an operation list insertion point for the fused loop
1237  // nest which preserves dependences.
1238  assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1239  Operation *insertPointInst =
1240  sibNode->op->isBeforeInBlock(dstNode->op)
1241  ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1242  : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1243  if (insertPointInst == nullptr)
1244  continue;
1245 
1246  // Check if fusion would be profitable and at what depth.
1247 
1248  // Get unique 'sibNode' load op to 'memref'.
1249  SmallVector<Operation *, 2> sibLoadOpInsts;
1250  sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1251  // Currently findSiblingNodeToFuse searches for siblings with one load.
1252  assert(sibLoadOpInsts.size() == 1);
1253  Operation *sibLoadOpInst = sibLoadOpInsts[0];
1254 
1255  // Gather 'dstNode' load ops to 'memref'.
1256  SmallVector<Operation *, 2> dstLoadOpInsts;
1257  dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1258 
1259  // It's possible this fusion is at an inner depth (i.e., there are common
1260  // surrounding affine loops for the source and destination for ops). We
1261  // need to get this number because the call to canFuseLoops needs to be
1262  // passed the absolute depth. The max legal depth and the depths we try
1263  // below are however *relative* and as such don't include the common
1264  // depth.
1265  SmallVector<AffineForOp, 4> surroundingLoops;
1266  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1267  unsigned numSurroundingLoops = surroundingLoops.size();
1268  SmallVector<AffineForOp, 4> dstLoopIVs;
1269  getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1270  unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1271  auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1272 
1273  // Compute loop depth and slice union for fusion.
1274  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1275  depthSliceUnions.resize(dstLoopDepthTest);
1276  unsigned maxLegalFusionDepth = 0;
1277  FusionStrategy strategy(memref);
1278  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1279  FusionResult result =
1280  affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1281  /*dstLoopDepth=*/i + numSurroundingLoops,
1282  &depthSliceUnions[i - 1], strategy);
1283 
1284  if (result.value == FusionResult::Success)
1285  maxLegalFusionDepth = i;
1286  }
1287 
1288  LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1289  << maxLegalFusionDepth << '\n');
1290 
1291  // Skip if fusion is not feasible at any loop depths.
1292  if (maxLegalFusionDepth == 0)
1293  continue;
1294 
1295  double computeToleranceThresholdToUse = computeToleranceThreshold;
1296 
1297  // Cyclic dependences in the source nest may be violated when performing
1298  // slicing-based fusion. They aren't actually violated in cases where no
1299  // redundant execution of the source happens (1:1 pointwise dep on the
1300  // producer-consumer memref access for example). Check this and allow
1301  // fusion accordingly.
1302  if (hasCyclicDependence(sibAffineForOp)) {
1303  LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
1304  // Maximal fusion does not check the compute tolerance threshold; so
1305  // perform maximal fusion only when the redundant computation is
1306  // zero.
1307  if (maximalFusion) {
1308  auto dstForOp = cast<AffineForOp>(dstNode->op);
1309  int64_t sliceCost;
1310  int64_t fusedLoopNestComputeCost;
1311  auto fraction = getAdditionalComputeFraction(
1312  sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1313  sliceCost, fusedLoopNestComputeCost);
1314  if (!fraction || fraction > 0) {
1315  LLVM_DEBUG(
1316  llvm::dbgs()
1317  << "Can't perform maximal fusion with a cyclic dependence "
1318  "and non-zero additional compute.\n");
1319  return;
1320  }
1321  } else {
1322  // Set redundant computation tolerance to zero regardless of what the
1323  // user specified. Without this, fusion would be invalid.
1324  LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since "
1325  "source has a cyclic dependence.\n");
1326  computeToleranceThresholdToUse = 0.0;
1327  }
1328  }
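 // Illustrative example (hypothetical IR, not taken from this pass): a
 // reduction-style nest that both loads and stores the same memref cell on
 // every iteration carries such a cyclic dependence; any slice that
 // re-executes one of its iterations would duplicate the accumulation, hence
 // the zero additional-compute requirement above.
 //
 //   affine.for %i = 0 to 100 {
 //     %0 = affine.load %sum[0] : memref<1xf32>
 //     %1 = arith.addf %0, %x : f32
 //     affine.store %1, %sum[0] : memref<1xf32>
 //   }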
1329 
1330  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1331  if (!maximalFusion) {
1332  // Check if fusion would be profitable. For sibling fusion, the sibling
1333  // load op is treated as the src "store" op for fusion profitability
1334  // purposes. The footprint of the load in the slice relative to the
1335  // unfused source's footprint determines reuse.
1336  if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
1337  depthSliceUnions, maxLegalFusionDepth,
1338  &bestDstLoopDepth,
1339  computeToleranceThresholdToUse))
1340  continue;
1341  }
1342 
1343  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1344 
1345  const ComputationSliceState &bestSlice =
1346  depthSliceUnions[bestDstLoopDepth - 1];
1347  assert(!bestSlice.isEmpty() &&
1348  "Fusion depth has no computed slice union");
1349 
1350  // Do not perform sibling fusion if it isn't maximal. We always remove the
1351  // sibling node, and as such fusion shouldn't be performed if the slice
1352  // used in the destination covers only a part of the sibling's iterations.
1353  auto isMaximal = bestSlice.isMaximal();
1354  if (!isMaximal.value_or(false)) {
1355  LLVM_DEBUG(llvm::dbgs()
1356  << "Slice isn't maximal; not performing sibling fusion.\n");
1357  continue;
1358  }
1359 
1360  // Check if source loop is being inserted in the innermost
1361  // destination loop. Based on this, the fused loop may be optimized
1362  // further inside `fuseLoops`.
1363  bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1364  // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1365  affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
1366  isInnermostInsertion);
1367 
1368  auto dstForInst = cast<AffineForOp>(dstNode->op);
1369  // Update operation position of fused loop nest (if needed).
1370  if (insertPointInst != dstForInst)
1371  dstForInst->moveBefore(insertPointInst);
1372 
1373  LLVM_DEBUG(llvm::dbgs()
1374  << "Fused sibling nest " << sibId << " into destination nest "
1375  << dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
1376  << dstAffineForOp << "\n");
1377 
1378  // Update data dependence graph state post fusion.
1379  updateStateAfterSiblingFusion(sibNode, dstNode);
1380 
1381  // Remove old sibling loop nest.
1382  // Get op before we invalidate the MDG node.
1383  Operation *op = sibNode->op;
1384  mdg->removeNode(sibNode->id);
1385  op->erase();
1386  }
1387  }
1388 
1389  // Searches block argument uses and the graph from 'dstNode' looking for a
1390  // fusion candidate sibling node which shares no dependences with 'dstNode'
1391  // but which loads from the same memref. Returns true and sets
1392  // 'idAndMemrefToFuse' on success. Returns false otherwise.
1393  bool findSiblingNodeToFuse(Node *dstNode,
1394  DenseSet<unsigned> *visitedSibNodeIds,
1395  std::pair<unsigned, Value> *idAndMemrefToFuse) {
1396  // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1397  // on 'memref'.
1398  auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1399  // Skip if 'outEdge' is not a read-after-write dependence.
1400  // TODO: Remove the restriction to a single load op.
1401  if (sibNode->getLoadOpCount(memref) != 1)
1402  return false;
1403  // Skip if there exists a path of dependent edges between
1404  // 'sibNode' and 'dstNode'.
1405  if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1406  mdg->hasDependencePath(dstNode->id, sibNode->id))
1407  return false;
1408  // Skip sib node if it loads from (and stores to) the same memref on
1409  // which it also has an input dependence edge.
1410  DenseSet<Value> loadAndStoreMemrefSet;
1411  sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1412  if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1413  return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1414  }))
1415  return false;
1416 
1417  // Check that all stores, if any, are to the same memref.
1418  DenseSet<Value> storeMemrefs;
1419  for (auto *storeOpInst : sibNode->stores) {
1420  storeMemrefs.insert(
1421  cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1422  }
1423  return storeMemrefs.size() <= 1;
1424  };
1425 
1426  // Search for siblings which load the same memref block argument.
1427  Block *block = dstNode->op->getBlock();
1428  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1429  for (Operation *user : block->getArgument(i).getUsers()) {
1430  auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1431  if (!loadOp)
1432  continue;
1433  // Gather loops surrounding 'use'.
1434  SmallVector<AffineForOp, 4> loops;
1435  getAffineForIVs(*user, &loops);
1436  // Skip 'use' if it is not within a loop nest.
1437  // Find the surrounding affine.for nested immediately within the
1438  // block.
1439  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1440  return loop->getBlock() == &mdg->block;
1441  });
1442  // Skip 'use' if it is not within a loop nest in `block`.
1443  if (it == loops.end())
1444  continue;
1445  Node *sibNode = mdg->getForOpNode(*it);
1446  assert(sibNode != nullptr);
1447  // Skip 'use' if it is not a sibling to 'dstNode'.
1448  if (sibNode->id == dstNode->id)
1449  continue;
1450  // Skip 'use' if it has been visited.
1451  if (visitedSibNodeIds->count(sibNode->id) > 0)
1452  continue;
1453  // Skip 'use' if it does not load from the same memref as 'dstNode'.
1454  auto memref = loadOp.getMemRef();
1455  if (dstNode->getLoadOpCount(memref) == 0)
1456  continue;
1457  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1458  if (canFuseWithSibNode(sibNode, memref)) {
1459  visitedSibNodeIds->insert(sibNode->id);
1460  idAndMemrefToFuse->first = sibNode->id;
1461  idAndMemrefToFuse->second = memref;
1462  return true;
1463  }
1464  }
1465  }
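 // For instance (hypothetical IR, not taken from this pass): with %arg0 a
 // block argument, the two nests below load the same memref, have no
 // dependence path between them, and store to distinct memrefs, so either
 // one is reported as a sibling candidate when fusing into the other.
 //
 //   affine.for %i = 0 to 100 {
 //     %0 = affine.load %arg0[%i] : memref<100xf32>
 //     affine.store %0, %B[%i] : memref<100xf32>
 //   }
 //   affine.for %i = 0 to 100 {
 //     %1 = affine.load %arg0[%i] : memref<100xf32>
 //     affine.store %1, %C[%i] : memref<100xf32>
 //   }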
1466 
1467  // Search for siblings by following edges through an intermediate src node.
1468  // Collect candidate 'dstNode' input edges in 'inEdges'.
1469  SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1470  mdg->forEachMemRefInputEdge(
1471  dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1472  // Add 'inEdge' if it is a read-after-write dependence.
1473  if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1474  mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1475  inEdges.push_back(inEdge);
1476  });
1477 
1478  // Search for sibling nodes to fuse by visiting output edges from each input
1479  // edge in 'inEdges'.
1480  for (auto &inEdge : inEdges) {
1481  // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1482  SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1483  mdg->forEachMemRefOutputEdge(
1484  inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1485  unsigned sibNodeId = outEdge.id;
1486  if (visitedSibNodeIds->count(sibNodeId) > 0)
1487  return;
1488  // Skip output edge if not a sibling using the same memref.
1489  if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1490  return;
1491  auto *sibNode = mdg->getNode(sibNodeId);
1492  if (!isa<AffineForOp>(sibNode->op))
1493  return;
1494  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1495  if (canFuseWithSibNode(sibNode, outEdge.value)) {
1496  // Add candidate 'outEdge' to sibling node.
1497  outEdges.push_back(outEdge);
1498  }
1499  });
1500 
1501  // Add first candidate if any were returned.
1502  if (!outEdges.empty()) {
1503  visitedSibNodeIds->insert(outEdges[0].id);
1504  idAndMemrefToFuse->first = outEdges[0].id;
1505  idAndMemrefToFuse->second = outEdges[0].value;
1506  return true;
1507  }
1508  }
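 // E.g. (hypothetical IR, not taken from this pass): if a producer nest
 // stores to %m and two consumer nests each load %m, the second consumer is
 // discovered as a sibling of the first by following the producer's output
 // edges on %m above, even though %m is not a block argument.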
1509  return false;
1510  }
1511 
1512  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1513  /// into 'dstNode'.
1514  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1515  // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1516  mdg->updateEdges(sibNode->id, dstNode->id);
1517 
1518  // Collect dst loop stats after memref privatization transformation.
1519  auto dstForInst = cast<AffineForOp>(dstNode->op);
1520  LoopNestStateCollector dstLoopCollector;
1521  dstLoopCollector.collect(dstForInst);
1522  // Clear and add back loads and stores
1523  mdg->clearNodeLoadAndStores(dstNode->id);
1524  mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1525  dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
1526  dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
1527  }
1528 
1529  // Clean up any allocs with no users.
1530  void eraseUnusedMemRefAllocations() {
1531  for (auto &pair : mdg->memrefEdgeCount) {
1532  if (pair.second > 0)
1533  continue;
1534  auto memref = pair.first;
1535  // Skip if there exist other uses (return operation or function calls).
1536  if (!memref.use_empty())
1537  continue;
1538  // Use list expected to match the dep graph info.
1539  auto *op = memref.getDefiningOp();
1540  if (isa_and_nonnull<memref::AllocOp>(op))
1541  op->erase();
1542  }
1543  }
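 // For example (hypothetical IR, not taken from this pass): if a producer's
 // buffer was fully privatized during fusion, an allocation like the one
 // below ends up with a zero memrefEdgeCount and no remaining uses, and is
 // erased here.
 //
 //   %buf = memref.alloc() : memref<100xf32>  // no loads/stores of %buf left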
1544 };
1545 
1546 } // namespace
1547 
1548 /// Run fusion on `block`.
1549 void LoopFusion::runOnBlock(Block *block) {
1550  MemRefDependenceGraph g(*block);
1551  if (!g.init()) {
1552  LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1553  return;
1554  }
1555 
1556  std::optional<unsigned> fastMemorySpaceOpt;
1557  if (fastMemorySpace.hasValue())
1558  fastMemorySpaceOpt = fastMemorySpace;
1559  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1560  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1561  maximalFusion, computeToleranceThreshold);
1562 
1563  if (affineFusionMode == FusionMode::ProducerConsumer)
1564  fusion.runProducerConsumerFusionOnly();
1565  else if (affineFusionMode == FusionMode::Sibling)
1566  fusion.runSiblingFusionOnly();
1567  else
1568  fusion.runGreedyFusion();
1569 }
1570 
1571 void LoopFusion::runOnOperation() {
1572  // Run fusion on every block that contains at least two affine.for nests,
1573  // visiting ops in post order.
1574  getOperation()->walk([&](Operation *op) {
1575  for (Region &region : op->getRegions()) {
1576  for (Block &block : region.getBlocks()) {
1577  auto affineFors = block.getOps<AffineForOp>();
1578  if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1579  runOnBlock(&block);
1580  }
1581  }
1582  });
1583 }
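// For example (hypothetical IR, not taken from this pass), the entry block of
// the function below holds two top-level affine.for nests and is therefore
// handed to runOnBlock, whereas a block with at most one nest is skipped:
//
//   func.func @f(%A: memref<100xf32>, %B: memref<100xf32>) {
//     affine.for %i = 0 to 100 { ... }   // producer nest
//     affine.for %i = 0 to 100 { ... }   // consumer nest
//     return
//   }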
1584 
1585 std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
1586  unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1587  bool maximalFusion, enum FusionMode affineFusionMode) {
1588  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1589  maximalFusion, affineFusionMode);
1590 }
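A minimal usage sketch (illustrative; the wrapper `runFusion`, the module, and the pipeline setup are assumptions, not part of this file) showing how the factory above can be scheduled on the functions of a module:

#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Runs greedy affine loop fusion on every function in `module`.
static mlir::LogicalResult runFusion(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::affine::createLoopFusionPass(
      /*fastMemorySpace=*/0, /*localBufSizeThreshold=*/0,
      /*maximalFusion=*/false, mlir::affine::FusionMode::Greedy));
  return pm.run(module);
}

From the command line, the equivalent is typically exercised with `mlir-opt --affine-loop-fusion`.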