MLIR  16.0.0git
LoopFusionUtils.cpp
Go to the documentation of this file.
1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinOps.h"
27 #include "mlir/IR/Operation.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #define DEBUG_TYPE "loop-fusion-utils"
34 
35 using namespace mlir;
36 
37 // Gathers all load and store memref accesses in 'opA' into 'values', where
38 // 'values[memref] == true' for each store operation.
40  DenseMap<Value, bool> &values) {
41  opA->walk([&](Operation *op) {
42  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
43  if (values.count(loadOp.getMemRef()) == 0)
44  values[loadOp.getMemRef()] = false;
45  } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
46  values[storeOp.getMemRef()] = true;
47  }
48  });
49 }
50 
51 /// Returns true if 'op' is a load or store operation which access a memref
52 /// accessed 'values' and at least one of the access is a store operation.
53 /// Returns false otherwise.
55  DenseMap<Value, bool> &values) {
56  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
57  return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
58  }
59  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
60  return values.count(storeOp.getMemRef()) > 0;
61  }
62  return false;
63 }
64 
65 // Returns the first operation in range ('opA', 'opB') which has a data
66 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
68  // Record memref values from all loads/store in loop nest rooted at 'opA'.
69  // Map from memref value to bool which is true if store, false otherwise.
70  DenseMap<Value, bool> values;
71  getLoadAndStoreMemRefAccesses(opA, values);
72 
73  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
74  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
75  // and at least one of the accesses is a store).
76  Operation *firstDepOp = nullptr;
77  for (Block::iterator it = std::next(Block::iterator(opA));
78  it != Block::iterator(opB); ++it) {
79  Operation *opX = &(*it);
80  opX->walk([&](Operation *op) {
81  if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
82  firstDepOp = opX;
83  });
84  if (firstDepOp)
85  break;
86  }
87  return firstDepOp;
88 }
89 
90 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there
91 // exists a data dependence from 'opX' to 'opB'.
92 // Returns 'nullptr' if no dependence exists.
// NOTE(review): source line 93 (the signature) is missing from this listing;
// per the index it is
// `static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB)`.
94  // Record memref values from all loads/store in loop nest rooted at 'opB'.
95  // Map from memref value to bool which is true if store, false otherwise.
96  DenseMap<Value, bool> values;
97  getLoadAndStoreMemRefAccesses(opB, values);
98 
99  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
100  // check if there is a data dependence from 'opX' to 'opB':
101  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
102  // is a store.
103  // *) 'opX' produces an SSA Value which is used by 'opB'.
104  Operation *lastDepOp = nullptr;
  // Iterate backwards, starting just before 'opB' and stopping before 'opA'.
