MLIR  19.0.0git
LoopFusion.cpp
Go to the documentation of this file.
1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements affine fusion.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/Transforms/Passes.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <iomanip>
34 #include <optional>
35 #include <sstream>
36 
37 namespace mlir {
38 namespace affine {
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
41 } // namespace affine
42 } // namespace mlir
43 
44 #define DEBUG_TYPE "affine-loop-fusion"
45 
46 using namespace mlir;
47 using namespace mlir::affine;
48 
49 namespace {
50 /// Loop fusion pass. This pass currently supports a greedy fusion policy,
51 /// which fuses loop nests with single-writer/single-reader memref dependences
52 /// with the goal of improving locality.
53 // TODO: Support fusion of source loop nests which write to multiple
54 // memrefs, where each memref can have multiple users (if profitable).
55 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
56  LoopFusion() = default;
57  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58  bool maximalFusion, enum FusionMode affineFusionMode) {
59  this->fastMemorySpace = fastMemorySpace;
60  this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61  this->maximalFusion = maximalFusion;
62  this->affineFusionMode = affineFusionMode;
63  }
64 
65  void runOnBlock(Block *block);
66  void runOnOperation() override;
67 };
68 
69 } // namespace
70 
71 /// Returns true if node 'srcId' can be removed after fusing it with node
72 /// 'dstId'. The node can be removed if any of the following conditions are met:
73 /// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
74 /// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
75 /// and the fusion slice is maximal.
76 /// 3. 'srcId' has output dependences after fusion, the fusion slice is
77 /// maximal and the fusion insertion point dominates all the dependences.
79  unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
80  Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
81  MemRefDependenceGraph *mdg) {
82 
83  Operation *dstNodeOp = mdg->getNode(dstId)->op;
84  bool hasOutDepsAfterFusion = false;
85 
86  for (auto &outEdge : mdg->outEdges[srcId]) {
87  Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
88  // Skip dependence with dstOp since it will be removed after fusion.
89  if (depNodeOp == dstNodeOp)
90  continue;
91 
92  // Only fusion within the same block is supported. Use domination analysis
93  // when needed.
94  if (depNodeOp->getBlock() != dstNodeOp->getBlock())
95  return false;
96 
97  // Check if the insertion point of the fused loop dominates the dependence.
98  // Otherwise, the src loop can't be removed.
99  if (fusedLoopInsPoint != depNodeOp &&
100  !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
101  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
102  "dominate dependence\n");
103  return false;
104  }
105 
106  hasOutDepsAfterFusion = true;
107  }
108 
109  // If src loop has dependences after fusion or it writes to an live-out or
110  // escaping memref, we can only remove it if the fusion slice is maximal so
111  // that all the dependences are preserved.
112  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113  std::optional<bool> isMaximal = fusionSlice.isMaximal();
114  if (!isMaximal) {
115  LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
116  "if fusion is maximal\n");
117  return false;
118  }
119 
120  if (!*isMaximal) {
121  LLVM_DEBUG(llvm::dbgs()
122  << "Src loop can't be removed: fusion is not maximal\n");
123  return false;
124  }
125  }
126 
127  return true;
128 }
129 
130 /// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
131 /// 'dstId'. Candidates are sorted by node id order. This order corresponds to
132 /// the program order when the 'mdg' is created. However, program order is not
133 /// guaranteed and must not be required by the client. Program order won't be
134 /// held if the 'mdg' is reused from a previous fusion step or if the node
135 /// creation order changes in the future to support more advance cases.
136 // TODO: Move this to a loop fusion utility once 'mdg' is also moved.
137 static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
138  SmallVectorImpl<unsigned> &srcIdCandidates) {
139  // Skip if no input edges along which to fuse.
140  if (mdg->inEdges.count(dstId) == 0)
141  return;
142 
143  // Gather memrefs from loads in 'dstId'.
144  auto *dstNode = mdg->getNode(dstId);
145  DenseSet<Value> consumedMemrefs;
146  for (Operation *load : dstNode->loads)
147  consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148 
149  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
150  // to one of the consumed memrefs.
151  for (auto &srcEdge : mdg->inEdges[dstId]) {
152  auto *srcNode = mdg->getNode(srcEdge.id);
153  // Skip if 'srcNode' is not a loop nest.
154  if (!isa<AffineForOp>(srcNode->op))
155  continue;
156 
157  if (any_of(srcNode->stores, [&](Operation *op) {
158  auto storeOp = cast<AffineWriteOpInterface>(op);
159  return consumedMemrefs.count(storeOp.getMemRef()) > 0;
160  }))
161  srcIdCandidates.push_back(srcNode->id);
162  }
163 
164  llvm::sort(srcIdCandidates);
165  srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
166 }
167 
168 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
169 /// producer-consumer dependence between 'srcId' and 'dstId'.
170 static void
171 gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
173  DenseSet<Value> &producerConsumerMemrefs) {
174  auto *dstNode = mdg->getNode(dstId);
175  auto *srcNode = mdg->getNode(srcId);
176  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
177  producerConsumerMemrefs);
178 }
179 
180 /// A memref escapes in the context of the fusion pass if either:
181 /// 1. it (or its alias) is a block argument, or
182 /// 2. created by an op not known to guarantee alias freedom,
183 /// 3. it (or its alias) are used by ops other than affine dereferencing ops
184 /// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
185 /// terminator ops, etc.); such ops do not deference the memref in an affine
186 /// way.
187 static bool isEscapingMemref(Value memref, Block *block) {
188  Operation *defOp = memref.getDefiningOp();
189  // Check if 'memref' is a block argument.
190  if (!defOp)
191  return true;
192 
193  // Check if this is defined to be an alias of another memref.
194  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
195  if (isEscapingMemref(viewOp.getViewSource(), block))
196  return true;
197 
198  // Any op besides allocating ops wouldn't guarantee alias freedom
199  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
200  return true;
201 
202  // Check if 'memref' is used by a non-deferencing op (including unknown ones)
203  // (e.g., call ops, alias creating ops, etc.).
204  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
205  // Ignore users outside of `block`.
206  Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
207  if (!ancestorOp)
208  return true;
209  if (ancestorOp->getBlock() != block)
210  return false;
211  return !isa<AffineMapAccessInterface>(*user);
212  });
213 }
214 
215 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
216 /// that escape the block or are accessed in a non-affine way.
217 static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
218  DenseSet<Value> &escapingMemRefs) {
219  auto *node = mdg->getNode(id);
220  for (Operation *storeOp : node->stores) {
221  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222  if (escapingMemRefs.count(memref))
223  continue;
224  if (isEscapingMemref(memref, &mdg->block))
225  escapingMemRefs.insert(memref);
226  }
227 }
228 
229 // Sinks all sequential loops to the innermost levels (while preserving
230 // relative order among them) and moves all parallel loops to the
231 // outermost (while again preserving relative order among them).
232 // This can increase the loop depth at which we can fuse a slice, since we are
233 // pushing loop carried dependence to a greater depth in the loop nest.
235  assert(isa<AffineForOp>(node->op));
236  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
237  node->op = newRootForOp;
238 }
239 
240 // Creates and returns a private (single-user) memref for fused loop rooted
241 // at 'forOp', with (potentially reduced) memref size based on the
242 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243 // TODO: consider refactoring the common code from generateDma and
244 // this one.
245 static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
246  unsigned dstLoopDepth,
247  std::optional<unsigned> fastMemorySpace,
248  uint64_t localBufSizeThreshold) {
249  Operation *forInst = forOp.getOperation();
250 
251  // Create builder to insert alloc op just before 'forOp'.
252  OpBuilder b(forInst);
253  // Builder to create constants at the top level.
254  OpBuilder top(forInst->getParentRegion());
255  // Create new memref type based on slice bounds.
256  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
257  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
258  unsigned rank = oldMemRefType.getRank();
259 
260  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261  MemRefRegion region(srcStoreOpInst->getLoc());
262  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
263  (void)validRegion;
264  assert(validRegion && "unexpected memref region failure");
265  SmallVector<int64_t, 4> newShape;
266  std::vector<SmallVector<int64_t, 4>> lbs;
267  SmallVector<int64_t, 8> lbDivisors;
268  lbs.reserve(rank);
269  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
270  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
271  std::optional<int64_t> numElements =
272  region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
273  assert(numElements && "non-constant number of elts in local buffer");
274 
275  const FlatAffineValueConstraints *cst = region.getConstraints();
276  // 'outerIVs' holds the values that this memory region is symbolic/parametric
277  // on; this would correspond to loop IVs surrounding the level at which the
278  // slice is being materialized.
279  SmallVector<Value, 8> outerIVs;
280  cst->getValues(rank, cst->getNumVars(), &outerIVs);
281 
282  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
284  offsets.reserve(rank);
285  for (unsigned d = 0; d < rank; ++d) {
286  assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
287 
288  AffineExpr offset = top.getAffineConstantExpr(0);
289  for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
290  offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
291  }
292  assert(lbDivisors[d] > 0);
293  offset =
294  (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
295  offsets.push_back(offset);
296  }
297 
298  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
299  // by 'srcStoreOpInst'.
300  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
301  assert(eltSize && "memrefs with size elt types expected");
302  uint64_t bufSize = *eltSize * *numElements;
303  unsigned newMemSpace;
304  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
305  newMemSpace = *fastMemorySpace;
306  } else {
307  newMemSpace = oldMemRefType.getMemorySpaceAsInt();
308  }
309  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
310  {}, newMemSpace);
311 
312  // Create new private memref for fused loop 'forOp'. 'newShape' is always
313  // a constant shape.
314  // TODO: Create/move alloc ops for private memrefs closer to their
315  // consumer loop nests to reduce their live range. Currently they are added
316  // at the beginning of the block, because loop nests can be reordered
317  // during the fusion pass.
318  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
319 
320  // Build an AffineMap to remap access functions based on lower bound offsets.
321  SmallVector<AffineExpr, 4> remapExprs;
322  remapExprs.reserve(rank);
323  for (unsigned i = 0; i < rank; i++) {
324  auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
325 
326  auto remapExpr =
327  simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
328  remapExprs.push_back(remapExpr);
329  }
330 
331  auto indexRemap =
332  AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
333 
334  // Replace all users of 'oldMemRef' with 'newMemRef'.
335  LogicalResult res =
336  replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
337  /*extraOperands=*/outerIVs,
338  /*symbolOperands=*/{},
339  /*domOpFilter=*/&*forOp.getBody()->begin());
340  assert(succeeded(res) &&
341  "replaceAllMemrefUsesWith should always succeed here");
342  (void)res;
343  return newMemRef;
344 }
345 
346 /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
347 /// 'dstId'), if there is any non-affine operation accessing 'memref', return
348 /// true. Otherwise, return false.
349 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
350  Value memref,
351  MemRefDependenceGraph *mdg) {
352  auto *srcNode = mdg->getNode(srcId);
353  auto *dstNode = mdg->getNode(dstId);
354  Value::user_range users = memref.getUsers();
355  // For each MemRefDependenceGraph's node that is between 'srcNode' and
356  // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any
357  // non-affine operation in the node accesses the 'memref'.
358  for (auto &idAndNode : mdg->nodes) {
359  Operation *op = idAndNode.second.op;
360  // Take care of operations between 'srcNode' and 'dstNode'.
361  if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
362  // Walk inside the operation to find any use of the memref.
363  // Interrupt the walk if found.
364  auto walkResult = op->walk([&](Operation *user) {
365  // Skip affine ops.
366  if (isa<AffineMapAccessInterface>(*user))
367  return WalkResult::advance();
368  // Find a non-affine op that uses the memref.
369  if (llvm::is_contained(users, user))
370  return WalkResult::interrupt();
371  return WalkResult::advance();
372  });
373  if (walkResult.wasInterrupted())
374  return true;
375  }
376  }
377  return false;
378 }
379 
380 /// Check whether a memref value in node 'srcId' has a non-affine that
381 /// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
382 /// 'dstNode').
383 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
384  MemRefDependenceGraph *mdg) {
385  // Collect memref values in node 'srcId'.
386  auto *srcNode = mdg->getNode(srcId);
387  llvm::SmallDenseSet<Value, 2> memRefValues;
388  srcNode->op->walk([&](Operation *op) {
389  // Skip affine ops.
390  if (isa<AffineForOp>(op))
391  return WalkResult::advance();
392  for (Value v : op->getOperands())
393  // Collect memref values only.
394  if (isa<MemRefType>(v.getType()))
395  memRefValues.insert(v);
396  return WalkResult::advance();
397  });
398  // Looking for users between node 'srcId' and node 'dstId'.
399  return llvm::any_of(memRefValues, [&](Value memref) {
400  return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
401  });
402 }
403 
404 // Checks the profitability of fusing a backwards slice of the loop nest
405 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
406 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
407 // the memref being produced and consumed, which is an input to the cost model.
408 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
409 // 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
410 // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
411 // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
412 // unique store op in the src node, which will be used to check that the write
413 // region is the same after input-reuse fusion. Computation slices are provided
414 // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
415 // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
416 // profitable to fuse the candidate loop nests. Returns false otherwise.
417 // `dstLoopDepth` is set to the most profitable depth at which to materialize
418 // the source loop nest slice.
419 // The profitability model executes the following steps:
420 // *) Computes the backward computation slice at 'srcOpInst'. This
421 // computation slice of the loop nest surrounding 'srcOpInst' is
422 // represented by modified src loop bounds in 'sliceState', which are
423 // functions of loop IVs in the loop nest surrounding 'srcOpInst'.
424 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
425 // loop nest is the total number of dynamic operation instances in the loop
426 // nest).
427 // *) Computes the cost of fusing a slice of the src loop nest into the dst
428 // loop nest at various values of dst loop depth, attempting to fuse
429 // the largest computation slice at the maximal dst loop depth (closest to
430 // the load) to minimize reuse distance and potentially enable subsequent
431 // load/store forwarding.
432 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
433 // nest, at which the src computation slice is inserted/fused.
434 // NOTE: We attempt to maximize the dst loop depth, but there are cases
435 // where a particular setting for 'dstLoopNest' might fuse an unsliced
436 // loop (within the src computation slice) at a depth which results in
437 // excessive recomputation (see unit tests for examples).
438 // *) Compares the total cost of the unfused loop nests to the min cost fused
439 // loop nest computed in the previous step, and returns true if the latter
440 // is lower.
441 // TODO: Extend profitability analysis to support scenarios with multiple
442 // stores.
443 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
444  AffineForOp dstForOp,
445  ArrayRef<ComputationSliceState> depthSliceUnions,
446  unsigned maxLegalFusionDepth,
447  unsigned *dstLoopDepth,
448  double computeToleranceThreshold) {
449  LLVM_DEBUG({
450  llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
451  llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
452  llvm::dbgs() << dstForOp << "\n";
453  });
454 
455  if (maxLegalFusionDepth == 0) {
456  LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
457  return false;
458  }
459 
460  // Compute cost of sliced and unsliced src loop nest.
461  SmallVector<AffineForOp, 4> srcLoopIVs;
462  getAffineForIVs(*srcOpInst, &srcLoopIVs);
463 
464  // Walk src loop nest and collect stats.
465  LoopNestStats srcLoopNestStats;
466  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
467  return false;
468 
469  // Compute cost of dst loop nest.
470  LoopNestStats dstLoopNestStats;
471  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
472  return false;
473 
474  // Search for min cost value for 'dstLoopDepth'. At each value of
475  // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
476  // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
477  // of these bounds). Next the union slice bounds are used to calculate
478  // the cost of the slice and the cost of the slice inserted into the dst
479  // loop nest at 'dstLoopDepth'.
480  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
481  double maxStorageReduction = 0.0;
482  std::optional<uint64_t> sliceMemEstimate;
483 
484  // The best loop depth at which to materialize the slice.
485  std::optional<unsigned> bestDstLoopDepth;
486 
487  // Compute op instance count for the src loop nest without iteration slicing.
488  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
489 
490  // Compute src loop nest write region size.
491  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
492  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
493  LLVM_DEBUG(llvm::dbgs()
494  << "Unable to compute MemRefRegion for source operation\n");
495  return false;
496  }
497 
498  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
499  srcWriteRegion.getRegionSize();
500  if (!maybeSrcWriteRegionSizeBytes.has_value())
501  return false;
502  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
503 
504  // Compute op instance count for the src loop nest.
505  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
506 
507  // Evaluate all depth choices for materializing the slice in the destination
508  // loop nest.
509  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
510  const ComputationSliceState &slice = depthSliceUnions[i - 1];
511  // Skip slice union if it wasn't computed for this depth.
512  if (slice.isEmpty())
513  continue;
514 
515  int64_t fusedLoopNestComputeCost;
516  if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
517  dstLoopNestStats, slice,
518  &fusedLoopNestComputeCost)) {
519  LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
520  continue;
521  }
522 
523  double additionalComputeFraction =
524  fusedLoopNestComputeCost /
525  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
526  1;
527 
528  // Determine what the slice write MemRefRegion would be, if the src loop
529  // nest slice 'slice' were to be inserted into the dst loop nest at loop
530  // depth 'i'.
531  MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
532  if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
533  &slice))) {
534  LLVM_DEBUG(llvm::dbgs()
535  << "Failed to compute slice write region at loopDepth: " << i
536  << "\n");
537  continue;
538  }
539 
540  std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
541  sliceWriteRegion.getRegionSize();
542  if (!maybeSliceWriteRegionSizeBytes.has_value() ||
543  *maybeSliceWriteRegionSizeBytes == 0) {
544  LLVM_DEBUG(llvm::dbgs()
545  << "Failed to get slice write region size at loopDepth: " << i
546  << "\n");
547  continue;
548  }
549  int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
550 
551  // If we are fusing for reuse, check that write regions remain the same.
552  // TODO: Write region check should check sizes and offsets in
553  // each dimension, so that we are sure they are covering the same memref
554  // region. Also, move this out to a isMemRefRegionSuperSet helper function.
555  if (srcOpInst != srcStoreOpInst &&
556  sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
557  continue;
558 
559  double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
560  static_cast<double>(sliceWriteRegionSizeBytes);
561 
562  LLVM_DEBUG({
563  std::stringstream msg;
564  msg << " evaluating fusion profitability at depth : " << i << "\n"
565  << std::fixed << std::setprecision(2)
566  << " additional compute fraction: "
567  << 100.0 * additionalComputeFraction << "%\n"
568  << " storage reduction factor: " << storageReduction << "x\n"
569  << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
570  << " src write region size: " << srcWriteRegionSizeBytes << "\n"
571  << " slice write region size: " << sliceWriteRegionSizeBytes
572  << "\n";
573  llvm::dbgs() << msg.str();
574  });
575 
576  // TODO: This is a placeholder cost model.
577  // Among all choices that add an acceptable amount of redundant computation
578  // (as per computeToleranceThreshold), we will simply pick the one that
579  // reduces the intermediary size the most.
580  if ((storageReduction > maxStorageReduction) &&
581  (additionalComputeFraction < computeToleranceThreshold)) {
582  maxStorageReduction = storageReduction;
583  bestDstLoopDepth = i;
584  minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
585  sliceMemEstimate = sliceWriteRegionSizeBytes;
586  }
587  }
588 
589  // A simple cost model: fuse if it reduces the memory footprint.
590 
591  if (!bestDstLoopDepth) {
592  LLVM_DEBUG(
593  llvm::dbgs()
594  << "All fusion choices involve more than the threshold amount of "
595  "redundant computation; NOT fusing.\n");
596  return false;
597  }
598 
599  if (!bestDstLoopDepth) {
600  LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
601  return false;
602  }
603 
604  // Set dstLoopDepth based on best values from search.
605  *dstLoopDepth = *bestDstLoopDepth;
606 
607  LLVM_DEBUG(
608  llvm::dbgs() << " LoopFusion fusion stats:"
609  << "\n best loop depth: " << bestDstLoopDepth
610  << "\n src loop nest compute cost: " << srcLoopNestCost
611  << "\n dst loop nest compute cost: " << dstLoopNestCost
612  << "\n fused loop nest compute cost: "
613  << minFusedLoopNestComputeCost << "\n");
614 
615  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
616  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
617 
618  std::optional<double> storageReduction;
619 
620  if (!dstMemSize || !srcMemSize) {
621  LLVM_DEBUG(llvm::dbgs()
622  << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
623  return false;
624  }
625 
626  auto srcMemSizeVal = *srcMemSize;
627  auto dstMemSizeVal = *dstMemSize;
628 
629  assert(sliceMemEstimate && "expected value");
630  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
631 
632  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
633  << " dst mem: " << dstMemSizeVal << "\n"
634  << " fused mem: " << fusedMem << "\n"
635  << " slice mem: " << sliceMemEstimate << "\n");
636 
637  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
638  LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
639  return false;
640  }
641  storageReduction =
642  100.0 *
643  (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
644 
645  double additionalComputeFraction =
646  100.0 * (minFusedLoopNestComputeCost /
647  (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
648  1);
649  (void)additionalComputeFraction;
650  LLVM_DEBUG({
651  std::stringstream msg;
652  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
653  << std::setprecision(2) << additionalComputeFraction
654  << "% redundant computation and a ";
655  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
656  msg << "% storage reduction.\n";
657  llvm::dbgs() << msg.str();
658  });
659 
660  return true;
661 }
662 
663 namespace {
664 
665 // GreedyFusion greedily fuses loop nests which have a producer/consumer or
666 // input-reuse relationship on a memref, with the goal of improving locality.
667 //
668 // The steps of the producer-consumer fusion algorithm are as follows:
669 //
670 // *) A worklist is initialized with node ids from the dependence graph.
671 // *) For each node id in the worklist:
672 // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
673 // candidate destination AffineForOp into which fusion will be attempted.
674 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
675 // *) For each LoadOp in 'dstLoadOps' do:
676 // *) Look up dependent loop nests which have a single store op to the same
677 // memref.
678 // *) Check if dependences would be violated by the fusion.
679 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop
680 // bounds to be functions of 'dstLoopNest' IVs and symbols.
681 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
682 // at a loop depth determined by the cost model in 'isFusionProfitable'.
683 // *) Add the newly fused load/store operations to the state,
684 // and also add newly fused load ops to 'dstLoopOps' to be considered
685 // as fusion dst load ops in another iteration.
686 // *) Remove old src loop nest and its associated state.
687 //
688 // The steps of the input-reuse fusion algorithm are as follows:
689 //
690 // *) Initialize 'worklist' with node ids from the dependence graph.
691 // *) For each 'dstNode' in the worklist:
692 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
693 // loads from the same memref, but which has no dependence paths to/from.
694 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop
695 // bounds to be functions of 'dstLoopNest' IVs and symbols.
696 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
697 // at a loop depth determined by the cost model in 'isFusionProfitable'.
698 // This function also checks that the memref write region of 'sibLoopNest',
699 // is preserved in the fused loop nest.
700 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
701 //
702 // Given a graph where top-level operations are vertices in the set 'V' and
703 // edges in the set 'E' are dependences between vertices, this algorithm
704 // takes O(V) time for initialization, and has runtime O(V + E).
705 //
706 // This greedy algorithm is not 'maximal' due to the current restriction of
707 // fusing along single producer consumer edges, but there is a TODO: to fix
708 // this.
709 //
710 // TODO: Experiment with other fusion policies.
711 struct GreedyFusion {
712 public:
713  // The data dependence graph to traverse during fusion.
715  // Worklist of graph nodes visited during the fusion pass.
716  SmallVector<unsigned, 8> worklist;
717  // Parameter for local buffer size threshold.
718  unsigned localBufSizeThreshold;
719  // Parameter for fast memory space.
720  std::optional<unsigned> fastMemorySpace;
721  // If true, ignore any additional (redundant) computation tolerance threshold
722  // that would have prevented fusion.
723  bool maximalFusion;
724  // The amount of additional computation that is tolerated while fusing
725  // pair-wise as a fraction of the total computation.
726  double computeToleranceThreshold;
727 
729 
730  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
731  std::optional<unsigned> fastMemorySpace, bool maximalFusion,
732  double computeToleranceThreshold)
733  : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
734  fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
735  computeToleranceThreshold(computeToleranceThreshold) {}
736 
737  /// Initializes 'worklist' with nodes from 'mdg'.
738  void init() {
739  // TODO: Add a priority queue for prioritizing nodes by different
740  // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
741  worklist.clear();
742  for (auto &idAndNode : mdg->nodes) {
743  const Node &node = idAndNode.second;
744  worklist.push_back(node.id);
745  }
746  }
747  /// Run only sibling fusion on the `mdg`.
748  void runSiblingFusionOnly() {
749  fuseSiblingNodes();
750  eraseUnusedMemRefAllocations();
751  }
752 
753  /// Run only producer/consumer fusion on the `mdg`.
754  void runProducerConsumerFusionOnly() {
755  fuseProducerConsumerNodes(
756  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
757  eraseUnusedMemRefAllocations();
758  }
759 
760  // Run the GreedyFusion pass.
761  // *) First pass through the nodes fuses single-use producer nodes into their
762  // unique consumer.
763  // *) Second pass fuses sibling nodes which share no dependence edges.
764  // *) Third pass fuses any remaining producer nodes into their users.
765  void runGreedyFusion() {
766  // TODO: Run this repeatedly until a fixed-point is reached.
767  fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
768  fuseSiblingNodes();
769  fuseProducerConsumerNodes(
770  /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
771  eraseUnusedMemRefAllocations();
772  }
773 
774  /// Returns true if a private memref can be created for `memref` given
775  /// the fusion scenario reflected by the other arguments.
776  bool canCreatePrivateMemRef(Value memref,
777  const DenseSet<Value> &srcEscapingMemRefs,
778  unsigned producerId, unsigned consumerId,
779  bool removeSrcNode) {
780  const Node *consumerNode = mdg->getNode(consumerId);
781  // If `memref` is an escaping one, do not create a private memref
782  // for the below scenarios, since doing so will leave the escaping
783  // memref unmodified as all the writes originally meant for the
784  // escaping memref would be performed on the private memref:
785  // 1. The source is to be removed after fusion,
786  // OR
787  // 2. The destination writes to `memref`.
788  if (srcEscapingMemRefs.count(memref) > 0 &&
789  (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
790  return false;
791 
792  // Don't create a private memref if 'srcNode' has in edges on
793  // 'memref' or 'dstNode' has out edges on 'memref'.
794  if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
795  mdg->getOutEdgeCount(consumerId, memref) > 0)
796  return false;
797 
798  // If 'srcNode' will be removed but it has out edges on 'memref' to
799  // nodes other than 'dstNode', we have to preserve dependences and
800  // cannot create a private memref.
801  if (removeSrcNode &&
802  any_of(mdg->outEdges[producerId], [&](const auto &edge) {
803  return edge.value == memref && edge.id != consumerId;
804  }))
805  return false;
806 
807  return true;
808  }
809 
810  /// Perform fusions with node `dstId` as the destination of fusion, with
811  /// No fusion is performed when producers with a user count greater than
812  /// `maxSrcUserCount` for any of the memrefs involved.
813  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
814  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
815  // Skip if this node was removed (fused into another node).
816  if (mdg->nodes.count(dstId) == 0)
817  return;
818  // Get 'dstNode' into which to attempt fusion.
819  auto *dstNode = mdg->getNode(dstId);
820  // Skip if 'dstNode' is not a loop nest.
821  if (!isa<AffineForOp>(dstNode->op))
822  return;
823  // Skip if 'dstNode' is a loop nest returning values.
824  // TODO: support loop nests that return values.
825  if (dstNode->op->getNumResults() > 0)
826  return;
827 
828  LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
829 
830  // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
831  // while preserving relative order. This can increase the maximum loop
832  // depth at which we can fuse a slice of a producer loop nest into a
833  // consumer loop nest.
834  sinkSequentialLoops(dstNode);
835  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
836 
837  // Try to fuse 'dstNode' with candidate producer loops until a fixed point
838  // is reached. Fusing two loops may expose new fusion opportunities.
839  bool dstNodeChanged;
840  do {
841  // Gather src loop candidates for 'dstNode' and visit them in "quasi"
842  // reverse program order to minimize the number of iterations needed to
843  // reach the fixed point. Note that this is a best effort approach since
844  // 'getProducerCandidates' does not always guarantee that program order
845  // in 'srcIdCandidates'.
846  dstNodeChanged = false;
847  SmallVector<unsigned, 16> srcIdCandidates;
848  getProducerCandidates(dstId, mdg, srcIdCandidates);
849 
850  for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
851  // Get 'srcNode' from which to attempt fusion into 'dstNode'.
852  auto *srcNode = mdg->getNode(srcId);
853  auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
854  LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
855  << " for dst loop " << dstId << "\n");
856 
857  // Skip if 'srcNode' is a loop nest returning values.
858  // TODO: support loop nests that return values.
859  if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
860  continue;
861 
862  DenseSet<Value> producerConsumerMemrefs;
863  gatherProducerConsumerMemrefs(srcId, dstId, mdg,
864  producerConsumerMemrefs);
865 
866  // Skip if 'srcNode' out edge count on any memref is greater than
867  // 'maxSrcUserCount'.
868  if (any_of(producerConsumerMemrefs, [&](Value memref) {
869  return mdg->getOutEdgeCount(srcNode->id, memref) >
870  maxSrcUserCount;
871  }))
872  continue;
873 
874  // Gather memrefs in 'srcNode' that are written and escape out of the
875  // block (e.g., memref block arguments, returned memrefs,
876  // memrefs passed to function calls, etc.).
877  DenseSet<Value> srcEscapingMemRefs;
878  gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
879 
880  // Skip if there are non-affine operations in between the 'srcNode'
881  // and 'dstNode' using their memrefs. If so, we wouldn't be able to
882  // compute a legal insertion point for now. 'srcNode' and 'dstNode'
883  // memrefs with non-affine operation users would be considered
884  // escaping memrefs so we can limit this check to only scenarios with
885  // escaping memrefs.
886  if (!srcEscapingMemRefs.empty() &&
887  hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
888  LLVM_DEBUG(llvm::dbgs()
889  << "Can't fuse: non-affine users in between the loops\n");
890  continue;
891  }
892 
893  // Compute an operation list insertion point for the fused loop
894  // nest which preserves dependences.
895  Operation *fusedLoopInsPoint =
896  mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
897  if (fusedLoopInsPoint == nullptr)
898  continue;
899 
900  // It's possible this fusion is at an inner depth (i.e., there are
901  // common surrounding affine loops for the source and destination for
902  // ops). We need to get this number because the call to canFuseLoops
903  // needs to be passed the absolute depth. The max legal depth and the
904  // depths we try below are however *relative* and as such don't include
905  // the common depth.
906  SmallVector<AffineForOp, 4> surroundingLoops;
907  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
908  unsigned numSurroundingLoops = surroundingLoops.size();
909 
910  // Compute the innermost common loop depth for dstNode
911  // producer-consumer loads/stores.
912  SmallVector<Operation *, 2> dstMemrefOps;
913  for (Operation *op : dstNode->loads)
914  if (producerConsumerMemrefs.count(
915  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
916  dstMemrefOps.push_back(op);
917  for (Operation *op : dstNode->stores)
918  if (producerConsumerMemrefs.count(
919  cast<AffineWriteOpInterface>(op).getMemRef()))
920  dstMemrefOps.push_back(op);
921  unsigned dstLoopDepthTest =
922  getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
923 
924  // Check the feasibility of fusing src loop nest into dst loop nest
925  // at loop depths in range [1, dstLoopDepthTest].
926  unsigned maxLegalFusionDepth = 0;
927  SmallVector<ComputationSliceState, 8> depthSliceUnions;
928  depthSliceUnions.resize(dstLoopDepthTest);
930  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
931  FusionResult result =
932  affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
933  /*dstLoopDepth=*/i + numSurroundingLoops,
934  &depthSliceUnions[i - 1], strategy);
935 
936  if (result.value == FusionResult::Success)
937  maxLegalFusionDepth = i;
938  }
939 
940  if (maxLegalFusionDepth == 0) {
941  LLVM_DEBUG(llvm::dbgs()
942  << "Can't fuse: fusion is not legal at any depth\n");
943  continue;
944  }
945 
946  // Check if fusion would be profitable. We skip profitability analysis
947  // for maximal fusion since we already know the maximal legal depth to
948  // fuse.
949  unsigned bestDstLoopDepth = maxLegalFusionDepth;
950  if (!maximalFusion) {
951  // Retrieve producer stores from the src loop.
952  SmallVector<Operation *, 2> producerStores;
953  for (Operation *op : srcNode->stores)
954  if (producerConsumerMemrefs.count(
955  cast<AffineWriteOpInterface>(op).getMemRef()))
956  producerStores.push_back(op);
957 
958  // TODO: Suppport multiple producer stores in profitability
959  // analysis. We limit profitability analysis to only scenarios with
960  // a single producer store for now. Note that some multi-store
961  // producer scenarios will still go through profitability analysis
962  // if only one of the stores is involved the producer-consumer
963  // relationship of the candidate loops.
964  assert(!producerStores.empty() && "Expected producer store");
965  if (producerStores.size() > 1)
966  LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
967  "supported for this case\n");
968  else if (!isFusionProfitable(producerStores[0], producerStores[0],
969  dstAffineForOp, depthSliceUnions,
970  maxLegalFusionDepth, &bestDstLoopDepth,
971  computeToleranceThreshold))
972  continue;
973  }
974 
975  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
976  ComputationSliceState &bestSlice =
977  depthSliceUnions[bestDstLoopDepth - 1];
978  assert(!bestSlice.isEmpty() && "Missing slice union for depth");
979 
980  // Determine if 'srcId' can be removed after fusion, taking into
981  // account remaining dependences, escaping memrefs and the fusion
982  // insertion point.
983  bool removeSrcNode = canRemoveSrcNodeAfterFusion(
984  srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
985  mdg);
986 
987  DenseSet<Value> privateMemrefs;
988  for (Value memref : producerConsumerMemrefs) {
989  if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
990  removeSrcNode)) {
991  // Create a private version of this memref.
992  LLVM_DEBUG(llvm::dbgs()
993  << "Creating private memref for " << memref << '\n');
994  // Create a private version of this memref.
995  privateMemrefs.insert(memref);
996  }
997  }
998 
999  // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
1000  fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1001  dstNodeChanged = true;
1002 
1003  LLVM_DEBUG(llvm::dbgs()
1004  << "Fused src loop " << srcId << " into dst loop " << dstId
1005  << " at depth " << bestDstLoopDepth << ":\n"
1006  << dstAffineForOp << "\n");
1007 
1008  // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1009  if (fusedLoopInsPoint != dstAffineForOp)
1010  dstAffineForOp->moveBefore(fusedLoopInsPoint);
1011 
1012  // Update edges between 'srcNode' and 'dstNode'.
1013  mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1014  removeSrcNode);
1015 
1016  // Create private memrefs.
1017  if (!privateMemrefs.empty()) {
1018  // Gather stores for all the private-to-be memrefs.
1019  DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1020  dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1021  Value storeMemRef = storeOp.getMemRef();
1022  if (privateMemrefs.count(storeMemRef) > 0)
1023  privateMemRefToStores[storeMemRef].push_back(storeOp);
1024  });
1025 
1026  // Replace original memrefs with private memrefs. Note that all the
1027  // loads and stores on these memrefs will be replaced with a new
1028  // loads and stores. Any reference to the original ones becomes
1029  // invalid after this point.
1030  for (auto &memrefToStoresPair : privateMemRefToStores) {
1031  // TODO: Use union of memref write regions to compute
1032  // private memref footprint.
1033  SmallVector<Operation *, 4> &storesForMemref =
1034  memrefToStoresPair.second;
1035  Value newMemRef = createPrivateMemRef(
1036  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1037  fastMemorySpace, localBufSizeThreshold);
1038  // Create new node in dependence graph for 'newMemRef' alloc op.
1039  unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
1040  // Add edge from 'newMemRef' node to dstNode.
1041  mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1042  }
1043  // One or more entries for 'newMemRef' alloc op are inserted into
1044  // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1045  // reallocate, update dstNode.
1046  dstNode = mdg->getNode(dstId);
1047  }
1048 
1049  // Collect dst loop stats after memref privatization transformation.
1050  LoopNestStateCollector dstLoopCollector;
1051  dstLoopCollector.collect(dstAffineForOp);
1052 
1053  // Clear and add back loads and stores.
1054  mdg->clearNodeLoadAndStores(dstNode->id);
1055  mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
1056  dstLoopCollector.storeOpInsts);
1057 
1058  if (removeSrcNode) {
1059  LLVM_DEBUG(llvm::dbgs()
1060  << "Removing src loop " << srcId << " after fusion\n");
1061  // srcNode is no longer valid after it is removed from mdg.
1062  srcAffineForOp.erase();
1063  mdg->removeNode(srcId);
1064  srcNode = nullptr;
1065  }
1066  }
1067  } while (dstNodeChanged);
1068  }
1069 
1070  /// Visit each node in the graph, and for each node, attempt to fuse it with
1071  /// producer-consumer candidates. No fusion is performed when producers with a
1072  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1073  /// are encountered.
1074  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1075  LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1076  init();
1077  while (!worklist.empty()) {
1078  unsigned dstId = worklist.back();
1079  worklist.pop_back();
1080  performFusionsIntoDest(dstId, maxSrcUserCount);
1081  }
1082  }
1083 
1084  // Visits each node in the graph, and for each node, attempts to fuse it with
1085  // its sibling nodes (nodes which share a parent, but no dependence edges).
1086  void fuseSiblingNodes() {
1087  LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1088  init();
1089  while (!worklist.empty()) {
1090  unsigned dstId = worklist.back();
1091  worklist.pop_back();
1092 
1093  // Skip if this node was removed (fused into another node).
1094  if (mdg->nodes.count(dstId) == 0)
1095  continue;
1096  // Get 'dstNode' into which to attempt fusion.
1097  auto *dstNode = mdg->getNode(dstId);
1098  // Skip if 'dstNode' is not a loop nest.
1099  if (!isa<AffineForOp>(dstNode->op))
1100  continue;
1101  // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1102  fuseWithSiblingNodes(dstNode);
1103  }
1104  }
1105 
1106  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1107  void fuseWithSiblingNodes(Node *dstNode) {
1108  DenseSet<unsigned> visitedSibNodeIds;
1109  std::pair<unsigned, Value> idAndMemref;
1110  auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1111 
1112  while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1113  unsigned sibId = idAndMemref.first;
1114  Value memref = idAndMemref.second;
1115  // TODO: Check that 'sibStoreOpInst' post-dominates all other
1116  // stores to the same memref in 'sibNode' loop nest.
1117  auto *sibNode = mdg->getNode(sibId);
1118  // Compute an operation list insertion point for the fused loop
1119  // nest which preserves dependences.
1120  assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1121  Operation *insertPointInst =
1122  sibNode->op->isBeforeInBlock(dstNode->op)
1123  ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
1124  : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
1125  if (insertPointInst == nullptr)
1126  continue;
1127 
1128  // Check if fusion would be profitable and at what depth.
1129 
1130  // Get unique 'sibNode' load op to 'memref'.
1131  SmallVector<Operation *, 2> sibLoadOpInsts;
1132  sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1133  // Currently findSiblingNodeToFuse searches for siblings with one load.
1134  assert(sibLoadOpInsts.size() == 1);
1135  Operation *sibLoadOpInst = sibLoadOpInsts[0];
1136 
1137  // Gather 'dstNode' load ops to 'memref'.
1138  SmallVector<Operation *, 2> dstLoadOpInsts;
1139  dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1140 
1141  // It's possible this fusion is at an inner depth (i.e., there are common
1142  // surrounding affine loops for the source and destination for ops). We
1143  // need to get this number because the call to canFuseLoops needs to be
1144  // passed the absolute depth. The max legal depth and the depths we try
1145  // below are however *relative* and as such don't include the common
1146  // depth.
1147  SmallVector<AffineForOp, 4> surroundingLoops;
1148  getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1149  unsigned numSurroundingLoops = surroundingLoops.size();
1150  SmallVector<AffineForOp, 4> dstLoopIVs;
1151  getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1152  unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1153  auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1154 
1155  // Compute loop depth and slice union for fusion.
1156  SmallVector<ComputationSliceState, 8> depthSliceUnions;
1157  depthSliceUnions.resize(dstLoopDepthTest);
1158  unsigned maxLegalFusionDepth = 0;
1159  FusionStrategy strategy(memref);
1160  for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1161  FusionResult result =
1162  affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
1163  /*dstLoopDepth=*/i + numSurroundingLoops,
1164  &depthSliceUnions[i - 1], strategy);
1165 
1166  if (result.value == FusionResult::Success)
1167  maxLegalFusionDepth = i;
1168  }
1169 
1170  LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1171  << maxLegalFusionDepth << '\n');
1172 
1173  // Skip if fusion is not feasible at any loop depths.
1174  if (maxLegalFusionDepth == 0)
1175  continue;
1176 
1177  unsigned bestDstLoopDepth = maxLegalFusionDepth;
1178  if (!maximalFusion) {
1179  // Check if fusion would be profitable. For sibling fusion, the sibling
1180  // load op is treated as the src "store" op for fusion profitability
1181  // purposes. The footprint of the load in the slice relative to the
1182  // unfused source's determines reuse.
1183  if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
1184  depthSliceUnions, maxLegalFusionDepth,
1185  &bestDstLoopDepth, computeToleranceThreshold))
1186  continue;
1187  }
1188 
1189  assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1190  assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1191  "Fusion depth has no computed slice union");
1192  // Check if source loop is being inserted in the innermost
1193  // destination loop. Based on this, the fused loop may be optimized
1194  // further inside `fuseLoops`.
1195  bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1196  // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1197  affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1198  depthSliceUnions[bestDstLoopDepth - 1],
1199  isInnermostInsertion);
1200 
1201  auto dstForInst = cast<AffineForOp>(dstNode->op);
1202  // Update operation position of fused loop nest (if needed).
1203  if (insertPointInst != dstForInst) {
1204  dstForInst->moveBefore(insertPointInst);
1205  }
1206  // Update data dependence graph state post fusion.
1207  updateStateAfterSiblingFusion(sibNode, dstNode);
1208  }
1209  }
1210 
1211  // Searches block argument uses and the graph from 'dstNode' looking for a
1212  // fusion candidate sibling node which shares no dependences with 'dstNode'
1213  // but which loads from the same memref. Returns true and sets
1214  // 'idAndMemrefToFuse' on success. Returns false otherwise.
1215  bool findSiblingNodeToFuse(Node *dstNode,
1216  DenseSet<unsigned> *visitedSibNodeIds,
1217  std::pair<unsigned, Value> *idAndMemrefToFuse) {
1218  // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1219  // on 'memref'.
1220  auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1221  // Skip if 'outEdge' is not a read-after-write dependence.
1222  // TODO: Remove restrict to single load op restriction.
1223  if (sibNode->getLoadOpCount(memref) != 1)
1224  return false;
1225  // Skip if there exists a path of dependent edges between
1226  // 'sibNode' and 'dstNode'.
1227  if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
1228  mdg->hasDependencePath(dstNode->id, sibNode->id))
1229  return false;
1230  // Skip sib node if it loads to (and stores from) the same memref on
1231  // which it also has an input dependence edge.
1232  DenseSet<Value> loadAndStoreMemrefSet;
1233  sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1234  if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1235  return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
1236  }))
1237  return false;
1238 
1239  // Check that all stores are to the same memref if any.
1240  DenseSet<Value> storeMemrefs;
1241  for (auto *storeOpInst : sibNode->stores) {
1242  storeMemrefs.insert(
1243  cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1244  }
1245  if (storeMemrefs.size() > 1)
1246  return false;
1247 
1248  // Skip if a memref value in one node is used by a non-affine memref
1249  // access that lies between 'dstNode' and 'sibNode'.
1250  if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
1251  hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
1252  return false;
1253  return true;
1254  };
1255 
1256  // Search for siblings which load the same memref block argument.
1257  Block *block = dstNode->op->getBlock();
1258  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1259  for (Operation *user : block->getArgument(i).getUsers()) {
1260  auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1261  if (!loadOp)
1262  continue;
1263  // Gather loops surrounding 'use'.
1265  getAffineForIVs(*user, &loops);
1266  // Skip 'use' if it is not within a loop nest.
1267  // Find the surrounding affine.for nested immediately within the
1268  // block.
1269  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1270  return loop->getBlock() == &mdg->block;
1271  });
1272  // Skip 'use' if it is not within a loop nest in `block`.
1273  if (it == loops.end())
1274  continue;
1275  Node *sibNode = mdg->getForOpNode(*it);
1276  assert(sibNode != nullptr);
1277  // Skip 'use' if it not a sibling to 'dstNode'.
1278  if (sibNode->id == dstNode->id)
1279  continue;
1280  // Skip 'use' if it has been visited.
1281  if (visitedSibNodeIds->count(sibNode->id) > 0)
1282  continue;
1283  // Skip 'use' if it does not load from the same memref as 'dstNode'.
1284  auto memref = loadOp.getMemRef();
1285  if (dstNode->getLoadOpCount(memref) == 0)
1286  continue;
1287  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1288  if (canFuseWithSibNode(sibNode, memref)) {
1289  visitedSibNodeIds->insert(sibNode->id);
1290  idAndMemrefToFuse->first = sibNode->id;
1291  idAndMemrefToFuse->second = memref;
1292  return true;
1293  }
1294  }
1295  }
1296 
1297  // Search for siblings by following edges through an intermediate src node.
1298  // Collect candidate 'dstNode' input edges in 'inEdges'.
1301  dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1302  // Add 'inEdge' if it is a read-after-write dependence.
1303  if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1304  mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1305  inEdges.push_back(inEdge);
1306  });
1307 
1308  // Search for sibling nodes to fuse by visiting output edges from each input
1309  // edge in 'inEdges'.
1310  for (auto &inEdge : inEdges) {
1311  // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1314  inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
1315  unsigned sibNodeId = outEdge.id;
1316  if (visitedSibNodeIds->count(sibNodeId) > 0)
1317  return;
1318  // Skip output edge if not a sibling using the same memref.
1319  if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1320  return;
1321  auto *sibNode = mdg->getNode(sibNodeId);
1322  if (!isa<AffineForOp>(sibNode->op))
1323  return;
1324  // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1325  if (canFuseWithSibNode(sibNode, outEdge.value)) {
1326  // Add candidate 'outEdge' to sibling node.
1327  outEdges.push_back(outEdge);
1328  }
1329  });
1330 
1331  // Add first candidate if any were returned.
1332  if (!outEdges.empty()) {
1333  visitedSibNodeIds->insert(outEdges[0].id);
1334  idAndMemrefToFuse->first = outEdges[0].id;
1335  idAndMemrefToFuse->second = outEdges[0].value;
1336  return true;
1337  }
1338  }
1339  return false;
1340  }
1341 
1342  /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1343  /// into 'dstNode'.
1344  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1345  // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1346  mdg->updateEdges(sibNode->id, dstNode->id);
1347 
1348  // Collect dst loop stats after memref privatization transformation.
1349  auto dstForInst = cast<AffineForOp>(dstNode->op);
1350  LoopNestStateCollector dstLoopCollector;
1351  dstLoopCollector.collect(dstForInst);
1352  // Clear and add back loads and stores
1353  mdg->clearNodeLoadAndStores(dstNode->id);
1354  mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
1355  dstLoopCollector.storeOpInsts);
1356  // Remove old sibling loop nest if it no longer has outgoing dependence
1357  // edges, and it does not write to a memref which escapes the block.
1358  if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1359  Operation *op = sibNode->op;
1360  mdg->removeNode(sibNode->id);
1361  op->erase();
1362  }
1363  }
1364 
1365  // Clean up any allocs with no users.
1366  void eraseUnusedMemRefAllocations() {
1367  for (auto &pair : mdg->memrefEdgeCount) {
1368  if (pair.second > 0)
1369  continue;
1370  auto memref = pair.first;
1371  // Skip if there exist other uses (return operation or function calls).
1372  if (!memref.use_empty())
1373  continue;
1374  // Use list expected to match the dep graph info.
1375  auto *op = memref.getDefiningOp();
1376  if (isa_and_nonnull<memref::AllocOp>(op))
1377  op->erase();
1378  }
1379  }
1380 };
1381 
1382 } // namespace
1383 
1384 /// Run fusion on `block`.
1385 void LoopFusion::runOnBlock(Block *block) {
1386  MemRefDependenceGraph g(*block);
1387  if (!g.init()) {
1388  LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1389  return;
1390  }
1391 
1392  std::optional<unsigned> fastMemorySpaceOpt;
1393  if (fastMemorySpace.hasValue())
1394  fastMemorySpaceOpt = fastMemorySpace;
1395  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1396  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1397  maximalFusion, computeToleranceThreshold);
1398 
1399  if (affineFusionMode == FusionMode::ProducerConsumer)
1400  fusion.runProducerConsumerFusionOnly();
1401  else if (affineFusionMode == FusionMode::Sibling)
1402  fusion.runSiblingFusionOnly();
1403  else
1404  fusion.runGreedyFusion();
1405 }
1406 
1407 void LoopFusion::runOnOperation() {
1408  // Call fusion on every op that has at least two affine.for nests (in post
1409  // order).
1410  getOperation()->walk([&](Operation *op) {
1411  for (Region &region : op->getRegions()) {
1412  for (Block &block : region.getBlocks()) {
1413  auto affineFors = block.getOps<AffineForOp>();
1414  if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1415  runOnBlock(&block);
1416  }
1417  }
1418  });
1419 }
1420 
1422  unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1423  bool maximalFusion, enum FusionMode affineFusionMode) {
1424  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1425  maximalFusion, affineFusionMode);
1426 }
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
Definition: LoopFusion.cpp:443
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, MemRefDependenceGraph *mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
Definition: LoopFusion.cpp:78
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg)
Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 'dstId'), if there is any non-aff...
Definition: LoopFusion.cpp:349
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
Definition: LoopFusion.cpp:187
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
Definition: LoopFusion.cpp:217
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, uint64_t localBufSizeThreshold)
Definition: LoopFusion.cpp:245
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
Definition: LoopFusion.cpp:137
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:379
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumCols() const
Returns the number of columns in the constraint system.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition: Utils.cpp:633
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition: Utils.cpp:1953
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
Definition: Passes.h:31
@ ProducerConsumer
Definition: Passes.h:31
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:1385
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
Definition: LoopUtils.cpp:1454
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
Definition: Utils.cpp:1271
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition: Utils.cpp:1196
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:259
bool isEmpty() const
Returns true if the computation slice is empty.
Definition: Utils.h:291
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition: Utils.cpp:888
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
Definition: Utils.h:40
void collect(Operation *opToWalk)
Definition: Utils.cpp:43
SmallVector< Operation *, 4 > storeOpInsts
Definition: Utils.h:41
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition: Utils.h:112
Block & block
The block for which this graph is created to perform fusion.
Definition: Utils.h:211
unsigned addNode(Operation *op)
Definition: Utils.cpp:251
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
Definition: Utils.cpp:388
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:316
Node * getForOpNode(AffineForOp forOp)
Definition: Utils.cpp:243
DenseMap< unsigned, Node > nodes
Definition: Utils.h:108
bool hasDependencePath(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:354
void clearNodeLoadAndStores(unsigned id)
Definition: Utils.cpp:576
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition: Utils.cpp:504
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition: Utils.h:110
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
Definition: Utils.cpp:568
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:584
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
Definition: Utils.cpp:404
void removeNode(unsigned id)
Definition: Utils.cpp:258
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:428
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:592
Node * getNode(unsigned id)
Definition: Utils.cpp:236
DenseMap< Value, unsigned > memrefEdgeCount
Definition: Utils.h:115
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
Definition: Utils.h:430
FlatAffineValueConstraints * getConstraints()
Definition: Utils.h:469
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition: Utils.cpp:942
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition: Utils.cpp:1215
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition: Utils.cpp:1044
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.