MLIR  18.0.0git
LoopFusion.cpp
Go to the documentation of this file.
1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements affine fusion.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/Transforms/Passes.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <iomanip>
34 #include <optional>
35 #include <sstream>
36 
37 namespace mlir {
38 namespace affine {
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
41 } // namespace affine
42 } // namespace mlir
43 
44 #define DEBUG_TYPE "affine-loop-fusion"
45 
46 using namespace mlir;
47 using namespace mlir::affine;
48 
49 namespace {
50 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
51 /// which fuses loop nests with single-writer/single-reader memref dependences
52 /// with the goal of improving locality.
53 // TODO: Support fusion of source loop nests which write to multiple
54 // memrefs, where each memref can have multiple users (if profitable).
55 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
56  LoopFusion() = default;
57  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58  bool maximalFusion, enum FusionMode affineFusionMode) {
59  this->fastMemorySpace = fastMemorySpace;
60  this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61  this->maximalFusion = maximalFusion;
62  this->affineFusionMode = affineFusionMode;
63  }
64 
65  void runOnBlock(Block *block);
66  void runOnOperation() override;
67 };
68 
69 } // namespace
70 
71 /// Returns true if node 'srcId' can be removed after fusing it with node
72 /// 'dstId'. The node can be removed if any of the following conditions are met:
73 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
74 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
75 /// and the fusion slice is maximal.
76 /// 3. 'srcId' has output dependences after fusion, the fusion slice is
77 /// maximal and the fusion insertion point dominates all the dependences.
79  unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
80  Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
81  MemRefDependenceGraph *mdg) {
82 
83  Operation *dstNodeOp = mdg->getNode(dstId)->op;
84  bool hasOutDepsAfterFusion = false;
85 
86  for (auto &outEdge : mdg->outEdges[srcId]) {
87  Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
88  // Skip dependence with dstOp since it will be removed after fusion.
89  if (depNodeOp == dstNodeOp)
90  continue;
91 
92  // Only fusion within the same block is supported. Use domination analysis
93  // when needed.
94  if (depNodeOp->getBlock() != dstNodeOp->getBlock())
95  return false;
96 
97  // Check if the insertion point of the fused loop dominates the dependence.
98  // Otherwise, the src loop can't be removed.
99  if (fusedLoopInsPoint != depNodeOp &&
100  !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
101  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
102  "dominate dependence\n");
103  return false;
104  }
105 
106  hasOutDepsAfterFusion = true;
107  }
108 
109  // If src loop has dependences after fusion or it writes to an live-out or
110  // escaping memref, we can only remove it if the fusion slice is maximal so
111  // that all the dependences are preserved.
112  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113  std::optional<bool> isMaximal = fusionSlice.isMaximal();
114  if (!isMaximal) {
115  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
116  "if fusion is maximal\n");
117  return false;
118  }
119 
120  if (!*isMaximal) {
121  LLVM_DEBUG(llvm::dbgs()
122  << "Src loop can't be removed: fusion is not maximal\n");
123  return false;
124  }
125  }
126 
127  return true;
128 }
129 
130 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
131 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to
132 /// the program order when the 'mdg' is created. However, program order is not
133 /// guaranteed and must not be required by the client. Program order won't be
134 /// held if the 'mdg' is reused from a previous fusion step or if the node
135 /// creation order changes in the future to support more advance cases.
136 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
137 static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
138  SmallVectorImpl<unsigned> &srcIdCandidates) {
139  // Skip if no input edges along which to fuse.
140  if (mdg->inEdges.count(dstId) == 0)
141  return;
142 
143  // Gather memrefs from loads in 'dstId'.
144  auto *dstNode = mdg->getNode(dstId);
145  DenseSet<Value> consumedMemrefs;
146  for (Operation *load : dstNode->loads)
147  consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148 
149  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
150  // to one of the consumed memrefs.
151  for (auto &srcEdge : mdg->inEdges[dstId]) {
152  auto *srcNode = mdg->getNode(srcEdge.id);
153  // Skip if 'srcNode' is not a loop nest.
154  if (!isa<AffineForOp>(srcNode->op))
155  continue;
156 
157  if (any_of(srcNode->stores, [&](Operation *op) {
158  auto storeOp = cast<AffineWriteOpInterface>(op);
159  return consumedMemrefs.count(storeOp.getMemRef()) > 0;
160  }))
161  srcIdCandidates.push_back(srcNode->id);
162  }
163 
164  llvm::sort(srcIdCandidates);
165  srcIdCandidates.erase(
166  std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
167  srcIdCandidates.end());
168 }
169 
170 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
171 /// producer-consumer dependence between 'srcId' and 'dstId'.
172 static void
173 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
175  DenseSet<Value> &producerConsumerMemrefs) {
176  auto *dstNode = mdg->getNode(dstId);
177  auto *srcNode = mdg->getNode(srcId);
178  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
179  producerConsumerMemrefs);
180 }
181 
182 /// A memref escapes in the context of the fusion pass if either:
183 /// 1. it (or its alias) is a block argument, or
184 /// 2. created by an op not known to guarantee alias freedom,
185 /// 3. it (or its alias) are used by ops other than affine dereferencing ops
186 /// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
187 /// terminator ops, etc.); such ops do not deference the memref in an affine
188 /// way.
189 static bool isEscapingMemref(Value memref, Block *block) {
190  Operation *defOp = memref.getDefiningOp();
191  // Check if 'memref' is a block argument.
192  if (!defOp)
193  return true;
194 
195  // Check if this is defined to be an alias of another memref.
196  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
197  if (isEscapingMemref(viewOp.getViewSource(), block))
198  return true;
199 
200  // Any op besides allocating ops wouldn't guarantee alias freedom
201  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
202  return true;
203 
204  // Check if 'memref' is used by a non-deferencing op (including unknown ones)
205  // (e.g., call ops, alias creating ops, etc.).
206  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
207  // Ignore users outside of `block`.
208  if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
209  return false;
210  return !isa<AffineMapAccessInterface>(*user);
211  });
212 }
213 
214 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
215 /// that escape the block or are accessed in a non-affine way.
216 static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
217  DenseSet<Value> &escapingMemRefs) {
218  auto *node = mdg->getNode(id);
219  for (Operation *storeOp : node->stores) {
220  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
221  if (escapingMemRefs.count(memref))
222  continue;
223  if (isEscapingMemref(memref, &mdg->block))
224  escapingMemRefs.insert(memref);
225  }
226 }
227 
228 // Sinks all sequential loops to the innermost levels (while preserving
229 // relative order among them) and moves all parallel loops to the
230 // outermost (while again preserving relative order among them).
231 // This can increase the loop depth at which we can fuse a slice, since we are
232 // pushing loop carried dependence to a greater depth in the loop nest.
234  assert(isa<AffineForOp>(node->op));
235  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
236  node->op = newRootForOp;
237 }
238 
239 // Creates and returns a private (single-user) memref for fused loop rooted
240 // at 'forOp', with (potentially reduced) memref size based on the
241 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
242 // TODO: consider refactoring the common code from generateDma and
243 // this one.
244 static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
245  unsigned dstLoopDepth,
246  std::optional<unsigned> fastMemorySpace,
247  uint64_t localBufSizeThreshold) {
248  Operation *forInst = forOp.getOperation();
249 
250  // Create builder to insert alloc op just before 'forOp'.
251  OpBuilder b(forInst);
252  // Builder to create constants at the top level.
253  OpBuilder top(forInst->getParentRegion());
254  // Create new memref type based on slice bounds.
255  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
256  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
257  unsigned rank = oldMemRefType.getRank();
258 
259  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
260  MemRefRegion region(srcStoreOpInst->getLoc());
261  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
262  (void)validRegion;
263  assert(validRegion && "unexpected memref region failure");
264  SmallVector<int64_t, 4> newShape;
265  std::vector<SmallVector<int64_t, 4>> lbs;
266  SmallVector<int64_t, 8> lbDivisors;
267  lbs.reserve(rank);
268  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
269  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
270  std::optional<int64_t> numElements =
271  region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
272  assert(numElements && "non-constant number of elts in local buffer");
273 
274  const FlatAffineValueConstraints *cst = region.getConstraints();
275  // 'outerIVs' holds the values that this memory region is symbolic/parametric
276  // on; this would correspond to loop IVs surrounding the level at which the
277  // slice is being materialized.
278  SmallVector<Value, 8> outerIVs;
279  cst->getValues(rank, cst->getNumVars(), &outerIVs);
280 
281  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
283  offsets.reserve(rank);
284  for (unsigned d = 0; d < rank; ++d) {
285  assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
286 
287  AffineExpr offset = top.getAffineConstantExpr(0);
288  for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
289  offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
290  }
291  assert(lbDivisors[d] > 0);
292  offset =
293  (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
294  offsets.push_back(offset);
295  }
296 
297  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
298  // by 'srcStoreOpInst'.
299  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
300  assert(eltSize && "memrefs with size elt types expected");
301  uint64_t bufSize = *eltSize * *numElements;
302  unsigned newMemSpace;
303  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
304  newMemSpace = *fastMemorySpace;
305  } else {
306  newMemSpace = oldMemRefType.getMemorySpaceAsInt();
307  }
308  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
309  {}, newMemSpace);
310 
311  // Create new private memref for fused loop 'forOp'. 'newShape' is always
312  // a constant shape.
313  // TODO: Create/move alloc ops for private memrefs closer to their
314  // consumer loop nests to reduce their live range. Currently they are added
315  // at the beginning of the block, because loop nests can be reordered
316  // during the fusion pass.
317  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
318 
319  // Build an AffineMap to remap access functions based on lower bound offsets.
320  SmallVector<AffineExpr, 4> remapExprs;
321  remapExprs.reserve(rank);
322  for (unsigned i = 0; i < rank; i++) {
323  auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
324 
325  auto remapExpr =
326  simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
327  remapExprs.push_back(remapExpr);
328  }
329 
330  auto indexRemap =
331  AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
332 
333  // Replace all users of 'oldMemRef' with 'newMemRef'.
334  LogicalResult res =
335  replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
336  /*extraOperands=*/outerIVs,
337  /*symbolOperands=*/{},
338  /*domOpFilter=*/&*forOp.getBody()->begin());
339  assert(succeeded(res) &&
340  "replaceAllMemrefUsesWith should always succeed here");
341  (void)res;
342  return newMemRef;
343 }
344 
345 /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
346 /// 'dstId'), if there is any non-affine operation accessing 'memref', return
347 /// true. Otherwise, return false.
348 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
349  Value memref,
350  MemRefDependenceGraph *mdg) {
351  auto *srcNode = mdg->getNode(srcId);
352  auto *dstNode = mdg->getNode(dstId);
353  Value::user_range users = memref.getUsers();
354  // For each MemRefDependenceGraph's node that is between 'srcNode' and
355  // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any
356  // non-affine operation in the node accesses the 'memref'.
357  for (auto &idAndNode : mdg->nodes) {
358  Operation *op = idAndNode.second.op;
359  // Take care of operations between 'srcNode' and 'dstNode'.
360  if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
361  // Walk inside the operation to find any use of the memref.
362  // Interrupt the walk if found.
363  auto walkResult = op->walk([&](Operation *user) {
364  // Skip affine ops.
365  if (isa<AffineMapAccessInterface>(*user))
366  return WalkResult::advance();
367  // Find a non-affine op that uses the memref.
368  if (llvm::is_contained(users, user))
369  return WalkResult::interrupt();
370  return WalkResult::advance();
371  });
372  if (walkResult.wasInterrupted())
373  return true;
374  }
375  }
376  return false;
377 }
378 
379 /// Check whether a memref value in node 'srcId' has a non-affine that
380 /// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
381 /// 'dstNode').
382 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
383  MemRefDependenceGraph *mdg) {
384  // Collect memref values in node 'srcId'.
385  auto *srcNode = mdg->getNode(srcId);
386  llvm::SmallDenseSet<Value, 2> memRefValues;
387  srcNode->op->walk([&](Operation *op) {
388  // Skip affine ops.
389  if (isa<AffineForOp>(op))
390  return WalkResult::advance();
391  for (Value v : op->getOperands())
392  // Collect memref values only.
393  if (isa<MemRefType>(v.getType()))
394  memRefValues.insert(v);
395  return WalkResult::advance();
396  });
397  // Looking for users between node 'srcId' and node 'dstId'.
398  return llvm::any_of(memRefValues, [&](Value memref) {
399  return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
400  });
401 }
402 
403 // Checks the profitability of fusing a backwards slice of the loop nest
404 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
405 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
406 // the memref being produced and consumed, which is an input to the cost model.
407 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
408 // 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
409 // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
410 // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
411 // unique store op in the src node, which will be used to check that the write
412 // region is the same after input-reuse fusion. Computation slices are provided
413 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
414 // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
415 // profitable to fuse the candidate loop nests. Returns false otherwise.
416 // `dstLoopDepth` is set to the most profitable depth at which to materialize
417 // the source loop nest slice.
418 // The profitability model executes the following steps:
419 // *) Computes the backward computation slice at 'srcOpInst'. This
420 // computation slice of the loop nest surrounding 'srcOpInst' is
421 // represented by modified src loop bounds in 'sliceState', which are
422 // functions of loop IVs in the loop nest surrounding 'srcOpInst'.
423 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
424 // loop nest is the total number of dynamic operation instances in the loop
425 // nest).
426 // *) Computes the cost of fusing a slice of the src loop nest into the dst
427 // loop nest at various values of dst loop depth, attempting to fuse
428 // the largest computation slice at the maximal dst loop depth (closest to
429 // the load) to minimize reuse distance and potentially enable subsequent
430 // load/store forwarding.
431 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
432 // nest, at which the src computation slice is inserted/fused.
433 // NOTE: We attempt to maximize the dst loop depth, but there are cases
434 // where a particular setting for 'dstLoopNest' might fuse an unsliced
435 // loop (within the src computation slice) at a depth which results in
436 // excessive recomputation (see unit tests for examples).
437 // *) Compares the total cost of the unfused loop nests to the min cost fused
438 // loop nest computed in the previous step, and returns true if the latter
439 // is lower.
440 // TODO: Extend profitability analysis to support scenarios with multiple
441 // stores.
442 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
443  AffineForOp dstForOp,
444  ArrayRef<ComputationSliceState> depthSliceUnions,
445  unsigned maxLegalFusionDepth,
446  unsigned *dstLoopDepth,
447  double computeToleranceThreshold) {
448  LLVM_DEBUG({
449  llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
450  llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
451  llvm::dbgs() << dstForOp << "\n";
452  });
453 
454  if (maxLegalFusionDepth == 0) {
455  LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
456  return false;
457  }
458 
459  // Compute cost of sliced and unsliced src loop nest.
460  SmallVector<AffineForOp, 4> srcLoopIVs;
461  getAffineForIVs(*srcOpInst, &srcLoopIVs);
462 
463  // Walk src loop nest and collect stats.
464  LoopNestStats srcLoopNestStats;
465  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
466  return false;
467 
468  // Compute cost of dst loop nest.
469  LoopNestStats dstLoopNestStats;
470  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
471  return false;
472 
473  // Search for min cost value for 'dstLoopDepth'. At each value of
474  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
475  // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
476  // of these bounds). Next the union slice bounds are used to calculate
477  // the cost of the slice and the cost of the slice inserted into the dst
478  // loop nest at 'dstLoopDepth'.
479  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
480  double maxStorageReduction = 0.0;
481  std::optional<uint64_t> sliceMemEstimate;
482 
483  // The best loop depth at which to materialize the slice.
484  std::optional<unsigned> bestDstLoopDepth;
485 
486  // Compute op instance count for the src loop nest without iteration slicing.
487  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
488 
489  // Compute src loop nest write region size.
490  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
491  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
492  LLVM_DEBUG(llvm::dbgs()
493  << "Unable to compute MemRefRegion for source operation\n");
494  return false;
495  }
496 
497  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
498  srcWriteRegion.getRegionSize();
499  if (!maybeSrcWriteRegionSizeBytes.has_value())
500  return false;
501  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
502 
503  // Compute op instance count for the src loop nest.
504  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
505 
506  // Evaluate all depth choices for materializing the slice in the destination
507  // loop nest.
508  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
509  const ComputationSliceState &slice = depthSliceUnions[i - 1];
510  // Skip slice union if it wasn't computed for this depth.
511  if (slice.isEmpty())
512  continue;
513 
514  int64_t fusedLoopNestComputeCost;
515  if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
516  dstLoopNestStats, slice,
517  &fusedLoopNestComputeCost)) {
518  LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
519  continue;
520  }
521 
522  double additionalComputeFraction =
523  fusedLoopNestComputeCost /
524  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
525  1;
526 
527  // Determine what the slice write MemRefRegion would be, if the src loop
528  // nest slice 'slice' were to be inserted into the dst loop nest at loop
529  // depth 'i'.
530  MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
531  if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
532  &slice))) {
533  LLVM_DEBUG(llvm::dbgs()
534  << "Failed to compute slice write region at loopDepth: " << i
535  << "\n");
536  continue;
537  }
538 
539  std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
540  sliceWriteRegion.getRegionSize();
541  if (!maybeSliceWriteRegionSizeBytes.has_value() ||
542  *maybeSliceWriteRegionSizeBytes == 0) {
543  LLVM_DEBUG(llvm::dbgs()
544  << "Failed to get slice write region size at loopDepth: " << i
545  << "\n");
546  continue;
547  }
548  int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
549 
550  // If we are fusing for reuse, check that write regions remain the same.
551  // TODO: Write region check should check sizes and offsets in
552  // each dimension, so that we are sure they are covering the same memref
553  // region. Also, move this out to a isMemRefRegionSuperSet helper function.
554  if (srcOpInst != srcStoreOpInst &&
555  sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
556  continue;
557 
558  double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
559  static_cast<double>(sliceWriteRegionSizeBytes);
560 
561  LLVM_DEBUG({
562  std::stringstream msg;
563  msg << " evaluating fusion profitability at depth : " << i << "\n"
564  << std::fixed << std::setprecision(2)
565  << " additional compute fraction: "
566  << 100.0 * additionalComputeFraction << "%\n"
567  << " storage reduction factor: " << storageReduction << "x\n"
568  << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
569  << " src write region size: " << srcWriteRegionSizeBytes << "\n"
570  << " slice write region size: " << sliceWriteRegionSizeBytes
571  << "\n";
572  llvm::dbgs() << msg.str();
573  });
574 
575  // TODO: This is a placeholder cost model.
576  // Among all choices that add an acceptable amount of redundant computation
577  // (as per computeToleranceThreshold), we will simply pick the one that
578  // reduces the intermediary size the most.
579  if ((storageReduction > maxStorageReduction) &&
580  (additionalComputeFraction < computeToleranceThreshold)) {
581  maxStorageReduction = storageReduction;
582  bestDstLoopDepth = i;
583  minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
584  sliceMemEstimate = sliceWriteRegionSizeBytes;
585  }
586  }
587 
588  // A simple cost model: fuse if it reduces the memory footprint.
589 
590  if (!bestDstLoopDepth) {
591  LLVM_DEBUG(
592  llvm::dbgs()
593  << "All fusion choices involve more than the threshold amount of "
594  "redundant computation; NOT fusing.\n");
595  return false;
596  }
597 
598  if (!bestDstLoopDepth) {
599  LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
600  return false;
601  }
602 
603  // Set dstLoopDepth based on best values from search.
604  *dstLoopDepth = *bestDstLoopDepth;
605 
606  LLVM_DEBUG(
607  llvm::dbgs() << " LoopFusion fusion stats:"
608  << "\n best loop depth: " << bestDstLoopDepth
609  << "\n src loop nest compute cost: " << srcLoopNestCost
610  << "\n dst loop nest compute cost: " << dstLoopNestCost
611  << "\n fused loop nest compute cost: "
612  << minFusedLoopNestComputeCost << "\n");
613 
614  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
615  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
616 
617  std::optional<double> storageReduction;
618 
619  if (!dstMemSize || !srcMemSize) {
620  LLVM_DEBUG(llvm::dbgs()
621  << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
622  return false;
623  }
624 
625  auto srcMemSizeVal = *srcMemSize;
626  auto dstMemSizeVal = *dstMemSize;
627 
628  assert(sliceMemEstimate && "expected value");
629  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
630 
631  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
632  << " dst mem: " << dstMemSizeVal << "\n"
633  << " fused mem: " << fusedMem << "\n"
634  << " slice mem: " << sliceMemEstimate << "\n");
635 
636  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
637  LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
638  return false;
639  }
640  storageReduction =
641  100.0 *
642  (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
643 
644  double additionalComputeFraction =
645  100.0 * (minFusedLoopNestComputeCost /
646  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
647  1);
648  (void)additionalComputeFraction;
649  LLVM_DEBUG({
650  std::stringstream msg;
651  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
652  << std::setprecision(2) << additionalComputeFraction
653  << "% redundant computation and a ";
654  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
655  msg << "% storage reduction.\n";
656  llvm::dbgs() << msg.str();
657  });
658 
659  return true;
660 }
661 
662 namespace {
663 
664 // GreedyFusion greedily fuses loop nests which have a producer/consumer or
665 // input-reuse relationship on a memref, with the goal of improving locality.
666 //
667 // The steps of the producer-consumer fusion algorithm are as follows:
668 //
669 // *) A worklist is initialized with node ids from the dependence graph.
670 // *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
672 // candidate destination AffineForOp into which fusion will be attempted.
673 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
674 // *) For each LoadOp in 'dstLoadOps' do:
675 // *) Look up dependent loop nests which have a single store op to the same
676 // memref.
677 // *) Check if dependences would be violated by the fusion.
678 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop
679 // bounds to be functions of 'dstLoopNest' IVs and symbols.
680 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
681 // at a loop depth determined by the cost model in 'isFusionProfitable'.
682 // *) Add the newly fused load/store operations to the state,
//            and also add newly fused load ops to 'dstLoadOps' to be considered
684 // as fusion dst load ops in another iteration.
685 // *) Remove old src loop nest and its associated state.
686 //
687 // The steps of the input-reuse fusion algorithm are as follows:
688 //
689 // *) Initialize 'worklist' with node ids from the dependence graph.
690 // *) For each 'dstNode' in the worklist:
691 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
692 // loads from the same memref, but which has no dependence paths to/from.
693 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop
694 // bounds to be functions of 'dstLoopNest' IVs and symbols.
695 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
696 // at a loop depth determined by the cost model in 'isFusionProfitable'.
697 // This function also checks that the memref write region of 'sibLoopNest',
698 // is preserved in the fused loop nest.
699 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
700 //
701 // Given a graph where top-level operations are vertices in the set 'V' and
702 // edges in the set 'E' are dependences between vertices, this algorithm
703 // takes O(V) time for initialization, and has runtime O(V + E).
704 //
705 // This greedy algorithm is not 'maximal' due to the current restriction of
706 // fusing along single producer consumer edges, but there is a TODO: to fix
707 // this.
708 //
709 // TODO: Experiment with other fusion policies.
710 struct GreedyFusion {
711 public:
712  // The data dependence graph to traverse during fusion.
714  // Worklist of graph nodes visited during the fusion pass.
715  SmallVector<unsigned, 8> worklist;
716  // Parameter for local buffer size threshold.
717  unsigned localBufSizeThreshold;
718  // Parameter for fast memory space.
719  std::optional<unsigned> fastMemorySpace;
720  // If true, ignore any additional (redundant) computation tolerance threshold
721  // that would have prevented fusion.
722  bool maximalFusion;
723  // The amount of additional computation that is tolerated while fusing
724  // pair-wise as a fraction of the total computation.
725  double computeToleranceThreshold;
726 
728 
729  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
730  std::optional<unsigned> fastMemorySpace, bool maximalFusion,
731  double computeToleranceThreshold)
732  : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
733  fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
734  computeToleranceThreshold(computeToleranceThreshold) {}
735 
736  /// Initializes 'worklist' with nodes from 'mdg'.
737  void init() {
738  // TODO: Add a priority queue for prioritizing nodes by different
739  // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
740  worklist.clear();
741  for (auto &idAndNode : mdg->nodes) {
742  const Node &node = idAndNode.second;
743  worklist.push_back(node.id);
744  }
745  }
746  /// Run only sibling fusion on the `mdg`.
747  void runSiblingFusionOnly() {
748  fuseSiblingNodes();
749  eraseUnusedMemRefAllocations();
750  }
751 
752  /// Run only producer/consumer fusion on the `mdg`.
753  void runProducerConsumerFusionOnly() {
754  fuseProducerConsumerNodes(
755  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
756  eraseUnusedMemRefAllocations();
757  }
758 
759  // Run the GreedyFusion pass.
760  // *) First pass through the nodes fuses single-use producer nodes into their
761  // unique consumer.
762  // *) Second pass fuses sibling nodes which share no dependence edges.
763  // *) Third pass fuses any remaining producer nodes into their users.
764  void runGreedyFusion() {
765  // TODO: Run this repeatedly until a fixed-point is reached.
766  fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
767  fuseSiblingNodes();
768  fuseProducerConsumerNodes(
769  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
770  eraseUnusedMemRefAllocations();
771  }
772 
773  /// Returns true if a private memref can be created for `memref` given
774  /// the fusion scenario reflected by the other arguments.
775  bool canCreatePrivateMemRef(Value memref,
776  const DenseSet<Value> &srcEscapingMemRefs,
777  unsigned producerId, unsigned consumerId,
778  bool removeSrcNode) {
779  const Node *consumerNode = mdg->getNode(consumerId);
780  // If `memref` is an escaping one, do not create a private memref
781  // for the below scenarios, since doing so will leave the escaping
782  // memref unmodified as all the writes originally meant for the
783  // escaping memref would be performed on the private memref:
784  // 1. The source is to be removed after fusion,
785  // OR
786  // 2. The destination writes to `memref`.
787  if (srcEscapingMemRefs.count(memref) > 0 &&
788  (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
789  return false;
790 
791  // Don't create a private memref if 'srcNode' has in edges on
792  // 'memref' or 'dstNode' has out edges on 'memref'.
793  if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
794  mdg->getOutEdgeCount(consumerId, memref) > 0)
795  return false;
796 
797  // If 'srcNode' will be removed but it has out edges on 'memref' to
798  // nodes other than 'dstNode', we have to preserve dependences and
799  // cannot create a private memref.
800  if (removeSrcNode &&
801  any_of(mdg->outEdges[producerId], [&](const auto &edge) {
802  return edge.value == memref && edge.id != consumerId;
803  }))
804  return false;
805 
806  return true;
807  }
808 
809  /// Perform fusions with node `dstId` as the destination of fusion, with
810  /// No fusion is performed when producers with a user count greater than
811  /// `maxSrcUserCount` for any of the memrefs involved.
812  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
813  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
814  // Skip if this node was removed (fused into another node).
815  if (mdg->nodes.count(dstId) == 0)
816  return;
817  // Get 'dstNode' into which to attempt fusion.
818  auto *dstNode = mdg->getNode(dstId);
819  // Skip if 'dstNode' is not a loop nest.
820  if (!isa<AffineForOp>(dstNode->op))
821  return;
822  // Skip if 'dstNode' is a loop nest returning values.
823  // TODO: support loop nests that return values.
824  if (dstNode->op->getNumResults() > 0)
825  return;
826 
827  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
828 
829  // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
830  // while preserving relative order. This can increase the maximum loop
831  // depth at which we can fuse a slice of a producer loop nest into a
832  // consumer loop nest.
833  sinkSequentialLoops(dstNode);
834  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
835 
836  // Try to fuse 'dstNode' with candidate producer loops until a fixed point
837  // is reached. Fusing two loops may expose new fusion opportunities.
838  bool dstNodeChanged;
839  do {
840  // Gather src loop candidates for 'dstNode' and visit them in "quasi"
841  // reverse program order to minimize the number of iterations needed to
842  // reach the fixed point. Note that this is a best effort approach since
843  // 'getProducerCandidates' does not always guarantee that program order
844  // in 'srcIdCandidates'.
845  dstNodeChanged = false;
846  SmallVector<unsigned, 16> srcIdCandidates;
847  getProducerCandidates(dstId, mdg, srcIdCandidates);
848 
849  for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
850  // Get 'srcNode' from which to attempt fusion into 'dstNode'.
851  auto *srcNode = mdg->getNode(srcId);
852  auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
853  LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
854  << " for dst loop " << dstId << "\n");
855 
856  // Skip if 'srcNode' is a loop nest returning values.
857  // TODO: support loop nests that return values.
858  if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
859  continue;
860 
861  DenseSet<Value> producerConsumerMemrefs;
862  gatherProducerConsumerMemrefs(srcId, dstId, mdg,
863  producerConsumerMemrefs);
864 
865  // Skip if 'srcNode' out edge count on any memref is greater than
866  // 'maxSrcUserCount'.
867  if (any_of(producerConsumerMemrefs, [&](Value memref) {
868  return mdg->getOutEdgeCount(srcNode->id, memref) >
869  maxSrcUserCount;
870  }))
871  continue;
872 
873  // Gather memrefs in 'srcNode' that are written and escape out of the
874  // block (e.g., memref block arguments, returned memrefs,
875  // memrefs passed to function calls, etc.).
876  DenseSet<Value> srcEscapingMemRefs;
877  gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
878 
879  // Skip if there are non-affine operations in between the 'srcNode'
880  // and 'dstNode' using their memrefs. If so, we wouldn't be able to
881  // compute a legal insertion point for now. 'srcNode' and 'dstNode'
882  // memrefs with non-affine operation users would be considered
883  // escaping memrefs so we can limit this check to only scenarios with
884  // escaping memrefs.
885  if (!srcEscapingMemRefs.empty() &&
886  hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
887  LLVM_DEBUG(llvm::dbgs()
888  << "Can't fuse: non-affine users in between the loops\n");
889  continue;
890  }
891 
892  // Compute an operation list insertion point for the fused loop
893  // nest which preserves dependences.
894  Operation *fusedLoopInsPoint =
895  mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
896  if (fusedLoopInsPoint == nullptr)
897  continue;
898 
899  // It's possible this fusion is at an inner depth (i.e., there are
900  // common surrounding affine loops for the source and destination for
901  // ops). We need to get this number because the call to canFuseLoops
902  // needs to be passed the absolute depth. The max legal depth and the
903  // depths we try below are however *relative* and as such don't include
904  // the common depth.
905  SmallVector<AffineForOp, 4> surroundingLoops;
906  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
907  unsigned numSurroundingLoops = surroundingLoops.size();
908 
909  // Compute the innermost common loop depth for dstNode
910  // producer-consumer loads/stores.
911  SmallVector<Operation *, 2> dstMemrefOps;
912  for (Operation *op : dstNode->loads)
913  if (producerConsumerMemrefs.count(
914  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
915  dstMemrefOps.push_back(op);
916  for (Operation *op : dstNode->stores)
917  if (producerConsumerMemrefs.count(
918  cast<AffineWriteOpInterface>(op).getMemRef()))
919  dstMemrefOps.push_back(op);
920  unsigned dstLoopDepthTest =
921  getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
922 
923  // Check the feasibility of fusing src loop nest into dst loop nest
924  // at loop depths in range [1, dstLoopDepthTest].
925  unsigned maxLegalFusionDepth = 0;
926  SmallVector<ComputationSliceState, 8> depthSliceUnions;
927  depthSliceUnions.resize(dstLoopDepthTest);
929  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
930  FusionResult result =
931  affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
932  /*dstLoopDepth=*/i + numSurroundingLoops,
933  &depthSliceUnions[i - 1], strategy);
934 
935  if (result.value == FusionResult::Success)
936  maxLegalFusionDepth = i;
937  }
938 
939  if (maxLegalFusionDepth == 0) {
940  LLVM_DEBUG(llvm::dbgs()
941  << "Can't fuse: fusion is not legal at any depth\n");
942  continue;
943  }
944 
945  // Check if fusion would be profitable. We skip profitability analysis
946  // for maximal fusion since we already know the maximal legal depth to
947  // fuse.
948  unsigned bestDstLoopDepth = maxLegalFusionDepth;
949  if (!maximalFusion) {
950  // Retrieve producer stores from the src loop.
951  SmallVector<Operation *, 2> producerStores;
952  for (Operation *op : srcNode->stores)
953  if (producerConsumerMemrefs.count(
954  cast<AffineWriteOpInterface>(op).getMemRef()))
955  producerStores.push_back(op);
956 
957  // TODO: Suppport multiple producer stores in profitability
958  // analysis. We limit profitability analysis to only scenarios with
959  // a single producer store for now. Note that some multi-store
960  // producer scenarios will still go through profitability analysis
961  // if only one of the stores is involved the producer-consumer
962  // relationship of the candidate loops.
963  assert(!producerStores.empty() && "Expected producer store");
964  if (producerStores.size() > 1)
965  LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
966  "supported for this case\n");
967  else if (!isFusionProfitable(producerStores[0], producerStores[0],
968  dstAffineForOp, depthSliceUnions,
969  maxLegalFusionDepth, &bestDstLoopDepth,
970  computeToleranceThreshold))
971  continue;
972  }
973 
974  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
975  ComputationSliceState &bestSlice =
976  depthSliceUnions[bestDstLoopDepth - 1];
977  assert(!bestSlice.isEmpty() && "Missing slice union for depth");
978 
979  // Determine if 'srcId' can be removed after fusion, taking into
980  // account remaining dependences, escaping memrefs and the fusion
981  // insertion point.
982  bool removeSrcNode = canRemoveSrcNodeAfterFusion(
983  srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
984  mdg);
985 
986  DenseSet<Value> privateMemrefs;
987  for (Value memref : producerConsumerMemrefs) {
988  if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
989  removeSrcNode)) {
990  // Create a private version of this memref.
991  LLVM_DEBUG(llvm::dbgs()
992  << "Creating private memref for " << memref << '\n');
993  // Create a private version of this memref.
994  privateMemrefs.insert(memref);
995  }
996  }
997 
998  // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
999  fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1000  dstNodeChanged = true;
1001 
1002  LLVM_DEBUG(llvm::dbgs()
1003  << "Fused src loop " << srcId << " into dst loop " << dstId
1004  << " at depth " << bestDstLoopDepth << ":\n"
1005  << dstAffineForOp << "\n");
1006 
1007  // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1008  if (fusedLoopInsPoint != dstAffineForOp)
1009  dstAffineForOp->moveBefore(fusedLoopInsPoint);
1010 
1011  // Update edges between 'srcNode' and 'dstNode'.
1012  mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1013  removeSrcNode);
1014 
1015  // Create private memrefs.
1016  if (!privateMemrefs.empty()) {
1017  // Gather stores for all the private-to-be memrefs.
1018  DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1019  dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1020  Value storeMemRef = storeOp.getMemRef();
1021  if (privateMemrefs.count(storeMemRef) > 0)
1022  privateMemRefToStores[storeMemRef].push_back(storeOp);
1023  });
1024 
1025  // Replace original memrefs with private memrefs. Note that all the
1026  // loads and stores on these memrefs will be replaced with a new
1027  // loads and stores. Any reference to the original ones becomes
1028  // invalid after this point.
1029  for (auto &memrefToStoresPair : privateMemRefToStores) {
1030  // TODO: Use union of memref write regions to compute
1031  // private memref footprint.
1032  SmallVector<Operation *, 4> &storesForMemref =
1033  memrefToStoresPair.second;
1034  Value newMemRef = createPrivateMemRef(
1035  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1036  fastMemorySpace, localBufSizeThreshold);
1037  // Create new node in dependence graph for 'newMemRef' alloc op.
1038  unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1039  // Add edge from 'newMemRef' node to dstNode.
1040  mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1041  }
1042  // One or more entries for 'newMemRef' alloc op are inserted into
1043  // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1044  // reallocate, update dstNode.
1045  dstNode = mdg->getNode(dstId);
1046  }
1047 
1048  // Collect dst loop stats after memref privatization transformation.
1049  LoopNestStateCollector dstLoopCollector;
1050  dstLoopCollector.collect(dstAffineForOp);
1051 
1052  // Clear and add back loads and stores.
1053  mdg->clearNodeLoadAndStores(dstNode->id);
1054  mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1055  dstLoopCollector.storeOpInsts);
1056 
1057  if (removeSrcNode) {
1058  LLVM_DEBUG(llvm::dbgs()
1059  << "Removing src loop " << srcId << " after fusion\n");
1060  // srcNode is no longer valid after it is removed from mdg.
1061  srcAffineForOp.erase();
1062  mdg->removeNode(srcId);
1063  srcNode = nullptr;
1064  }
1065  }
1066  } while (dstNodeChanged);
1067  }
1068 
1069  /// Visit each node in the graph, and for each node, attempt to fuse it with
1070  /// producer-consumer candidates. No fusion is performed when producers with a
1071  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1072  /// are encountered.
1073  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1074  LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1075  init();
1076  while (!worklist.empty()) {
1077  unsigned dstId = worklist.back();
1078  worklist.pop_back();
1079  performFusionsIntoDest(dstId, maxSrcUserCount);
1080  }
1081  }
1082 
1083  // Visits each node in the graph, and for each node, attempts to fuse it with
1084  // its sibling nodes (nodes which share a parent, but no dependence edges).
1085  void fuseSiblingNodes() {
1086  LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1087  init();
1088  while (!worklist.empty()) {
1089  unsigned dstId = worklist.back();
1090  worklist.pop_back();
1091 
1092  // Skip if this node was removed (fused into another node).
1093  if (mdg->nodes.count(dstId) == 0)
1094  continue;
1095  // Get 'dstNode' into which to attempt fusion.
1096  auto *dstNode = mdg->getNode(dstId);
1097  // Skip if 'dstNode' is not a loop nest.
1098  if (!isa<AffineForOp>(dstNode->op))
1099  continue;
1100  // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1101  fuseWithSiblingNodes(dstNode);
1102  }
1103  }
1104 
1105  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1106  void fuseWithSiblingNodes(Node *dstNode) {
1107  DenseSet<unsigned> visitedSibNodeIds;
1108  std::pair<unsigned, Value> idAndMemref;
1109  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1110 
1111  while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1112  unsigned sibId = idAndMemref.first;
1113  Value memref = idAndMemref.second;
1114  // TODO: Check that 'sibStoreOpInst' post-dominates all other
1115  // stores to the same memref in 'sibNode' loop nest.
1116  auto *sibNode = mdg->getNode(sibId);
1117  // Compute an operation list insertion point for the fused loop
1118  // nest which preserves dependences.
1119  assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1120  Operation *insertPointInst =
1121  sibNode->op->isBeforeInBlock(dstNode->op)
1122  ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1123  : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1124  if (insertPointInst == nullptr)
1125  continue;
1126 
1127  // Check if fusion would be profitable and at what depth.
1128 
1129  // Get unique 'sibNode' load op to 'memref'.
1130  SmallVector<Operation *, 2> sibLoadOpInsts;
1131  sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1132  // Currently findSiblingNodeToFuse searches for siblings with one load.
1133  assert(sibLoadOpInsts.size() == 1);
1134  Operation *sibLoadOpInst = sibLoadOpInsts[0];
1135 
1136  // Gather 'dstNode' load ops to 'memref'.
1137  SmallVector<Operation *, 2> dstLoadOpInsts;
1138  dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1139 
1140  // It's possible this fusion is at an inner depth (i.e., there are common
1141  // surrounding affine loops for the source and destination for ops). We
1142  // need to get this number because the call to canFuseLoops needs to be
1143  // passed the absolute depth. The max legal depth and the depths we try
1144  // below are however *relative* and as such don't include the common
1145  // depth.
1146  SmallVector<AffineForOp, 4> surroundingLoops;
1147  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1148  unsigned numSurroundingLoops = surroundingLoops.size();
1149  SmallVector<AffineForOp, 4> dstLoopIVs;
1150  getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1151  unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1152  auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1153 
1154  // Compute loop depth and slice union for fusion.
1155  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1156  depthSliceUnions.resize(dstLoopDepthTest);
1157  unsigned maxLegalFusionDepth = 0;
1158  FusionStrategy strategy(memref);
1159  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1160  FusionResult result =
1161  affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1162  /*dstLoopDepth=*/i + numSurroundingLoops,
1163  &depthSliceUnions[i - 1], strategy);
1164 
1165  if (result.value == FusionResult::Success)
1166  maxLegalFusionDepth = i;
1167  }
1168 
1169  LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1170  << maxLegalFusionDepth << '\n');
1171 
1172  // Skip if fusion is not feasible at any loop depths.
1173  if (maxLegalFusionDepth == 0)
1174  continue;
1175 
1176  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1177  if (!maximalFusion) {
1178  // Check if fusion would be profitable. For sibling fusion, the sibling
1179  // load op is treated as the src "store" op for fusion profitability
1180  // purposes. The footprint of the load in the slice relative to the
1181  // unfused source's determines reuse.
1182  if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
1183  depthSliceUnions, maxLegalFusionDepth,
1184  &bestDstLoopDepth, computeToleranceThreshold))
1185  continue;
1186  }
1187 
1188  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1189  assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1190  "Fusion depth has no computed slice union");
1191  // Check if source loop is being inserted in the innermost
1192  // destination loop. Based on this, the fused loop may be optimized
1193  // further inside `fuseLoops`.
1194  bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1195  // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1196  affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1197  depthSliceUnions[bestDstLoopDepth - 1],
1198  isInnermostInsertion);
1199 
1200  auto dstForInst = cast<AffineForOp>(dstNode->op);
1201  // Update operation position of fused loop nest (if needed).
1202  if (insertPointInst != dstForInst) {
1203  dstForInst->moveBefore(insertPointInst);
1204  }
1205  // Update data dependence graph state post fusion.
1206  updateStateAfterSiblingFusion(sibNode, dstNode);
1207  }
1208  }
1209 
1210  // Searches block argument uses and the graph from 'dstNode' looking for a
1211  // fusion candidate sibling node which shares no dependences with 'dstNode'
1212  // but which loads from the same memref. Returns true and sets
1213  // 'idAndMemrefToFuse' on success. Returns false otherwise.
1214  bool findSiblingNodeToFuse(Node *dstNode,
1215  DenseSet<unsigned> *visitedSibNodeIds,
1216  std::pair<unsigned, Value> *idAndMemrefToFuse) {
1217  // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1218  // on 'memref'.
1219  auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1220  // Skip if 'outEdge' is not a read-after-write dependence.
1221  // TODO: Remove restrict to single load op restriction.
1222  if (sibNode->getLoadOpCount(memref) != 1)
1223  return false;
1224  // Skip if there exists a path of dependent edges between
1225  // 'sibNode' and 'dstNode'.
1226  if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1227  mdg->hasDependencePath(dstNode->id, sibNode->id))
1228  return false;
1229  // Skip sib node if it loads to (and stores from) the same memref on
1230  // which it also has an input dependence edge.
1231  DenseSet<Value> loadAndStoreMemrefSet;
1232  sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1233  if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1234  return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1235  }))
1236  return false;
1237 
1238  // Check that all stores are to the same memref if any.
1239  DenseSet<Value> storeMemrefs;
1240  for (auto *storeOpInst : sibNode->stores) {
1241  storeMemrefs.insert(
1242  cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1243  }
1244  if (storeMemrefs.size() > 1)
1245  return false;
1246 
1247  // Skip if a memref value in one node is used by a non-affine memref
1248  // access that lies between 'dstNode' and 'sibNode'.
1249  if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
1250  hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
1251  return false;
1252  return true;
1253  };
1254 
1255  // Search for siblings which load the same memref block argument.
1256  Block *block = dstNode->op->getBlock();
1257  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1258  for (Operation *user : block->getArgument(i).getUsers()) {
1259  auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1260  if (!loadOp)
1261  continue;
1262  // Gather loops surrounding 'use'.
1264  getAffineForIVs(*user, &loops);
1265  // Skip 'use' if it is not within a loop nest.
1266  // Find the surrounding affine.for nested immediately within the
1267  // block.
1268  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1269  return loop->getBlock() == &mdg->block;
1270  });
1271  // Skip 'use' if it is not within a loop nest in `block`.
1272  if (it == loops.end())
1273  continue;
1274  Node *sibNode = mdg->getForOpNode(*it);
1275  assert(sibNode != nullptr);
1276  // Skip 'use' if it not a sibling to 'dstNode'.
1277  if (sibNode->id == dstNode->id)
1278  continue;
1279  // Skip 'use' if it has been visited.
1280  if (visitedSibNodeIds->count(sibNode->id) > 0)
1281  continue;
1282  // Skip 'use' if it does not load from the same memref as 'dstNode'.
1283  auto memref = loadOp.getMemRef();
1284  if (dstNode->getLoadOpCount(memref) == 0)
1285  continue;
1286  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1287  if (canFuseWithSibNode(sibNode, memref)) {
1288  visitedSibNodeIds->insert(sibNode->id);
1289  idAndMemrefToFuse->first = sibNode->id;
1290  idAndMemrefToFuse->second = memref;
1291  return true;
1292  }
1293  }
1294  }
1295 
1296  // Search for siblings by following edges through an intermediate src node.
1297  // Collect candidate 'dstNode' input edges in 'inEdges'.
1300  dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1301  // Add 'inEdge' if it is a read-after-write dependence.
1302  if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1303  mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1304  inEdges.push_back(inEdge);
1305  });
1306 
1307  // Search for sibling nodes to fuse by visiting output edges from each input
1308  // edge in 'inEdges'.
1309  for (auto &inEdge : inEdges) {
1310  // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1313  inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1314  unsigned sibNodeId = outEdge.id;
1315  if (visitedSibNodeIds->count(sibNodeId) > 0)
1316  return;
1317  // Skip output edge if not a sibling using the same memref.
1318  if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1319  return;
1320  auto *sibNode = mdg->getNode(sibNodeId);
1321  if (!isa<AffineForOp>(sibNode->op))
1322  return;
1323  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1324  if (canFuseWithSibNode(sibNode, outEdge.value)) {
1325  // Add candidate 'outEdge' to sibling node.
1326  outEdges.push_back(outEdge);
1327  }
1328  });
1329 
1330  // Add first candidate if any were returned.
1331  if (!outEdges.empty()) {
1332  visitedSibNodeIds->insert(outEdges[0].id);
1333  idAndMemrefToFuse->first = outEdges[0].id;
1334  idAndMemrefToFuse->second = outEdges[0].value;
1335  return true;
1336  }
1337  }
1338  return false;
1339  }
1340 
1341  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1342  /// into 'dstNode'.
1343  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1344  // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1345  mdg->updateEdges(sibNode->id, dstNode->id);
1346 
1347  // Collect dst loop stats after memref privatization transformation.
1348  auto dstForInst = cast<AffineForOp>(dstNode->op);
1349  LoopNestStateCollector dstLoopCollector;
1350  dstLoopCollector.collect(dstForInst);
1351  // Clear and add back loads and stores
1352  mdg->clearNodeLoadAndStores(dstNode->id);
1353  mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1354  dstLoopCollector.storeOpInsts);
1355  // Remove old sibling loop nest if it no longer has outgoing dependence
1356  // edges, and it does not write to a memref which escapes the block.
1357  if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1358  Operation *op = sibNode->op;
1359  mdg->removeNode(sibNode->id);
1360  op->erase();
1361  }
1362  }
1363 
1364  // Clean up any allocs with no users.
1365  void eraseUnusedMemRefAllocations() {
1366  for (auto &pair : mdg->memrefEdgeCount) {
1367  if (pair.second > 0)
1368  continue;
1369  auto memref = pair.first;
1370  // Skip if there exist other uses (return operation or function calls).
1371  if (!memref.use_empty())
1372  continue;
1373  // Use list expected to match the dep graph info.
1374  auto *op = memref.getDefiningOp();
1375  if (isa_and_nonnull<memref::AllocOp>(op))
1376  op->erase();
1377  }
1378  }
1379 };
1380 
1381 } // namespace
1382 
1383 /// Run fusion on `block`.
1384 void LoopFusion::runOnBlock(Block *block) {
1385  MemRefDependenceGraph g(*block);
1386  if (!g.init()) {
1387  LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1388  return;
1389  }
1390 
1391  std::optional<unsigned> fastMemorySpaceOpt;
1392  if (fastMemorySpace.hasValue())
1393  fastMemorySpaceOpt = fastMemorySpace;
1394  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1395  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1396  maximalFusion, computeToleranceThreshold);
1397 
1398  if (affineFusionMode == FusionMode::ProducerConsumer)
1399  fusion.runProducerConsumerFusionOnly();
1400  else if (affineFusionMode == FusionMode::Sibling)
1401  fusion.runSiblingFusionOnly();
1402  else
1403  fusion.runGreedyFusion();
1404 }
1405 
1406 void LoopFusion::runOnOperation() {
1407  // Call fusion on every op that has at least two affine.for nests (in post
1408  // order).
1409  getOperation()->walk([&](Operation *op) {
1410  for (Region &region : op->getRegions()) {
1411  for (Block &block : region.getBlocks()) {
1412  auto affineFors = block.getOps<AffineForOp>();
1413  if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1414  runOnBlock(&block);
1415  }
1416  }
1417  });
1418 }
1419 
1421  unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1422  bool maximalFusion, enum FusionMode affineFusionMode) {
1423  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1424  maximalFusion, affineFusionMode);
1425 }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
Definition: LoopFusion.cpp:442
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, MemRefDependenceGraph *mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
Definition: LoopFusion.cpp:78
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg)
Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 'dstId'), if there is any non-aff...
Definition: LoopFusion.cpp:348
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
Definition: LoopFusion.cpp:189
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
Definition: LoopFusion.cpp:216
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, uint64_t localBufSizeThreshold)
Definition: LoopFusion.cpp:244
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
Definition: LoopFusion.cpp:137
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:361
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:353
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:776
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
user_range getUsers() const
Definition: Value.h:224
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumCols() const
Returns the number of columns in the constraint system.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition: Utils.cpp:633
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition: Utils.cpp:1953
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
Definition: Passes.h:31
@ ProducerConsumer
Definition: Passes.h:31
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:1385
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
Definition: LoopUtils.cpp:1493
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
Definition: Utils.cpp:1266
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition: Utils.cpp:1196
Include the generated interface declarations.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
Definition: MathExtras.h:33
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:260
bool isEmpty() const
Returns true if the computation slice is empty.
Definition: Utils.h:292
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition: Utils.cpp:888
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
Definition: Utils.h:40
void collect(Operation *opToWalk)
Definition: Utils.cpp:43
SmallVector< Operation *, 4 > storeOpInsts
Definition: Utils.h:41
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition: Utils.h:113
Block & block
The block for which this graph is created to perform fusion.
Definition: Utils.h:212
unsigned addNode(Operation *op)
Definition: Utils.cpp:251
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
Definition: Utils.cpp:388
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:316
Node * getForOpNode(AffineForOp forOp)
Definition: Utils.cpp:243
DenseMap< unsigned, Node > nodes
Definition: Utils.h:109
bool hasDependencePath(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:354
void clearNodeLoadAndStores(unsigned id)
Definition: Utils.cpp:576
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition: Utils.cpp:504
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition: Utils.h:111
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
Definition: Utils.cpp:568
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:584
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
Definition: Utils.cpp:404
void removeNode(unsigned id)
Definition: Utils.cpp:258
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:428
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:592
Node * getNode(unsigned id)
Definition: Utils.cpp:236
DenseMap< Value, unsigned > memrefEdgeCount
Definition: Utils.h:116
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
Definition: Utils.h:432
FlatAffineValueConstraints * getConstraints()
Definition: Utils.h:471
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition: Utils.cpp:942
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition: Utils.cpp:1215
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition: Utils.cpp:1044
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.