105  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
106  it != Block::reverse_iterator(opA); ++it) {
107  Operation *opX = &(*it);
108  opX->walk([&](Operation *op) {
  // NOTE(review): affine loads/stores return from the callback here before the
  // SSA-use check below runs. A load nested in 'opX' whose result is used
  // inside 'opB', but whose memref does not conflict with 'opB', is therefore
  // never reported as a dependence. Confirm whether this is intended.
109  if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
110  if (isDependentLoadOrStoreOp(op, values)) {
111  lastDepOp = opX;
112  return WalkResult::interrupt();
113  }
114  return WalkResult::advance();
115  }
116  for (auto value : op->getResults()) {
117  for (Operation *user : value.getUsers()) {
  // NOTE(review): source line 118 is missing from this listing; it presumably
  // declares the local `loops` vector (SmallVector<AffineForOp, N>) filled
  // below. Because it is declared inside this loop, it is fresh for each user.
119  // Check if any loop in loop nest surrounding 'user' is 'opB'.
  // NOTE(review): getLoopIVs collects the loops *surrounding* 'user', so a
  // value used directly as an operand of 'opB' itself (user == opB, e.g. as
  // a loop bound) is not detected by this check.
120  getLoopIVs(*user, &loops);
  // cast<> asserts that 'opB' is an affine.for.
121  if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
122  lastDepOp = opX;
123  return WalkResult::interrupt();
124  }
125  }
126  }
127  return WalkResult::advance();
128  });
129  if (lastDepOp)
130  break;
131  }
132  return lastDepOp;
133 }
134 
135 // Computes and returns an insertion point operation, before which the
136 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
137 // dependences. Returns nullptr if no such insertion point is found.
138 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
139  AffineForOp dstForOp) {
140  bool isSrcForOpBeforeDstForOp =
141  srcForOp->isBeforeInBlock(dstForOp.getOperation());
142  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
143  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
144 
145  auto *firstDepOpA =
146  getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
147  auto *lastDepOpB =
148  getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
149  // Block:
150  // ...
151  // |-- opA
152  // | ...
153  // | lastDepOpB --|
154  // | ... |
155  // |-> firstDepOpA |
156  // ... |
157  // opB <---------
158  //
159  // Valid insertion point range: (lastDepOpB, firstDepOpA)
160  //
161  if (firstDepOpA != nullptr) {
162  if (lastDepOpB != nullptr) {
163  if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
164  // No valid insertion point exists which preserves dependences.
165  return nullptr;
166  }
167  // Return insertion point in valid range closest to 'opB'.
168  // TODO: Consider other insertion points in valid range.
169  return firstDepOpA;
170  }
171  // No dependences from 'opA' to operation in range ('opA', 'opB'), return
172  // 'opB' insertion point.
173  return forOpB.getOperation();
174 }
175 
176 // Gathers all load and store ops in loop nest rooted at 'forOp' into
177 // 'loadAndStoreOps'.
178 static bool
179 gatherLoadsAndStores(AffineForOp forOp,
180  SmallVectorImpl<Operation *> &loadAndStoreOps) {
181  bool hasIfOp = false;
182  forOp.walk([&](Operation *op) {
183  if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
184  loadAndStoreOps.push_back(op);
185  else if (isa<AffineIfOp>(op))
186  hasIfOp = true;
187  });
188  return !hasIfOp;
189 }
190 
191 /// Returns the maximum loop depth at which we could fuse producer loop
192 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
/// 'srcOps' / 'dstOps' are the affine loads/stores gathered from the producer
/// and consumer nests. Only consumer ops that access a producer-consumer
/// memref (stored in 'srcOps' and loaded in 'dstOps') are analyzed. The result
/// starts at their innermost common loop depth and is lowered to 'd - 1' for
/// the smallest depth 'd' at which any pair of those ops has a dependence.
/// Returns 0 if 'dstOps' is empty.
193 // TODO: Generalize this check for sibling and more generic fusion scenarios.
194 // TODO: Support forward slice fusion.
195 static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
196  ArrayRef<Operation *> dstOps) {
197  if (dstOps.empty())
198  // Expected at least one memory operation.
199  // TODO: Revisit this case with a specific example.
200  return 0;
201 
202  // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
203  // that they are not considered for analysis.
204  DenseSet<Value> producerConsumerMemrefs;
205  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
206  SmallVector<Operation *, 4> targetDstOps;
207  for (Operation *dstOp : dstOps) {
  // Every op in 'dstOps' is assumed to be an affine load or store; cast<>
  // asserts this for non-loads.
208  auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
209  Value memref = loadOp ? loadOp.getMemRef()
210  : cast<AffineWriteOpInterface>(dstOp).getMemRef();
211  if (producerConsumerMemrefs.count(memref) > 0)
212  targetDstOps.push_back(dstOp);
213  }
214 
  // NOTE(review): this is only checked in assert-enabled builds; release builds
  // continue with an empty 'targetDstOps'.
215  assert(!targetDstOps.empty() &&
216  "No dependences between 'srcForOp' and 'dstForOp'?");
217 
218  // Compute the innermost common loop depth for loads and stores.
219  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
220 
221  // Return common loop depth for loads if there are no store ops.
222  if (all_of(targetDstOps,
223  [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
224  return loopDepth;
225 
226  // Check dependences on all pairs of ops in 'targetDstOps' and store the
227  // minimum loop depth at which a dependence is satisfied.
  // All ordered pairs are checked, including each op paired with itself.
228  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
229  auto *srcOpInst = targetDstOps[i];
230  MemRefAccess srcAccess(srcOpInst);
231  for (unsigned j = 0; j < e; ++j) {
232  auto *dstOpInst = targetDstOps[j];
233  MemRefAccess dstAccess(dstOpInst);
234 
235  unsigned numCommonLoops =
236  getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
237  for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
238  FlatAffineValueConstraints dependenceConstraints;
239  // TODO: Cache dependence analysis results, check cache here.
  // NOTE(review): source line 240 is missing from this listing; given the
  // arguments below and the use of 'result', it presumably reads
  // `DependenceResult result = checkMemrefAccessDependence(`.
241  srcAccess, dstAccess, d, &dependenceConstraints,
242  /*dependenceComponents=*/nullptr);
243  if (hasDependence(result)) {
244  // Store minimum loop depth and break because we want the min 'd' at
245  // which there is a dependence.
246  loopDepth = std::min(loopDepth, d - 1);
247  break;
248  }
249  }
250  }
251  }
252 
253  return loopDepth;
254 }
255 
256 // TODO: Prevent fusion of loop nests with side-effecting operations.
257 // TODO: This pass performs some computation that is the same for all the depths
258 // (e.g., getMaxLoopDepth). Implement a version of this utility that processes
259 // all the depths at once or only the legal maximal depth for maximal fusion.
// Checks whether 'srcForOp' can be fused into 'dstForOp' at 'dstLoopDepth'
// under 'fusionStrategy'. On the success path the union of the computation
// slices is computed into 'srcSlice' via computeSliceUnion.
// NOTE(review): source lines 267, 273, 280, 291, 294, 298, 301, 313, 325, 330,
// 338, 356, 359 and 361 are missing from this listing. They appear to be the
// early-exit `return FusionResult::...` statements, the `opsA` / `opsB`
// declarations and the `case FusionStrategy::...:` labels of the switch.
260 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
261  unsigned dstLoopDepth,
262  ComputationSliceState *srcSlice,
263  FusionStrategy fusionStrategy) {
264  // Return 'failure' if 'dstLoopDepth == 0'.
265  if (dstLoopDepth == 0) {
266  LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
268  }
269  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
270  auto *block = srcForOp->getBlock();
271  if (block != dstForOp->getBlock()) {
272  LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
274  }
275 
276  // Return 'failure' if no valid insertion point for fused loop nest in 'block'
277  // exists which would preserve dependences.
278  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
279  LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
281  }
282 
283  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
284  bool isSrcForOpBeforeDstForOp =
285  srcForOp->isBeforeInBlock(dstForOp.getOperation());
286  // 'forOpA' executes before 'forOpB' in 'block'.
287  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
288  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
289 
290  // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
292  if (!gatherLoadsAndStores(forOpA, opsA)) {
293  LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
295  }
296 
297  // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
299  if (!gatherLoadsAndStores(forOpB, opsB)) {
300  LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
302  }
303 
304  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
305  // loop dependences.
306  // TODO: Enable this check for sibling and more generic loop fusion
307  // strategies.
308  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
309  // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
  // NOTE(review): only enforced in assert-enabled builds.
