1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements affine fusion.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Passes.h"
14 
15 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/Utils.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
20 #include "mlir/Dialect/Affine/LoopUtils.h"
21 #include "mlir/Dialect/Affine/Utils.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/Transforms/Passes.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <iomanip>
34 #include <optional>
35 #include <sstream>
36 
37 namespace mlir {
38 namespace affine {
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
41 } // namespace affine
42 } // namespace mlir
43 
44 #define DEBUG_TYPE "affine-loop-fusion"
45 
46 using namespace mlir;
47 using namespace mlir::affine;
48 
49 namespace {
50 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
51 /// which fuses loop nests with single-writer/single-reader memref dependences
52 /// with the goal of improving locality.
53 // TODO: Support fusion of source loop nests which write to multiple
54 // memrefs, where each memref can have multiple users (if profitable).
55 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
56  LoopFusion() = default;
57  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58  bool maximalFusion, enum FusionMode affineFusionMode) {
59  this->fastMemorySpace = fastMemorySpace;
60  this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61  this->maximalFusion = maximalFusion;
62  this->affineFusionMode = affineFusionMode;
63  }
64 
65  void runOnBlock(Block *block);
66  void runOnOperation() override;
67 };
68 
69 } // namespace
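// For illustration only (a sketch, not code or IR from this file; the memref
// arguments %A, %B and %C of type memref<100xf32> are hypothetical): given a
// producer loop writing %B and a consumer loop reading %B,
//
//   affine.for %i = 0 to 100 {
//     %v = affine.load %A[%i] : memref<100xf32>
//     %p = arith.mulf %v, %v : f32
//     affine.store %p, %B[%i] : memref<100xf32>
//   }
//   affine.for %i = 0 to 100 {
//     %w = affine.load %B[%i] : memref<100xf32>
//     %r = arith.addf %w, %w : f32
//     affine.store %r, %C[%i] : memref<100xf32>
//   }
//
// greedy producer-consumer fusion merges the two nests and, when %B does not
// escape, may replace it with a privatized single-element buffer:
//
//   %Bpriv = memref.alloc() : memref<1xf32>
//   affine.for %i = 0 to 100 {
//     %v = affine.load %A[%i] : memref<100xf32>
//     %p = arith.mulf %v, %v : f32
//     affine.store %p, %Bpriv[0] : memref<1xf32>
//     %w = affine.load %Bpriv[0] : memref<1xf32>
//     %r = arith.addf %w, %w : f32
//     affine.store %r, %C[%i] : memref<100xf32>
//   }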
70 
71 /// Returns true if node 'srcId' can be removed after fusing it with node
72 /// 'dstId'. The node can be removed if any of the following conditions are met:
73 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
74 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
75 /// and the fusion slice is maximal.
76 /// 3. 'srcId' has output dependences after fusion, the fusion slice is
77 /// maximal and the fusion insertion point dominates all the dependences.
78 static bool canRemoveSrcNodeAfterFusion(
79  unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
80  Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
81  MemRefDependenceGraph *mdg) {
82 
83  Operation *dstNodeOp = mdg->getNode(dstId)->op;
84  bool hasOutDepsAfterFusion = false;
85 
86  for (auto &outEdge : mdg->outEdges[srcId]) {
87  Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
88  // Skip dependence with dstOp since it will be removed after fusion.
89  if (depNodeOp == dstNodeOp)
90  continue;
91 
92  // Only fusion within the same block is supported. Use domination analysis
93  // when needed.
94  if (depNodeOp->getBlock() != dstNodeOp->getBlock())
95  return false;
96 
97  // Check if the insertion point of the fused loop dominates the dependence.
98  // Otherwise, the src loop can't be removed.
99  if (fusedLoopInsPoint != depNodeOp &&
100  !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
101  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
102  "dominate dependence\n");
103  return false;
104  }
105 
106  hasOutDepsAfterFusion = true;
107  }
108 
109  // If src loop has dependences after fusion or it writes to a live-out or
110  // escaping memref, we can only remove it if the fusion slice is maximal so
111  // that all the dependences are preserved.
112  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113  std::optional<bool> isMaximal = fusionSlice.isMaximal();
114  if (!isMaximal) {
115  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
116  "if fusion is maximal\n");
117  return false;
118  }
119 
120  if (!*isMaximal) {
121  LLVM_DEBUG(llvm::dbgs()
122  << "Src loop can't be removed: fusion is not maximal\n");
123  return false;
124  }
125  }
126 
127  return true;
128 }
129 
130 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
131 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to
132 /// the program order when the 'mdg' is created. However, program order is not
133 /// guaranteed and must not be relied upon by the client. Program order won't
134 /// hold if the 'mdg' is reused from a previous fusion step or if the node
135 /// creation order changes in the future to support more advanced cases.
136 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
137 static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
138  SmallVectorImpl<unsigned> &srcIdCandidates) {
139  // Skip if no input edges along which to fuse.
140  if (mdg->inEdges.count(dstId) == 0)
141  return;
142 
143  // Gather memrefs from loads in 'dstId'.
144  auto *dstNode = mdg->getNode(dstId);
145  DenseSet<Value> consumedMemrefs;
146  for (Operation *load : dstNode->loads)
147  consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148 
149  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
150  // to one of the consumed memrefs.
151  for (auto &srcEdge : mdg->inEdges[dstId]) {
152  auto *srcNode = mdg->getNode(srcEdge.id);
153  // Skip if 'srcNode' is not a loop nest.
154  if (!isa<AffineForOp>(srcNode->op))
155  continue;
156 
157  if (any_of(srcNode->stores, [&](Operation *op) {
158  auto storeOp = cast<AffineWriteOpInterface>(op);
159  return consumedMemrefs.count(storeOp.getMemRef()) > 0;
160  }))
161  srcIdCandidates.push_back(srcNode->id);
162  }
163 
164  llvm::sort(srcIdCandidates);
165  srcIdCandidates.erase(
166  std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
167  srcIdCandidates.end());
168 }
169 
170 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
171 /// producer-consumer dependence between 'srcId' and 'dstId'.
172 static void
173 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
174  MemRefDependenceGraph *mdg,
175  DenseSet<Value> &producerConsumerMemrefs) {
176  auto *dstNode = mdg->getNode(dstId);
177  auto *srcNode = mdg->getNode(srcId);
178  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
179  producerConsumerMemrefs);
180 }
181 
182 /// A memref escapes in the context of the fusion pass if either:
183 /// 1. it (or its alias) is a block argument, or
184 /// 2. it is created by an op not known to guarantee alias freedom, or
185 /// 3. it (or its alias) is used by ops other than affine dereferencing ops
186 /// (e.g., by call ops, memref load/store ops, alias creating ops, unknown ops,
187 /// terminator ops, etc.); such ops do not dereference the memref in an affine
188 /// way.
189 static bool isEscapingMemref(Value memref, Block *block) {
190  Operation *defOp = memref.getDefiningOp();
191  // Check if 'memref' is a block argument.
192  if (!defOp)
193  return true;
194 
195  // Check if this is defined to be an alias of another memref.
196  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
197  if (isEscapingMemref(viewOp.getViewSource(), block))
198  return true;
199 
200  // Any op other than an allocation op does not guarantee alias freedom.
201  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
202  return true;
203 
204  // Check if 'memref' is used by a non-dereferencing op (including unknown ones)
205  // (e.g., call ops, alias creating ops, etc.).
206  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
207  // Ignore users outside of `block`.
208  Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
209  if (!ancestorOp)
210  return true;
211  if (ancestorOp->getBlock() != block)
212  return false;
213  return !isa<AffineMapAccessInterface>(*user);
214  });
215 }
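// A hypothetical example of an escaping memref (not taken from this file or
// its tests): %m below escapes because it is passed to a call op, which is not
// an affine dereferencing use, so no private copy may be substituted for it.
//
//   %cst = arith.constant 0.0 : f32
//   %m = memref.alloc() : memref<16xf32>
//   affine.for %i = 0 to 16 {
//     affine.store %cst, %m[%i] : memref<16xf32>
//   }
//   func.call @consume(%m) : (memref<16xf32>) -> ()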
216 
217 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
218 /// that escape the block or are accessed in a non-affine way.
219 static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
220  DenseSet<Value> &escapingMemRefs) {
221  auto *node = mdg->getNode(id);
222  for (Operation *storeOp : node->stores) {
223  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
224  if (escapingMemRefs.count(memref))
225  continue;
226  if (isEscapingMemref(memref, &mdg->block))
227  escapingMemRefs.insert(memref);
228  }
229 }
230 
231 // Sinks all sequential loops to the innermost levels (while preserving
232 // relative order among them) and moves all parallel loops to the
233 // outermost (while again preserving relative order among them).
234 // This can increase the loop depth at which we can fuse a slice, since we are
235 // pushing loop carried dependence to a greater depth in the loop nest.
236 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
237  assert(isa<AffineForOp>(node->op));
238  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
239  node->op = newRootForOp;
240 }
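// Sketch of the effect (hypothetical IR, not from this file): if the outer
// loop over %j carries a dependence (e.g., a reduction) while the inner loop
// over %i is parallel, the nest is permuted so the parallel loop becomes
// outermost, which allows fusing a slice at a greater depth along %i:
//
//   affine.for %j ... {           // sequential (loop-carried dependence)
//     affine.for %i ... { ... }   // parallel
//   }
// becomes
//   affine.for %i ... {           // parallel, now outermost
//     affine.for %j ... { ... }   // sequential, sunk innermost
//   }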
241 
242 // Creates and returns a private (single-user) memref for fused loop rooted
243 // at 'forOp', with (potentially reduced) memref size based on the
244 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
245 // TODO: consider refactoring the common code from generateDma and
246 // this one.
247 static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
248  unsigned dstLoopDepth,
249  std::optional<unsigned> fastMemorySpace,
250  uint64_t localBufSizeThreshold) {
251  Operation *forInst = forOp.getOperation();
252 
253  // Create builder to insert alloc op just before 'forOp'.
254  OpBuilder b(forInst);
255  // Builder to create constants at the top level.
256  OpBuilder top(forInst->getParentRegion());
257  // Create new memref type based on slice bounds.
258  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
259  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
260  unsigned rank = oldMemRefType.getRank();
261 
262  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
263  MemRefRegion region(srcStoreOpInst->getLoc());
264  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
265  (void)validRegion;
266  assert(validRegion && "unexpected memref region failure");
267  SmallVector<int64_t, 4> newShape;
268  std::vector<SmallVector<int64_t, 4>> lbs;
269  SmallVector<int64_t, 8> lbDivisors;
270  lbs.reserve(rank);
271  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
272  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
273  std::optional<int64_t> numElements =
274  region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
275  assert(numElements && "non-constant number of elts in local buffer");
276 
277  const FlatAffineValueConstraints *cst = region.getConstraints();
278  // 'outerIVs' holds the values that this memory region is symbolic/parametric
279  // on; this would correspond to loop IVs surrounding the level at which the
280  // slice is being materialized.
281  SmallVector<Value, 8> outerIVs;
282  cst->getValues(rank, cst->getNumVars(), &outerIVs);
283 
284  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
285  SmallVector<AffineExpr, 4> offsets;
286  offsets.reserve(rank);
287  for (unsigned d = 0; d < rank; ++d) {
288  assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
289 
290  AffineExpr offset = top.getAffineConstantExpr(0);
291  for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
292  offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
293  }
294  assert(lbDivisors[d] > 0);
295  offset =
296  (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
297  offsets.push_back(offset);
298  }
299 
300  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
301  // by 'srcStoreOpInst'.
302  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
303  assert(eltSize && "memrefs with size elt types expected");
304  uint64_t bufSize = *eltSize * *numElements;
305  unsigned newMemSpace;
306  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
307  newMemSpace = *fastMemorySpace;
308  } else {
309  newMemSpace = oldMemRefType.getMemorySpaceAsInt();
310  }
311  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
312  {}, newMemSpace);
313 
314  // Create new private memref for fused loop 'forOp'. 'newShape' is always
315  // a constant shape.
316  // TODO: Create/move alloc ops for private memrefs closer to their
317  // consumer loop nests to reduce their live range. Currently they are added
318  // at the beginning of the block, because loop nests can be reordered
319  // during the fusion pass.
320  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
321 
322  // Build an AffineMap to remap access functions based on lower bound offsets.
323  SmallVector<AffineExpr, 4> remapExprs;
324  remapExprs.reserve(rank);
325  for (unsigned i = 0; i < rank; i++) {
326  auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
327 
328  auto remapExpr =
329  simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
330  remapExprs.push_back(remapExpr);
331  }
332 
333  auto indexRemap =
334  AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
335 
336  // Replace all users of 'oldMemRef' with 'newMemRef'.
337  LogicalResult res =
338  replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
339  /*extraOperands=*/outerIVs,
340  /*symbolOperands=*/{},
341  /*domOpFilter=*/&*forOp.getBody()->begin());
342  assert(succeeded(res) &&
343  "replaceAllMemrefUsesWith should always succeed here");
344  (void)res;
345  return newMemRef;
346 }
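// A worked sketch of the index remapping above (hypothetical shapes and
// accesses, not from this file): with one outer IV %i and a store
// affine.store %v, %B[%i + 3] inside the fused nest, the write region at the
// chosen depth covers a single element per iteration, so newShape is [1] and
// the lower-bound offset is d0 + 3 (d0 standing for %i). The remap then is
//   affine_map<(d0, d1) -> (d1 - (d0 + 3))>
// with %i passed as an extra operand, and replaceAllMemRefUsesWith rewrites
// the access into an access of the privatized memref<1xf32> at index 0.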
347 
348 /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
349 /// 'dstId'), if there is any non-affine operation accessing 'memref', return
350 /// true. Otherwise, return false.
351 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
352  Value memref,
353  MemRefDependenceGraph *mdg) {
354  auto *srcNode = mdg->getNode(srcId);
355  auto *dstNode = mdg->getNode(dstId);
356  Value::user_range users = memref.getUsers();
357  // For each MemRefDependenceGraph's node that is between 'srcNode' and
358  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
359  // non-affine operation in the node accesses the 'memref'.
360  for (auto &idAndNode : mdg->nodes) {
361  Operation *op = idAndNode.second.op;
362  // Take care of operations between 'srcNode' and 'dstNode'.
363  if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
364  // Walk inside the operation to find any use of the memref.
365  // Interrupt the walk if found.
366  auto walkResult = op->walk([&](Operation *user) {
367  // Skip affine ops.
368  if (isa<AffineMapAccessInterface>(*user))
369  return WalkResult::advance();
370  // Find a non-affine op that uses the memref.
371  if (llvm::is_contained(users, user))
372  return WalkResult::interrupt();
373  return WalkResult::advance();
374  });
375  if (walkResult.wasInterrupted())
376  return true;
377  }
378  }
379  return false;
380 }
381 
382 /// Check whether a memref value in node 'srcId' has a non-affine user that
383 /// lies between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
384 /// 'dstNode').
385 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
386  MemRefDependenceGraph *mdg) {
387  // Collect memref values in node 'srcId'.
388  auto *srcNode = mdg->getNode(srcId);
389  llvm::SmallDenseSet<Value, 2> memRefValues;
390  srcNode->op->walk([&](Operation *op) {
391  // Skip affine ops.
392  if (isa<AffineForOp>(op))
393  return WalkResult::advance();
394  for (Value v : op->getOperands())
395  // Collect memref values only.
396  if (isa<MemRefType>(v.getType()))
397  memRefValues.insert(v);
398  return WalkResult::advance();
399  });
400  // Looking for users between node 'srcId' and node 'dstId'.
401  return llvm::any_of(memRefValues, [&](Value memref) {
402  return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
403  });
404 }
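// For instance (hypothetical IR; %v and %c0 are assumed to be defined): a
// memref.load placed between the two candidate loop nests is a non-affine
// user of %m, so this function would return true for that (srcId, dstId) pair:
//
//   affine.for %i = 0 to 10 {                    // srcId
//     affine.store %v, %m[%i] : memref<10xf32>
//   }
//   %x = memref.load %m[%c0] : memref<10xf32>    // non-affine use of %m
//   affine.for %i = 0 to 10 {                    // dstId
//     %y = affine.load %m[%i] : memref<10xf32>
//   }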
405 
406 // Checks the profitability of fusing a backwards slice of the loop nest
407 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
408 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
409 // the memref being produced and consumed, which is an input to the cost model.
410 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
411 // 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
412 // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
413 // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
414 // unique store op in the src node, which will be used to check that the write
415 // region is the same after input-reuse fusion. Computation slices are provided
416 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
417 // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
418 // profitable to fuse the candidate loop nests. Returns false otherwise.
419 // `dstLoopDepth` is set to the most profitable depth at which to materialize
420 // the source loop nest slice.
421 // The profitability model executes the following steps:
422 // *) Computes the backward computation slice at 'srcOpInst'. This
423 // computation slice of the loop nest surrounding 'srcOpInst' is
424 // represented by modified src loop bounds in 'sliceState', which are
425 // functions of loop IVs in the loop nest surrounding 'srcOpInst'.
426 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
427 // loop nest is the total number of dynamic operation instances in the loop
428 // nest).
429 // *) Computes the cost of fusing a slice of the src loop nest into the dst
430 // loop nest at various values of dst loop depth, attempting to fuse
431 // the largest computation slice at the maximal dst loop depth (closest to
432 // the load) to minimize reuse distance and potentially enable subsequent
433 // load/store forwarding.
434 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
435 // nest, at which the src computation slice is inserted/fused.
436 // NOTE: We attempt to maximize the dst loop depth, but there are cases
437 // where a particular setting for 'dstLoopNest' might fuse an unsliced
438 // loop (within the src computation slice) at a depth which results in
439 // excessive recomputation (see unit tests for examples).
440 // *) Compares the total cost of the unfused loop nests to the min cost fused
441 // loop nest computed in the previous step, and returns true if the latter
442 // is lower.
443 // TODO: Extend profitability analysis to support scenarios with multiple
444 // stores.
445 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
446  AffineForOp dstForOp,
447  ArrayRef<ComputationSliceState> depthSliceUnions,
448  unsigned maxLegalFusionDepth,
449  unsigned *dstLoopDepth,
450  double computeToleranceThreshold) {
451  LLVM_DEBUG({
452  llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
453  llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
454  llvm::dbgs() << dstForOp << "\n";
455  });
456 
457  if (maxLegalFusionDepth == 0) {
458  LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
459  return false;
460  }
461 
462  // Compute cost of sliced and unsliced src loop nest.
463  SmallVector<AffineForOp, 4> srcLoopIVs;
464  getAffineForIVs(*srcOpInst, &srcLoopIVs);
465 
466  // Walk src loop nest and collect stats.
467  LoopNestStats srcLoopNestStats;
468  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
469  return false;
470 
471  // Compute cost of dst loop nest.
472  LoopNestStats dstLoopNestStats;
473  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
474  return false;
475 
476  // Search for min cost value for 'dstLoopDepth'. At each value of
477 // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation slice
478 // bounds between 'srcOpInst' and each op in 'dstOpInsts' (taking the union
479  // of these bounds). Next the union slice bounds are used to calculate
480  // the cost of the slice and the cost of the slice inserted into the dst
481  // loop nest at 'dstLoopDepth'.
482  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
483  double maxStorageReduction = 0.0;
484  std::optional<uint64_t> sliceMemEstimate;
485 
486  // The best loop depth at which to materialize the slice.
487  std::optional<unsigned> bestDstLoopDepth;
488 
489  // Compute op instance count for the src loop nest without iteration slicing.
490  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
491 
492  // Compute src loop nest write region size.
493  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
494  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
495  LLVM_DEBUG(llvm::dbgs()
496  << "Unable to compute MemRefRegion for source operation\n");
497  return false;
498  }
499 
500  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
501  srcWriteRegion.getRegionSize();
502  if (!maybeSrcWriteRegionSizeBytes.has_value())
503  return false;
504  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
505 
506  // Compute op instance count for the dst loop nest.
507  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
508 
509  // Evaluate all depth choices for materializing the slice in the destination
510  // loop nest.
511  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
512  const ComputationSliceState &slice = depthSliceUnions[i - 1];
513  // Skip slice union if it wasn't computed for this depth.
514  if (slice.isEmpty())
515  continue;
516 
517  int64_t fusedLoopNestComputeCost;
518  if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
519  dstLoopNestStats, slice,
520  &fusedLoopNestComputeCost)) {
521  LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
522  continue;
523  }
524 
525  double additionalComputeFraction =
526  fusedLoopNestComputeCost /
527  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
528  1;
529 
530  // Determine what the slice write MemRefRegion would be, if the src loop
531  // nest slice 'slice' were to be inserted into the dst loop nest at loop
532  // depth 'i'.
533  MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
534  if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
535  &slice))) {
536  LLVM_DEBUG(llvm::dbgs()
537  << "Failed to compute slice write region at loopDepth: " << i
538  << "\n");
539  continue;
540  }
541 
542  std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
543  sliceWriteRegion.getRegionSize();
544  if (!maybeSliceWriteRegionSizeBytes.has_value() ||
545  *maybeSliceWriteRegionSizeBytes == 0) {
546  LLVM_DEBUG(llvm::dbgs()
547  << "Failed to get slice write region size at loopDepth: " << i
548  << "\n");
549  continue;
550  }
551  int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
552 
553  // If we are fusing for reuse, check that write regions remain the same.
554  // TODO: Write region check should check sizes and offsets in
555  // each dimension, so that we are sure they are covering the same memref
556  // region. Also, move this out to a isMemRefRegionSuperSet helper function.
557  if (srcOpInst != srcStoreOpInst &&
558  sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
559  continue;
560 
561  double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
562  static_cast<double>(sliceWriteRegionSizeBytes);
563 
564  LLVM_DEBUG({
565  std::stringstream msg;
566  msg << " evaluating fusion profitability at depth : " << i << "\n"
567  << std::fixed << std::setprecision(2)
568  << " additional compute fraction: "
569  << 100.0 * additionalComputeFraction << "%\n"
570  << " storage reduction factor: " << storageReduction << "x\n"
571  << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
572  << " src write region size: " << srcWriteRegionSizeBytes << "\n"
573  << " slice write region size: " << sliceWriteRegionSizeBytes
574  << "\n";
575  llvm::dbgs() << msg.str();
576  });
577 
578  // TODO: This is a placeholder cost model.
579  // Among all choices that add an acceptable amount of redundant computation
580  // (as per computeToleranceThreshold), we will simply pick the one that
581  // reduces the intermediary size the most.
582  if ((storageReduction > maxStorageReduction) &&
583  (additionalComputeFraction < computeToleranceThreshold)) {
584  maxStorageReduction = storageReduction;
585  bestDstLoopDepth = i;
586  minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
587  sliceMemEstimate = sliceWriteRegionSizeBytes;
588  }
589  }
590 
591  // A simple cost model: fuse if it reduces the memory footprint.
592 
593  if (!bestDstLoopDepth) {
594  LLVM_DEBUG(
595  llvm::dbgs()
596  << "All fusion choices involve more than the threshold amount of "
597  "redundant computation; NOT fusing.\n");
598  return false;
599  }
600 
601  if (!bestDstLoopDepth) {
602  LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
603  return false;
604  }
605 
606  // Set dstLoopDepth based on best values from search.
607  *dstLoopDepth = *bestDstLoopDepth;
608 
609  LLVM_DEBUG(
610  llvm::dbgs() << " LoopFusion fusion stats:"
611  << "\n best loop depth: " << bestDstLoopDepth
612  << "\n src loop nest compute cost: " << srcLoopNestCost
613  << "\n dst loop nest compute cost: " << dstLoopNestCost
614  << "\n fused loop nest compute cost: "
615  << minFusedLoopNestComputeCost << "\n");
616 
617  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
618  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
619 
620  std::optional<double> storageReduction;
621 
622  if (!dstMemSize || !srcMemSize) {
623  LLVM_DEBUG(llvm::dbgs()
624  << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
625  return false;
626  }
627 
628  auto srcMemSizeVal = *srcMemSize;
629  auto dstMemSizeVal = *dstMemSize;
630 
631  assert(sliceMemEstimate && "expected value");
632  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
633 
634  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
635  << " dst mem: " << dstMemSizeVal << "\n"
636  << " fused mem: " << fusedMem << "\n"
637  << " slice mem: " << sliceMemEstimate << "\n");
638 
639  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
640  LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
641  return false;
642  }
643  storageReduction =
644  100.0 *
645  (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
646 
647  double additionalComputeFraction =
648  100.0 * (minFusedLoopNestComputeCost /
649  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
650  1);
651  (void)additionalComputeFraction;
652  LLVM_DEBUG({
653  std::stringstream msg;
654  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
655  << std::setprecision(2) << additionalComputeFraction
656  << "% redundant computation and a ";
657  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
658  msg << "% storage reduction.\n";
659  llvm::dbgs() << msg.str();
660  });
661 
662  return true;
663 }
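// A worked example of the two metrics used above (all numbers hypothetical):
// with srcLoopNestCost = 1000, dstLoopNestCost = 4000 and a candidate depth
// whose fusedLoopNestComputeCost = 5500,
//   additionalComputeFraction = 5500 / (1000 + 4000) - 1 = 0.10   (10%)
// and with srcWriteRegionSizeBytes = 400, sliceWriteRegionSizeBytes = 4,
//   storageReduction = 400 / 4 = 100x.
// Such a depth is selected only if 0.10 is below computeToleranceThreshold and
// 100x is the best storage reduction seen across the legal depths.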
664 
665 namespace {
666 
667 // GreedyFusion greedily fuses loop nests which have a producer/consumer or
668 // input-reuse relationship on a memref, with the goal of improving locality.
669 //
670 // The steps of the producer-consumer fusion algorithm are as follows:
671 //
672 // *) A worklist is initialized with node ids from the dependence graph.
673 // *) For each node id in the worklist:
674 // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
675 // candidate destination AffineForOp into which fusion will be attempted.
676 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
677 // *) For each LoadOp in 'dstLoadOps' do:
678 // *) Look up dependent loop nests which have a single store op to the same
679 // memref.
680 // *) Check if dependences would be violated by the fusion.
681 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop
682 // bounds to be functions of 'dstLoopNest' IVs and symbols.
683 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
684 // at a loop depth determined by the cost model in 'isFusionProfitable'.
685 // *) Add the newly fused load/store operations to the state,
686 // and also add newly fused load ops to 'dstLoadOps' to be considered
687 // as fusion dst load ops in another iteration.
688 // *) Remove old src loop nest and its associated state.
689 //
690 // The steps of the input-reuse fusion algorithm are as follows:
691 //
692 // *) Initialize 'worklist' with node ids from the dependence graph.
693 // *) For each 'dstNode' in the worklist:
694 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
695 // loads from the same memref, but which has no dependence paths to/from 'dstNode'.
696 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop
697 // bounds to be functions of 'dstLoopNest' IVs and symbols.
698 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
699 // at a loop depth determined by the cost model in 'isFusionProfitable'.
700 // This function also checks that the memref write region of 'sibLoopNest',
701 // is preserved in the fused loop nest.
702 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
703 //
704 // Given a graph where top-level operations are vertices in the set 'V' and
705 // edges in the set 'E' are dependences between vertices, this algorithm
706 // takes O(V) time for initialization, and has runtime O(V + E).
707 //
708 // This greedy algorithm is not 'maximal' due to the current restriction of
709 // fusing along single producer consumer edges, but there is a TODO: to fix
710 // this.
711 //
712 // TODO: Experiment with other fusion policies.
713 struct GreedyFusion {
714 public:
715  // The data dependence graph to traverse during fusion.
716  MemRefDependenceGraph *mdg;
717  // Worklist of graph nodes visited during the fusion pass.
718  SmallVector<unsigned, 8> worklist;
719  // Parameter for local buffer size threshold.
720  unsigned localBufSizeThreshold;
721  // Parameter for fast memory space.
722  std::optional<unsigned> fastMemorySpace;
723  // If true, ignore any additional (redundant) computation tolerance threshold
724  // that would have prevented fusion.
725  bool maximalFusion;
726  // The amount of additional computation that is tolerated while fusing
727  // pair-wise as a fraction of the total computation.
728  double computeToleranceThreshold;
729 
730  using Node = MemRefDependenceGraph::Node;
731 
732  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
733  std::optional<unsigned> fastMemorySpace, bool maximalFusion,
734  double computeToleranceThreshold)
735  : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
736  fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
737  computeToleranceThreshold(computeToleranceThreshold) {}
738 
739  /// Initializes 'worklist' with nodes from 'mdg'.
740  void init() {
741  // TODO: Add a priority queue for prioritizing nodes by different
742  // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
743  worklist.clear();
744  for (auto &idAndNode : mdg->nodes) {
745  const Node &node = idAndNode.second;
746  worklist.push_back(node.id);
747  }
748  }
749  /// Run only sibling fusion on the `mdg`.
750  void runSiblingFusionOnly() {
751  fuseSiblingNodes();
752  eraseUnusedMemRefAllocations();
753  }
754 
755  /// Run only producer/consumer fusion on the `mdg`.
756  void runProducerConsumerFusionOnly() {
757  fuseProducerConsumerNodes(
758  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
759  eraseUnusedMemRefAllocations();
760  }
761 
762  // Run the GreedyFusion pass.
763  // *) First pass through the nodes fuses single-use producer nodes into their
764  // unique consumer.
765  // *) Second pass fuses sibling nodes which share no dependence edges.
766  // *) Third pass fuses any remaining producer nodes into their users.
767  void runGreedyFusion() {
768  // TODO: Run this repeatedly until a fixed-point is reached.
769  fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
770  fuseSiblingNodes();
771  fuseProducerConsumerNodes(
772  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
773  eraseUnusedMemRefAllocations();
774  }
775 
776  /// Returns true if a private memref can be created for `memref` given
777  /// the fusion scenario reflected by the other arguments.
778  bool canCreatePrivateMemRef(Value memref,
779  const DenseSet<Value> &srcEscapingMemRefs,
780  unsigned producerId, unsigned consumerId,
781  bool removeSrcNode) {
782  const Node *consumerNode = mdg->getNode(consumerId);
783  // If `memref` is an escaping one, do not create a private memref
784  // for the below scenarios, since doing so will leave the escaping
785  // memref unmodified as all the writes originally meant for the
786  // escaping memref would be performed on the private memref:
787  // 1. The source is to be removed after fusion,
788  // OR
789  // 2. The destination writes to `memref`.
790  if (srcEscapingMemRefs.count(memref) > 0 &&
791  (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
792  return false;
793 
794  // Don't create a private memref if 'srcNode' has in edges on
795  // 'memref' or 'dstNode' has out edges on 'memref'.
796  if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
797  mdg->getOutEdgeCount(consumerId, memref) > 0)
798  return false;
799 
800  // If 'srcNode' will be removed but it has out edges on 'memref' to
801  // nodes other than 'dstNode', we have to preserve dependences and
802  // cannot create a private memref.
803  if (removeSrcNode &&
804  any_of(mdg->outEdges[producerId], [&](const auto &edge) {
805  return edge.value == memref && edge.id != consumerId;
806  }))
807  return false;
808 
809  return true;
810  }
811 
812  /// Perform fusions with node `dstId` as the fusion destination and other
813  /// nodes as fusion sources. No fusion is performed for producers whose user
814  /// count on any of the memrefs involved exceeds `maxSrcUserCount`.
815  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
816  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
817  // Skip if this node was removed (fused into another node).
818  if (mdg->nodes.count(dstId) == 0)
819  return;
820  // Get 'dstNode' into which to attempt fusion.
821  auto *dstNode = mdg->getNode(dstId);
822  // Skip if 'dstNode' is not a loop nest.
823  if (!isa<AffineForOp>(dstNode->op))
824  return;
825  // Skip if 'dstNode' is a loop nest returning values.
826  // TODO: support loop nests that return values.
827  if (dstNode->op->getNumResults() > 0)
828  return;
829 
830  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
831 
832  // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
833  // while preserving relative order. This can increase the maximum loop
834  // depth at which we can fuse a slice of a producer loop nest into a
835  // consumer loop nest.
836  sinkSequentialLoops(dstNode);
837  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
838 
839  // Try to fuse 'dstNode' with candidate producer loops until a fixed point
840  // is reached. Fusing two loops may expose new fusion opportunities.
841  bool dstNodeChanged;
842  do {
843  // Gather src loop candidates for 'dstNode' and visit them in "quasi"
844  // reverse program order to minimize the number of iterations needed to
845  // reach the fixed point. Note that this is a best effort approach since
846  // 'getProducerCandidates' does not always guarantee program order in
847  // 'srcIdCandidates'.
848  dstNodeChanged = false;
849  SmallVector<unsigned, 16> srcIdCandidates;
850  getProducerCandidates(dstId, mdg, srcIdCandidates);
851 
852  for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
853  // Get 'srcNode' from which to attempt fusion into 'dstNode'.
854  auto *srcNode = mdg->getNode(srcId);
855  auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
856  LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
857  << " for dst loop " << dstId << "\n");
858 
859  // Skip if 'srcNode' is a loop nest returning values.
860  // TODO: support loop nests that return values.
861  if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
862  continue;
863 
864  DenseSet<Value> producerConsumerMemrefs;
865  gatherProducerConsumerMemrefs(srcId, dstId, mdg,
866  producerConsumerMemrefs);
867 
868  // Skip if 'srcNode' out edge count on any memref is greater than
869  // 'maxSrcUserCount'.
870  if (any_of(producerConsumerMemrefs, [&](Value memref) {
871  return mdg->getOutEdgeCount(srcNode->id, memref) >
872  maxSrcUserCount;
873  }))
874  continue;
875 
876  // Gather memrefs in 'srcNode' that are written and escape out of the
877  // block (e.g., memref block arguments, returned memrefs,
878  // memrefs passed to function calls, etc.).
879  DenseSet<Value> srcEscapingMemRefs;
880  gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
881 
882  // Skip if there are non-affine operations in between the 'srcNode'
883  // and 'dstNode' using their memrefs. If so, we wouldn't be able to
884  // compute a legal insertion point for now. 'srcNode' and 'dstNode'
885  // memrefs with non-affine operation users would be considered
886  // escaping memrefs so we can limit this check to only scenarios with
887  // escaping memrefs.
888  if (!srcEscapingMemRefs.empty() &&
889  hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
890  LLVM_DEBUG(llvm::dbgs()
891  << "Can't fuse: non-affine users in between the loops\n");
892  continue;
893  }
894 
895  // Compute an operation list insertion point for the fused loop
896  // nest which preserves dependences.
897  Operation *fusedLoopInsPoint =
898  mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
899  if (fusedLoopInsPoint == nullptr)
900  continue;
901 
902  // It's possible this fusion is at an inner depth (i.e., there are
903  // common surrounding affine loops for the source and destination for
904  // ops). We need to get this number because the call to canFuseLoops
905  // needs to be passed the absolute depth. The max legal depth and the
906  // depths we try below are however *relative* and as such don't include
907  // the common depth.
908  SmallVector<AffineForOp, 4> surroundingLoops;
909  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
910  unsigned numSurroundingLoops = surroundingLoops.size();
911 
912  // Compute the innermost common loop depth for dstNode
913  // producer-consumer loads/stores.
914  SmallVector<Operation *, 2> dstMemrefOps;
915  for (Operation *op : dstNode->loads)
916  if (producerConsumerMemrefs.count(
917  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
918  dstMemrefOps.push_back(op);
919  for (Operation *op : dstNode->stores)
920  if (producerConsumerMemrefs.count(
921  cast<AffineWriteOpInterface>(op).getMemRef()))
922  dstMemrefOps.push_back(op);
923  unsigned dstLoopDepthTest =
924  getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
925 
926  // Check the feasibility of fusing src loop nest into dst loop nest
927  // at loop depths in range [1, dstLoopDepthTest].
928  unsigned maxLegalFusionDepth = 0;
929  SmallVector<ComputationSliceState, 8> depthSliceUnions;
930  depthSliceUnions.resize(dstLoopDepthTest);
931  FusionStrategy strategy(FusionStrategy::ProducerConsumer);
932  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
933  FusionResult result =
934  affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
935  /*dstLoopDepth=*/i + numSurroundingLoops,
936  &depthSliceUnions[i - 1], strategy);
937 
938  if (result.value == FusionResult::Success)
939  maxLegalFusionDepth = i;
940  }
941 
942  if (maxLegalFusionDepth == 0) {
943  LLVM_DEBUG(llvm::dbgs()
944  << "Can't fuse: fusion is not legal at any depth\n");
945  continue;
946  }
947 
948  // Check if fusion would be profitable. We skip profitability analysis
949  // for maximal fusion since we already know the maximal legal depth to
950  // fuse.
951  unsigned bestDstLoopDepth = maxLegalFusionDepth;
952  if (!maximalFusion) {
953  // Retrieve producer stores from the src loop.
954  SmallVector<Operation *, 2> producerStores;
955  for (Operation *op : srcNode->stores)
956  if (producerConsumerMemrefs.count(
957  cast<AffineWriteOpInterface>(op).getMemRef()))
958  producerStores.push_back(op);
959 
960  // TODO: Support multiple producer stores in profitability
961  // analysis. We limit profitability analysis to only scenarios with
962  // a single producer store for now. Note that some multi-store
963  // producer scenarios will still go through profitability analysis
964  // if only one of the stores is involved in the producer-consumer
965  // relationship of the candidate loops.
966  assert(!producerStores.empty() && "Expected producer store");
967  if (producerStores.size() > 1)
968  LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
969  "supported for this case\n");
970  else if (!isFusionProfitable(producerStores[0], producerStores[0],
971  dstAffineForOp, depthSliceUnions,
972  maxLegalFusionDepth, &bestDstLoopDepth,
973  computeToleranceThreshold))
974  continue;
975  }
976 
977  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
978  ComputationSliceState &bestSlice =
979  depthSliceUnions[bestDstLoopDepth - 1];
980  assert(!bestSlice.isEmpty() && "Missing slice union for depth");
981 
982  // Determine if 'srcId' can be removed after fusion, taking into
983  // account remaining dependences, escaping memrefs and the fusion
984  // insertion point.
985  bool removeSrcNode = canRemoveSrcNodeAfterFusion(
986  srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
987  mdg);
988 
989  DenseSet<Value> privateMemrefs;
990  for (Value memref : producerConsumerMemrefs) {
991  if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
992  removeSrcNode)) {
993  // Create a private version of this memref.
994  LLVM_DEBUG(llvm::dbgs()
995  << "Creating private memref for " << memref << '\n');
997  privateMemrefs.insert(memref);
998  }
999  }
1000 
1001  // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
1002  fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1003  dstNodeChanged = true;
1004 
1005  LLVM_DEBUG(llvm::dbgs()
1006  << "Fused src loop " << srcId << " into dst loop " << dstId
1007  << " at depth " << bestDstLoopDepth << ":\n"
1008  << dstAffineForOp << "\n");
1009 
1010  // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1011  if (fusedLoopInsPoint != dstAffineForOp)
1012  dstAffineForOp->moveBefore(fusedLoopInsPoint);
1013 
1014  // Update edges between 'srcNode' and 'dstNode'.
1015  mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1016  removeSrcNode);
1017 
1018  // Create private memrefs.
1019  if (!privateMemrefs.empty()) {
1020  // Gather stores for all the private-to-be memrefs.
1021  DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1022  dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1023  Value storeMemRef = storeOp.getMemRef();
1024  if (privateMemrefs.count(storeMemRef) > 0)
1025  privateMemRefToStores[storeMemRef].push_back(storeOp);
1026  });
1027 
1028  // Replace original memrefs with private memrefs. Note that all the
1029  // loads and stores on these memrefs will be replaced with new
1030  // loads and stores. Any reference to the original ones becomes
1031  // invalid after this point.
1032  for (auto &memrefToStoresPair : privateMemRefToStores) {
1033  // TODO: Use union of memref write regions to compute
1034  // private memref footprint.
1035  SmallVector<Operation *, 4> &storesForMemref =
1036  memrefToStoresPair.second;
1037  Value newMemRef = createPrivateMemRef(
1038  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1039  fastMemorySpace, localBufSizeThreshold);
1040  // Create new node in dependence graph for 'newMemRef' alloc op.
1041  unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1042  // Add edge from 'newMemRef' node to dstNode.
1043  mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1044  }
1045  // One or more entries for 'newMemRef' alloc op are inserted into
1046  // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1047  // reallocate, update dstNode.
1048  dstNode = mdg->getNode(dstId);
1049  }
1050 
1051  // Collect dst loop stats after memref privatization transformation.
1052  LoopNestStateCollector dstLoopCollector;
1053  dstLoopCollector.collect(dstAffineForOp);
1054 
1055  // Clear and add back loads and stores.
1056  mdg->clearNodeLoadAndStores(dstNode->id);
1057  mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1058  dstLoopCollector.storeOpInsts);
1059 
1060  if (removeSrcNode) {
1061  LLVM_DEBUG(llvm::dbgs()
1062  << "Removing src loop " << srcId << " after fusion\n");
1063  // srcNode is no longer valid after it is removed from mdg.
1064  srcAffineForOp.erase();
1065  mdg->removeNode(srcId);
1066  srcNode = nullptr;
1067  }
1068  }
1069  } while (dstNodeChanged);
1070  }
1071 
1072  /// Visit each node in the graph, and for each node, attempt to fuse it with
1073  /// producer-consumer candidates. No fusion is performed when producers with a
1074  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1075  /// are encountered.
1076  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1077  LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1078  init();
1079  while (!worklist.empty()) {
1080  unsigned dstId = worklist.back();
1081  worklist.pop_back();
1082  performFusionsIntoDest(dstId, maxSrcUserCount);
1083  }
1084  }
1085 
1086  // Visits each node in the graph, and for each node, attempts to fuse it with
1087  // its sibling nodes (nodes which share a parent, but no dependence edges).
1088  void fuseSiblingNodes() {
1089  LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1090  init();
1091  while (!worklist.empty()) {
1092  unsigned dstId = worklist.back();
1093  worklist.pop_back();
1094 
1095  // Skip if this node was removed (fused into another node).
1096  if (mdg->nodes.count(dstId) == 0)
1097  continue;
1098  // Get 'dstNode' into which to attempt fusion.
1099  auto *dstNode = mdg->getNode(dstId);
1100  // Skip if 'dstNode' is not a loop nest.
1101  if (!isa<AffineForOp>(dstNode->op))
1102  continue;
1103  // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1104  fuseWithSiblingNodes(dstNode);
1105  }
1106  }
1107 
1108  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1109  void fuseWithSiblingNodes(Node *dstNode) {
1110  DenseSet<unsigned> visitedSibNodeIds;
1111  std::pair<unsigned, Value> idAndMemref;
1112  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1113 
1114  while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1115  unsigned sibId = idAndMemref.first;
1116  Value memref = idAndMemref.second;
1117  // TODO: Check that 'sibStoreOpInst' post-dominates all other
1118  // stores to the same memref in 'sibNode' loop nest.
1119  auto *sibNode = mdg->getNode(sibId);
1120  // Compute an operation list insertion point for the fused loop
1121  // nest which preserves dependences.
1122  assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1123  Operation *insertPointInst =
1124  sibNode->op->isBeforeInBlock(dstNode->op)
1125  ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1126  : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1127  if (insertPointInst == nullptr)
1128  continue;
1129 
1130  // Check if fusion would be profitable and at what depth.
1131 
1132  // Get unique 'sibNode' load op to 'memref'.
1133  SmallVector<Operation *, 2> sibLoadOpInsts;
1134  sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1135  // Currently findSiblingNodeToFuse searches for siblings with one load.
1136  assert(sibLoadOpInsts.size() == 1);
1137  Operation *sibLoadOpInst = sibLoadOpInsts[0];
1138 
1139  // Gather 'dstNode' load ops to 'memref'.
1140  SmallVector<Operation *, 2> dstLoadOpInsts;
1141  dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1142 
1143  // It's possible this fusion is at an inner depth (i.e., there are common
1144  // surrounding affine loops for the source and destination for ops). We
1145  // need to get this number because the call to canFuseLoops needs to be
1146  // passed the absolute depth. The max legal depth and the depths we try
1147  // below are however *relative* and as such don't include the common
1148  // depth.
1149  SmallVector<AffineForOp, 4> surroundingLoops;
1150  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1151  unsigned numSurroundingLoops = surroundingLoops.size();
1152  SmallVector<AffineForOp, 4> dstLoopIVs;
1153  getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1154  unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1155  auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1156 
1157  // Compute loop depth and slice union for fusion.
1158  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1159  depthSliceUnions.resize(dstLoopDepthTest);
1160  unsigned maxLegalFusionDepth = 0;
1161  FusionStrategy strategy(memref);
1162  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1163  FusionResult result =
1164  affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1165  /*dstLoopDepth=*/i + numSurroundingLoops,
1166  &depthSliceUnions[i - 1], strategy);
1167 
1168  if (result.value == FusionResult::Success)
1169  maxLegalFusionDepth = i;
1170  }
1171 
1172  LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1173  << maxLegalFusionDepth << '\n');
1174 
1175  // Skip if fusion is not feasible at any loop depths.
1176  if (maxLegalFusionDepth == 0)
1177  continue;
1178 
1179  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1180  if (!maximalFusion) {
1181  // Check if fusion would be profitable. For sibling fusion, the sibling
1182  // load op is treated as the src "store" op for fusion profitability
1183  // purposes. The footprint of the load in the slice relative to that of
1184  // the unfused source determines the reuse.
1185  if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
1186  depthSliceUnions, maxLegalFusionDepth,
1187  &bestDstLoopDepth, computeToleranceThreshold))
1188  continue;
1189  }
1190 
1191  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1192  assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1193  "Fusion depth has no computed slice union");
1194  // Check if source loop is being inserted in the innermost
1195  // destination loop. Based on this, the fused loop may be optimized
1196  // further inside `fuseLoops`.
1197  bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1198  // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1199  affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1200  depthSliceUnions[bestDstLoopDepth - 1],
1201  isInnermostInsertion);
1202 
1203  auto dstForInst = cast<AffineForOp>(dstNode->op);
1204  // Update operation position of fused loop nest (if needed).
1205  if (insertPointInst != dstForInst) {
1206  dstForInst->moveBefore(insertPointInst);
1207  }
1208  // Update data dependence graph state post fusion.
1209  updateStateAfterSiblingFusion(sibNode, dstNode);
1210  }
1211  }
1212 
1213  // Searches block argument uses and the graph from 'dstNode' looking for a
1214  // fusion candidate sibling node which shares no dependences with 'dstNode'
1215  // but which loads from the same memref. Returns true and sets
1216  // 'idAndMemrefToFuse' on success. Returns false otherwise.
1217  bool findSiblingNodeToFuse(Node *dstNode,
1218  DenseSet<unsigned> *visitedSibNodeIds,
1219  std::pair<unsigned, Value> *idAndMemrefToFuse) {
1220  // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1221  // on 'memref'.
1222  auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1223  // Skip if 'outEdge' is not a read-after-write dependence.
1224  // TODO: Remove the single load op restriction.
1225  if (sibNode->getLoadOpCount(memref) != 1)
1226  return false;
1227  // Skip if there exists a path of dependent edges between
1228  // 'sibNode' and 'dstNode'.
1229  if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1230  mdg->hasDependencePath(dstNode->id, sibNode->id))
1231  return false;
1232  // Skip the sib node if it loads from (and stores to) the same memref on
1233  // which it also has an input dependence edge.
1234  DenseSet<Value> loadAndStoreMemrefSet;
1235  sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1236  if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1237  return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1238  }))
1239  return false;
1240 
1241  // Check that all stores are to the same memref if any.
1242  DenseSet<Value> storeMemrefs;
1243  for (auto *storeOpInst : sibNode->stores) {
1244  storeMemrefs.insert(
1245  cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1246  }
1247  if (storeMemrefs.size() > 1)
1248  return false;
1249 
1250  // Skip if a memref value in one node is used by a non-affine memref
1251  // access that lies between 'dstNode' and 'sibNode'.
1252  if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
1253  hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
1254  return false;
1255  return true;
1256  };
1257 
1258  // Search for siblings which load the same memref block argument.
1259  Block *block = dstNode->op->getBlock();
1260  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1261  for (Operation *user : block->getArgument(i).getUsers()) {
1262  auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1263  if (!loadOp)
1264  continue;
1265  // Gather loops surrounding 'use'.
1266  SmallVector<AffineForOp, 4> loops;
1267  getAffineForIVs(*user, &loops);
1268  // Skip 'use' if it is not within a loop nest.
1269  // Find the surrounding affine.for nested immediately within the
1270  // block.
1271  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1272  return loop->getBlock() == &mdg->block;
1273  });
1274  // Skip 'use' if it is not within a loop nest in `block`.
1275  if (it == loops.end())
1276  continue;
1277  Node *sibNode = mdg->getForOpNode(*it);
1278  assert(sibNode != nullptr);
1279  // Skip 'use' if it is not a sibling to 'dstNode'.
1280  if (sibNode->id == dstNode->id)
1281  continue;
1282  // Skip 'use' if it has been visited.
1283  if (visitedSibNodeIds->count(sibNode->id) > 0)
1284  continue;
1285  // Skip 'use' if it does not load from the same memref as 'dstNode'.
1286  auto memref = loadOp.getMemRef();
1287  if (dstNode->getLoadOpCount(memref) == 0)
1288  continue;
1289  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1290  if (canFuseWithSibNode(sibNode, memref)) {
1291  visitedSibNodeIds->insert(sibNode->id);
1292  idAndMemrefToFuse->first = sibNode->id;
1293  idAndMemrefToFuse->second = memref;
1294  return true;
1295  }
1296  }
1297  }
1298 
1299  // Search for siblings by following edges through an intermediate src node.
1300  // Collect candidate 'dstNode' input edges in 'inEdges'.
1301  SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1302  mdg->forEachMemRefInputEdge(
1303  dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1304  // Add 'inEdge' if it is a read-after-write dependence.
1305  if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1306  mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1307  inEdges.push_back(inEdge);
1308  });
1309 
1310  // Search for sibling nodes to fuse by visiting output edges from each input
1311  // edge in 'inEdges'.
1312  for (auto &inEdge : inEdges) {
1313  // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1314  SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1315  mdg->forEachMemRefOutputEdge(
1316  inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1317  unsigned sibNodeId = outEdge.id;
1318  if (visitedSibNodeIds->count(sibNodeId) > 0)
1319  return;
1320  // Skip output edge if not a sibling using the same memref.
1321  if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1322  return;
1323  auto *sibNode = mdg->getNode(sibNodeId);
1324  if (!isa<AffineForOp>(sibNode->op))
1325  return;
1326  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1327  if (canFuseWithSibNode(sibNode, outEdge.value)) {
1328  // Add candidate 'outEdge' to sibling node.
1329  outEdges.push_back(outEdge);
1330  }
1331  });
1332 
1333  // Add first candidate if any were returned.
1334  if (!outEdges.empty()) {
1335  visitedSibNodeIds->insert(outEdges[0].id);
1336  idAndMemrefToFuse->first = outEdges[0].id;
1337  idAndMemrefToFuse->second = outEdges[0].value;
1338  return true;
1339  }
1340  }
1341  return false;
1342  }
1343 
1344  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1345  /// into 'dstNode'.
1346  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1347  // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1348  mdg->updateEdges(sibNode->id, dstNode->id);
1349 
1350  // Collect dst loop stats after memref privatization transformation.
1351  auto dstForInst = cast<AffineForOp>(dstNode->op);
1352  LoopNestStateCollector dstLoopCollector;
1353  dstLoopCollector.collect(dstForInst);
1354  // Clear and add back loads and stores
1355  mdg->clearNodeLoadAndStores(dstNode->id);
1356  mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1357  dstLoopCollector.storeOpInsts);
1358  // Remove old sibling loop nest if it no longer has outgoing dependence
1359  // edges, and it does not write to a memref which escapes the block.
1360  if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1361  Operation *op = sibNode->op;
1362  mdg->removeNode(sibNode->id);
1363  op->erase();
1364  }
1365  }
1366 
1367  // Clean up any allocs with no users.
1368  void eraseUnusedMemRefAllocations() {
1369  for (auto &pair : mdg->memrefEdgeCount) {
1370  if (pair.second > 0)
1371  continue;
1372  auto memref = pair.first;
1373  // Skip if there exist other uses (return operation or function calls).
1374  if (!memref.use_empty())
1375  continue;
1376  // Use list expected to match the dep graph info.
1377  auto *op = memref.getDefiningOp();
1378  if (isa_and_nonnull<memref::AllocOp>(op))
1379  op->erase();
1380  }
1381  }
1382 };
1383 
1384 } // namespace
1385 
1386 /// Run fusion on `block`.
1387 void LoopFusion::runOnBlock(Block *block) {
1388  MemRefDependenceGraph g(*block);
1389  if (!g.init()) {
1390  LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1391  return;
1392  }
1393 
1394  std::optional<unsigned> fastMemorySpaceOpt;
1395  if (fastMemorySpace.hasValue())
1396  fastMemorySpaceOpt = fastMemorySpace;
1397  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1398  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1399  maximalFusion, computeToleranceThreshold);
1400 
1401  if (affineFusionMode == FusionMode::ProducerConsumer)
1402  fusion.runProducerConsumerFusionOnly();
1403  else if (affineFusionMode == FusionMode::Sibling)
1404  fusion.runSiblingFusionOnly();
1405  else
1406  fusion.runGreedyFusion();
1407 }
1408 
1409 void LoopFusion::runOnOperation() {
1410  // Call fusion on every op that has at least two affine.for nests (in post
1411  // order).
1412  getOperation()->walk([&](Operation *op) {
1413  for (Region &region : op->getRegions()) {
1414  for (Block &block : region.getBlocks()) {
1415  auto affineFors = block.getOps<AffineForOp>();
1416  if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1417  runOnBlock(&block);
1418  }
1419  }
1420  });
1421 }
1422 
1423 std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
1424  unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1425  bool maximalFusion, enum FusionMode affineFusionMode) {
1426  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1427  maximalFusion, affineFusionMode);
1428 }