MLIR  20.0.0git
LoopFusion.cpp
Go to the documentation of this file.
1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements affine fusion.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/Transforms/Passes.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <iomanip>
34 #include <optional>
35 #include <sstream>
36 
37 namespace mlir {
38 namespace affine {
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
41 } // namespace affine
42 } // namespace mlir
43 
44 #define DEBUG_TYPE "affine-loop-fusion"
45 
46 using namespace mlir;
47 using namespace mlir::affine;
48 
49 namespace {
50 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
51 /// which fuses loop nests with single-writer/single-reader memref dependences
52 /// with the goal of improving locality.
53 // TODO: Support fusion of source loop nests which write to multiple
54 // memrefs, where each memref can have multiple users (if profitable).
55 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
56  LoopFusion() = default;
57  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58  bool maximalFusion, enum FusionMode affineFusionMode) {
59  this->fastMemorySpace = fastMemorySpace;
60  this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61  this->maximalFusion = maximalFusion;
62  this->affineFusionMode = affineFusionMode;
63  }
64 
65  void runOnBlock(Block *block);
66  void runOnOperation() override;
67 };
68 
69 } // namespace
70 
71 /// Returns true if node 'srcId' can be removed after fusing it with node
72 /// 'dstId'. The node can be removed if any of the following conditions are met:
73 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
74 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
75 /// and the fusion slice is maximal.
76 /// 3. 'srcId' has output dependences after fusion, the fusion slice is
77 /// maximal and the fusion insertion point dominates all the dependences.
79  unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
80  Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
81  MemRefDependenceGraph *mdg) {
82 
83  Operation *dstNodeOp = mdg->getNode(dstId)->op;
84  bool hasOutDepsAfterFusion = false;
85 
86  for (auto &outEdge : mdg->outEdges[srcId]) {
87  Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
88  // Skip dependence with dstOp since it will be removed after fusion.
89  if (depNodeOp == dstNodeOp)
90  continue;
91 
92  // Only fusion within the same block is supported. Use domination analysis
93  // when needed.
94  if (depNodeOp->getBlock() != dstNodeOp->getBlock())
95  return false;
96 
97  // Check if the insertion point of the fused loop dominates the dependence.
98  // Otherwise, the src loop can't be removed.
99  if (fusedLoopInsPoint != depNodeOp &&
100  !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
101  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
102  "dominate dependence\n");
103  return false;
104  }
105 
106  hasOutDepsAfterFusion = true;
107  }
108 
109  // If src loop has dependences after fusion or it writes to an live-out or
110  // escaping memref, we can only remove it if the fusion slice is maximal so
111  // that all the dependences are preserved.
112  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113  std::optional<bool> isMaximal = fusionSlice.isMaximal();
114  if (!isMaximal) {
115  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
116  "if fusion is maximal\n");
117  return false;
118  }
119 
120  if (!*isMaximal) {
121  LLVM_DEBUG(llvm::dbgs()
122  << "Src loop can't be removed: fusion is not maximal\n");
123  return false;
124  }
125  }
126 
127  return true;
128 }
129 
130 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
131 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to
132 /// the program order when the 'mdg' is created. However, program order is not
133 /// guaranteed and must not be required by the client. Program order won't be
134 /// held if the 'mdg' is reused from a previous fusion step or if the node
135 /// creation order changes in the future to support more advance cases.
136 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
137 static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
138  SmallVectorImpl<unsigned> &srcIdCandidates) {
139  // Skip if no input edges along which to fuse.
140  if (mdg->inEdges.count(dstId) == 0)
141  return;
142 
143  // Gather memrefs from loads in 'dstId'.
144  auto *dstNode = mdg->getNode(dstId);
145  DenseSet<Value> consumedMemrefs;
146  for (Operation *load : dstNode->loads)
147  consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148 
149  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
150  // to one of the consumed memrefs.
151  for (auto &srcEdge : mdg->inEdges[dstId]) {
152  auto *srcNode = mdg->getNode(srcEdge.id);
153  // Skip if 'srcNode' is not a loop nest.
154  if (!isa<AffineForOp>(srcNode->op))
155  continue;
156 
157  if (any_of(srcNode->stores, [&](Operation *op) {
158  auto storeOp = cast<AffineWriteOpInterface>(op);
159  return consumedMemrefs.count(storeOp.getMemRef()) > 0;
160  }))
161  srcIdCandidates.push_back(srcNode->id);
162  }
163 
164  llvm::sort(srcIdCandidates);
165  srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
166 }
167 
168 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
169 /// producer-consumer dependence between 'srcId' and 'dstId'.
170 static void
171 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
173  DenseSet<Value> &producerConsumerMemrefs) {
174  auto *dstNode = mdg->getNode(dstId);
175  auto *srcNode = mdg->getNode(srcId);
176  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
177  producerConsumerMemrefs);
178 }
179 
180 /// A memref escapes in the context of the fusion pass if either:
181 /// 1. it (or its alias) is a block argument, or
182 /// 2. created by an op not known to guarantee alias freedom,
183 /// 3. it (or its alias) are used by ops other than affine dereferencing ops
184 /// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
185 /// terminator ops, etc.); such ops do not deference the memref in an affine
186 /// way.
187 static bool isEscapingMemref(Value memref, Block *block) {
188  Operation *defOp = memref.getDefiningOp();
189  // Check if 'memref' is a block argument.
190  if (!defOp)
191  return true;
192 
193  // Check if this is defined to be an alias of another memref.
194  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
195  if (isEscapingMemref(viewOp.getViewSource(), block))
196  return true;
197 
198  // Any op besides allocating ops wouldn't guarantee alias freedom
199  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
200  return true;
201 
202  // Check if 'memref' is used by a non-deferencing op (including unknown ones)
203  // (e.g., call ops, alias creating ops, etc.).
204  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
205  // Ignore users outside of `block`.
206  Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
207  if (!ancestorOp)
208  return true;
209  if (ancestorOp->getBlock() != block)
210  return false;
211  return !isa<AffineMapAccessInterface>(*user);
212  });
213 }
214 
215 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
216 /// that escape the block or are accessed in a non-affine way.
217 static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
218  DenseSet<Value> &escapingMemRefs) {
219  auto *node = mdg->getNode(id);
220  for (Operation *storeOp : node->stores) {
221  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222  if (escapingMemRefs.count(memref))
223  continue;
224  if (isEscapingMemref(memref, &mdg->block))
225  escapingMemRefs.insert(memref);
226  }
227 }
228 
229 // Sinks all sequential loops to the innermost levels (while preserving
230 // relative order among them) and moves all parallel loops to the
231 // outermost (while again preserving relative order among them).
232 // This can increase the loop depth at which we can fuse a slice, since we are
233 // pushing loop carried dependence to a greater depth in the loop nest.
235  assert(isa<AffineForOp>(node->op));
236  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
237  node->op = newRootForOp;
238 }
239 
240 // Creates and returns a private (single-user) memref for fused loop rooted
241 // at 'forOp', with (potentially reduced) memref size based on the
242 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243 // TODO: consider refactoring the common code from generateDma and
244 // this one.
245 static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
246  unsigned dstLoopDepth,
247  std::optional<unsigned> fastMemorySpace,
248  uint64_t localBufSizeThreshold) {
249  Operation *forInst = forOp.getOperation();
250 
251  // Create builder to insert alloc op just before 'forOp'.
252  OpBuilder b(forInst);
253  // Builder to create constants at the top level.
254  OpBuilder top(forInst->getParentRegion());
255  // Create new memref type based on slice bounds.
256  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
257  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
258  unsigned rank = oldMemRefType.getRank();
259 
260  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261  MemRefRegion region(srcStoreOpInst->getLoc());
262  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
263  (void)validRegion;
264  assert(validRegion && "unexpected memref region failure");
265  SmallVector<int64_t, 4> newShape;
266  std::vector<SmallVector<int64_t, 4>> lbs;
267  SmallVector<int64_t, 8> lbDivisors;
268  lbs.reserve(rank);
269  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
270  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
271  std::optional<int64_t> numElements =
272  region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
273  assert(numElements && "non-constant number of elts in local buffer");
274 
275  const FlatAffineValueConstraints *cst = region.getConstraints();
276  // 'outerIVs' holds the values that this memory region is symbolic/parametric
277  // on; this would correspond to loop IVs surrounding the level at which the
278  // slice is being materialized.
279  SmallVector<Value, 8> outerIVs;
280  cst->getValues(rank, cst->getNumVars(), &outerIVs);
281 
282  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
284  offsets.reserve(rank);
285  for (unsigned d = 0; d < rank; ++d) {
286  assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
287 
288  AffineExpr offset = top.getAffineConstantExpr(0);
289  for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
290  offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
291  }
292  assert(lbDivisors[d] > 0);
293  offset =
294  (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
295  offsets.push_back(offset);
296  }
297 
298  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
299  // by 'srcStoreOpInst'.
300  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
301  assert(eltSize && "memrefs with size elt types expected");
302  uint64_t bufSize = *eltSize * *numElements;
303  unsigned newMemSpace;
304  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
305  newMemSpace = *fastMemorySpace;
306  } else {
307  newMemSpace = oldMemRefType.getMemorySpaceAsInt();
308  }
309  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
310  {}, newMemSpace);
311 
312  // Create new private memref for fused loop 'forOp'. 'newShape' is always
313  // a constant shape.
314  // TODO: Create/move alloc ops for private memrefs closer to their
315  // consumer loop nests to reduce their live range. Currently they are added
316  // at the beginning of the block, because loop nests can be reordered
317  // during the fusion pass.
318  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
319 
320  // Build an AffineMap to remap access functions based on lower bound offsets.
321  SmallVector<AffineExpr, 4> remapExprs;
322  remapExprs.reserve(rank);
323  for (unsigned i = 0; i < rank; i++) {
324  auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
325 
326  auto remapExpr =
327  simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
328  remapExprs.push_back(remapExpr);
329  }
330 
331  auto indexRemap =
332  AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
333 
334  // Replace all users of 'oldMemRef' with 'newMemRef'.
335  LogicalResult res =
336  replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
337  /*extraOperands=*/outerIVs,
338  /*symbolOperands=*/{},
339  /*domOpFilter=*/&*forOp.getBody()->begin());
340  assert(succeeded(res) &&
341  "replaceAllMemrefUsesWith should always succeed here");
342  (void)res;
343  return newMemRef;
344 }
345 
346 /// Returns true if there are any non-affine uses of `memref` in any of
347 /// the operations between `start` and `end` (both exclusive). Any other
348 /// than affine read/write are treated as non-affine uses of `memref`.
349 static bool hasNonAffineUsersOnPath(Operation *start, Operation *end,
350  Value memref) {
351  assert(start->getBlock() == end->getBlock());
352  assert(start->isBeforeInBlock(end) && "start expected to be before end");
353  Block *block = start->getBlock();
354  // Check if there is a non-affine memref user in any op between `start` and
355  // `end`.
356  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
357  if (isa<AffineReadOpInterface, AffineWriteOpInterface>(user))
358  return false;
359  Operation *ancestor = block->findAncestorOpInBlock(*user);
360  return ancestor && start->isBeforeInBlock(ancestor) &&
361  ancestor->isBeforeInBlock(end);
362  });
363 }
364 
365 /// Check whether a memref value used in any operation of 'src' has a
366 /// non-affine operation that is between `src` and `end` (exclusive of `src`
367 /// and `end`) where `src` and `end` are expected to be in the same Block.
368 /// Any other than affine read/write are treated as non-affine uses of memref.
369 static bool hasNonAffineUsersOnPath(Operation *src, Operation *end) {
370  assert(src->getBlock() == end->getBlock() && "same block expected");
371 
372  // Trivial case. `src` and `end` are exclusive.
373  if (src == end || end->isBeforeInBlock(src))
374  return false;
375 
376  // Collect relevant memref values.
377  llvm::SmallDenseSet<Value, 2> memRefValues;
378  src->walk([&](Operation *op) {
379  for (Value v : op->getOperands())
380  // Collect memref values only.
381  if (isa<MemRefType>(v.getType()))
382  memRefValues.insert(v);
383  return WalkResult::advance();
384  });
385  // Look for non-affine users between `src` and `end`.
386  return llvm::any_of(memRefValues, [&](Value memref) {
387  return hasNonAffineUsersOnPath(src, end, memref);
388  });
389 }
390 
391 // Checks the profitability of fusing a backwards slice of the loop nest
392 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
393 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
394 // the memref being produced and consumed, which is an input to the cost model.
395 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
396 // 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
397 // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
398 // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
399 // unique store op in the src node, which will be used to check that the write
400 // region is the same after input-reuse fusion. Computation slices are provided
401 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
402 // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
403 // profitable to fuse the candidate loop nests. Returns false otherwise.
404 // `dstLoopDepth` is set to the most profitable depth at which to materialize
405 // the source loop nest slice.
406 // The profitability model executes the following steps:
407 // *) Computes the backward computation slice at 'srcOpInst'. This
408 // computation slice of the loop nest surrounding 'srcOpInst' is
409 // represented by modified src loop bounds in 'sliceState', which are
410 // functions of loop IVs in the loop nest surrounding 'srcOpInst'.
411 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
412 // loop nest is the total number of dynamic operation instances in the loop
413 // nest).
414 // *) Computes the cost of fusing a slice of the src loop nest into the dst
415 // loop nest at various values of dst loop depth, attempting to fuse
416 // the largest computation slice at the maximal dst loop depth (closest to
417 // the load) to minimize reuse distance and potentially enable subsequent
418 // load/store forwarding.
419 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
420 // nest, at which the src computation slice is inserted/fused.
421 // NOTE: We attempt to maximize the dst loop depth, but there are cases
422 // where a particular setting for 'dstLoopNest' might fuse an unsliced
423 // loop (within the src computation slice) at a depth which results in
424 // excessive recomputation (see unit tests for examples).
425 // *) Compares the total cost of the unfused loop nests to the min cost fused
426 // loop nest computed in the previous step, and returns true if the latter
427 // is lower.
428 // TODO: Extend profitability analysis to support scenarios with multiple
429 // stores.
430 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
431  AffineForOp dstForOp,
432  ArrayRef<ComputationSliceState> depthSliceUnions,
433  unsigned maxLegalFusionDepth,
434  unsigned *dstLoopDepth,
435  double computeToleranceThreshold) {
436  LLVM_DEBUG({
437  llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
438  llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
439  llvm::dbgs() << dstForOp << "\n";
440  });
441 
442  if (maxLegalFusionDepth == 0) {
443  LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
444  return false;
445  }
446 
447  // Compute cost of sliced and unsliced src loop nest.
448  SmallVector<AffineForOp, 4> srcLoopIVs;
449  getAffineForIVs(*srcOpInst, &srcLoopIVs);
450 
451  // Walk src loop nest and collect stats.
452  LoopNestStats srcLoopNestStats;
453  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
454  return false;
455 
456  // Compute cost of dst loop nest.
457  LoopNestStats dstLoopNestStats;
458  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
459  return false;
460 
461  // Search for min cost value for 'dstLoopDepth'. At each value of
462  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
463  // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
464  // of these bounds). Next the union slice bounds are used to calculate
465  // the cost of the slice and the cost of the slice inserted into the dst
466  // loop nest at 'dstLoopDepth'.
467  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
468  double maxStorageReduction = 0.0;
469  std::optional<uint64_t> sliceMemEstimate;
470 
471  // The best loop depth at which to materialize the slice.
472  std::optional<unsigned> bestDstLoopDepth;
473 
474  // Compute op instance count for the src loop nest without iteration slicing.
475  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
476 
477  // Compute src loop nest write region size.
478  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
479  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
480  LLVM_DEBUG(llvm::dbgs()
481  << "Unable to compute MemRefRegion for source operation\n");
482  return false;
483  }
484 
485  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
486  srcWriteRegion.getRegionSize();
487  if (!maybeSrcWriteRegionSizeBytes.has_value())
488  return false;
489  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
490 
491  // Compute op instance count for the src loop nest.
492  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
493 
494  // Evaluate all depth choices for materializing the slice in the destination
495  // loop nest.
496  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
497  const ComputationSliceState &slice = depthSliceUnions[i - 1];
498  // Skip slice union if it wasn't computed for this depth.
499  if (slice.isEmpty())
500  continue;
501 
502  int64_t fusedLoopNestComputeCost;
503  if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
504  dstLoopNestStats, slice,
505  &fusedLoopNestComputeCost)) {
506  LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
507  continue;
508  }
509 
510  double additionalComputeFraction =
511  fusedLoopNestComputeCost /
512  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
513  1;
514 
515  // Determine what the slice write MemRefRegion would be, if the src loop
516  // nest slice 'slice' were to be inserted into the dst loop nest at loop
517  // depth 'i'.
518  MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
519  if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
520  &slice))) {
521  LLVM_DEBUG(llvm::dbgs()
522  << "Failed to compute slice write region at loopDepth: " << i
523  << "\n");
524  continue;
525  }
526 
527  std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
528  sliceWriteRegion.getRegionSize();
529  if (!maybeSliceWriteRegionSizeBytes.has_value() ||
530  *maybeSliceWriteRegionSizeBytes == 0) {
531  LLVM_DEBUG(llvm::dbgs()
532  << "Failed to get slice write region size at loopDepth: " << i
533  << "\n");
534  continue;
535  }
536  int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
537 
538  // If we are fusing for reuse, check that write regions remain the same.
539  // TODO: Write region check should check sizes and offsets in
540  // each dimension, so that we are sure they are covering the same memref
541  // region. Also, move this out to a isMemRefRegionSuperSet helper function.
542  if (srcOpInst != srcStoreOpInst &&
543  sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
544  continue;
545 
546  double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
547  static_cast<double>(sliceWriteRegionSizeBytes);
548 
549  LLVM_DEBUG({
550  std::stringstream msg;
551  msg << " evaluating fusion profitability at depth : " << i << "\n"
552  << std::fixed << std::setprecision(2)
553  << " additional compute fraction: "
554  << 100.0 * additionalComputeFraction << "%\n"
555  << " storage reduction factor: " << storageReduction << "x\n"
556  << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
557  << " src write region size: " << srcWriteRegionSizeBytes << "\n"
558  << " slice write region size: " << sliceWriteRegionSizeBytes
559  << "\n";
560  llvm::dbgs() << msg.str();
561  });
562 
563  // TODO: This is a placeholder cost model.
564  // Among all choices that add an acceptable amount of redundant computation
565  // (as per computeToleranceThreshold), we will simply pick the one that
566  // reduces the intermediary size the most.
567  if ((storageReduction > maxStorageReduction) &&
568  (additionalComputeFraction < computeToleranceThreshold)) {
569  maxStorageReduction = storageReduction;
570  bestDstLoopDepth = i;
571  minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
572  sliceMemEstimate = sliceWriteRegionSizeBytes;
573  }
574  }
575 
576  // A simple cost model: fuse if it reduces the memory footprint.
577 
578  if (!bestDstLoopDepth) {
579  LLVM_DEBUG(
580  llvm::dbgs()
581  << "All fusion choices involve more than the threshold amount of "
582  "redundant computation; NOT fusing.\n");
583  return false;
584  }
585 
586  if (!bestDstLoopDepth) {
587  LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
588  return false;
589  }
590 
591  // Set dstLoopDepth based on best values from search.
592  *dstLoopDepth = *bestDstLoopDepth;
593 
594  LLVM_DEBUG(
595  llvm::dbgs() << " LoopFusion fusion stats:"
596  << "\n best loop depth: " << bestDstLoopDepth
597  << "\n src loop nest compute cost: " << srcLoopNestCost
598  << "\n dst loop nest compute cost: " << dstLoopNestCost
599  << "\n fused loop nest compute cost: "
600  << minFusedLoopNestComputeCost << "\n");
601 
602  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
603  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
604 
605  std::optional<double> storageReduction;
606 
607  if (!dstMemSize || !srcMemSize) {
608  LLVM_DEBUG(llvm::dbgs()
609  << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
610  return false;
611  }
612 
613  auto srcMemSizeVal = *srcMemSize;
614  auto dstMemSizeVal = *dstMemSize;
615 
616  assert(sliceMemEstimate && "expected value");
617  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
618 
619  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
620  << " dst mem: " << dstMemSizeVal << "\n"
621  << " fused mem: " << fusedMem << "\n"
622  << " slice mem: " << sliceMemEstimate << "\n");
623 
624  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
625  LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
626  return false;
627  }
628  storageReduction =
629  100.0 *
630  (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
631 
632  double additionalComputeFraction =
633  100.0 * (minFusedLoopNestComputeCost /
634  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
635  1);
636  (void)additionalComputeFraction;
637  LLVM_DEBUG({
638  std::stringstream msg;
639  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
640  << std::setprecision(2) << additionalComputeFraction
641  << "% redundant computation and a ";
642  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
643  msg << "% storage reduction.\n";
644  llvm::dbgs() << msg.str();
645  });
646 
647  return true;
648 }
649 
650 namespace {
651 
652 // GreedyFusion greedily fuses loop nests which have a producer/consumer or
653 // input-reuse relationship on a memref, with the goal of improving locality.
654 //
655 // The steps of the producer-consumer fusion algorithm are as follows:
656 //
657 // *) A worklist is initialized with node ids from the dependence graph.
658 // *) For each node id in the worklist:
659 // *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
660 // candidate destination AffineForOp into which fusion will be attempted.
661 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
662 // *) For each LoadOp in 'dstLoadOps' do:
663 // *) Look up dependent loop nests which have a single store op to the same
664 // memref.
665 // *) Check if dependences would be violated by the fusion.
666 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop
667 // bounds to be functions of 'dstLoopNest' IVs and symbols.
668 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
669 // at a loop depth determined by the cost model in 'isFusionProfitable'.
670 // *) Add the newly fused load/store operations to the state,
671 // and also add newly fused load ops to 'dstLoadOps' to be considered
672 // as fusion dst load ops in another iteration.
673 // *) Remove old src loop nest and its associated state.
674 //
675 // The steps of the input-reuse fusion algorithm are as follows:
676 //
677 // *) Initialize 'worklist' with node ids from the dependence graph.
678 // *) For each 'dstNode' in the worklist:
679 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
680 // loads from the same memref, but which has no dependence paths to/from.
681 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop
682 // bounds to be functions of 'dstLoopNest' IVs and symbols.
683 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
684 // at a loop depth determined by the cost model in 'isFusionProfitable'.
685 // This function also checks that the memref write region of 'sibLoopNest',
686 // is preserved in the fused loop nest.
687 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
688 //
689 // Given a graph where top-level operations are vertices in the set 'V' and
690 // edges in the set 'E' are dependences between vertices, this algorithm
691 // takes O(V) time for initialization, and has runtime O(V + E).
692 //
693 // This greedy algorithm is not 'maximal' due to the current restriction of
694 // fusing along single producer consumer edges, but there is a TODO: to fix
695 // this.
696 //
697 // TODO: Experiment with other fusion policies.
698 struct GreedyFusion {
699 public:
700  // The data dependence graph to traverse during fusion.
702  // Worklist of graph nodes visited during the fusion pass.
703  SmallVector<unsigned, 8> worklist;
704  // Parameter for local buffer size threshold.
705  unsigned localBufSizeThreshold;
706  // Parameter for fast memory space.
707  std::optional<unsigned> fastMemorySpace;
708  // If true, ignore any additional (redundant) computation tolerance threshold
709  // that would have prevented fusion.
710  bool maximalFusion;
711  // The amount of additional computation that is tolerated while fusing
712  // pair-wise as a fraction of the total computation.
713  double computeToleranceThreshold;
714 
716 
717  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
718  std::optional<unsigned> fastMemorySpace, bool maximalFusion,
719  double computeToleranceThreshold)
720  : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
721  fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
722  computeToleranceThreshold(computeToleranceThreshold) {}
723 
724  /// Initializes 'worklist' with nodes from 'mdg'.
725  void init() {
726  // TODO: Add a priority queue for prioritizing nodes by different
727  // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
728  worklist.clear();
729  for (auto &idAndNode : mdg->nodes) {
730  const Node &node = idAndNode.second;
731  worklist.push_back(node.id);
732  }
733  }
734  /// Run only sibling fusion on the `mdg`.
735  void runSiblingFusionOnly() {
736  fuseSiblingNodes();
737  eraseUnusedMemRefAllocations();
738  }
739 
740  /// Run only producer/consumer fusion on the `mdg`.
741  void runProducerConsumerFusionOnly() {
742  fuseProducerConsumerNodes(
743  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
744  eraseUnusedMemRefAllocations();
745  }
746 
747  // Run the GreedyFusion pass.
748  // *) First pass through the nodes fuses single-use producer nodes into their
749  // unique consumer.
750  // *) Second pass fuses sibling nodes which share no dependence edges.
751  // *) Third pass fuses any remaining producer nodes into their users.
752  void runGreedyFusion() {
753  // TODO: Run this repeatedly until a fixed-point is reached.
754  fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
755  fuseSiblingNodes();
756  fuseProducerConsumerNodes(
757  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
758  eraseUnusedMemRefAllocations();
759  }
760 
761  /// Returns true if a private memref can be created for `memref` given
762  /// the fusion scenario reflected by the other arguments.
763  bool canCreatePrivateMemRef(Value memref,
764  const DenseSet<Value> &srcEscapingMemRefs,
765  unsigned producerId, unsigned consumerId,
766  bool removeSrcNode) {
767  const Node *consumerNode = mdg->getNode(consumerId);
768  // If `memref` is an escaping one, do not create a private memref
769  // for the below scenarios, since doing so will leave the escaping
770  // memref unmodified as all the writes originally meant for the
771  // escaping memref would be performed on the private memref:
772  // 1. The source is to be removed after fusion,
773  // OR
774  // 2. The destination writes to `memref`.
775  if (srcEscapingMemRefs.count(memref) > 0 &&
776  (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
777  return false;
778 
779  // Don't create a private memref if 'srcNode' has in edges on
780  // 'memref' or 'dstNode' has out edges on 'memref'.
781  if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
782  mdg->getOutEdgeCount(consumerId, memref) > 0)
783  return false;
784 
785  // If 'srcNode' will be removed but it has out edges on 'memref' to
786  // nodes other than 'dstNode', we have to preserve dependences and
787  // cannot create a private memref.
788  if (removeSrcNode &&
789  any_of(mdg->outEdges[producerId], [&](const auto &edge) {
790  return edge.value == memref && edge.id != consumerId;
791  }))
792  return false;
793 
794  return true;
795  }
796 
797  /// Perform fusions with node `dstId` as the destination of fusion, with
798  /// No fusion is performed when producers with a user count greater than
799  /// `maxSrcUserCount` for any of the memrefs involved.
800  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
801  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
802  // Skip if this node was removed (fused into another node).
803  if (mdg->nodes.count(dstId) == 0)
804  return;
805  // Get 'dstNode' into which to attempt fusion.
806  auto *dstNode = mdg->getNode(dstId);
807  // Skip if 'dstNode' is not a loop nest.
808  if (!isa<AffineForOp>(dstNode->op))
809  return;
810  // Skip if 'dstNode' is a loop nest returning values.
811  // TODO: support loop nests that return values.
812  if (dstNode->op->getNumResults() > 0)
813  return;
814 
815  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
816 
817  // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
818  // while preserving relative order. This can increase the maximum loop
819  // depth at which we can fuse a slice of a producer loop nest into a
820  // consumer loop nest.
821  sinkSequentialLoops(dstNode);
822  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
823 
824  // Try to fuse 'dstNode' with candidate producer loops until a fixed point
825  // is reached. Fusing two loops may expose new fusion opportunities.
826  bool dstNodeChanged;
827  do {
828  // Gather src loop candidates for 'dstNode' and visit them in "quasi"
829  // reverse program order to minimize the number of iterations needed to
830  // reach the fixed point. Note that this is a best effort approach since
831  // 'getProducerCandidates' does not always guarantee that program order
832  // in 'srcIdCandidates'.
833  dstNodeChanged = false;
834  SmallVector<unsigned, 16> srcIdCandidates;
835  getProducerCandidates(dstId, mdg, srcIdCandidates);
836 
837  for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
838  // Get 'srcNode' from which to attempt fusion into 'dstNode'.
839  auto *srcNode = mdg->getNode(srcId);
840  auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
841  LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
842  << " for dst loop " << dstId << "\n");
843 
844  // Skip if 'srcNode' is a loop nest returning values.
845  // TODO: support loop nests that return values.
846  if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
847  continue;
848 
849  DenseSet<Value> producerConsumerMemrefs;
850  gatherProducerConsumerMemrefs(srcId, dstId, mdg,
851  producerConsumerMemrefs);
852 
853  // Skip if 'srcNode' out edge count on any memref is greater than
854  // 'maxSrcUserCount'.
855  if (any_of(producerConsumerMemrefs, [&](Value memref) {
856  return mdg->getOutEdgeCount(srcNode->id, memref) >
857  maxSrcUserCount;
858  }))
859  continue;
860 
861  // Gather memrefs in 'srcNode' that are written and escape out of the
862  // block (e.g., memref block arguments, returned memrefs,
863  // memrefs passed to function calls, etc.).
864  DenseSet<Value> srcEscapingMemRefs;
865  gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
866 
867  // Skip if there are non-affine operations in between the 'srcNode'
868  // and 'dstNode' using their memrefs. If so, we wouldn't be able to
869  // compute a legal insertion point for now. 'srcNode' and 'dstNode'
870  // memrefs with non-affine operation users would be considered
871  // escaping memrefs so we can limit this check to only scenarios with
872  // escaping memrefs.
873  if (!srcEscapingMemRefs.empty() &&
874  hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) {
875  LLVM_DEBUG(llvm::dbgs()
876  << "Can't fuse: non-affine users in between the loops\n");
877  continue;
878  }
879 
880  // Compute an operation list insertion point for the fused loop
881  // nest which preserves dependences.
882  Operation *fusedLoopInsPoint =
883  mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
884  if (fusedLoopInsPoint == nullptr)
885  continue;
886 
887  // It's possible this fusion is at an inner depth (i.e., there are
888  // common surrounding affine loops for the source and destination for
889  // ops). We need to get this number because the call to canFuseLoops
890  // needs to be passed the absolute depth. The max legal depth and the
891  // depths we try below are however *relative* and as such don't include
892  // the common depth.
893  SmallVector<AffineForOp, 4> surroundingLoops;
894  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
895  unsigned numSurroundingLoops = surroundingLoops.size();
896 
897  // Compute the innermost common loop depth for dstNode
898  // producer-consumer loads/stores.
899  SmallVector<Operation *, 2> dstMemrefOps;
900  for (Operation *op : dstNode->loads)
901  if (producerConsumerMemrefs.count(
902  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
903  dstMemrefOps.push_back(op);
904  for (Operation *op : dstNode->stores)
905  if (producerConsumerMemrefs.count(
906  cast<AffineWriteOpInterface>(op).getMemRef()))
907  dstMemrefOps.push_back(op);
908  unsigned dstLoopDepthTest =
909  getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
910 
911  // Check the feasibility of fusing src loop nest into dst loop nest
912  // at loop depths in range [1, dstLoopDepthTest].
913  unsigned maxLegalFusionDepth = 0;
914  SmallVector<ComputationSliceState, 8> depthSliceUnions;
915  depthSliceUnions.resize(dstLoopDepthTest);
917  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
918  FusionResult result =
919  affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
920  /*dstLoopDepth=*/i + numSurroundingLoops,
921  &depthSliceUnions[i - 1], strategy);
922 
923  if (result.value == FusionResult::Success)
924  maxLegalFusionDepth = i;
925  }
926 
927  if (maxLegalFusionDepth == 0) {
928  LLVM_DEBUG(llvm::dbgs()
929  << "Can't fuse: fusion is not legal at any depth\n");
930  continue;
931  }
932 
933  // Check if fusion would be profitable. We skip profitability analysis
934  // for maximal fusion since we already know the maximal legal depth to
935  // fuse.
936  unsigned bestDstLoopDepth = maxLegalFusionDepth;
937  if (!maximalFusion) {
938  // Retrieve producer stores from the src loop.
939  SmallVector<Operation *, 2> producerStores;
940  for (Operation *op : srcNode->stores)
941  if (producerConsumerMemrefs.count(
942  cast<AffineWriteOpInterface>(op).getMemRef()))
943  producerStores.push_back(op);
944 
945  // TODO: Suppport multiple producer stores in profitability
946  // analysis. We limit profitability analysis to only scenarios with
947  // a single producer store for now. Note that some multi-store
948  // producer scenarios will still go through profitability analysis
949  // if only one of the stores is involved the producer-consumer
950  // relationship of the candidate loops.
951  assert(!producerStores.empty() && "Expected producer store");
952  if (producerStores.size() > 1)
953  LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
954  "supported for this case\n");
955  else if (!isFusionProfitable(producerStores[0], producerStores[0],
956  dstAffineForOp, depthSliceUnions,
957  maxLegalFusionDepth, &bestDstLoopDepth,
958  computeToleranceThreshold))
959  continue;
960  }
961 
962  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
963  ComputationSliceState &bestSlice =
964  depthSliceUnions[bestDstLoopDepth - 1];
965  assert(!bestSlice.isEmpty() && "Missing slice union for depth");
966 
967  // Determine if 'srcId' can be removed after fusion, taking into
968  // account remaining dependences, escaping memrefs and the fusion
969  // insertion point.
970  bool removeSrcNode = canRemoveSrcNodeAfterFusion(
971  srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
972  mdg);
973 
974  DenseSet<Value> privateMemrefs;
975  for (Value memref : producerConsumerMemrefs) {
976  if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
977  removeSrcNode)) {
978  // Create a private version of this memref.
979  LLVM_DEBUG(llvm::dbgs()
980  << "Creating private memref for " << memref << '\n');
981  // Create a private version of this memref.
982  privateMemrefs.insert(memref);
983  }
984  }
985 
986  // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
987  fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
988  dstNodeChanged = true;
989 
990  LLVM_DEBUG(llvm::dbgs()
991  << "Fused src loop " << srcId << " into dst loop " << dstId
992  << " at depth " << bestDstLoopDepth << ":\n"
993  << dstAffineForOp << "\n");
994 
995  // Move 'dstAffineForOp' before 'insertPointInst' if needed.
996  if (fusedLoopInsPoint != dstAffineForOp)
997  dstAffineForOp->moveBefore(fusedLoopInsPoint);
998 
999  // Update edges between 'srcNode' and 'dstNode'.
1000  mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1001  removeSrcNode);
1002 
1003  // Create private memrefs.
1004  if (!privateMemrefs.empty()) {
1005  // Gather stores for all the private-to-be memrefs.
1006  DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1007  dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1008  Value storeMemRef = storeOp.getMemRef();
1009  if (privateMemrefs.count(storeMemRef) > 0)
1010  privateMemRefToStores[storeMemRef].push_back(storeOp);
1011  });
1012 
1013  // Replace original memrefs with private memrefs. Note that all the
1014  // loads and stores on these memrefs will be replaced with a new
1015  // loads and stores. Any reference to the original ones becomes
1016  // invalid after this point.
1017  for (auto &memrefToStoresPair : privateMemRefToStores) {
1018  // TODO: Use union of memref write regions to compute
1019  // private memref footprint.
1020  SmallVector<Operation *, 4> &storesForMemref =
1021  memrefToStoresPair.second;
1022  Value newMemRef = createPrivateMemRef(
1023  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1024  fastMemorySpace, localBufSizeThreshold);
1025  // Create new node in dependence graph for 'newMemRef' alloc op.
1026  unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1027  // Add edge from 'newMemRef' node to dstNode.
1028  mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1029  }
1030  // One or more entries for 'newMemRef' alloc op are inserted into
1031  // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1032  // reallocate, update dstNode.
1033  dstNode = mdg->getNode(dstId);
1034  }
1035 
1036  // Collect dst loop stats after memref privatization transformation.
1037  LoopNestStateCollector dstLoopCollector;
1038  dstLoopCollector.collect(dstAffineForOp);
1039 
1040  // Clear and add back loads and stores.
1041  mdg->clearNodeLoadAndStores(dstNode->id);
1042  mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1043  dstLoopCollector.storeOpInsts);
1044 
1045  if (removeSrcNode) {
1046  LLVM_DEBUG(llvm::dbgs()
1047  << "Removing src loop " << srcId << " after fusion\n");
1048  // srcNode is no longer valid after it is removed from mdg.
1049  srcAffineForOp.erase();
1050  mdg->removeNode(srcId);
1051  srcNode = nullptr;
1052  }
1053  }
1054  } while (dstNodeChanged);
1055  }
1056 
1057  /// Visit each node in the graph, and for each node, attempt to fuse it with
1058  /// producer-consumer candidates. No fusion is performed when producers with a
1059  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1060  /// are encountered.
1061  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1062  LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1063  init();
1064  while (!worklist.empty()) {
1065  unsigned dstId = worklist.back();
1066  worklist.pop_back();
1067  performFusionsIntoDest(dstId, maxSrcUserCount);
1068  }
1069  }
1070 
1071  // Visits each node in the graph, and for each node, attempts to fuse it with
1072  // its sibling nodes (nodes which share a parent, but no dependence edges).
1073  void fuseSiblingNodes() {
1074  LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1075  init();
1076  while (!worklist.empty()) {
1077  unsigned dstId = worklist.back();
1078  worklist.pop_back();
1079 
1080  // Skip if this node was removed (fused into another node).
1081  if (mdg->nodes.count(dstId) == 0)
1082  continue;
1083  // Get 'dstNode' into which to attempt fusion.
1084  auto *dstNode = mdg->getNode(dstId);
1085  // Skip if 'dstNode' is not a loop nest.
1086  if (!isa<AffineForOp>(dstNode->op))
1087  continue;
1088  // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1089  fuseWithSiblingNodes(dstNode);
1090  }
1091  }
1092 
1093  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1094  void fuseWithSiblingNodes(Node *dstNode) {
1095  DenseSet<unsigned> visitedSibNodeIds;
1096  std::pair<unsigned, Value> idAndMemref;
1097  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1098 
1099  while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1100  unsigned sibId = idAndMemref.first;
1101  Value memref = idAndMemref.second;
1102  // TODO: Check that 'sibStoreOpInst' post-dominates all other
1103  // stores to the same memref in 'sibNode' loop nest.
1104  auto *sibNode = mdg->getNode(sibId);
1105  // Compute an operation list insertion point for the fused loop
1106  // nest which preserves dependences.
1107  assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1108  Operation *insertPointInst =
1109  sibNode->op->isBeforeInBlock(dstNode->op)
1110  ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1111  : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1112  if (insertPointInst == nullptr)
1113  continue;
1114 
1115  // Check if fusion would be profitable and at what depth.
1116 
1117  // Get unique 'sibNode' load op to 'memref'.
1118  SmallVector<Operation *, 2> sibLoadOpInsts;
1119  sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1120  // Currently findSiblingNodeToFuse searches for siblings with one load.
1121  assert(sibLoadOpInsts.size() == 1);
1122  Operation *sibLoadOpInst = sibLoadOpInsts[0];
1123 
1124  // Gather 'dstNode' load ops to 'memref'.
1125  SmallVector<Operation *, 2> dstLoadOpInsts;
1126  dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1127 
1128  // It's possible this fusion is at an inner depth (i.e., there are common
1129  // surrounding affine loops for the source and destination for ops). We
1130  // need to get this number because the call to canFuseLoops needs to be
1131  // passed the absolute depth. The max legal depth and the depths we try
1132  // below are however *relative* and as such don't include the common
1133  // depth.
1134  SmallVector<AffineForOp, 4> surroundingLoops;
1135  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1136  unsigned numSurroundingLoops = surroundingLoops.size();
1137  SmallVector<AffineForOp, 4> dstLoopIVs;
1138  getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1139  unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1140  auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1141 
1142  // Compute loop depth and slice union for fusion.
1143  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1144  depthSliceUnions.resize(dstLoopDepthTest);
1145  unsigned maxLegalFusionDepth = 0;
1146  FusionStrategy strategy(memref);
1147  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1148  FusionResult result =
1149  affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1150  /*dstLoopDepth=*/i + numSurroundingLoops,
1151  &depthSliceUnions[i - 1], strategy);
1152 
1153  if (result.value == FusionResult::Success)
1154  maxLegalFusionDepth = i;
1155  }
1156 
1157  LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1158  << maxLegalFusionDepth << '\n');
1159 
1160  // Skip if fusion is not feasible at any loop depths.
1161  if (maxLegalFusionDepth == 0)
1162  continue;
1163 
1164  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1165  if (!maximalFusion) {
1166  // Check if fusion would be profitable. For sibling fusion, the sibling
1167  // load op is treated as the src "store" op for fusion profitability
1168  // purposes. The footprint of the load in the slice relative to the
1169  // unfused source's determines reuse.
1170  if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
1171  depthSliceUnions, maxLegalFusionDepth,
1172  &bestDstLoopDepth, computeToleranceThreshold))
1173  continue;
1174  }
1175 
1176  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1177  assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1178  "Fusion depth has no computed slice union");
1179  // Check if source loop is being inserted in the innermost
1180  // destination loop. Based on this, the fused loop may be optimized
1181  // further inside `fuseLoops`.
1182  bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1183  // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1184  affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1185  depthSliceUnions[bestDstLoopDepth - 1],
1186  isInnermostInsertion);
1187 
1188  auto dstForInst = cast<AffineForOp>(dstNode->op);
1189  // Update operation position of fused loop nest (if needed).
1190  if (insertPointInst != dstForInst) {
1191  dstForInst->moveBefore(insertPointInst);
1192  }
1193  // Update data dependence graph state post fusion.
1194  updateStateAfterSiblingFusion(sibNode, dstNode);
1195  }
1196  }
1197 
1198  // Searches block argument uses and the graph from 'dstNode' looking for a
1199  // fusion candidate sibling node which shares no dependences with 'dstNode'
1200  // but which loads from the same memref. Returns true and sets
1201  // 'idAndMemrefToFuse' on success. Returns false otherwise.
1202  bool findSiblingNodeToFuse(Node *dstNode,
1203  DenseSet<unsigned> *visitedSibNodeIds,
1204  std::pair<unsigned, Value> *idAndMemrefToFuse) {
1205  // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1206  // on 'memref'.
1207  auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1208  // Skip if 'outEdge' is not a read-after-write dependence.
1209  // TODO: Remove restrict to single load op restriction.
1210  if (sibNode->getLoadOpCount(memref) != 1)
1211  return false;
1212  // Skip if there exists a path of dependent edges between
1213  // 'sibNode' and 'dstNode'.
1214  if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1215  mdg->hasDependencePath(dstNode->id, sibNode->id))
1216  return false;
1217  // Skip sib node if it loads to (and stores from) the same memref on
1218  // which it also has an input dependence edge.
1219  DenseSet<Value> loadAndStoreMemrefSet;
1220  sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1221  if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1222  return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1223  }))
1224  return false;
1225 
1226  // Check that all stores are to the same memref if any.
1227  DenseSet<Value> storeMemrefs;
1228  for (auto *storeOpInst : sibNode->stores) {
1229  storeMemrefs.insert(
1230  cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1231  }
1232  if (storeMemrefs.size() > 1)
1233  return false;
1234 
1235  // Skip if a memref value in one node is used by a non-affine memref
1236  // access that lies between 'dstNode' and 'sibNode'.
1237  if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) ||
1238  hasNonAffineUsersOnPath(sibNode->op, dstNode->op))
1239  return false;
1240  return true;
1241  };
1242 
1243  // Search for siblings which load the same memref block argument.
1244  Block *block = dstNode->op->getBlock();
1245  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1246  for (Operation *user : block->getArgument(i).getUsers()) {
1247  auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1248  if (!loadOp)
1249  continue;
1250  // Gather loops surrounding 'use'.
1252  getAffineForIVs(*user, &loops);
1253  // Skip 'use' if it is not within a loop nest.
1254  // Find the surrounding affine.for nested immediately within the
1255  // block.
1256  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1257  return loop->getBlock() == &mdg->block;
1258  });
1259  // Skip 'use' if it is not within a loop nest in `block`.
1260  if (it == loops.end())
1261  continue;
1262  Node *sibNode = mdg->getForOpNode(*it);
1263  assert(sibNode != nullptr);
1264  // Skip 'use' if it not a sibling to 'dstNode'.
1265  if (sibNode->id == dstNode->id)
1266  continue;
1267  // Skip 'use' if it has been visited.
1268  if (visitedSibNodeIds->count(sibNode->id) > 0)
1269  continue;
1270  // Skip 'use' if it does not load from the same memref as 'dstNode'.
1271  auto memref = loadOp.getMemRef();
1272  if (dstNode->getLoadOpCount(memref) == 0)
1273  continue;
1274  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1275  if (canFuseWithSibNode(sibNode, memref)) {
1276  visitedSibNodeIds->insert(sibNode->id);
1277  idAndMemrefToFuse->first = sibNode->id;
1278  idAndMemrefToFuse->second = memref;
1279  return true;
1280  }
1281  }
1282  }
1283 
1284  // Search for siblings by following edges through an intermediate src node.
1285  // Collect candidate 'dstNode' input edges in 'inEdges'.
1288  dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1289  // Add 'inEdge' if it is a read-after-write dependence.
1290  if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1291  mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1292  inEdges.push_back(inEdge);
1293  });
1294 
1295  // Search for sibling nodes to fuse by visiting output edges from each input
1296  // edge in 'inEdges'.
1297  for (auto &inEdge : inEdges) {
1298  // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1301  inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1302  unsigned sibNodeId = outEdge.id;
1303  if (visitedSibNodeIds->count(sibNodeId) > 0)
1304  return;
1305  // Skip output edge if not a sibling using the same memref.
1306  if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1307  return;
1308  auto *sibNode = mdg->getNode(sibNodeId);
1309  if (!isa<AffineForOp>(sibNode->op))
1310  return;
1311  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1312  if (canFuseWithSibNode(sibNode, outEdge.value)) {
1313  // Add candidate 'outEdge' to sibling node.
1314  outEdges.push_back(outEdge);
1315  }
1316  });
1317 
1318  // Add first candidate if any were returned.
1319  if (!outEdges.empty()) {
1320  visitedSibNodeIds->insert(outEdges[0].id);
1321  idAndMemrefToFuse->first = outEdges[0].id;
1322  idAndMemrefToFuse->second = outEdges[0].value;
1323  return true;
1324  }
1325  }
1326  return false;
1327  }
1328 
1329  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1330  /// into 'dstNode'.
1331  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1332  // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1333  mdg->updateEdges(sibNode->id, dstNode->id);
1334 
1335  // Collect dst loop stats after memref privatization transformation.
1336  auto dstForInst = cast<AffineForOp>(dstNode->op);
1337  LoopNestStateCollector dstLoopCollector;
1338  dstLoopCollector.collect(dstForInst);
1339  // Clear and add back loads and stores
1340  mdg->clearNodeLoadAndStores(dstNode->id);
1341  mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1342  dstLoopCollector.storeOpInsts);
1343  // Remove old sibling loop nest if it no longer has outgoing dependence
1344  // edges, and it does not write to a memref which escapes the block.
1345  if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1346  Operation *op = sibNode->op;
1347  mdg->removeNode(sibNode->id);
1348  op->erase();
1349  }
1350  }
1351 
1352  // Clean up any allocs with no users.
1353  void eraseUnusedMemRefAllocations() {
1354  for (auto &pair : mdg->memrefEdgeCount) {
1355  if (pair.second > 0)
1356  continue;
1357  auto memref = pair.first;
1358  // Skip if there exist other uses (return operation or function calls).
1359  if (!memref.use_empty())
1360  continue;
1361  // Use list expected to match the dep graph info.
1362  auto *op = memref.getDefiningOp();
1363  if (isa_and_nonnull<memref::AllocOp>(op))
1364  op->erase();
1365  }
1366  }
1367 };
1368 
1369 } // namespace
1370 
1371 /// Run fusion on `block`.
1372 void LoopFusion::runOnBlock(Block *block) {
1373  MemRefDependenceGraph g(*block);
1374  if (!g.init()) {
1375  LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1376  return;
1377  }
1378 
1379  std::optional<unsigned> fastMemorySpaceOpt;
1380  if (fastMemorySpace.hasValue())
1381  fastMemorySpaceOpt = fastMemorySpace;
1382  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1383  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1384  maximalFusion, computeToleranceThreshold);
1385 
1386  if (affineFusionMode == FusionMode::ProducerConsumer)
1387  fusion.runProducerConsumerFusionOnly();
1388  else if (affineFusionMode == FusionMode::Sibling)
1389  fusion.runSiblingFusionOnly();
1390  else
1391  fusion.runGreedyFusion();
1392 }
1393 
1394 void LoopFusion::runOnOperation() {
1395  // Call fusion on every op that has at least two affine.for nests (in post
1396  // order).
1397  getOperation()->walk([&](Operation *op) {
1398  for (Region &region : op->getRegions()) {
1399  for (Block &block : region.getBlocks()) {
1400  auto affineFors = block.getOps<AffineForOp>();
1401  if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1402  runOnBlock(&block);
1403  }
1404  }
1405  });
1406 }
1407 
1409  unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1410  bool maximalFusion, enum FusionMode affineFusionMode) {
1411  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1412  maximalFusion, affineFusionMode);
1413 }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
Definition: LoopFusion.cpp:430
static bool hasNonAffineUsersOnPath(Operation *start, Operation *end, Value memref)
Returns true if there are any non-affine uses of memref in any of the operations between start and en...
Definition: LoopFusion.cpp:349
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, MemRefDependenceGraph *mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
Definition: LoopFusion.cpp:78
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
Definition: LoopFusion.cpp:187
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
Definition: LoopFusion.cpp:217
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, uint64_t localBufSizeThreshold)
Definition: LoopFusion.cpp:245
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
Definition: LoopFusion.cpp:137
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumCols() const
Returns the number of columns in the constraint system.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition: Utils.cpp:633
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition: Utils.cpp:1953
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
Definition: Passes.h:31
@ ProducerConsumer
Definition: Passes.h:31
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:1385
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (e.g., loop trip count and operation count) in 'stats' for the loop nest rooted at 'forOp'.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
Definition: LoopUtils.cpp:1457
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
Definition: Utils.cpp:1366
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition: Utils.cpp:1196
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:259
bool isEmpty() const
Returns true if the computation slice is empty.
Definition: Utils.h:291
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition: Utils.cpp:888
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
Definition: Utils.h:40
void collect(Operation *opToWalk)
Definition: Utils.cpp:43
SmallVector< Operation *, 4 > storeOpInsts
Definition: Utils.h:41
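LoopNestStats aggregates various per-loop statistics (e.g., loop trip count and operation count) for a loop nest up until (and including) the innermost loop body.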
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition: Utils.h:112
Block & block
The block for which this graph is created to perform fusion.
Definition: Utils.h:211
unsigned addNode(Operation *op)
Definition: Utils.cpp:251
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
Definition: Utils.cpp:388
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:316
Node * getForOpNode(AffineForOp forOp)
Definition: Utils.cpp:243
DenseMap< unsigned, Node > nodes
Definition: Utils.h:108
bool hasDependencePath(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:354
void clearNodeLoadAndStores(unsigned id)
Definition: Utils.cpp:576
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition: Utils.cpp:504
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition: Utils.h:110
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
Definition: Utils.cpp:568
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:584
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
Definition: Utils.cpp:404
void removeNode(unsigned id)
Definition: Utils.cpp:258
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:428
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:592
Node * getNode(unsigned id)
Definition: Utils.cpp:236
DenseMap< Value, unsigned > memrefEdgeCount
Definition: Utils.h:115
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
Definition: Utils.h:430
FlatAffineValueConstraints * getConstraints()
Definition: Utils.h:469
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition: Utils.cpp:942
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition: Utils.cpp:1215
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition: Utils.cpp:1044
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.