310  assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
311  if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
312  LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
314  }
315  }
316 
317  // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
318  unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
319  *srcForOp.getOperation(), *dstForOp.getOperation());
320 
321  // Filter out ops in 'opsA' to compute the slice union based on the
322  // assumptions made by the fusion strategy.
323  SmallVector<Operation *, 4> strategyOpsA;
324  switch (fusionStrategy.getStrategy()) {
326  // Generic fusion. Take into account all the memory operations to compute
327  // the slice union.
328  strategyOpsA.append(opsA.begin(), opsA.end());
329  break;
331  // Producer-consumer fusion (AffineLoopFusion pass) only takes into
332  // account stores in 'srcForOp' to compute the slice union.
333  for (Operation *op : opsA) {
334  if (isa<AffineWriteOpInterface>(op))
335  strategyOpsA.push_back(op);
336  }
337  break;
339  // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
340  // to 'memref' in 'srcForOp' to compute the slice union.
341  for (Operation *op : opsA) {
342  auto load = dyn_cast<AffineReadOpInterface>(op);
343  if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
344  strategyOpsA.push_back(op);
345  }
346  break;
347  }
348 
349  // Compute union of computation slices computed between all pairs of ops
350  // from 'forOpA' and 'forOpB'.
  // The slice is a backward slice when 'srcForOp' precedes 'dstForOp'.
