MLIR  22.0.0git
LoopFusion.cpp
Go to the documentation of this file.
1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements affine fusion.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Builders.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <iomanip>
32 #include <optional>
33 #include <sstream>
34 
35 namespace mlir {
36 namespace affine {
37 #define GEN_PASS_DEF_AFFINELOOPFUSION
38 #include "mlir/Dialect/Affine/Passes.h.inc"
39 } // namespace affine
40 } // namespace mlir
41 
42 #define DEBUG_TYPE "affine-fusion"
43 
44 using namespace mlir;
45 using namespace mlir::affine;
46 
47 namespace {
48 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
49 /// which fuses loop nests with single-writer/single-reader memref dependences
50 /// with the goal of improving locality.
51 // TODO: Support fusion of source loop nests which write to multiple
52 // memrefs, where each memref can have multiple users (if profitable).
53 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
54  LoopFusion() = default;
55  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
56  bool maximalFusion, enum FusionMode affineFusionMode) {
57  this->fastMemorySpace = fastMemorySpace;
58  this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
59  this->maximalFusion = maximalFusion;
60  this->affineFusionMode = affineFusionMode;
61  }
62 
63  void runOnBlock(Block *block);
64  void runOnOperation() override;
65 };
66 
67 } // namespace
68 
69 /// Returns true if node 'srcId' can be removed after fusing it with node
70 /// 'dstId'. The node can be removed if any of the following conditions are met:
71 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
72 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
73 /// and the fusion slice is maximal.
74 /// 3. 'srcId' has output dependences after fusion, the fusion slice is
75 /// maximal and the fusion insertion point dominates all the dependences.
77  unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
78  Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
79  const MemRefDependenceGraph &mdg) {
80 
81  Operation *dstNodeOp = mdg.getNode(dstId)->op;
82  bool hasOutDepsAfterFusion = false;
83 
84  for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
85  Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
86  // Skip dependence with dstOp since it will be removed after fusion.
87  if (depNodeOp == dstNodeOp)
88  continue;
89 
90  // Only fusion within the same block is supported. Use domination analysis
91  // when needed.
92  if (depNodeOp->getBlock() != dstNodeOp->getBlock())
93  return false;
94 
95  // Check if the insertion point of the fused loop dominates the dependence.
96  // Otherwise, the src loop can't be removed.
97  if (fusedLoopInsPoint != depNodeOp &&
98  !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
99  LDBG() << "Src loop can't be removed: dst loop doesn't "
100  << "dominate dependence";
101  return false;
102  }
103 
104  hasOutDepsAfterFusion = true;
105  }
106 
107  // If src loop has dependences after fusion or it writes to an live-out or
108  // escaping memref, we can only remove it if the fusion slice is maximal so
109  // that all the dependences are preserved.
110  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
111  std::optional<bool> isMaximal = fusionSlice.isMaximal();
112  if (!isMaximal) {
113  LDBG() << "Src loop can't be removed: can't determine "
114  << "if fusion is maximal";
115  return false;
116  }
117 
118  if (!*isMaximal) {
119  LDBG() << "Src loop can't be removed: fusion is not maximal";
120  return false;
121  }
122  }
123 
124  return true;
125 }
126 
127 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
128 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to
129 /// the program order when the 'mdg' is created. However, program order is not
130 /// guaranteed and must not be required by the client. Program order won't be
131 /// held if the 'mdg' is reused from a previous fusion step or if the node
132 /// creation order changes in the future to support more advance cases.
133 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
134 static void getProducerCandidates(unsigned dstId,
135  const MemRefDependenceGraph &mdg,
136  SmallVectorImpl<unsigned> &srcIdCandidates) {
137  // Skip if no input edges along which to fuse.
138  if (mdg.inEdges.count(dstId) == 0)
139  return;
140 
141  // Gather memrefs from loads in 'dstId'.
142  auto *dstNode = mdg.getNode(dstId);
143  DenseSet<Value> consumedMemrefs;
144  for (Operation *load : dstNode->loads)
145  consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
146 
147  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
148  // to one of the consumed memrefs.
149  for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
150  const auto *srcNode = mdg.getNode(srcEdge.id);
151  // Skip if 'srcNode' is not a loop nest.
152  if (!isa<AffineForOp>(srcNode->op))
153  continue;
154 
155  if (any_of(srcNode->stores, [&](Operation *op) {
156  auto storeOp = cast<AffineWriteOpInterface>(op);
157  return consumedMemrefs.count(storeOp.getMemRef()) > 0;
158  }))
159  srcIdCandidates.push_back(srcNode->id);
160  }
161 
162  llvm::sort(srcIdCandidates);
163  srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
164 }
165 
166 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
167 /// producer-consumer dependence between 'srcId' and 'dstId'.
168 static void
169 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
170  const MemRefDependenceGraph &mdg,
171  DenseSet<Value> &producerConsumerMemrefs) {
172  auto *dstNode = mdg.getNode(dstId);
173  auto *srcNode = mdg.getNode(srcId);
174  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
175  producerConsumerMemrefs);
176 }
177 
178 /// A memref escapes in the context of the fusion pass if either:
179 /// 1. it (or its alias) is a block argument, or
180 /// 2. created by an op not known to guarantee alias freedom,
181 /// 3. it (or its alias) are used by ops other than affine dereferencing ops
182 /// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
183 /// terminator ops, etc.); such ops do not deference the memref in an affine
184 /// way.
185 static bool isEscapingMemref(Value memref, Block *block) {
186  Operation *defOp = memref.getDefiningOp();
187  // Check if 'memref' is a block argument.
188  if (!defOp)
189  return true;
190 
191  // Check if this is defined to be an alias of another memref.
192  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
193  if (memref == viewOp.getViewDest() &&
194  isEscapingMemref(viewOp.getViewSource(), block))
195  return true;
196 
197  // Any op besides allocating ops wouldn't guarantee alias freedom
198  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
199  return true;
200 
201  // Check if 'memref' is used by a non-deferencing op (including unknown ones)
202  // (e.g., call ops, alias creating ops, etc.).
203  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
204  // Ignore users outside of `block`.
205  Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
206  if (!ancestorOp)
207  return true;
208  if (ancestorOp->getBlock() != block)
209  return false;
210  return !isa<AffineMapAccessInterface>(*user);
211  });
212 }
213 
214 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
215 /// that escape the block or are accessed in a non-affine way.
216 static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
217  DenseSet<Value> &escapingMemRefs) {
218  auto *node = mdg.getNode(id);
219  for (Operation *storeOp : node->stores) {
220  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
221  if (escapingMemRefs.count(memref))
222  continue;
223  if (isEscapingMemref(memref, &mdg.block))
224  escapingMemRefs.insert(memref);
225  }
226 }
227 
228 // Sinks all sequential loops to the innermost levels (while preserving
229 // relative order among them) and moves all parallel loops to the
230 // outermost (while again preserving relative order among them).
231 // This can increase the loop depth at which we can fuse a slice, since we are
232 // pushing loop carried dependence to a greater depth in the loop nest.
234  assert(isa<AffineForOp>(node->op));
235  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
236  node->op = newRootForOp;
237 }
238 
239 /// Get the operation that should act as a dominance filter while replacing
240 /// memref uses with a private memref for which `producerStores` and
241 /// `sliceInsertionBlock` are provided. This effectively determines in what
242 /// part of the IR we should be performing the replacement.
243 static Operation *
245  ArrayRef<Operation *> producerStores) {
246  assert(!producerStores.empty() && "expected producer store");
247 
248  // We first find the common block that contains the producer stores and
249  // the slice computation. The first ancestor among the ancestors of the
250  // producer stores in that common block is the dominance filter to use for
251  // replacement.
252  Block *commonBlock = nullptr;
253  // Find the common block of all relevant operations.
254  for (Operation *store : producerStores) {
255  Operation *otherOp =
256  !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
257  commonBlock = findInnermostCommonBlockInScope(store, otherOp);
258  }
259  assert(commonBlock &&
260  "common block of producer stores and slice should exist");
261 
262  // Find the first ancestor among the ancestors of `producerStores` in
263  // `commonBlock`.
264  Operation *firstAncestor = nullptr;
265  for (Operation *store : producerStores) {
266  Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
267  assert(ancestor && "producer store should be contained in common block");
268  firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
269  ? ancestor
270  : firstAncestor;
271  }
272  return firstAncestor;
273 }
274 
275 /// Returns the amount of additional (redundant) computation that will be done
276 /// as a fraction of the total computation if `srcForOp` is fused into
277 /// `dstForOp` at depth `depth`. The method returns the compute cost of the
278 /// slice and the fused nest's compute cost in the trailing output arguments.
279 static std::optional<double> getAdditionalComputeFraction(
280  AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
281  ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
282  int64_t &fusedLoopNestComputeCost) {
283  LDBG() << "Determining additional compute fraction...";
284  // Compute cost of sliced and unsliced src loop nest.
285  // Walk src loop nest and collect stats.
286  LoopNestStats srcLoopNestStats;
287  if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
288  LDBG() << "Failed to get source loop nest stats.";
289  return std::nullopt;
290  }
291 
292  // Compute cost of dst loop nest.
293  LoopNestStats dstLoopNestStats;
294  if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
295  LDBG() << "Failed to get destination loop nest stats.";
296  return std::nullopt;
297  }
298 
299  // Compute op instance count for the src loop nest without iteration slicing.
300  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
301 
302  // Compute op cost for the dst loop nest.
303  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
304 
305  const ComputationSliceState &slice = depthSliceUnions[depth - 1];
306  // Skip slice union if it wasn't computed for this depth.
307  if (slice.isEmpty()) {
308  LDBG() << "Slice wasn't computed.";
309  return std::nullopt;
310  }
311 
312  if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
313  dstLoopNestStats, slice,
314  &fusedLoopNestComputeCost)) {
315  LDBG() << "Unable to compute fusion compute cost";
316  return std::nullopt;
317  }
318 
319  double additionalComputeFraction =
320  fusedLoopNestComputeCost /
321  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
322  1;
323 
324  return additionalComputeFraction;
325 }
326 
327 // Creates and returns a private (single-user) memref for fused loop rooted at
328 // 'forOp', with (potentially reduced) memref size based on the memref region
329 // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
330 // specifies the block in which the slice was/will be inserted. The method
331 // expects that all stores ops to the memref have the same access function.
332 // Returns nullptr if the creation failed.
333 static Value createPrivateMemRef(AffineForOp forOp,
334  ArrayRef<Operation *> storeOps,
335  unsigned dstLoopDepth,
336  std::optional<unsigned> fastMemorySpace,
337  Block *sliceInsertionBlock,
338  uint64_t localBufSizeThreshold) {
339  assert(!storeOps.empty() && "no source stores supplied");
340 
341  // Check if all stores have the same access function; we only support this
342  // case.
343  // TODO: Use union of memref write regions to compute private memref footprint
344  // for store ops with different access functions.
345  if (storeOps.size() > 1 &&
346  !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
347  [](Operation *a, Operation *b) {
348  MemRefAccess aM(cast<AffineWriteOpInterface>(a));
349  MemRefAccess bM(cast<AffineWriteOpInterface>(b));
350  return aM == bM;
351  })) {
352  LDBG() << "Private memref creation unsupported for multiple producer "
353  << "stores with different access functions.";
354  return nullptr;
355  }
356 
357  Operation *srcStoreOp = storeOps[0];
358 
359  // Create builder to insert alloc op just before 'forOp'.
360  OpBuilder b(forOp);
361  // Builder to create constants at the top level.
362  OpBuilder top(forOp->getParentRegion());
363  // Create new memref type based on slice bounds.
364  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
365  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
366  unsigned rank = oldMemRefType.getRank();
367 
368  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
369  MemRefRegion region(srcStoreOp->getLoc());
370  bool validRegion = succeeded(
371  region.compute(srcStoreOp, dstLoopDepth, /*sliceState=*/nullptr,
372  /*addMemRefDimBounds=*/true, /*dropLocalVars=*/false));
373 
374  (void)validRegion;
375  assert(validRegion && "unexpected memref region failure");
376  SmallVector<int64_t, 4> newShape;
378  lbs.reserve(rank);
379  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
380  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
381  std::optional<int64_t> numElements =
382  region.getConstantBoundingSizeAndShape(&newShape, &lbs);
383  assert(numElements && "non-constant number of elts in local buffer");
384 
385  const FlatAffineValueConstraints *cst = region.getConstraints();
386  // 'outerIVs' holds the values that this memory region is symbolic/parametric
387  // on; this would correspond to loop IVs surrounding the level at which the
388  // slice is being materialized.
389  SmallVector<Value, 8> outerIVs;
390  cst->getValues(rank, cst->getNumDimAndSymbolVars(), &outerIVs);
391 
392  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
394  offsets.reserve(rank);
395 
396  // Outer IVs are considered symbols during memref region computation. Replace
397  // them uniformly with dims so that valid IR is guaranteed.
398  SmallVector<AffineExpr> replacements;
399  for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
400  replacements.push_back(mlir::getAffineDimExpr(j, forOp.getContext()));
401  for (unsigned d = 0; d < rank; ++d) {
402  assert(lbs[d].getNumResults() == 1 &&
403  "invalid private memref bound calculation");
404  offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
405  }
406 
407  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
408  // by 'srcStoreOpInst'.
409  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
410  assert(eltSize && "memrefs with size elt types expected");
411  uint64_t bufSize = *eltSize * *numElements;
412  Attribute newMemSpace;
413  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
414  newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
415  } else {
416  newMemSpace = oldMemRefType.getMemorySpace();
417  }
418  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
419  /*map=*/AffineMap(), newMemSpace);
420 
421  // Create new private memref for fused loop 'forOp'. 'newShape' is always
422  // a constant shape.
423  // TODO: Create/move alloc ops for private memrefs closer to their
424  // consumer loop nests to reduce their live range. Currently they are added
425  // at the beginning of the block, because loop nests can be reordered
426  // during the fusion pass.
427  Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
428 
429  // Build an AffineMap to remap access functions based on lower bound offsets.
430  SmallVector<AffineExpr, 4> remapExprs;
431  remapExprs.reserve(rank);
432  for (unsigned i = 0; i < rank; i++) {
433  auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
434 
435  auto remapExpr =
436  simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
437  remapExprs.push_back(remapExpr);
438  }
439 
440  auto indexRemap =
441  AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
442 
443  // Replace all users of 'oldMemRef' with 'newMemRef'.
444  Operation *domFilter =
445  getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
446  auto userFilterFn = [&](Operation *user) {
447  auto domInfo = std::make_unique<DominanceInfo>(
448  domFilter->getParentOfType<FunctionOpInterface>());
449  return domInfo->dominates(domFilter, user);
450  };
451  LogicalResult res = replaceAllMemRefUsesWith(
452  oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
453  /*extraOperands=*/outerIVs,
454  /*symbolOperands=*/{}, userFilterFn);
455  assert(succeeded(res) &&
456  "replaceAllMemrefUsesWith should always succeed here");
457  (void)res;
458  LDBG() << "Created private memref of type: " << newMemRefType;
459  return newMemRef;
460 }
461 
462 // Checks the profitability of fusing a backwards slice of the loop nest
463 // `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
464 // 'srcStoreOpInst' is used to calculate the storage reduction on the memref
465 // being produced and consumed, which is an input to the cost model. For
466 // producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
467 // as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
468 // will be the src loop nest LoadOp which reads from the same memref as dst loop
469 // nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
470 // node, which will be used to check that the write region is the same after
471 // input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
472 // each legal fusion depth. The maximal depth at which fusion is legal is
473 // provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
474 // the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
475 // the most profitable depth at which to materialize the source loop nest slice.
476 // The profitability model executes the following steps:
477 // *) Computes the backward computation slice at 'srcOpInst'. This
478 // computation slice of the loop nest surrounding 'srcOpInst' is
479 // represented by modified src loop bounds in 'sliceState', which are
480 // functions of loop IVs in the loop nest surrounding 'srcOpInst'.
481 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
482 // loop nest is the total number of dynamic operation instances in the loop
483 // nest).
484 // *) Computes the cost of fusing a slice of the src loop nest into the dst
485 // loop nest at various values of dst loop depth, attempting to fuse
486 // the largest computation slice at the maximal dst loop depth (closest to
487 // the load) to minimize reuse distance and potentially enable subsequent
488 // load/store forwarding.
489 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
490 // nest, at which the src computation slice is inserted/fused.
491 // NOTE: We attempt to maximize the dst loop depth, but there are cases
492 // where a particular setting for 'dstLoopNest' might fuse an unsliced
493 // loop (within the src computation slice) at a depth which results in
494 // excessive recomputation (see unit tests for examples).
495 // *) Compares the total cost of the unfused loop nests to the min cost fused
496 // loop nest computed in the previous step, and returns true if the latter
497 // is lower.
498 // TODO: Extend profitability analysis to support scenarios with multiple
499 // stores.
500 static bool isFusionProfitable(AffineForOp srcForOp,
501  ArrayRef<Operation *> producerStores,
502  AffineForOp dstForOp,
503  ArrayRef<ComputationSliceState> depthSliceUnions,
504  unsigned maxLegalFusionDepth,
505  unsigned *dstLoopDepth,
506  double computeToleranceThreshold) {
507  LDBG() << "Checking whether fusion is profitable between source nest:";
508  LDBG() << ' ' << srcForOp << " and destination nest:";
509  LDBG() << dstForOp;
510 
511  if (maxLegalFusionDepth == 0) {
512  LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
513  return false;
514  }
515 
516  // Compute cost of sliced and unsliced src loop nest.
517 
518  // Walk src loop nest and collect stats.
519  LoopNestStats srcLoopNestStats;
520  if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
521  return false;
522 
523  // Compute cost of dst loop nest.
524  LoopNestStats dstLoopNestStats;
525  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
526  return false;
527 
528  // We limit profitability analysis to only scenarios with
529  // a single producer store for now. Note that some multi-store
530  // producer scenarios will still go through profitability analysis
531  // if only one of the stores is involved in the producer-consumer
532  // relationship of the candidate loops.
533  // TODO: Suppport multiple producer stores in profitability
534  // analysis.
535  if (producerStores.size() > 1) {
536  LDBG() << "Limited profitability analysis. Not "
537  << "supported for multiple producer store case.";
538  int64_t sliceCost;
539  int64_t fusedLoopNestComputeCost;
540  // We will still fuse if fusion obeys the specified compute
541  // tolerance at the max legal depth.
542  auto fraction = getAdditionalComputeFraction(
543  srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
544  fusedLoopNestComputeCost);
545  if (!fraction || fraction > computeToleranceThreshold) {
546  LDBG() << "Additional computation exceeds "
547  << "compute tolerance. Not fusing.";
548  return false;
549  }
550  LDBG() << "Considering fusion profitable at max legal depth.";
551  return true;
552  }
553 
554  Operation *srcStoreOp = producerStores.front();
555 
556  // Search for min cost value for 'dstLoopDepth'. At each value of
557  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
558  // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
559  // of these bounds). Next the union slice bounds are used to calculate
560  // the cost of the slice and the cost of the slice inserted into the dst
561  // loop nest at 'dstLoopDepth'.
562  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
563  double maxStorageReduction = 0.0;
564  std::optional<uint64_t> sliceMemEstimate;
565 
566  // The best loop depth at which to materialize the slice.
567  std::optional<unsigned> bestDstLoopDepth;
568 
569  // Compute src loop nest write region size.
570  MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
571  if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
572  LDBG() << "Unable to compute MemRefRegion for source operation";
573  return false;
574  }
575 
576  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
577  srcWriteRegion.getRegionSize();
578  if (!maybeSrcWriteRegionSizeBytes.has_value())
579  return false;
580  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
581 
582  // Compute op instance count for the src loop nest without iteration slicing.
583  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
584 
585  // Compute op instance count for the destination loop nest.
586  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
587 
588  // Evaluate all depth choices for materializing the slice in the destination
589  // loop nest.
590  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
591  const ComputationSliceState &slice = depthSliceUnions[i - 1];
592  // Skip slice union if it wasn't computed for this depth.
593  if (slice.isEmpty())
594  continue;
595 
596  // Compute cost of the slice separately, i.e, the compute cost of the slice
597  // if all outer trip counts are one.
598  int64_t sliceCost;
599 
600  int64_t fusedLoopNestComputeCost;
601 
602  auto mayAdditionalComputeFraction =
603  getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
604  sliceCost, fusedLoopNestComputeCost);
605  if (!mayAdditionalComputeFraction) {
606  LDBG() << "Can't determine additional compute fraction.";
607  continue;
608  }
609  double additionalComputeFraction = *mayAdditionalComputeFraction;
610 
611  // Determine what the slice write MemRefRegion would be, if the src loop
612  // nest slice 'slice' were to be inserted into the dst loop nest at loop
613  // depth 'i'.
614  MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
615  if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
616  LDBG() << "Failed to compute slice write region at loopDepth: " << i;
617  continue;
618  }
619 
620  std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
621  sliceWriteRegion.getRegionSize();
622  if (!maybeSliceWriteRegionSizeBytes.has_value() ||
623  *maybeSliceWriteRegionSizeBytes == 0) {
624  LDBG() << "Failed to get slice write region size at loopDepth: " << i;
625  continue;
626  }
627  int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
628 
629  double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
630  static_cast<double>(sliceWriteRegionSizeBytes);
631 
632  LLVM_DEBUG({
633  std::stringstream msg;
634  msg << " evaluating fusion profitability at depth : " << i << "\n"
635  << std::fixed << std::setprecision(2)
636  << " additional compute fraction: "
637  << 100.0 * additionalComputeFraction << "%\n"
638  << " storage reduction factor: " << storageReduction << "x\n"
639  << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
640  << " src write region size: " << srcWriteRegionSizeBytes << "\n"
641  << " slice write region size: " << sliceWriteRegionSizeBytes;
642  LDBG() << msg.str();
643  });
644 
645  // TODO: This is a placeholder cost model.
646  // Among all choices that add an acceptable amount of redundant computation
647  // (as per computeToleranceThreshold), we will simply pick the one that
648  // reduces the intermediary size the most.
649  if ((storageReduction > maxStorageReduction) &&
650  (additionalComputeFraction <= computeToleranceThreshold)) {
651  maxStorageReduction = storageReduction;
652  bestDstLoopDepth = i;
653  minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
654  sliceMemEstimate = sliceWriteRegionSizeBytes;
655  }
656  }
657 
658  // A simple cost model: fuse if it reduces the memory footprint.
659 
660  if (!bestDstLoopDepth) {
661  LDBG() << "All fusion choices involve more than the threshold amount of "
662  << "redundant computation; NOT fusing.";
663  return false;
664  }
665 
666  if (!bestDstLoopDepth) {
667  LDBG() << "no fusion depth could be evaluated.";
668  return false;
669  }
670 
671  // Set dstLoopDepth based on best values from search.
672  *dstLoopDepth = *bestDstLoopDepth;
673 
674  LDBG() << " LoopFusion fusion stats:";
675  LDBG() << " best loop depth: " << bestDstLoopDepth;
676  LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
677  LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
678  LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
679 
680  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
681  auto srcMemSize = getMemoryFootprintBytes(srcForOp);
682 
683  std::optional<double> storageReduction;
684 
685  if (!dstMemSize || !srcMemSize) {
686  LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
687  return false;
688  }
689 
690  auto srcMemSizeVal = *srcMemSize;
691  auto dstMemSizeVal = *dstMemSize;
692 
693  assert(sliceMemEstimate && "expected value");
694  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
695 
696  LDBG() << " src mem: " << srcMemSizeVal;
697  LDBG() << " dst mem: " << dstMemSizeVal;
698  LDBG() << " fused mem: " << fusedMem;
699  LDBG() << " slice mem: " << sliceMemEstimate;
700 
701  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
702  LDBG() << "Fusion is not profitable; NOT fusing.";
703  return false;
704  }
705  storageReduction =
706  100.0 *
707  (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
708 
709  double additionalComputeFraction =
710  100.0 * (minFusedLoopNestComputeCost /
711  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
712  1);
713  (void)additionalComputeFraction;
714  LLVM_DEBUG({
715  std::stringstream msg;
716  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
717  << std::setprecision(2) << additionalComputeFraction
718  << "% redundant computation and a ";
719  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
720  msg << "% storage reduction.";
721  LDBG() << msg.str();
722  });
723 
724  return true;
725 }
726 
727 namespace {
728 
729 // GreedyFusion greedily fuses loop nests which have a producer/consumer or
730 // input-reuse relationship on a memref, with the goal of improving locality.
731 //
732 // The steps of the producer-consumer fusion algorithm are as follows:
733 //
734 // *) A worklist is initialized with node ids from the dependence graph.
735 // *) For each node id in the worklist:
736 // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
737 // candidate destination AffineForOp into which fusion will be attempted.
738 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
739 // *) For each LoadOp in 'dstLoadOps' do:
740 // *) Look up dependent loop nests which have a single store op to the same
741 // memref.
742 // *) Check if dependences would be violated by the fusion.
743 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop
744 // bounds to be functions of 'dstLoopNest' IVs and symbols.
745 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
746 // at a loop depth determined by the cost model in 'isFusionProfitable'.
747 // *) Add the newly fused load/store operations to the state,
748 // and also add newly fused load ops to 'dstLoopOps' to be considered
749 // as fusion dst load ops in another iteration.
750 // *) Remove old src loop nest and its associated state.
751 //
752 // The steps of the input-reuse fusion algorithm are as follows:
753 //
754 // *) Initialize 'worklist' with node ids from the dependence graph.
755 // *) For each 'dstNode' in the worklist:
756 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
757 // loads from the same memref, but which has no dependence paths to/from.
758 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop
759 // bounds to be functions of 'dstLoopNest' IVs and symbols.
760 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
761 // at a loop depth determined by the cost model in 'isFusionProfitable'.
762 // This function also checks that the memref write region of 'sibLoopNest',
763 // is preserved in the fused loop nest.
764 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
765 //
766 // Given a graph where top-level operations are vertices in the set 'V' and
767 // edges in the set 'E' are dependences between vertices, this algorithm
768 // takes O(V) time for initialization, and has runtime O(V + E).
769 //
770 // This greedy algorithm is not 'maximal' due to the current restriction of
771 // fusing along single producer consumer edges, but there is a TODO: to fix
772 // this.
773 //
774 // TODO: Experiment with other fusion policies.
775 struct GreedyFusion {
776 public:
777  // The data dependence graph to traverse during fusion.
779  // Worklist of graph nodes visited during the fusion pass.
780  SmallVector<unsigned, 8> worklist;
781  // Parameter for local buffer size threshold.
782  unsigned localBufSizeThreshold;
783  // Parameter for fast memory space.
784  std::optional<unsigned> fastMemorySpace;
785  // If true, ignore any additional (redundant) computation tolerance threshold
786  // that would have prevented fusion.
787  bool maximalFusion;
788  // The amount of additional computation that is tolerated while fusing
789  // pair-wise as a fraction of the total computation.
790  double computeToleranceThreshold;
791 
793 
794  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
795  std::optional<unsigned> fastMemorySpace, bool maximalFusion,
796  double computeToleranceThreshold)
797  : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
798  fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
799  computeToleranceThreshold(computeToleranceThreshold) {}
800 
801  /// Initializes 'worklist' with nodes from 'mdg'.
802  void init() {
803  // TODO: Add a priority queue for prioritizing nodes by different
804  // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
805  worklist.clear();
806  for (auto &idAndNode : mdg->nodes) {
807  const Node &node = idAndNode.second;
808  worklist.push_back(node.id);
809  }
810  }
811  /// Run only sibling fusion on the `mdg`.
812  void runSiblingFusionOnly() {
813  fuseSiblingNodes();
814  eraseUnusedMemRefAllocations();
815  }
816 
817  /// Run only producer/consumer fusion on the `mdg`.
818  void runProducerConsumerFusionOnly() {
819  fuseProducerConsumerNodes(
820  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
821  eraseUnusedMemRefAllocations();
822  }
823 
824  // Run the GreedyFusion pass.
825  // *) First pass through the nodes fuses single-use producer nodes into their
826  // unique consumer.
827  // *) Second pass fuses sibling nodes which share no dependence edges.
828  // *) Third pass fuses any remaining producer nodes into their users.
829  void runGreedyFusion() {
830  // TODO: Run this repeatedly until a fixed-point is reached.
831  fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
832  fuseSiblingNodes();
833  fuseProducerConsumerNodes(
834  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
835  eraseUnusedMemRefAllocations();
836  }
837 
838  /// Returns true if a private memref can be created for `memref` given
839  /// the fusion scenario reflected by the other arguments.
840  bool canCreatePrivateMemRef(Value memref,
841  const DenseSet<Value> &srcEscapingMemRefs,
842  unsigned producerId, unsigned consumerId,
843  bool removeSrcNode) {
844  // We can't generate private memrefs if their size can't be computed.
845  if (!getMemRefIntOrFloatEltSizeInBytes(cast<MemRefType>(memref.getType())))
846  return false;
847  const Node *consumerNode = mdg->getNode(consumerId);
848  // If `memref` is an escaping one, do not create a private memref
849  // for the below scenarios, since doing so will leave the escaping
850  // memref unmodified as all the writes originally meant for the
851  // escaping memref would be performed on the private memref:
852  // 1. The source is to be removed after fusion,
853  // OR
854  // 2. The destination writes to `memref`.
855  if (srcEscapingMemRefs.count(memref) > 0 &&
856  (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
857  return false;
858 
859  // Don't create a private memref if 'srcNode' has in edges on
860  // 'memref' or 'dstNode' has out edges on 'memref'.
861  if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
862  mdg->getOutEdgeCount(consumerId, memref) > 0)
863  return false;
864 
865  // If 'srcNode' will be removed but it has out edges on 'memref' to
866  // nodes other than 'dstNode', we have to preserve dependences and
867  // cannot create a private memref.
868  if (removeSrcNode &&
869  any_of(mdg->outEdges[producerId], [&](const auto &edge) {
870  return edge.value == memref && edge.id != consumerId;
871  }))
872  return false;
873 
874  return true;
875  }
876 
877  /// Perform fusions with node `dstId` as the destination of fusion, with
878  /// No fusion is performed when producers with a user count greater than
879  /// `maxSrcUserCount` for any of the memrefs involved.
880  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
881  LDBG() << "Evaluating dst loop " << dstId;
882  // Skip if this node was removed (fused into another node).
883  if (mdg->nodes.count(dstId) == 0)
884  return;
885  // Get 'dstNode' into which to attempt fusion.
886  auto *dstNode = mdg->getNode(dstId);
887  // Skip if 'dstNode' is not a loop nest.
888  if (!isa<AffineForOp>(dstNode->op))
889  return;
890  // Skip if 'dstNode' is a loop nest returning values.
891  // TODO: support loop nests that return values.
892  if (dstNode->op->getNumResults() > 0)
893  return;
894 
895  LDBG() << "Evaluating dst loop " << dstId;
896 
897  // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
898  // while preserving relative order. This can increase the maximum loop
899  // depth at which we can fuse a slice of a producer loop nest into a
900  // consumer loop nest.
901  sinkSequentialLoops(dstNode);
902  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
903 
904  // Try to fuse 'dstNode' with candidate producer loops until a fixed point
905  // is reached. Fusing two loops may expose new fusion opportunities.
906  bool dstNodeChanged;
907  do {
908  // Gather src loop candidates for 'dstNode' and visit them in "quasi"
909  // reverse program order to minimize the number of iterations needed to
910  // reach the fixed point. Note that this is a best effort approach since
911  // 'getProducerCandidates' does not always guarantee that program order
912  // in 'srcIdCandidates'.
913  dstNodeChanged = false;
914  SmallVector<unsigned, 16> srcIdCandidates;
915  getProducerCandidates(dstId, *mdg, srcIdCandidates);
916 
917  for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
918  // Get 'srcNode' from which to attempt fusion into 'dstNode'.
919  auto *srcNode = mdg->getNode(srcId);
920  auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
921 
922  LDBG() << "Trying to fuse producer loop nest " << srcId
923  << " with consumer loop nest " << dstId;
924  LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
925  LDBG() << "Producer loop nest:";
926  LDBG() << *srcNode->op << " and consumer loop nest:";
927  LDBG() << *dstNode->op;
928 
929  LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
930 
931  // Skip if 'srcNode' is a loop nest returning values.
932  // TODO: support loop nests that return values.
933  if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
934  continue;
935 
936  DenseSet<Value> producerConsumerMemrefs;
937  gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
938  producerConsumerMemrefs);
939 
940  // Skip if 'srcNode' out edge count on any memref is greater than
941  // 'maxSrcUserCount'.
942  if (any_of(producerConsumerMemrefs, [&](Value memref) {
943  return mdg->getOutEdgeCount(srcNode->id, memref) >
944  maxSrcUserCount;
945  }))
946  continue;
947 
948  // Gather memrefs in 'srcNode' that are written and escape out of the
949  // block (e.g., memref block arguments, returned memrefs,
950  // memrefs passed to function calls, etc.).
951  DenseSet<Value> srcEscapingMemRefs;
952  gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
953 
954  // Compute an operation list insertion point for the fused loop
955  // nest which preserves dependences.
956  Operation *fusedLoopInsPoint =
957  mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
958  if (fusedLoopInsPoint == nullptr)
959  continue;
960 
961  // It's possible this fusion is at an inner depth (i.e., there are
962  // common surrounding affine loops for the source and destination for
963  // ops). We need to get this number because the call to canFuseLoops
964  // needs to be passed the absolute depth. The max legal depth and the
965  // depths we try below are however *relative* and as such don't include
966  // the common depth.
967  SmallVector<AffineForOp, 4> surroundingLoops;
968  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
969  unsigned numSurroundingLoops = surroundingLoops.size();
970 
971  // Compute the innermost common loop depth for dstNode
972  // producer-consumer loads/stores.
973  SmallVector<Operation *, 2> dstMemrefOps;
974  for (Operation *op : dstNode->loads)
975  if (producerConsumerMemrefs.count(
976  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
977  dstMemrefOps.push_back(op);
978  for (Operation *op : dstNode->stores)
979  if (producerConsumerMemrefs.count(
980  cast<AffineWriteOpInterface>(op).getMemRef()))
981  dstMemrefOps.push_back(op);
982  if (dstMemrefOps.empty())
983  continue;
984  unsigned dstLoopDepthTest =
985  getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
986 
987  // Check the feasibility of fusing src loop nest into dst loop nest
988  // at loop depths in range [1, dstLoopDepthTest].
989  unsigned maxLegalFusionDepth = 0;
990  SmallVector<ComputationSliceState, 8> depthSliceUnions;
991  depthSliceUnions.resize(dstLoopDepthTest);
993  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
994  FusionResult result =
995  affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
996  /*dstLoopDepth=*/i + numSurroundingLoops,
997  &depthSliceUnions[i - 1], strategy);
998  if (result.value == FusionResult::Success) {
999  maxLegalFusionDepth = i;
1000  LDBG() << "Found valid slice for depth: " << i;
1001  }
1002  }
1003 
1004  if (maxLegalFusionDepth == 0) {
1005  LDBG() << "Can't fuse: fusion is not legal at any depth";
1006  continue;
1007  }
1008 
1009  LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1010 
1011  double computeToleranceThresholdToUse = computeToleranceThreshold;
1012 
1013  // Cyclic dependences in the source nest may be violated when performing
1014  // slicing-based fusion. They aren't actually violated in cases where no
1015  // redundant execution of the source happens (1:1 pointwise dep on the
1016  // producer-consumer memref access for example). Check this and allow
1017  // fusion accordingly.
1018  if (hasCyclicDependence(srcAffineForOp)) {
1019  LDBG() << "Source nest has a cyclic dependence.";
1020  // Maximal fusion does not check for compute tolerance threshold; so
1021  // perform the maximal fusion only when the redundanation computation
1022  // is zero.
1023  if (maximalFusion) {
1024  auto srcForOp = cast<AffineForOp>(srcNode->op);
1025  auto dstForOp = cast<AffineForOp>(dstNode->op);
1026  int64_t sliceCost;
1027  int64_t fusedLoopNestComputeCost;
1028  auto fraction = getAdditionalComputeFraction(
1029  srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1030  sliceCost, fusedLoopNestComputeCost);
1031  if (!fraction || fraction > 0) {
1032  LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1033  << "and non-zero additional compute.";
1034  return;
1035  }
1036  } else {
1037  // Set redundant computation tolerance to zero regardless of what
1038  // the user specified. Without this, fusion would be invalid.
1039  LDBG() << "Setting compute tolerance to zero since "
1040  << "source has a cylic dependence.";
1041  computeToleranceThresholdToUse = 0;
1042  }
1043  }
1044 
1045  // Check if fusion would be profitable. We skip profitability analysis
1046  // for maximal fusion since we already know the maximal legal depth to
1047  // fuse.
1048  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1049  if (!maximalFusion) {
1050  // Retrieve producer stores from the src loop.
1051  SmallVector<Operation *, 2> producerStores;
1052  for (Operation *op : srcNode->stores)
1053  if (producerConsumerMemrefs.count(
1054  cast<AffineWriteOpInterface>(op).getMemRef()))
1055  producerStores.push_back(op);
1056 
1057  assert(!producerStores.empty() && "Expected producer store");
1058  if (!isFusionProfitable(srcAffineForOp, producerStores,
1059  dstAffineForOp, depthSliceUnions,
1060  maxLegalFusionDepth, &bestDstLoopDepth,
1061  computeToleranceThresholdToUse)) {
1062  continue;
1063  }
1064  }
1065 
1066  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1067  ComputationSliceState &bestSlice =
1068  depthSliceUnions[bestDstLoopDepth - 1];
1069  assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1070 
1071  // Determine if 'srcId' can be removed after fusion, taking into
1072  // account remaining dependences, escaping memrefs and the fusion
1073  // insertion point.
1074  bool removeSrcNode = canRemoveSrcNodeAfterFusion(
1075  srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1076  *mdg);
1077 
1078  DenseSet<Value> privateMemrefs;
1079  for (Value memref : producerConsumerMemrefs) {
1080  if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1081  removeSrcNode)) {
1082  // Create a private version of this memref.
1083  LDBG() << "Creating private memref for " << memref;
1084  // Create a private version of this memref.
1085  privateMemrefs.insert(memref);
1086  }
1087  }
1088 
1089  // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
1090  fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1091  dstNodeChanged = true;
1092 
1093  LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
1094  << " at depth " << bestDstLoopDepth << ":";
1095  LDBG() << dstAffineForOp;
1096 
1097  // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1098  if (fusedLoopInsPoint != dstAffineForOp)
1099  dstAffineForOp->moveBefore(fusedLoopInsPoint);
1100 
1101  // Update edges between 'srcNode' and 'dstNode'.
1102  mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1103  removeSrcNode);
1104 
1105  // Create private memrefs.
1106  if (!privateMemrefs.empty()) {
1107  // Note the block into which fusion was performed. This can be used to
1108  // place `alloc`s that create private memrefs.
1109  Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
1110 
1111  // Gather stores for all the private-to-be memrefs.
1112  DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1113  dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1114  Value storeMemRef = storeOp.getMemRef();
1115  if (privateMemrefs.count(storeMemRef) > 0)
1116  privateMemRefToStores[storeMemRef].push_back(storeOp);
1117  });
1118 
1119  // Replace original memrefs with private memrefs. Note that all the
1120  // loads and stores on these memrefs will be replaced with a new
1121  // loads and stores. Any reference to the original ones becomes
1122  // invalid after this point.
1123  for (auto &memrefToStoresPair : privateMemRefToStores) {
1124  ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
1125  Value newMemRef = createPrivateMemRef(
1126  dstAffineForOp, storesForMemref, bestDstLoopDepth,
1127  fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1128  if (!newMemRef)
1129  continue;
1130  // Create new node in dependence graph for 'newMemRef' alloc op.
1131  unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1132  // Add edge from 'newMemRef' node to dstNode.
1133  mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1134  }
1135  // One or more entries for 'newMemRef' alloc op are inserted into
1136  // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1137  // reallocate, update dstNode.
1138  dstNode = mdg->getNode(dstId);
1139  }
1140 
1141  // Collect dst loop stats after memref privatization transformation.
1142  LoopNestStateCollector dstLoopCollector;
1143  dstLoopCollector.collect(dstAffineForOp);
1144 
1145  // Clear and add back loads and stores.
1146  mdg->clearNodeLoadAndStores(dstNode->id);
1147  mdg->addToNode(
1148  dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts,
1149  dstLoopCollector.memrefLoads, dstLoopCollector.memrefStores,
1150  dstLoopCollector.memrefFrees);
1151 
1152  if (removeSrcNode) {
1153  LDBG() << "Removing src loop " << srcId << " after fusion";
1154  // srcNode is no longer valid after it is removed from mdg.
1155  srcAffineForOp.erase();
1156  mdg->removeNode(srcId);
1157  srcNode = nullptr;
1158  }
1159  }
1160  } while (dstNodeChanged);
1161  }
1162 
1163  /// Visit each node in the graph, and for each node, attempt to fuse it with
1164  /// producer-consumer candidates. No fusion is performed when producers with a
1165  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1166  /// are encountered.
1167  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1168  LDBG() << "--- Producer/Consumer Fusion ---";
1169  init();
1170  while (!worklist.empty()) {
1171  unsigned dstId = worklist.back();
1172  worklist.pop_back();
1173  performFusionsIntoDest(dstId, maxSrcUserCount);
1174  }
1175  }
1176 
1177  // Visits each node in the graph, and for each node, attempts to fuse it with
1178  // its sibling nodes (nodes which share a parent, but no dependence edges).
1179  void fuseSiblingNodes() {
1180  LDBG() << "--- Sibling Fusion ---";
1181  init();
1182  while (!worklist.empty()) {
1183  unsigned dstId = worklist.back();
1184  worklist.pop_back();
1185 
1186  // Skip if this node was removed (fused into another node).
1187  if (mdg->nodes.count(dstId) == 0)
1188  continue;
1189  // Get 'dstNode' into which to attempt fusion.
1190  auto *dstNode = mdg->getNode(dstId);
1191  // Skip if 'dstNode' is not a loop nest.
1192  if (!isa<AffineForOp>(dstNode->op))
1193  continue;
1194  // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1195  fuseWithSiblingNodes(dstNode);
1196  }
1197  }
1198 
1199  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1200  void fuseWithSiblingNodes(Node *dstNode) {
1201  DenseSet<unsigned> visitedSibNodeIds;
1202  std::pair<unsigned, Value> idAndMemref;
1203  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1204 
1205  while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1206  unsigned sibId = idAndMemref.first;
1207  Value memref = idAndMemref.second;
1208  // TODO: Check that 'sibStoreOpInst' post-dominates all other
1209  // stores to the same memref in 'sibNode' loop nest.
1210  auto *sibNode = mdg->getNode(sibId);
1211  // Compute an operation list insertion point for the fused loop
1212  // nest which preserves dependences.
1213  assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1214  Operation *insertPointInst =
1215  sibNode->op->isBeforeInBlock(dstNode->op)
1216  ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1217  : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1218  if (insertPointInst == nullptr)
1219  continue;
1220 
1221  // Check if fusion would be profitable and at what depth.
1222 
1223  // Get unique 'sibNode' load op to 'memref'.
1224  SmallVector<Operation *, 2> sibLoadOpInsts;
1225  sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1226  // Currently findSiblingNodeToFuse searches for siblings with one load.
1227  Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
1228 
1229  // Gather 'dstNode' load ops to 'memref'.
1230  SmallVector<Operation *, 2> dstLoadOpInsts;
1231  dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1232 
1233  // It's possible this fusion is at an inner depth (i.e., there are common
1234  // surrounding affine loops for the source and destination for ops). We
1235  // need to get this number because the call to canFuseLoops needs to be
1236  // passed the absolute depth. The max legal depth and the depths we try
1237  // below are however *relative* and as such don't include the common
1238  // depth.
1239  SmallVector<AffineForOp, 4> surroundingLoops;
1240  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1241  unsigned numSurroundingLoops = surroundingLoops.size();
1242  SmallVector<AffineForOp, 4> dstLoopIVs;
1243  getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1244  unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1245  auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1246 
1247  // Compute loop depth and slice union for fusion.
1248  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1249  depthSliceUnions.resize(dstLoopDepthTest);
1250  unsigned maxLegalFusionDepth = 0;
1251  FusionStrategy strategy(memref);
1252  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1253  FusionResult result =
1254  affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1255  /*dstLoopDepth=*/i + numSurroundingLoops,
1256  &depthSliceUnions[i - 1], strategy);
1257 
1258  if (result.value == FusionResult::Success)
1259  maxLegalFusionDepth = i;
1260  }
1261 
1262  LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1263 
1264  // Skip if fusion is not feasible at any loop depths.
1265  if (maxLegalFusionDepth == 0)
1266  continue;
1267 
1268  double computeToleranceThresholdToUse = computeToleranceThreshold;
1269 
1270  // Cyclic dependences in the source nest may be violated when performing
1271  // slicing-based fusion. They aren't actually violated in cases where no
1272  // redundant execution of the source happens (1:1 pointwise dep on the
1273  // producer-consumer memref access for example). Check this and allow
1274  // fusion accordingly.
1275  if (hasCyclicDependence(sibAffineForOp)) {
1276  LDBG() << "Source nest has a cyclic dependence.";
1277  // Maximal fusion does not check for compute tolerance threshold; so
1278  // perform the maximal fusion only when the redundanation computation is
1279  // zero.
1280  if (maximalFusion) {
1281  auto dstForOp = cast<AffineForOp>(dstNode->op);
1282  int64_t sliceCost;
1283  int64_t fusedLoopNestComputeCost;
1284  auto fraction = getAdditionalComputeFraction(
1285  sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1286  sliceCost, fusedLoopNestComputeCost);
1287  if (!fraction || fraction > 0) {
1288  LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1289  << "and non-zero additional compute.";
1290  return;
1291  }
1292  } else {
1293  // Set redundant computation tolerance to zero regardless of what the
1294  // user specified. Without this, fusion would be invalid.
1295  LDBG() << "Setting compute tolerance to zero since "
1296  << "source has a cyclic dependence.";
1297  computeToleranceThresholdToUse = 0.0;
1298  }
1299  }
1300 
1301  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1302  if (!maximalFusion) {
1303  // Check if fusion would be profitable. For sibling fusion, the sibling
1304  // load op is treated as the src "store" op for fusion profitability
1305  // purposes. The footprint of the load in the slice relative to the
1306  // unfused source's determines reuse.
1307  if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
1308  depthSliceUnions, maxLegalFusionDepth,
1309  &bestDstLoopDepth,
1310  computeToleranceThresholdToUse))
1311  continue;
1312  }
1313 
1314  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1315 
1316  const ComputationSliceState &bestSlice =
1317  depthSliceUnions[bestDstLoopDepth - 1];
1318  assert(!bestSlice.isEmpty() &&
1319  "Fusion depth has no computed slice union");
1320 
1321  // Do not perform sibling fusion if it isn't maximal. We always remove the
1322  // sibling node and as such fusion shouldn't be performed if a part of the
1323  // slice is used in the destination.
1324  auto isMaximal = bestSlice.isMaximal();
1325  if (!isMaximal.value_or(false)) {
1326  LDBG() << "Slice isn't maximal; not performing sibling fusion.";
1327  continue;
1328  }
1329 
1330  // Check if source loop is being inserted in the innermost
1331  // destination loop. Based on this, the fused loop may be optimized
1332  // further inside `fuseLoops`.
1333  bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1334  // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1335  affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
1336  isInnermostInsertion);
1337 
1338  auto dstForInst = cast<AffineForOp>(dstNode->op);
1339  // Update operation position of fused loop nest (if needed).
1340  if (insertPointInst != dstForInst)
1341  dstForInst->moveBefore(insertPointInst);
1342 
1343  LDBG() << "Fused sibling nest " << sibId << " into destination nest "
1344  << dstNode->id << " at depth " << bestDstLoopDepth << ":";
1345  LDBG() << dstAffineForOp;
1346 
1347  // Update data dependence graph state post fusion.
1348  updateStateAfterSiblingFusion(sibNode, dstNode);
1349 
1350  // Remove old sibling loop nest.
1351  // Get op before we invalidate the MDG node.
1352  Operation *op = sibNode->op;
1353  mdg->removeNode(sibNode->id);
1354  op->erase();
1355  }
1356  }
1357 
1358  // Searches block argument uses and the graph from 'dstNode' looking for a
1359  // fusion candidate sibling node which shares no dependences with 'dstNode'
1360  // but which loads from the same memref. Returns true and sets
1361  // 'idAndMemrefToFuse' on success. Returns false otherwise.
1362  bool findSiblingNodeToFuse(Node *dstNode,
1363  DenseSet<unsigned> *visitedSibNodeIds,
1364  std::pair<unsigned, Value> *idAndMemrefToFuse) {
1365  // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1366  // on 'memref'.
1367  auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1368  // Skip if 'outEdge' is not a read-after-write dependence.
1369  // TODO: Remove restrict to single load op restriction.
1370  if (sibNode->getLoadOpCount(memref) != 1)
1371  return false;
1372  // Skip if there exists a path of dependent edges between
1373  // 'sibNode' and 'dstNode'.
1374  if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1375  mdg->hasDependencePath(dstNode->id, sibNode->id))
1376  return false;
1377  // Skip sib node if it loads to (and stores from) the same memref on
1378  // which it also has an input dependence edge.
1379  DenseSet<Value> loadAndStoreMemrefSet;
1380  sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1381  if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1382  return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1383  }))
1384  return false;
1385 
1386  // Check that all stores are to the same memref if any.
1387  DenseSet<Value> storeMemrefs;
1388  for (auto *storeOpInst : sibNode->stores) {
1389  storeMemrefs.insert(
1390  cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1391  }
1392  return storeMemrefs.size() <= 1;
1393  };
1394 
1395  // Search for siblings which load the same memref block argument.
1396  Block *block = dstNode->op->getBlock();
1397  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1398  for (Operation *user : block->getArgument(i).getUsers()) {
1399  auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1400  if (!loadOp)
1401  continue;
1402  // Gather loops surrounding 'use'.
1404  getAffineForIVs(*user, &loops);
1405  // Skip 'use' if it is not within a loop nest.
1406  // Find the surrounding affine.for nested immediately within the
1407  // block.
1408  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1409  return loop->getBlock() == &mdg->block;
1410  });
1411  // Skip 'use' if it is not within a loop nest in `block`.
1412  if (it == loops.end())
1413  continue;
1414  Node *sibNode = mdg->getForOpNode(*it);
1415  assert(sibNode != nullptr);
1416  // Skip 'use' if it not a sibling to 'dstNode'.
1417  if (sibNode->id == dstNode->id)
1418  continue;
1419  // Skip 'use' if it has been visited.
1420  if (visitedSibNodeIds->count(sibNode->id) > 0)
1421  continue;
1422  // Skip 'use' if it does not load from the same memref as 'dstNode'.
1423  auto memref = loadOp.getMemRef();
1424  if (dstNode->getLoadOpCount(memref) == 0)
1425  continue;
1426  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1427  if (canFuseWithSibNode(sibNode, memref)) {
1428  visitedSibNodeIds->insert(sibNode->id);
1429  idAndMemrefToFuse->first = sibNode->id;
1430  idAndMemrefToFuse->second = memref;
1431  return true;
1432  }
1433  }
1434  }
1435 
1436  // Search for siblings by following edges through an intermediate src node.
1437  // Collect candidate 'dstNode' input edges in 'inEdges'.
1440  dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1441  // Add 'inEdge' if it is a read-after-write dependence or an edge
1442  // from a memref defining op (e.g. view-like op or alloc op).
1443  if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1444  (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
1445  inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
1446  inEdges.push_back(inEdge);
1447  });
1448 
1449  // Search for sibling nodes to fuse by visiting output edges from each input
1450  // edge in 'inEdges'.
1451  for (auto &inEdge : inEdges) {
1452  // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1455  inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1456  unsigned sibNodeId = outEdge.id;
1457  if (visitedSibNodeIds->count(sibNodeId) > 0)
1458  return;
1459  // Skip output edge if not a sibling using the same memref.
1460  if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1461  return;
1462  auto *sibNode = mdg->getNode(sibNodeId);
1463  if (!isa<AffineForOp>(sibNode->op))
1464  return;
1465  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1466  if (canFuseWithSibNode(sibNode, outEdge.value)) {
1467  // Add candidate 'outEdge' to sibling node.
1468  outEdges.push_back(outEdge);
1469  }
1470  });
1471 
1472  // Add first candidate if any were returned.
1473  if (!outEdges.empty()) {
1474  visitedSibNodeIds->insert(outEdges[0].id);
1475  idAndMemrefToFuse->first = outEdges[0].id;
1476  idAndMemrefToFuse->second = outEdges[0].value;
1477  return true;
1478  }
1479  }
1480  return false;
1481  }
1482 
1483  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1484  /// into 'dstNode'.
1485  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1486  // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1487  mdg->updateEdges(sibNode->id, dstNode->id);
1488 
1489  // Collect dst loop stats after memref privatization transformation.
1490  auto dstForInst = cast<AffineForOp>(dstNode->op);
1491  LoopNestStateCollector dstLoopCollector;
1492  dstLoopCollector.collect(dstForInst);
1493  // Clear and add back loads and stores
1494  mdg->clearNodeLoadAndStores(dstNode->id);
1495  mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1496  dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
1497  dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
1498  }
1499 
1500  // Clean up any allocs with no users.
1501  void eraseUnusedMemRefAllocations() {
1502  for (auto &pair : mdg->memrefEdgeCount) {
1503  if (pair.second > 0)
1504  continue;
1505  auto memref = pair.first;
1506  // Skip if there exist other uses (return operation or function calls).
1507  if (!memref.use_empty())
1508  continue;
1509  // Use list expected to match the dep graph info.
1510  auto *op = memref.getDefiningOp();
1511  if (isa_and_nonnull<memref::AllocOp>(op))
1512  op->erase();
1513  }
1514  }
1515 };
1516 
1517 } // namespace
1518 
1519 /// Run fusion on `block`.
1520 void LoopFusion::runOnBlock(Block *block) {
1521  MemRefDependenceGraph g(*block);
1522  if (!g.init()) {
1523  LDBG() << "MDG init failed";
1524  return;
1525  }
1526 
1527  std::optional<unsigned> fastMemorySpaceOpt;
1528  if (fastMemorySpace.hasValue())
1529  fastMemorySpaceOpt = fastMemorySpace;
1530  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1531  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1532  maximalFusion, computeToleranceThreshold);
1533 
1534  if (affineFusionMode == FusionMode::ProducerConsumer)
1535  fusion.runProducerConsumerFusionOnly();
1536  else if (affineFusionMode == FusionMode::Sibling)
1537  fusion.runSiblingFusionOnly();
1538  else
1539  fusion.runGreedyFusion();
1540 }
1541 
1542 void LoopFusion::runOnOperation() {
1543  // Call fusion on every op that has at least two affine.for nests (in post
1544  // order).
1545  getOperation()->walk([&](Operation *op) {
1546  for (Region &region : op->getRegions()) {
1547  for (Block &block : region.getBlocks()) {
1548  auto affineFors = block.getOps<AffineForOp>();
1549  if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1550  runOnBlock(&block);
1551  }
1552  }
1553  });
1554 }
1555 
1557  unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1558  bool maximalFusion, enum FusionMode affineFusionMode) {
1559  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1560  maximalFusion, affineFusionMode);
1561 }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:37
static Operation * getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock, ArrayRef< Operation * > producerStores)
Get the operation that should act as a dominance filter while replacing memref uses with a private me...
Definition: LoopFusion.cpp:244
static void getProducerCandidates(unsigned dstId, const MemRefDependenceGraph &mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
Definition: LoopFusion.cpp:134
static bool isFusionProfitable(AffineForOp srcForOp, ArrayRef< Operation * > producerStores, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
Definition: LoopFusion.cpp:500
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
Definition: LoopFusion.cpp:185
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, const MemRefDependenceGraph &mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
Definition: LoopFusion.cpp:76
static Value createPrivateMemRef(AffineForOp forOp, ArrayRef< Operation * > storeOps, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, Block *sliceInsertionBlock, uint64_t localBufSizeThreshold)
Definition: LoopFusion.cpp:333
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
Definition: LoopFusion.cpp:216
static std::optional< double > getAdditionalComputeFraction(AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, ArrayRef< ComputationSliceState > depthSliceUnions, int64_t &sliceCost, int64_t &fusedLoopNestComputeCost)
Returns the amount of additional (redundant) computation that will be done as a fraction of the total...
Definition: LoopFusion.cpp:279
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:74
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
iterator begin()
Definition: Block.h:143
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition: Utils.cpp:755
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition: Utils.cpp:2077
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
Definition: Passes.h:35
@ ProducerConsumer
Definition: Passes.h:35
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:1505
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
Definition: LoopUtils.cpp:1453
bool hasCyclicDependence(AffineForOp root)
Returns true if the affine nest rooted at root has a cyclic dependence among its affine memory access...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
Definition: Utils.cpp:2295
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, llvm::function_ref< bool(Operation *)> userFilterFn=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
Definition: Utils.cpp:1306
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition: Utils.cpp:1323
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:310
bool isEmpty() const
Returns true if the computation slice is empty.
Definition: Utils.h:342
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition: Utils.cpp:1010
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > memrefFrees
Definition: Utils.h:49
SmallVector< Operation *, 4 > loadOpInsts
Definition: Utils.h:41
SmallVector< Operation *, 4 > memrefStores
Definition: Utils.h:47
void collect(Operation *opToWalk)
Definition: Utils.cpp:42
SmallVector< Operation *, 4 > memrefLoads
Definition: Utils.h:45
SmallVector< Operation *, 4 > storeOpInsts
Definition: Utils.h:43
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition: Utils.h:147
Block & block
The block for which this graph is created to perform fusion.
Definition: Utils.h:262
unsigned addNode(Operation *op)
Definition: Utils.cpp:371
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:436
DenseMap< unsigned, Node > nodes
Definition: Utils.h:141
bool hasDependencePath(unsigned srcId, unsigned dstId) const
Definition: Utils.cpp:474
void clearNodeLoadAndStores(unsigned id)
Definition: Utils.cpp:701
const Node * getForOpNode(AffineForOp forOp) const
Definition: Utils.cpp:363
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) const
Definition: Utils.cpp:549
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition: Utils.cpp:624
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition: Utils.h:144
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:709
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr) const
Definition: Utils.cpp:525
const Node * getNode(unsigned id) const
Definition: Utils.cpp:356
void removeNode(unsigned id)
Definition: Utils.cpp:378
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:717
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
Definition: Utils.cpp:688
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const
Definition: Utils.cpp:509
DenseMap< Value, unsigned > memrefEdgeCount
Definition: Utils.h:150
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
Definition: Utils.h:481
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition: Utils.cpp:1064
FlatAffineValueConstraints * getConstraints()
Definition: Utils.h:527
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition: Utils.cpp:1161
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition: Utils.cpp:1342
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.