351  SliceComputationResult sliceComputationResult =
352  mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
353  isSrcForOpBeforeDstForOp, srcSlice);
354  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
355  LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
357  }
358  if (sliceComputationResult.value ==
360  LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
362  }
363 
364  return FusionResult::Success;
365 }
366 
367 /// Patch the loop body of a forOp that is a single iteration reduction loop
368 /// into its containing block.
/// Returns failure() without modifying the IR unless 'forOp' has a constant
/// trip count of 1 and its immediate parent is an affine.for. Otherwise the
/// parent loop is rebuilt with 'forOp''s iter operands, yielded values and
/// region iter args appended, 'forOp''s body is spliced in place of 'forOp',
/// and both 'forOp' and the old parent loop are erased.
/// NOTE(review): source line 369 (the start of the signature) is missing from
/// this listing; the index lists it as
/// `LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp, bool siblingFusionUser)`.
370  bool siblingFusionUser) {
371  // Check if the reduction loop is a single iteration loop.
372  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
373  if (!tripCount || *tripCount != 1)
374  return failure();
375  auto iterOperands = forOp.getIterOperands();
376  auto *parentOp = forOp->getParentOp();
377  if (!isa<AffineForOp>(parentOp))
378  return failure();
379  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
380  OpBuilder b(parentOp);
381  // Replace the parent loop and add iteroperands and results from the `forOp`.
382  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
383  AffineForOp newLoop = replaceForOpWithNewYields(
384  b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());
385 
386  // For sibling-fusion users, collect operations that use the results of the
387  // `forOp` outside the new parent loop that has absorbed all its iter args
388  // and operands. These operations will be moved later after the results
389  // have been replaced.
390  SetVector<Operation *> forwardSlice;
391  if (siblingFusionUser) {
392  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
393  SetVector<Operation *> tmpForwardSlice;
394  getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
395  forwardSlice.set_union(tmpForwardSlice);
396  }
397  }
398  // Update the results of the `forOp` in the new loop.
  // The results contributed by 'forOp' follow the old parent's own results.
399  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
400  forOp.getResult(i).replaceAllUsesWith(
401  newLoop.getResult(i + parentOp->getNumResults()));
402  }
403  // For sibling-fusion users, move operations that use the results of the
404  // `forOp` outside the new parent loop
405  if (siblingFusionUser) {
  // NOTE(review): per its declared signature, topologicalSort takes the set by
  // const reference and returns a sorted copy. The result is discarded here,
  // so 'forwardSlice' keeps its original order. This probably should be
  // `forwardSlice = topologicalSort(forwardSlice);` — please confirm.
406  topologicalSort(forwardSlice);
407  for (Operation *op : llvm::reverse(forwardSlice))
408  op->moveAfter(newLoop);
409  }
410  // Replace the induction variable.
411  auto iv = forOp.getInductionVar();
412  iv.replaceAllUsesWith(newLoop.getInductionVar());
413  // Replace the iter args.
  // 'forOp''s iter args were appended last, so they are the trailing ones.
414  auto forOpIterArgs = forOp.getRegionIterArgs();
415  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
416  forOpIterArgs.size()))) {
417  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
418  }
419  // Move the loop body operations, except for its terminator, to the loop's
420  // containing block.
421  forOp.getBody()->back().erase();
422  auto *parentBlock = forOp->getBlock();
423  parentBlock->getOperations().splice(Block::iterator(forOp),
424  forOp.getBody()->getOperations());
425  forOp.erase();
426  parentForOp.erase();
427  return success();
428 }
429 
430 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
431 /// and source slice loop bounds specified in 'srcSlice'.
/// A clone of 'srcForOp' is created at 'srcSlice.insertPoint'. The bounds of
/// the cloned loops that correspond to 'srcSlice.ivs' are then overwritten
/// with the slice bounds, and single-iteration slice loops are promoted where
/// possible. 'srcForOp' itself is not erased here.
/// NOTE(review): 'dstForOp' is not referenced in the visible body; the
/// insertion location comes entirely from 'srcSlice.insertPoint'.
432 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
433  const ComputationSliceState &srcSlice,
434  bool isInnermostSiblingInsertion) {
435  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
436  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
437  BlockAndValueMapping mapper;
438  b.clone(*srcForOp, mapper);
439 
440  // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
441  SmallVector<AffineForOp, 4> sliceLoops;
442  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
  // Slice IVs that have no clone in 'mapper' are skipped.
443  auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
444  if (!loopIV)
445  continue;
446  auto forOp = getForInductionVarOwner(loopIV);
447  sliceLoops.push_back(forOp);
  // A null bound map means "keep the cloned loop's existing bound".
448  if (AffineMap lbMap = srcSlice.lbs[i]) {
449  auto lbOperands = srcSlice.lbOperands[i];
450  canonicalizeMapAndOperands(&lbMap, &lbOperands);
451  forOp.setLowerBound(lbOperands, lbMap);
452  }
453  if (AffineMap ubMap = srcSlice.ubs[i]) {
454  auto ubOperands = srcSlice.ubOperands[i];
455  canonicalizeMapAndOperands(&ubMap, &ubOperands);
456  forOp.setUpperBound(ubOperands, ubMap);
457  }
458  }
459 
460  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  // Evaluated lazily (only when the preceding conditions hold); each call
  // rebuilds 'sliceTripCountMap' into the same map.
461  auto srcIsUnitSlice = [&]() {
462  return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
463  (getSliceIterationCount(sliceTripCountMap) == 1));
464  };
465  // Fix up and if possible, eliminate single iteration loops.
  // NOTE(review): source lines 467 and 474 are missing from this listing.
  // Judging by the index, they presumably are
  // `if (isLoopParallelAndContainsReduction(forOp) &&` and
  // `(void)promoteIfSingleIteration(forOp);` respectively.
466  for (AffineForOp forOp : sliceLoops) {
468  isInnermostSiblingInsertion && srcIsUnitSlice())
469  // Patch reduction loop - only ones that are sibling-fused with the
470  // destination loop - into the parent loop.
471  (void)promoteSingleIterReductionLoop(forOp, true);
472  else
473  // Promote any single iteration slice loops.
475  }
476 }
477 
478 /// Collect loop nest statistics (eg. loop trip count and operation count)
479 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
480 /// returns false otherwise.
481 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
482  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
483  auto *childForOp = forOp.getOperation();
484  auto *parentForOp = forOp->getParentOp();
485  if (!llvm::isa<func::FuncOp>(parentForOp)) {
486  if (!isa<AffineForOp>(parentForOp)) {
487  LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
488  return WalkResult::interrupt();
489  }
490  // Add mapping to 'forOp' from its parent AffineForOp.
491  stats->loopMap[parentForOp].push_back(forOp);
492  }
493 
494  // Record the number of op operations in the body of 'forOp'.
495  unsigned count = 0;
496  stats->opCountMap[childForOp] = 0;
497  for (auto &op : *forOp.getBody()) {
498  if (!isa<AffineForOp, AffineIfOp>(op))
499  ++count;
500  }
501  stats->opCountMap[childForOp] = count;
502 
503  // Record trip count for 'forOp'. Set flag if trip count is not
504  // constant.
505  Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
506  if (!maybeConstTripCount) {
507  // Currently only constant trip count loop nests are supported.
508  LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
509  return WalkResult::interrupt();
510  }
511 
512  stats->tripCountMap[childForOp] = *maybeConstTripCount;
513  return WalkResult::advance();
514  });
515  return !walkResult.wasInterrupted();
516 }
517 
518 // Computes the total cost of the loop nest rooted at 'forOp'.
519 // Currently, the total cost is computed by counting the total operation
520 // instance count (i.e. total number of operations in the loop bodyloop
521 // operation count * loop trip count) for the entire loop nest.
522 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
523 // specified in the map when computing the total op instance count.
524 // NOTEs: 1) This is used to compute the cost of computation slices, which are
525 // sliced along the iteration dimension, and thus reduce the trip count.
526 // If 'computeCostMap' is non-null, the total op count for forOps specified
527 // in the map is increased (not overridden) by adding the op count from the
528 // map to the existing op count for the for loop. This is done before
529 // multiplying by the loop's trip count, and is used to model the cost of
530 // inserting a sliced loop nest of known cost into the loop's body.
531 // 2) This is also used to compute the cost of fusing a slice of some loop nest
532 // within another loop.
533 static int64_t getComputeCostHelper(
534  Operation *forOp, LoopNestStats &stats,
535  llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
536  DenseMap<Operation *, int64_t> *computeCostMap) {
537  // 'opCount' is the total number operations in one iteration of 'forOp' body,
538  // minus terminator op which is a no-op.
539  int64_t opCount = stats.opCountMap[forOp] - 1;
540  if (stats.loopMap.count(forOp) > 0) {
541  for (auto childForOp : stats.loopMap[forOp]) {
542  opCount += getComputeCostHelper(childForOp.getOperation(), stats,
543  tripCountOverrideMap, computeCostMap);
544  }
545  }
546  // Add in additional op instances from slice (if specified in map).
547  if (computeCostMap != nullptr) {
548  auto it = computeCostMap->find(forOp);
549  if (it != computeCostMap->end()) {
550  opCount += it->second;
551  }
552  }
553  // Override trip count (if specified in map).
554  int64_t tripCount = stats.tripCountMap[forOp];
555  if (tripCountOverrideMap != nullptr) {
556  auto it = tripCountOverrideMap->find(forOp);
557  if (it != tripCountOverrideMap->end()) {
558  tripCount = it->second;
559  }
560  }
561  // Returns the total number of dynamic instances of operations in loop body.
562  return tripCount * opCount;
563 }
564 
565 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
566 /// Currently, the total cost is computed by counting the total operation
567 /// instance count (i.e. total number of operations in the loop body * loop
568 /// trip count) for the entire loop nest.
569 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
570  return getComputeCostHelper(forOp.getOperation(), stats,
571  /*tripCountOverrideMap=*/nullptr,
572  /*computeCostMap=*/nullptr);
573 }
574 
575 /// Computes and returns in 'computeCost', the total compute cost of fusing the
576 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
577 /// the total cost is computed by counting the total operation instance count
578 /// (i.e. total number of operations in the loop body * loop trip count) for
579 /// the entire loop nest.
/// Returns false, leaving '*computeCost' untouched, if the slice trip count
/// map cannot be built; returns true otherwise.
580 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
581  AffineForOp dstForOp, LoopNestStats &dstStats,
582  const ComputationSliceState &slice,
583  int64_t *computeCost) {
584  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
585  DenseMap<Operation *, int64_t> computeCostMap;
586 
587  // Build trip count map for computation slice.
588  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
589  return false;
590  // Checks whether a store to load forwarding will happen.
591  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
592  assert(sliceIterationCount > 0);
593  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
594  auto *insertPointParent = slice.insertPoint->getParentOp();
595 
596  // The store and loads to this memref will disappear.
597  // TODO: Add load coalescing to memref data flow opt pass.
598  if (storeLoadFwdGuaranteed) {
599  // Subtract from operation count the loads/store we expect load/store
600  // forwarding to remove.
601  unsigned storeCount = 0;
602  llvm::SmallDenseSet<Value, 4> storeMemrefs;
603  srcForOp.walk([&](Operation *op) {
604  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
605  storeMemrefs.insert(storeOp.getMemRef());
606  ++storeCount;
607  }
608  });
609  // Subtract out any store ops in single-iteration src slice loop nest.
  // NOTE(review): 'storeCount' is unsigned, so '-storeCount' wraps to a large
  // positive value before it is stored as int64_t (e.g. 2 -> 4294967294).
  // It should probably be '-static_cast<int64_t>(storeCount)'. Also note that
  // this entry is overwritten (=, not +=) at source line 639 below.
610  if (storeCount > 0)
611  computeCostMap[insertPointParent] = -storeCount;
612  // Subtract out any load users of 'storeMemrefs' nested below
613  // 'insertPointParent'.
614  for (auto value : storeMemrefs) {
615  for (auto *user : value.getUsers()) {
  // 'loadOp' is only used as a type test here.
616  if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
  // NOTE(review): source line 617 is missing from this listing; it presumably
  // declares the local `loops` vector (SmallVector<AffineForOp, N>) filled
  // below.
618  // Check if any loop in loop nest surrounding 'user' is
619  // 'insertPointParent'.
620  getLoopIVs(*user, &loops);
  // cast<> asserts that the slice insertion point is inside an affine.for.
621  if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
622  if (auto forOp =
623  dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
624  if (computeCostMap.count(forOp) == 0)
625  computeCostMap[forOp] = 0;
626  computeCostMap[forOp] -= 1;
627  }
628  }
629  }
630  }
631  }
632  }
633 
634  // Compute op instance count for the src loop nest with iteration slicing.
635  int64_t sliceComputeCost = getComputeCostHelper(
636  srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
637 
638  // Compute cost of fusion for this depth.
  // NOTE(review): plain assignment discards any adjustment recorded above for
  // 'insertPointParent' itself (the store subtraction, and load decrements for
  // loads whose immediate parent is 'insertPointParent'). Confirm whether '+='
  // was intended.
639  computeCostMap[insertPointParent] = sliceComputeCost;
640 
641  *computeCost =
642  getComputeCostHelper(dstForOp.getOperation(), dstStats,
643  /*tripCountOverrideMap=*/nullptr, &computeCostMap);
644  return true;
645 }
646 
647 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
648 /// producer-consumer dependence between write ops in 'srcOps' and read ops in
649 /// 'dstOps'.
/// Results are inserted into 'producerConsumerMemrefs' without clearing it
/// first, so any existing entries are kept.
/// NOTE(review): source lines 650-651 (the start of the signature) are missing
/// from this listing; the index lists it as
/// `void gatherProducerConsumerMemrefs(ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps, DenseSet<Value> &producerConsumerMemrefs)`.
652  DenseSet<Value> &producerConsumerMemrefs) {
653  // Gather memrefs from stores in 'srcOps'.
654  DenseSet<Value> srcStoreMemRefs;
655  for (Operation *op : srcOps)
656  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
657  srcStoreMemRefs.insert(storeOp.getMemRef());
658 
659  // Compute the intersection between memrefs from stores in 'srcOps' and
660  // memrefs from loads in 'dstOps'.
661  for (Operation *op : dstOps)
662  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
663  if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
664  producerConsumerMemrefs.insert(loadOp.getMemRef());
665 }
DependenceResult checkMemrefAccessDependence(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints, SmallVector< DependenceComponent, 2 > *dependenceComponents, bool allowRAR=false)
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
OpListType::reverse_iterator reverse_iterator
Definition: Block.h:132
SmallVector< Value, 4 > ivs
Definition: Utils.h:78
static bool isDependentLoadOrStoreOp(Operation *op, DenseMap< Value, bool > &values)
Returns true if 'op' is a load or store operation which accesses a memref accessed in 'values' and at least one of the accesses is a store operation.
bool isLoopParallelAndContainsReduction(AffineForOp forOp)
Returns whether a loop is a parallel loop and contains a reduction loop.
Definition: Utils.cpp:1339
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:131
bool buildSliceTripCountMap(const ComputationSliceState &slice, llvm::SmallDenseMap< Operation *, uint64_t, 8 > *tripCountMap)
Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop nest surrounding represe...
Definition: Utils.cpp:993
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:261
DenseMap< Operation *, uint64_t > opCountMap
Map from AffineForOp to count of operations in its loop body.
Value getSiblingFusionMemRef() const
Returns the memref attached to this sibling fusion strategy.
void getLoopIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the loops surrounding 'op' ordered from the outermost 'affine...
Definition: Utils.cpp:35
static bool gatherLoadsAndStores(AffineForOp forOp, SmallVectorImpl< Operation *> &loadAndStoreOps)
LoopNestStats aggregates various per-loop statistics (eg.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bounds specified in 'srcSlice'.
Checks whether two accesses to the same memref access the same element.
unsigned getNumCommonSurroundingLoops(Operation &a, Operation &b)
Returns the number of surrounding loops common to both A and B.
Definition: Utils.cpp:1271
Enumerates different result statuses of slice computation by computeSliceUnion
Definition: Utils.h:62
static constexpr const bool value
DenseMap< Operation *, SmallVector< AffineForOp, 2 > > loopMap
Map from AffineForOp to immediate child AffineForOps in its loop body.
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation *> ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:779
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::enable_if< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT >::type walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one)...
Definition: Operation.h:574
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
OpListType::iterator iterator
Definition: Block.h:131
StrategyEnum getStrategy() const
Returns the fusion strategy.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost' the total compute cost of fusing the 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:75
SmallVector< AffineMap, 4 > lbs
Definition: Utils.h:80
SetVector< Operation * > topologicalSort(const SetVector< Operation *> &toSort)
Multi-root DAG topological sort.
bool hasDependence(DependenceResult result)
Utility function that returns true if the provided DependenceResult corresponds to a dependence resul...
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
void gatherProducerConsumerMemrefs(ArrayRef< Operation *> srcOps, ArrayRef< Operation *> dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between write ops in 'srcOps' and read ops in 'dstOps'.
Block::iterator insertPoint
Definition: Utils.h:88
void getForwardSlice(Operation *op, SetVector< Operation *> *forwardSlice, TransitiveFilter filter=nullptr)
Fills forwardSlice with the computed forward slice (i.e.
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1066
static WalkResult advance()
Definition: Visitors.h:51
DenseMap< Operation *, uint64_t > tripCountMap
Map from AffineForOp to its constant trip count.
static void getLoadAndStoreMemRefAccesses(Operation *opA, DenseMap< Value, bool > &values)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:42
static WalkResult interrupt()
Definition: Visitors.h:50
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
static Operation * getLastDependentOpInRange(Operation *opA, Operation *opB)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp, bool siblingFusionUser)
Patch the loop body of a forOp that is a single iteration reduction loop into its containing block...
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2144
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Describes the fusion strategy to be used in the Affine loop fusion utilities.
SmallVector< AffineMap, 4 > ubs
Definition: Utils.h:82
static Operation * getFusedLoopNestInsertionPoint(AffineForOp srcForOp, AffineForOp dstForOp)
SliceComputationResult computeSliceUnion(ArrayRef< Operation *> opsA, ArrayRef< Operation *> opsB, unsigned loopDepth, unsigned numCommonLoops, bool isBackwardSlice, ComputationSliceState *sliceUnion)
Computes in 'sliceUnion' the union of all slice bounds computed at 'loopDepth' between all dependent ...
Definition: Utils.cpp:811
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
std::vector< SmallVector< Value, 4 > > ubOperands
Definition: Utils.h:86
static int64_t getComputeCostHelper(Operation *forOp, LoopNestStats &stats, llvm::SmallDenseMap< Operation *, uint64_t, 8 > *tripCountOverrideMap, DenseMap< Operation *, int64_t > *computeCostMap)
FlatAffineValueConstraints represents an extension of IntegerPolyhedron where each non-local variable...
uint64_t getSliceIterationCount(const llvm::SmallDenseMap< Operation *, uint64_t, 8 > &sliceTripCountMap)
Return the number of iterations for the sliceTripCountMap provided.
Definition: Utils.cpp:1031
enum mlir::SliceComputationResult::ResultEnum value
AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop, ValueRange newIterOperands, ValueRange newYieldedValues, ValueRange newIterArgs, bool replaceLoopResults=true)
Replace loop with a new loop where newIterOperands are appended with new initialization values and ne...
Definition: AffineOps.cpp:2246
Encapsulates a memref load or store access information.
static Operation * getFirstDependentOpInRange(Operation *opA, Operation *opB)
Optional< uint64_t > getConstantTripCount(AffineForOp forOp)
Returns the trip count of the loop if it's a constant, None otherwise.
std::vector< SmallVector< Value, 4 > > lbOperands
Definition: Utils.h:84
result_range getResults()
Definition: Operation.h:332
This class helps build Operations.
Definition: Builders.h:192
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at &#39;forOp&#39; using &#39;stats&#39;.
static unsigned getMaxLoopDepth(ArrayRef< Operation *> srcOps, ArrayRef< Operation *> dstOps)
Returns the maximum loop depth at which we could fuse producer loop 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.