//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"

#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      forOps.push_back(forOp);
    } else if (isa<AffineReadOpInterface>(op)) {
      loadOpInsts.push_back(op);
    } else if (isa<AffineWriteOpInterface>(op)) {
      storeOpInsts.push_back(op);
    } else {
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!memInterface) {
        if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
          // This op itself is memory-effect free.
          return;
        // Check operands. E.g., ops like the `call` op are handled here.
        for (Value v : op->getOperands()) {
          if (!isa<MemRefType>(v.getType()))
            continue;
          // Conservatively, we assume the memref is read and written to.
          memrefLoads.push_back(op);
          memrefStores.push_back(op);
        }
      } else {
        // Non-affine loads and stores.
        if (hasEffect<MemoryEffects::Read>(op))
          memrefLoads.push_back(op);
        if (hasEffect<MemoryEffects::Write>(op))
          memrefStores.push_back(op);
        if (hasEffect<MemoryEffects::Free>(op))
          memrefFrees.push_back(op);
      }
    }
  });
}

unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    // Common case: affine reads.
    if (auto affineLoad = dyn_cast<AffineReadOpInterface>(loadOp)) {
      if (memref == affineLoad.getMemRef())
        ++loadOpCount;
    } else if (hasEffect<MemoryEffects::Read>(loadOp, memref)) {
      ++loadOpCount;
    }
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (auto *storeOp : llvm::concat<Operation *const>(stores, memrefStores)) {
    // Common case: affine writes.
    if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
      if (memref == affineStore.getMemRef())
        ++storeOpCount;
    } else if (hasEffect<MemoryEffects::Write>(const_cast<Operation *>(storeOp),
                                               memref)) {
      ++storeOpCount;
    }
  }
  return storeOpCount;
}

// Returns true if there exists an operation with a write memory effect on
// 'memref' in this node.
unsigned Node::hasStore(Value memref) const {
  return llvm::any_of(
      llvm::concat<Operation *const>(stores, memrefStores),
      [&](Operation *storeOp) {
        if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
          if (memref == affineStore.getMemRef())
            return true;
        } else if (hasEffect<MemoryEffects::Write>(storeOp, memref)) {
          return true;
        }
        return false;
      });
}

unsigned Node::hasFree(Value memref) const {
  return llvm::any_of(memrefFrees, [&](Operation *freeOp) {
    return hasEffect<MemoryEffects::Free>(freeOp, memref);
  });
}

// Returns all store ops in 'storeOps' which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Returns all load ops in 'loadOps' which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

/// Returns the values that this op has a memref effect of type `EffectTys` on,
/// not considering recursive effects.
template <typename... EffectTys>
static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
  auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memOp) {
    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
      // No effects.
      return;
    // Memref operands have to be considered as being affected.
    for (Value operand : op->getOperands()) {
      if (isa<MemRefType>(operand.getType()))
        values.push_back(operand);
    }
    return;
  }
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
  memOp.getEffects(effects);
  for (auto &effect : effects) {
    Value effectVal = effect.getValue();
    if (isa<EffectTys...>(effect.getEffect()) && effectVal &&
        isa<MemRefType>(effectVal.getType()))
      values.push_back(effectVal);
  }
}

/// Add `op` to the MDG, creating a new node and adding its memory accesses
/// (affine or non-affine) to the `memrefAccesses` (memref -> list of nodes
/// with accesses) map.
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
  auto &nodes = mdg.nodes;
  // Create graph node 'id' to represent top-level 'forOp' and record
  // all loads and store accesses it contains.
  LoopNestStateCollector collector;
  collector.collect(nodeOp);
  unsigned newNodeId = mdg.nextNodeId++;
  Node &node = nodes.insert({newNodeId, Node(newNodeId, nodeOp)}).first->second;
  for (Operation *op : collector.loadOpInsts) {
    node.loads.push_back(op);
    auto memref = cast<AffineReadOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.storeOpInsts) {
    node.stores.push_back(op);
    auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.memrefLoads) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Read>(op, effectedValues);
    if (llvm::any_of(ValueRange(effectedValues).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      // We do not know the interaction here.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefLoads.push_back(op);
  }
  for (Operation *op : collector.memrefStores) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Write>(op, effectedValues);
    if (llvm::any_of(ValueRange(effectedValues).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefStores.push_back(op);
  }
  for (Operation *op : collector.memrefFrees) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Free>(op, effectedValues);
    if (llvm::any_of(ValueRange(effectedValues).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefFrees.push_back(op);
  }

  return &node;
}

/// Returns the memref being read/written by a memref/affine load/store op.
static Value getMemRef(Operation *memOp) {
  if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
    return memrefLoad.getMemRef();
  if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
    return affineLoad.getMemRef();
  if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
    return memrefStore.getMemRef();
  if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
    return affineStore.getMemRef();
  llvm_unreachable("unexpected op");
}

/// Returns true if there may be a dependence on `memref` from srcNode's
/// memory ops to dstNode's memory ops, while using the affine memory
/// dependence analysis checks. The method assumes that there is at least one
/// memory op in srcNode's loads and stores on `memref`, and similarly for
/// `dstNode`. `srcNode.op` and `dstNode.op` are expected to be nested in the
/// same block and so the dependences are tested at the depth of that block.
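///
/// For example (an illustrative sketch, not from the test suite), with two
/// sibling nests
///
///   affine.for %i = 0 to 10 {           // srcNode
///     affine.store %v, %A[%i] : memref<10xf32>
///   }
///   affine.for %j = 0 to 10 {           // dstNode
///     %x = affine.load %A[%j] : memref<10xf32>
///   }
///
/// the store/load pair on %A makes this return true when `memref` is %A.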
static bool mayDependence(const Node &srcNode, const Node &dstNode,
                          Value memref) {
  assert(srcNode.op->getBlock() == dstNode.op->getBlock());
  if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
    return true;

  // Conservatively handle dependences involving non-affine load/stores. Return
  // true if there exists a conflicting read/write access involving such.

  // Check whether there is a dependence from a source read/write op to a
  // destination read/write one; all expected to be memref/affine load/store.
  auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps,
                             ArrayRef<Operation *> dstMemOps) {
    return llvm::any_of(srcMemOps, [&](Operation *srcOp) {
      Value srcMemref = getMemRef(srcOp);
      if (srcMemref != memref)
        return false;
      return llvm::find_if(dstMemOps, [&](Operation *dstOp) {
               return srcMemref == getMemRef(dstOp);
             }) != dstMemOps.end();
    });
  };

  SmallVector<Operation *> dstOps;
  // Between non-affine src stores and dst load/store.
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.loads, dstNode.stores,
                                 dstNode.memrefLoads, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.memrefStores, dstOps))
    return true;
  // Between non-affine loads and dst stores.
  dstOps.clear();
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.stores, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
    return true;
  // Between affine stores and memref load/stores.
  dstOps.clear();
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.memrefLoads, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.stores, dstOps))
    return true;
  // Between affine loads and memref stores.
  dstOps.clear();
  llvm::append_range(dstOps, dstNode.memrefStores);
  if (hasNonAffineDep(srcNode.loads, dstOps))
    return true;

  // Affine load/store pairs. We don't need to check for locally allocated
  // memrefs since the dependence analysis here is between mem ops from
  // srcNode's for op to dstNode's for op at the depth at which those
  // `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
  // `d` is the number of common surrounding loops.
  for (auto *srcMemOp :
       llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
    MemRefAccess srcAcc(srcMemOp);
    if (srcAcc.memref != memref)
      continue;
    for (auto *destMemOp :
         llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
      MemRefAccess destAcc(destMemOp);
      if (destAcc.memref != memref)
        continue;
      // Check for a top-level dependence between srcNode and destNode's ops.
      if (!noDependence(checkMemrefAccessDependence(
              srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
        return true;
    }
  }
  return false;
}

bool MemRefDependenceGraph::init(bool fullAffineDependences) {
  LDBG() << "--- Initializing MDG ---";
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Create graph nodes.
  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
      forToNodeMap[&op] = node->id;
    } else if (isa<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (isa<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, edges
      // to/from these ops get looked at.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., they may
      // not be imperative ops.
      LDBG() << "MDG init failed; unknown region-holding op found!";
      return false;
    }
    // We aren't creating nodes for memory-effect free ops that either have no
    // regions (unless they have results being used) or implement the region
    // branch op interface.
  }

  LDBG() << "Created " << nodes.size() << " nodes";

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    Value srcMemRef = memrefAndList.first;
    // Add edges between all dependent pairs among the node IDs on this memref.
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      Node *srcNode = getNode(srcId);
      bool srcHasStoreOrFree =
          srcNode->hasStore(srcMemRef) || srcNode->hasFree(srcMemRef);
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        Node *dstNode = getNode(dstId);
        bool dstHasStoreOrFree =
            dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
        if (srcHasStoreOrFree || dstHasStoreOrFree) {
          // Check precise affine deps if asked for; otherwise, conservative.
          if (!fullAffineDependences ||
              mayDependence(*srcNode, *dstNode, srcMemRef))
            addEdge(srcId, dstId, srcMemRef);
        }
      }
    }
  }
  return true;
}
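
// Example (illustrative, not from the source): for a block with two sibling
// `affine.for` nests where the first stores to %A and the second loads from
// %A, `init()` creates one node per nest and adds a memref dependence edge
// (first -> second) on %A. A top-level op whose SSA result is used inside a
// nest additionally gets an SSA value edge to that nest's node.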

// Returns the graph node for 'id'.
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Removes node 'id' (and its associated edges) from the graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.contains(id)) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(
    unsigned id) const {
  const Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) const {
  if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
                                              unsigned dstId) const {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (!outEdges.contains(idAndIndex.first) ||
        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}
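
// Illustrative note (not from the source): with edges 0 -> 1 and 1 -> 2,
// hasDependencePath(0, 2) starts at node 0, pushes node 1, then reaches
// node 2 and returns true. Since edges always point forward in block order,
// nodes whose ops come after getNode(dstId)->op are pruned from the search.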

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) const {
  unsigned inEdgeCount = 0;
  for (const Edge &inEdge : inEdges.lookup(id)) {
    if (inEdge.value == memref) {
      const Node *srcNode = getNode(inEdge.id);
      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
      if (srcNode->getStoreOpCount(memref) > 0)
        ++inEdgeCount;
    }
  }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
                                                Value memref) const {
  unsigned outEdgeCount = 0;
  for (const auto &outEdge : outEdges.lookup(id))
    if (!memref || outEdge.value == memref)
      ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) const {
  for (const Edge &edge : inEdges.lookup(id))
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) const {
  if (!outEdges.contains(srcId))
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LDBG() << "Can't fuse: a defining op with a user in the dst "
           << "loop has dependence from the src loop";
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges.lookup(srcId))
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges.lookup(dstId))
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //    operation insertion point (or return null pointer if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
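  //
  // Illustrative sketch (not from the source): for block order src, P, Q,
  // dst with edges srcId -> P and Q -> dstId, firstSrcDepPos = 0 (P) and
  // lastDstDepPos = 1 (Q); since firstSrcDepPos <= lastDstDepPos, no valid
  // insertion point exists. With order src, Q, P, dst instead, the fused
  // nest is inserted right before P.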
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in the node at 'dstId' by
//    private memrefs.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (!privateMemRefs.contains(inEdge.value))
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds the ops in 'loads', 'stores', 'memrefLoads', 'memrefStores', and
// 'memrefFrees' to the node at 'id'.
void MemRefDependenceGraph::addToNode(unsigned id, ArrayRef<Operation *> loads,
                                      ArrayRef<Operation *> stores,
                                      ArrayRef<Operation *> memrefLoads,
                                      ArrayRef<Operation *> memrefStores,
                                      ArrayRef<Operation *> memrefFrees) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
  llvm::append_range(node->memrefLoads, memrefLoads);
  llvm::append_range(node->memrefStores, memrefStores);
  llvm::append_range(node->memrefFrees, memrefFrees);
}

// Clears the affine load and store ops on the node at 'id'.
void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges.at(id), callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges.at(id), callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << "  InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << "  OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the original
// domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent the slice
// bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion, and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
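///
/// For example (an illustrative sketch), a slice with lb = (d0) -> (d0) and
/// ub = (d0) -> (d0 + 1) over a destination loop `affine.for %j = 0 to 16`
/// pins the source IV to %j; if the source loop is `affine.for %i = 0 to 16`
/// with the same step, the fast check returns true.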
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest from which the slice is
  // computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LDBG() << "Unable to compute source's domain";
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LDBG() << "Cannot handle symbols in source domain";
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LDBG() << "Cannot handle locals in source domain";
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LDBG() << "Unable to compute slice's domain";
    return std::nullopt;
  }

  // Project out every dimension other than the 'ivs' to express the slice's
  // domain completely in terms of the source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LDBG() << "Domain of the source of the slice:\n"
         << "Source constraints:" << srcConstraints
         << "\nDomain of the slice if this fusion succeeds "
         << "(expressed in terms of its source's IVs):\n"
         << "Slice constraints:" << sliceConstraints;

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LDBG() << "Incorrect slice";
    return false;
  }
  return true;
}
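
// Illustrative note (not from the source): if the source domain is
// {0 <= %i <= 15} and the slice domain projected onto %i is {0 <= %i <= 31},
// the difference {16 <= %i <= 31} is non-empty, so the slice would execute
// iterations the source never did and is declared invalid above.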

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

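// Illustrative note (not from the source): for a region on memref<10x20xf32>
// with constraints {0 <= d0 <= 9, %i <= d1 <= %i + 7}, the method below
// returns 80 (= 10 * 8) with shape [10, 8]; the extent along d1 is the
// constant difference 8 even though its bounds are symbolic in %i.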
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  MLIRContext *context = memref.getContext();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatLinearValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  for (unsigned d = 0; d < rank; d++) {
    AffineMap lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize(context, d, &lb);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "dim size bound cannot be negative");
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (ShapedType::isDynamic(dimSize))
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                          /*result=*/getAffineConstantExpr(0, context));
    }
    numElements *= diffConstant;
    // Populate outputs if available.
    if (lbs)
      lbs->push_back(lb);
    if (shape)
      shape->push_back(diffConstant);
  }
  return numElements;
}

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds, bool dropLocalVars,
                                    bool dropOuterIvs) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add inequalities for loop lower/upper bounds, and equality constraints for
  // operands that are constant symbols.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LDBG() << "unknown affine dimensional value";
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      if (failed(cst.addInductionVarOrTerminalSymbol(operand)))
        return failure();
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LDBG() << "Access map: " << accessValueMap.getAffineMap();
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (auto en : llvm::enumerate(vars)) {
    if ((isAffineInductionVar(en.value())) &&
        !llvm::is_contained(enclosingIVs, en.value())) {
      if (dropOuterIvs) {
        cst.projectOut(en.value());
      } else {
        unsigned varPosition;
        cst.findVar(en.value(), &varPosition);
        auto varKind = cst.getVarKindAt(varPosition);
        varPosition -= cst.getNumDimVars();
        cst.convertToLocal(varKind, varPosition, varPosition + 1);
      }
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs) if specified.
  if (dropLocalVars)
    cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LDBG() << "Memory region: " << cst;
  return success();
}

std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}
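
// For example, f32 yields 4, i1 yields 1 (divideCeil rounds sub-byte widths
// up), and vector<4xf32> yields 16 (4 lanes * 4 bytes).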

// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LDBG() << "Non-identity layout map not yet supported";
    return std::nullopt;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LDBG() << "Dynamic shapes not yet supported";
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}
1457 
1458 /// Returns the size of memref data in bytes if it's statically shaped,
1459 /// std::nullopt otherwise. If the element of the memref has vector type, takes
1460 /// into account size of the vector as well.
1461 // TODO: improve/complete this when we have target data.
1462 std::optional<uint64_t>
1464  if (!memRefType.hasStaticShape())
1465  return std::nullopt;
1466  auto elementType = memRefType.getElementType();
1467  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
1468  return std::nullopt;
1469 
1470  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
1471  if (!sizeInBytes)
1472  return std::nullopt;
1473  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
1474  sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
1475  }
1476  return sizeInBytes;
1477 }
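For a statically shaped memref, the loop above multiplies the element size by every dimension; an illustrative trace (not from the source):

// memref<16x4xf32>: 4 bytes/elt * 16 * 4 == 256 bytes.
// memref<?x4xf32> : hasStaticShape() fails -> std::nullopt.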
1478 
1479 template <typename LoadOrStoreOp>
1480 LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
1481  bool emitError) {
1482  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
1483  AffineWriteOpInterface>::value,
1484  "argument should be either a AffineReadOpInterface or a "
1485  "AffineWriteOpInterface");
1486 
1487  Operation *op = loadOrStoreOp.getOperation();
1488  MemRefRegion region(op->getLoc());
1489  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
1490  /*addMemRefDimBounds=*/false)))
1491  return success();
1492 
1493  LDBG() << "Memory region: " << region.getConstraints();
1494 
1495  bool outOfBounds = false;
1496  unsigned rank = loadOrStoreOp.getMemRefType().getRank();
1497 
1498  // For each dimension, check for out of bounds.
1499  for (unsigned r = 0; r < rank; r++) {
1500  FlatAffineValueConstraints ucst(*region.getConstraints());
1501 
1502  // Intersect the memory region with a constraint capturing an out-of-bounds
1503  // access (above the upper bound or below the lower bound), and check if the
1504  // system is feasible. If so, at least one access point lies out of bounds.
1505  int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
1506  // TODO: handle dynamic dim sizes.
1507  if (dimSize == -1)
1508  continue;
1509 
1510  // Check for overflow: d_i >= memref dim size.
1511  ucst.addBound(BoundType::LB, r, dimSize);
1512  bool outOfUpperBound = !ucst.isEmpty();
1513  if (outOfUpperBound && emitError) {
1514  loadOrStoreOp.emitOpError()
1515  << "memref out of upper bound access along dimension #" << (r + 1);
1516  }
1517 
1518  // Check for a negative index.
1519  FlatAffineValueConstraints lcst(*region.getConstraints());
1520  // d_i <= -1;
1521  lcst.addBound(BoundType::UB, r, -1);
1522  bool outOfLowerBound = !lcst.isEmpty();
1523  if (outOfLowerBound && emitError) {
1524  loadOrStoreOp.emitOpError()
1525  << "memref out of lower bound access along dimension #" << (r + 1);
1526  }
1527  // Accumulate across dimensions so that any violation is reported.
1528  outOfBounds |= outOfUpperBound || outOfLowerBound;
1529  }
1530  return failure(outOfBounds);
1531 }
1532 
1533 // Explicitly instantiate the template so that the compiler knows we need them!
1534 template LogicalResult
1535 mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
1536  bool emitError);
1537 template LogicalResult
1538 mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
1539  bool emitError);
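
A typical driver for these instantiations walks a function and checks every affine access. A minimal sketch, assuming the usual MLIR includes; `checkAllAccesses` is a hypothetical helper name, not part of this file:

static void checkAllAccesses(func::FuncOp fn) {
  // Bound-check every affine read/write, emitting diagnostics on violations.
  fn.walk([](Operation *op) {
    if (auto load = dyn_cast<AffineReadOpInterface>(op))
      (void)boundCheckLoadOrStoreOp(load, /*emitError=*/true);
    else if (auto store = dyn_cast<AffineWriteOpInterface>(op))
      (void)boundCheckLoadOrStoreOp(store, /*emitError=*/true);
  });
}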
1540 
1541 // Returns in 'positions' the Block positions of 'op' in each ancestor
1542 // Block, starting from the Block containing 'op' and stopping at 'limitBlock'.
1543 static void findInstPosition(Operation *op, Block *limitBlock,
1544  SmallVectorImpl<unsigned> *positions) {
1545  Block *block = op->getBlock();
1546  while (block != limitBlock) {
1547  // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1548  // rely on linear scans.
1549  int instPosInBlock = std::distance(block->begin(), op->getIterator());
1550  positions->push_back(instPosInBlock);
1551  op = block->getParentOp();
1552  block = op->getBlock();
1553  }
1554  std::reverse(positions->begin(), positions->end());
1555 }
1556 
1557 // Returns the Operation in a possibly nested set of Blocks, where the
1558 // position of the operation is represented by 'positions', which has a
1559 // Block position for each level of nesting.
1560 static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
1561  unsigned level, Block *block) {
1562  unsigned i = 0;
1563  for (auto &op : *block) {
1564  if (i != positions[level]) {
1565  ++i;
1566  continue;
1567  }
1568  if (level == positions.size() - 1)
1569  return &op;
1570  if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
1571  return getInstAtPosition(positions, level + 1,
1572  childAffineForOp.getBody());
1573 
1574  for (auto &region : op.getRegions()) {
1575  for (auto &b : region)
1576  if (auto *ret = getInstAtPosition(positions, level + 1, &b))
1577  return ret;
1578  }
1579  return nullptr;
1580  }
1581  return nullptr;
1582 }
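An illustrative trace (not from the source): positions = {2, 0} selects the 3rd op in 'block'; since another level remains, the search descends into that op's nested block(s) and returns the 1st op found there.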
1583 
1584 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1585 static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
1586  FlatAffineValueConstraints *cst) {
1587  for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1588  auto value = cst->getValue(i);
1589  if (ivs.count(value) == 0) {
1590  assert(isAffineForInductionVar(value));
1591  auto loop = getForInductionVarOwner(value);
1592  if (failed(cst->addAffineForOpDomain(loop)))
1593  return failure();
1594  }
1595  }
1596  return success();
1597 }
1598 
1599 /// Returns the innermost common loop depth for the set of operations in 'ops'.
1600 // TODO: Move this to LoopUtils.
1601 unsigned mlir::affine::getInnermostCommonLoopDepth(
1602  ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1603  unsigned numOps = ops.size();
1604  assert(numOps > 0 && "Expected at least one operation");
1605 
1606  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1607  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1608  for (unsigned i = 0; i < numOps; ++i) {
1609  getAffineForIVs(*ops[i], &loops[i]);
1610  loopDepthLimit =
1611  std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
1612  }
1613 
1614  unsigned loopDepth = 0;
1615  for (unsigned d = 0; d < loopDepthLimit; ++d) {
1616  unsigned i;
1617  for (i = 1; i < numOps; ++i) {
1618  if (loops[i - 1][d] != loops[i][d])
1619  return loopDepth;
1620  }
1621  if (surroundingLoops)
1622  surroundingLoops->push_back(loops[i - 1][d]);
1623  ++loopDepth;
1624  }
1625  return loopDepth;
1626 }
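An illustrative case (not from the source): with one op nested under loops (%i, %j) and another under (%i, %k), only the outermost loop %i is shared, so the innermost common loop depth is 1 and 'surroundingLoops' receives just the %i loop.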
1627 
1628 /// Computes in 'sliceUnion' the union of all slice bounds computed at
1629 /// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1630 /// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1631 /// union was computed correctly, an appropriate failure otherwise.
1632 SliceComputationResult
1633 mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
1634  ArrayRef<Operation *> opsB, unsigned loopDepth,
1635  unsigned numCommonLoops, bool isBackwardSlice,
1636  ComputationSliceState *sliceUnion) {
1637  // Compute the union of slice bounds between all pairs in 'opsA' and
1638  // 'opsB' in 'sliceUnionCst'.
1639  FlatAffineValueConstraints sliceUnionCst;
1640  assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1641  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1642  MemRefAccess srcAccess;
1643  MemRefAccess dstAccess;
1644  for (Operation *a : opsA) {
1645  srcAccess = MemRefAccess(a);
1646  for (Operation *b : opsB) {
1647  dstAccess = MemRefAccess(b);
1648  if (srcAccess.memref != dstAccess.memref)
1649  continue;
1650  // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1651  if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
1652  (isBackwardSlice && loopDepth > getNestingDepth(b))) {
1653  LDBG() << "Invalid loop depth";
1655  }
1656 
1657  bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1658  isa<AffineReadOpInterface>(dstAccess.opInst);
1659  FlatAffineValueConstraints dependenceConstraints;
1660  // Check dependence between 'srcAccess' and 'dstAccess'.
1661  DependenceResult result = checkMemrefAccessDependence(
1662  srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1663  &dependenceConstraints, /*dependenceComponents=*/nullptr,
1664  /*allowRAR=*/readReadAccesses);
1665  if (result.value == DependenceResult::Failure) {
1666  LDBG() << "Dependence check failed";
1668  }
1669  if (result.value == DependenceResult::NoDependence)
1670  continue;
1671  dependentOpPairs.emplace_back(a, b);
1672 
1673  // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1674  ComputationSliceState tmpSliceState;
1675  getComputationSliceState(a, b, dependenceConstraints, loopDepth,
1676  isBackwardSlice, &tmpSliceState);
1677 
1678  if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1679  // Initialize 'sliceUnionCst' with the bounds computed in previous step.
1680  if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
1681  LDBG() << "Unable to compute slice bound constraints";
1683  }
1684  assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1685  continue;
1686  }
1687 
1688  // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1689  FlatAffineValueConstraints tmpSliceCst;
1690  if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
1691  LDBG() << "Unable to compute slice bound constraints";
1693  }
1694 
1695  // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1696  if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {
1697 
1698  // Pre-constraint var alignment: record loop IVs used in each constraint
1699  // system.
1700  SmallPtrSet<Value, 8> sliceUnionIVs;
1701  for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1702  sliceUnionIVs.insert(sliceUnionCst.getValue(k));
1703  SmallPtrSet<Value, 8> tmpSliceIVs;
1704  for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1705  tmpSliceIVs.insert(tmpSliceCst.getValue(k));
1706 
1707  sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);
1708 
1709  // Post-constraint var alignment: add loop IV bounds missing after
1710  // var alignment to constraint systems. This can occur if one constraint
1711  // system uses a loop IV that is not used by the other. The call
1712  // to unionBoundingBox below expects constraints for each loop IV, even
1713  // if they are the unsliced full loop bounds added here.
1714  if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
1715  return SliceComputationResult::GenericFailure;
1716  if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
1717  return SliceComputationResult::GenericFailure;
1718  }
1719  // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
1720  if (sliceUnionCst.getNumLocalVars() > 0 ||
1721  tmpSliceCst.getNumLocalVars() > 0 ||
1722  failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
1723  LDBG() << "Unable to compute union bounding box of slice bounds";
1725  }
1726  }
1727  }
1728 
1729  // Empty union.
1730  if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1731  LDBG() << "empty slice union - unexpected";
1733  }
1734 
1735  // Gather loops surrounding ops from loop nest where slice will be inserted.
1736  SmallVector<Operation *, 4> ops;
1737  for (auto &dep : dependentOpPairs) {
1738  ops.push_back(isBackwardSlice ? dep.second : dep.first);
1739  }
1740  SmallVector<AffineForOp, 4> surroundingLoops;
1741  unsigned innermostCommonLoopDepth =
1742  getInnermostCommonLoopDepth(ops, &surroundingLoops);
1743  if (loopDepth > innermostCommonLoopDepth) {
1744  LDBG() << "Exceeds max loop depth";
1746  }
1747 
1748  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1749  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1750 
1751  // Convert any dst loop IVs which are symbol variables to dim variables.
1752  sliceUnionCst.convertLoopIVSymbolsToDims();
1753  sliceUnion->clearBounds();
1754  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
1755  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
1756 
1757  // Get slice bounds from slice union constraints 'sliceUnionCst'.
1758  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
1759  opsA[0]->getContext(), &sliceUnion->lbs,
1760  &sliceUnion->ubs);
1761 
1762  // Add slice bound operands of union.
1763  SmallVector<Value, 4> sliceBoundOperands;
1764  sliceUnionCst.getValues(numSliceLoopIVs,
1765  sliceUnionCst.getNumDimAndSymbolVars(),
1766  &sliceBoundOperands);
1767 
1768  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1769  sliceUnion->ivs.clear();
1770  sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);
1771 
1772  // Set the loop nest insertion point to the block start at 'loopDepth' for
1773  // backward slices, and to the block end for forward slices.
1774  sliceUnion->insertPoint =
1775  isBackwardSlice
1776  ? surroundingLoops[loopDepth - 1].getBody()->begin()
1777  : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1778 
1779  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1780  // canonicalization.
1781  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1782  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1783 
1784  // Check if the slice computed is valid. Return success only if it is verified
1785  // that the slice is valid, otherwise return appropriate failure status.
1786  std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1787  if (!isSliceValid) {
1788  LDBG() << "Cannot determine if the slice is valid";
1790  }
1791  if (!*isSliceValid)
1792  return SliceComputationResult::IncorrectSliceFailure;
1793 
1794  return SliceComputationResult::Success;
1795 }
1796 
1797 // TODO: extend this to handle multiple result maps.
1798 static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1799  AffineMap ubMap) {
1800  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1801  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1802  assert(lbMap.getNumDims() == ubMap.getNumDims());
1803  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1804  AffineExpr lbExpr(lbMap.getResult(0));
1805  AffineExpr ubExpr(ubMap.getResult(0));
1806  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
1807  lbMap.getNumSymbols());
1808  auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
1809  if (!cExpr)
1810  return std::nullopt;
1811  return cExpr.getValue();
1812 }
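An illustrative case (not from the source): lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 4) give simplifyAffineExpr((d0 + 4) - d0) == 4, so the result is 4; with ubMap = (d0) -> (2 * d0), the difference 2 * d0 - d0 = d0 is not constant and std::nullopt is returned.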
1813 
1814 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
1815 // loop nest surrounding the ops represented by the slice loop bounds in
1816 // 'slice'. Returns true on success, false otherwise (if a non-constant trip count was encountered).
1817 // TODO: Make this work with non-unit step loops.
1818 bool mlir::affine::buildSliceTripCountMap(
1819  const ComputationSliceState &slice,
1820  llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1821  unsigned numSrcLoopIVs = slice.ivs.size();
1822  // Populate map from AffineForOp -> trip count
1823  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1824  AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1825  auto *op = forOp.getOperation();
1826  AffineMap lbMap = slice.lbs[i];
1827  AffineMap ubMap = slice.ubs[i];
1828  // If lower or upper bound maps are null or provide no results, it implies
1829  // that the source loop was not sliced at all, and the entire loop will be a
1830  // part of the slice.
1831  if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1832  ubMap.getNumResults() == 0) {
1833  // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1834  if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1835  (*tripCountMap)[op] =
1836  forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1837  continue;
1838  }
1839  std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1840  if (maybeConstTripCount.has_value()) {
1841  (*tripCountMap)[op] = *maybeConstTripCount;
1842  continue;
1843  }
1844  return false;
1845  }
1846  std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1847  // Slice bounds are created with a constant ub - lb difference.
1848  if (!tripCount.has_value())
1849  return false;
1850  (*tripCountMap)[op] = *tripCount;
1851  }
1852  return true;
1853 }
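An illustrative case (not from the source): slice bounds lb = (d0) -> (d0) and ub = (d0) -> (d0 + 2) on a unit-step loop contribute a trip count of 2 for that loop; a loop with null slice bounds contributes its full constant trip count instead.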
1854 
1855 // Return the number of iterations in the given slice.
1856 uint64_t mlir::affine::getSliceIterationCount(
1857  const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1858  uint64_t iterCount = 1;
1859  for (const auto &count : sliceTripCountMap) {
1860  iterCount *= count.second;
1861  }
1862  return iterCount;
1863 }
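An illustrative case (not from the source): per-loop trip counts {forOp_i: 8, forOp_j: 4} give 8 * 4 == 32 slice iterations.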
1864 
1865 const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
1866 // Computes slice bounds by projecting out any loop IVs from
1867 // 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
1868 // bounds in 'sliceState' which represent the one loop nest's IVs in terms of
1869 // the other loop nest's IVs, symbols and constants (using 'isBackwardSlice').
1870 void mlir::affine::getComputationSliceState(
1871  Operation *depSourceOp, Operation *depSinkOp,
1872  const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth,
1873  bool isBackwardSlice, ComputationSliceState *sliceState) {
1874  // Get loop nest surrounding src operation.
1875  SmallVector<AffineForOp, 4> srcLoopIVs;
1876  getAffineForIVs(*depSourceOp, &srcLoopIVs);
1877  unsigned numSrcLoopIVs = srcLoopIVs.size();
1878 
1879  // Get loop nest surrounding dst operation.
1880  SmallVector<AffineForOp, 4> dstLoopIVs;
1881  getAffineForIVs(*depSinkOp, &dstLoopIVs);
1882  unsigned numDstLoopIVs = dstLoopIVs.size();
1883 
1884  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1885  (isBackwardSlice && loopDepth <= numDstLoopIVs));
1886 
1887  // Project out dimensions other than those up to 'loopDepth'.
1888  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1889  unsigned num =
1890  isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1891  FlatAffineValueConstraints sliceCst(dependenceConstraints);
1892  sliceCst.projectOut(pos, num);
1893 
1894  // Add slice loop IV values to 'sliceState'.
1895  unsigned offset = isBackwardSlice ? 0 : loopDepth;
1896  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1897  sliceCst.getValues(offset, offset + numSliceLoopIVs, &sliceState->ivs);
1898 
1899  // Set up lower/upper bound affine maps for the slice.
1900  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
1901  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
1902 
1903  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1904  sliceCst.getSliceBounds(offset, numSliceLoopIVs, depSourceOp->getContext(),
1905  &sliceState->lbs, &sliceState->ubs);
1906 
1907  // Set up bound operands for the slice's lower and upper bounds.
1908  SmallVector<Value, 4> sliceBoundOperands;
1909  unsigned numDimsAndSymbols = sliceCst.getNumDimAndSymbolVars();
1910  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1911  if (i < offset || i >= offset + numSliceLoopIVs)
1912  sliceBoundOperands.push_back(sliceCst.getValue(i));
1913  }
1914 
1915  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1916  // canonicalization.
1917  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1918  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1919 
1920  // Set the insertion point: dst loop body start at 'loopDepth' for backward slices, src loop body end for forward slices.
1921  sliceState->insertPoint =
1922  isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1923  : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1924 
1925  llvm::SmallDenseSet<Value, 8> sequentialLoops;
1926  if (isa<AffineReadOpInterface>(depSourceOp) &&
1927  isa<AffineReadOpInterface>(depSinkOp)) {
1928  // For read-read access pairs, clear any slice bounds on sequential loops.
1929  // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
1930  getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1931  &sequentialLoops);
1932  }
1933  auto getSliceLoop = [&](unsigned i) {
1934  return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1935  };
1936  auto isInnermostInsertion = [&]() {
1937  return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1938  : loopDepth >= dstLoopIVs.size());
1939  };
1940  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
1941  auto srcIsUnitSlice = [&]() {
1942  return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
1943  (getSliceIterationCount(sliceTripCountMap) == 1));
1944  };
1945  // Clear all sliced loop bounds beginning at the first sequential loop or the
1946  // first loop with a slice fusion barrier attribute.
1948  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1949  Value iv = getSliceLoop(i).getInductionVar();
1950  if (sequentialLoops.count(iv) == 0 &&
1951  getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1952  continue;
1953  // Skip resetting the bounds of a reduction loop inserted in the destination
1954  // loop nest when it meets the following conditions:
1955  // 1. Slice is single trip count.
1956  // 2. Loop bounds of the source and destination match.
1957  // 3. Is being inserted at the innermost insertion point.
1958  std::optional<bool> isMaximal = sliceState->isMaximal();
1959  if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1960  isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1961  continue;
1962  for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1963  sliceState->lbs[j] = AffineMap();
1964  sliceState->ubs[j] = AffineMap();
1965  }
1966  break;
1967  }
1968 }
1969 
1970 /// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1971 /// updates the slice loop bounds with any non-null bound maps specified in
1972 /// 'sliceState', and inserts this slice into the loop nest surrounding
1973 /// 'dstOpInst' at loop depth 'dstLoopDepth'.
1974 // TODO: extend the slicing utility to compute slices that
1975 // aren't necessarily a one-to-one relation between the source and destination. The
1976 // relation between the source and destination could be many-to-many in general.
1977 // TODO: the slice computation is incorrect in the cases
1978 // where the dependence from the source to the destination does not cover the
1979 // entire destination index set. Subtract out the dependent destination
1980 // iterations from destination index set and check for emptiness --- this is one
1981 // solution.
1982 AffineForOp mlir::affine::insertBackwardComputationSlice(
1983  Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1984  ComputationSliceState *sliceState) {
1985  // Get loop nest surrounding src operation.
1986  SmallVector<AffineForOp, 4> srcLoopIVs;
1987  getAffineForIVs(*srcOpInst, &srcLoopIVs);
1988  unsigned numSrcLoopIVs = srcLoopIVs.size();
1989 
1990  // Get loop nest surrounding dst operation.
1991  SmallVector<AffineForOp, 4> dstLoopIVs;
1992  getAffineForIVs(*dstOpInst, &dstLoopIVs);
1993  unsigned dstLoopIVsSize = dstLoopIVs.size();
1994  if (dstLoopDepth > dstLoopIVsSize) {
1995  dstOpInst->emitError("invalid destination loop depth");
1996  return AffineForOp();
1997  }
1998 
1999  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
2000  SmallVector<unsigned, 4> positions;
2001  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
2002  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
2003 
2004  // Clone the src loop nest and insert it at the beginning of the operation block
2005  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
2006  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
2007  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
2008  auto sliceLoopNest =
2009  cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
2010 
2011  Operation *sliceInst =
2012  getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
2013  // Get loop nest surrounding 'sliceInst'.
2014  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
2015  getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
2016 
2017  // Sanity check.
2018  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
2019  (void)sliceSurroundingLoopsSize;
2020  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
2021  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
2022  (void)sliceLoopLimit;
2023  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
2024 
2025  // Update loop bounds for loops in 'sliceLoopNest'.
2026  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
2027  auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
2028  if (AffineMap lbMap = sliceState->lbs[i])
2029  forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
2030  if (AffineMap ubMap = sliceState->ubs[i])
2031  forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
2032  }
2033  return sliceLoopNest;
2034 }
2035 
2036 // Constructs a MemRefAccess, populating it with the memref, its indices, and
2037 // the op, all taken from 'memOp'.
2038 MemRefAccess::MemRefAccess(Operation *memOp) {
2039  if (auto loadOp = dyn_cast<AffineReadOpInterface>(memOp)) {
2040  memref = loadOp.getMemRef();
2041  opInst = memOp;
2042  llvm::append_range(indices, loadOp.getMapOperands());
2043  } else {
2044  assert(isa<AffineWriteOpInterface>(memOp) &&
2045  "Affine read/write op expected");
2046  auto storeOp = cast<AffineWriteOpInterface>(memOp);
2047  opInst = memOp;
2048  memref = storeOp.getMemRef();
2049  llvm::append_range(indices, storeOp.getMapOperands());
2050  }
2051 }
2052 
2053 unsigned MemRefAccess::getRank() const {
2054  return cast<MemRefType>(memref.getType()).getRank();
2055 }
2056 
2057 bool MemRefAccess::isStore() const {
2058  return isa<AffineWriteOpInterface>(opInst);
2059 }
2060 
2061 /// Returns the nesting depth of this operation, i.e., the number of loops
2062 /// surrounding this operation.
2063 unsigned mlir::affine::getNestingDepth(Operation *op) {
2064  Operation *currOp = op;
2065  unsigned depth = 0;
2066  while ((currOp = currOp->getParentOp())) {
2067  if (isa<AffineForOp>(currOp))
2068  depth++;
2069  if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2070  depth += parOp.getNumDims();
2071  }
2072  return depth;
2073 }
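An illustrative nest (not from the source):

//   affine.for %i ...                // +1
//     affine.parallel (%p, %q) ...   // +2 (one per parallel dimension)
//       %v = affine.load %A[...]     // getNestingDepth(loadOp) == 3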
2074 
2075 /// Equal if both affine accesses are provably equivalent (at compile
2076 /// time) when considering the memref, the affine maps and their respective
2077 /// operands. The equality of access functions + operands is checked by
2078 /// subtracting fully composed value maps, and then simplifying the difference
2079 /// using the expression flattener.
2080 /// TODO: this does not account for aliasing of memrefs.
2081 bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
2082  if (memref != rhs.memref)
2083  return false;
2084 
2085  AffineValueMap diff, thisMap, rhsMap;
2086  getAccessMap(&thisMap);
2087  rhs.getAccessMap(&rhsMap);
2088  return thisMap == rhsMap;
2089 }
2090 
2091 void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
2092  auto *currOp = op.getParentOp();
2094  // Traverse up the hierarchy collecting all 'affine.for' and 'affine.parallel'
2095  // operations while skipping over 'affine.if' operations.
2096  while (currOp) {
2097  if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
2098  ivs.push_back(currAffineForOp.getInductionVar());
2099  else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2100  llvm::append_range(ivs, parOp.getIVs());
2101  currOp = currOp->getParentOp();
2102  }
2103  std::reverse(ivs.begin(), ivs.end());
2104 }
2105 
2106 /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
2107 /// where each lists loops from outermost to innermost in the loop nest.
2108 unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
2109  Operation &b) {
2110  SmallVector<Value, 4> loopsA, loopsB;
2111  getAffineIVs(a, loopsA);
2112  getAffineIVs(b, loopsB);
2113 
2114  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
2115  unsigned numCommonLoops = 0;
2116  for (unsigned i = 0; i < minNumLoops; ++i) {
2117  if (loopsA[i] != loopsB[i])
2118  break;
2119  ++numCommonLoops;
2120  }
2121  return numCommonLoops;
2122 }
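An illustrative case (not from the source): op a under loops (%i, %j) and op b under (%i, %k) share only the outermost IV, so getNumCommonSurroundingLoops(a, b) == 1.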
2123 
2124 static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
2125  Block::iterator start,
2126  Block::iterator end,
2127  int memorySpace) {
2128  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
2129 
2130  // Walk this range of operations to gather all memory regions.
2131  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
2132  if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
2133  // Neither load nor a store op.
2134  return WalkResult::advance();
2135  }
2136 
2137  // Compute the memref region symbolic in any IVs enclosing this block.
2138  auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2139  if (failed(
2140  region->compute(opInst,
2141  /*loopDepth=*/getNestingDepth(&*block.begin())))) {
2142  LDBG() << "Error obtaining memory region";
2143  opInst->emitError("error obtaining memory region");
2144  return failure();
2145  }
2146 
2147  auto [it, inserted] = regions.try_emplace(region->memref);
2148  if (inserted) {
2149  it->second = std::move(region);
2150  } else if (failed(it->second->unionBoundingBox(*region))) {
2151  LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
2152  "memory region";
2153  opInst->emitWarning(
2154  "getMemoryFootprintBytes: unable to perform a union on a memory "
2155  "region");
2156  return failure();
2157  }
2158  return WalkResult::advance();
2159  });
2160  if (result.wasInterrupted())
2161  return std::nullopt;
2162 
2163  int64_t totalSizeInBytes = 0;
2164  for (const auto &region : regions) {
2165  std::optional<int64_t> size = region.second->getRegionSize();
2166  if (!size.has_value())
2167  return std::nullopt;
2168  totalSizeInBytes += *size;
2169  }
2170  return totalSizeInBytes;
2171 }
2172 
2173 std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
2174  int memorySpace) {
2175  auto *forInst = forOp.getOperation();
2176  return ::getMemoryFootprintBytes(
2177  *forInst->getBlock(), Block::iterator(forInst),
2178  std::next(Block::iterator(forInst)), memorySpace);
2179 }
2180 
2181 /// Returns whether a loop is parallel and contains a reduction loop.
2182 bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
2183  SmallVector<LoopReduction> reductions;
2184  if (!isLoopParallel(forOp, &reductions))
2185  return false;
2186  return !reductions.empty();
2187 }
2188 
2189 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
2190 /// at 'forOp'.
2191 void mlir::affine::getSequentialLoops(
2192  AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
2193  forOp->walk([&](Operation *op) {
2194  if (auto innerFor = dyn_cast<AffineForOp>(op))
2195  if (!isLoopParallel(innerFor))
2196  sequentialLoops->insert(innerFor.getInductionVar());
2197  });
2198 }
2199 
2200 IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
2201  FlatAffineValueConstraints fac(set);
2202  if (fac.isEmpty())
2203  return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
2204  set.getContext());
2205  fac.removeTrivialRedundancy();
2206 
2207  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
2208  assert(simplifiedSet && "guaranteed to succeed while roundtripping");
2209  return simplifiedSet;
2210 }
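An illustrative case (not from the source): the set (d0) : (d0 >= 0, -d0 - 1 >= 0) is infeasible, so the function returns the canonical empty set for 1 dim / 0 symbols.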
2211 
2212 static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
2213  SmallVector<Value> &target) {
2214  target =
2215  llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
2216  return val.has_value() ? *val : Value();
2217  }));
2218 }
2219 
2220 /// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2221 /// constraints drawn from an affine map. Before adding the constraint, the
2222 /// dimensions/symbols of the affine map are aligned with `constraints`.
2223 /// `operands` are the SSA Value operands used with the affine map.
2224 /// Note: This function adds a new symbol column to the `constraints` for each
2225 /// dimension/symbol that exists in the affine map but not in `constraints`.
2226 static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2227  BoundType type, unsigned pos,
2228  AffineMap map, ValueRange operands) {
2229  SmallVector<Value> dims, syms, newSyms;
2230  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
2231  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
2232 
2233  AffineMap alignedMap =
2234  alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
2235  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2236  constraints.appendSymbolVar(newSyms[i]);
2237  return constraints.addBound(type, pos, alignedMap);
2238 }
2239 
2240 /// Add `val` to each result of `map`.
2241 static AffineMap addConstToResults(AffineMap map, int64_t val) {
2242  SmallVector<AffineExpr> newResults;
2243  for (AffineExpr r : map.getResults())
2244  newResults.push_back(r + val);
2245  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
2246  map.getContext());
2247 }
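An illustrative case (not from the source): addConstToResults((d0)[s0] -> (d0, s0), 1) yields (d0)[s0] -> (d0 + 1, s0 + 1); this is used below to convert between the inclusive bounds of affine.min/max and the exclusive upper bounds used internally.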
2248 
2249 // Attempt to simplify the given min/max operation by proving that its value is
2250 // bounded by the same lower and upper bound.
2251 //
2252 // Bounds are computed by FlatAffineValueConstraints. Invariants required for
2253 // finding/proving bounds should be supplied via `constraints`.
2254 //
2255 // 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2256 // 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2257 // case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2258 // `op` but are not part of `constraints`, are added as extra symbols.
2259 // 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2260 // * If `isMin`: r_i >= opBound
2261 // * If `isMax`: r_i <= opBound
2262 // If this is the case, ub(op) == lb(op).
2263 // 4. Replace `op` with `opBound`.
2264 //
2265 // In summary, the following constraints are added throughout this function.
2266 // Note: `invar` are dimensions added by the caller to express the invariants.
2267 // (Showing only the case where `isMin`.)
2268 //
2269 // invar | op | opBound | r_i | extra syms... | const | eq/ineq
2270 // ------+-------+---------+-----+---------------+-------+-------------------
2271 // (various eq./ineq. constraining `invar`, added by the caller)
2272 // ... | 0 | 0 | 0 | 0 | ... | ...
2273 // ------+-------+---------+-----+---------------+-------+-------------------
2274 // (various ineq. constraining `op` in terms of `op` operands (`invar` and
2275 // extra `op` operands "extra syms" that are not in `invar`)).
2276 // ... | -1 | 0 | 0 | ... | ... | >= 0
2277 // ------+-------+---------+-----+---------------+-------+-------------------
2278 // (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2279 // ... | 0 | -1 | 0 | ... | ... | = 0
2280 // ------+-------+---------+-----+---------------+-------+-------------------
2281 // (for each `op` map result r_i: set r_i to corresponding map result,
2282 // prove that r_i >= minOpUb via contradiction)
2283 // ... | 0 | 0 | -1 | ... | ... | = 0
2284 // 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2285 //
2287  Operation *op, FlatAffineValueConstraints constraints) {
2288  bool isMin = isa<AffineMinOp>(op);
2289  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2290  MLIRContext *ctx = op->getContext();
2291  Builder builder(ctx);
2292  AffineMap map =
2293  isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2294  ValueRange operands = op->getOperands();
2295  unsigned numResults = map.getNumResults();
2296 
2297  // Add a few extra dimensions.
2298  unsigned dimOp = constraints.appendDimVar(); // `op`
2299  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2300  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2301 
2302  // Add an inequality for each result expr_i of map:
2303  // isMin: op <= expr_i, !isMin: op >= expr_i
2304  auto boundType = isMin ? BoundType::UB : BoundType::LB;
2305  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2306  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
2307  if (failed(
2308  alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
2309  return failure();
2310 
2311  // Try to compute a lower/upper bound for op, expressed in terms of the other
2312  // `dims` and extra symbols.
2313  SmallVector<AffineMap> opLb(1), opUb(1);
2314  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
2315  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2316  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2317  // a TODO of `getSliceBounds` and not handled here.
2318  if (!sliceBound || sliceBound.getNumResults() != 1)
2319  return failure(); // No or multiple bounds found.
2320  // Recover the inclusive UB in the case of an `affine.min`.
2321  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
2322 
2323  // Add an equality: Set dimOpBound to computed bound.
2324  // Add back dimension for op. (Was removed by `getSliceBounds`.)
2325  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2326  if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
2327  return failure();
2328 
2329  // If the constraint system is empty, there is an inconsistency. (E.g., this
2330  // can happen if loop lb > ub.)
2331  if (constraints.isEmpty())
2332  return failure();
2333 
2334  // In the case of `isMin` (for `!isMin`, the reasoning below is inverted):
2335  // Prove that each result of `map` has a lower bound that is equal to (or
2336  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2337  // can be replaced with the bound. I.e., prove that for each result
2338  // expr_i (represented by dimension r_i):
2339  //
2340  // r_i >= opBound
2341  //
2342  // To prove this inequality, add its negation to the constraint set and prove
2343  // that the constraint set is empty.
2344  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
2345  FlatAffineValueConstraints newConstr(constraints);
2346 
2347  // Add an equality: r_i = expr_i
2348  // Note: These equalities could have been added earlier and used to express
2349  // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
2350  // computes minOpUb in terms of r_i dims, which is not desired.
2351  if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
2352  map.getSubMap({i - resultDimStart}), operands)))
2353  return failure();
2354 
2355  // If `isMin`: Add inequality: r_i < opBound
2356  // equiv.: opBound - r_i - 1 >= 0
2357  // If `!isMin`: Add inequality: r_i > opBound
2358  // equiv.: -opBound + r_i - 1 >= 0
2359  SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2360  ineq[dimOpBound] = isMin ? 1 : -1;
2361  ineq[i] = isMin ? -1 : 1;
2362  ineq[newConstr.getNumCols() - 1] = -1;
2363  newConstr.addInequality(ineq);
2364  if (!newConstr.isEmpty())
2365  return failure();
2366  }
2367 
2368  // Lower and upper bound of `op` are equal. Replace `op` with its bound.
2369  AffineMap newMap = alignedBoundMap;
2370  SmallVector<Value> newOperands;
2371  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
2372  // If dims/symbols have known constant values, use those in order to simplify
2373  // the affine map further.
2374  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2375  // Skip unused operands and operands that are already constants.
2376  if (!newOperands[i] || getConstantIntValue(newOperands[i]))
2377  continue;
2378  if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
2379  AffineExpr expr =
2380  i < newMap.getNumDims()
2381  ? builder.getAffineDimExpr(i)
2382  : builder.getAffineSymbolExpr(i - newMap.getNumDims());
2383  newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
2384  newMap.getNumDims(), newMap.getNumSymbols());
2385  }
2386  }
2387  affine::canonicalizeMapAndOperands(&newMap, &newOperands);
2388  return AffineValueMap(newMap, newOperands);
2389 }
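As a sketch of the proof strategy above (illustrative, assuming the caller seeded `constraints` with the invariant 0 <= %i <= 127):

// %m = affine.min affine_map<(d0) -> (d0, 128)>(%i)
// Upper bound of %m computed from the constraints: d0.
// Result 0: d0 >= d0 holds trivially; result 1: 128 >= d0 since d0 <= 127.
// Both results dominate the bound, so the op simplifies to the single-result
// AffineValueMap (d0) -> (d0) with operand %i.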
2390 
2391 mlir::Block *mlir::affine::findInnermostCommonBlockInScope(mlir::Operation *a,
2392  Operation *b) {
2393  Region *aScope = getAffineAnalysisScope(a);
2394  Region *bScope = getAffineAnalysisScope(b);
2395  if (aScope != bScope)
2396  return nullptr;
2397 
2398  // Get the block ancestry of `op` while stopping at the affine scope `aScope`
2399  // and store them in `ancestry`.
2400  auto getBlockAncestry = [&](Operation *op,
2401  SmallVectorImpl<Block *> &ancestry) {
2402  Operation *curOp = op;
2403  do {
2404  ancestry.push_back(curOp->getBlock());
2405  if (curOp->getParentRegion() == aScope)
2406  break;
2407  curOp = curOp->getParentOp();
2408  } while (curOp);
2409  assert(curOp && "can't reach root op without passing through affine scope");
2410  std::reverse(ancestry.begin(), ancestry.end());
2411  };
2412 
2413  SmallVector<Block *, 4> aAncestors, bAncestors;
2414  getBlockAncestry(a, aAncestors);
2415  getBlockAncestry(b, bAncestors);
2416  assert(!aAncestors.empty() && !bAncestors.empty() &&
2417  "at least one Block ancestor expected");
2418 
2419  Block *innermostCommonBlock = nullptr;
2420  for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
2421  a < e && b < f; ++a, ++b) {
2422  if (aAncestors[a] != bAncestors[b])
2423  break;
2424  innermostCommonBlock = aAncestors[a];
2425  }
2426  return innermostCommonBlock;
2427 }
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition: Utils.cpp:246
static std::optional< uint64_t > getConstDifference(AffineMap lbMap, AffineMap ubMap)
Definition: Utils.cpp:1798
static void findInstPosition(Operation *op, Block *limitBlock, SmallVectorImpl< unsigned > *positions)
Definition: Utils.cpp:1543
static Node * addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, DenseMap< Value, SetVector< unsigned >> &memrefAccesses)
Add op to MDG creating a new node and adding its memory accesses (affine or non-affine to memrefAcces...
Definition: Utils.cpp:191
static bool mayDependence(const Node &srcNode, const Node &dstNode, Value memref)
Returns true if there may be a dependence on memref from srcNode's memory ops to dstNode's memory ops...
Definition: Utils.cpp:264
const char *const kSliceFusionBarrierAttrName
Definition: Utils.cpp:1865
static LogicalResult addMissingLoopIVBounds(SmallPtrSet< Value, 8 > &ivs, FlatAffineValueConstraints *cst)
Definition: Utils.cpp:1585
static void getEffectedValues(Operation *op, SmallVectorImpl< Value > &values)
Returns the values that this op has a memref effect of type EffectTys on, not considering recursive e...
Definition: Utils.cpp:165
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static Operation * getInstAtPosition(ArrayRef< unsigned > positions, unsigned level, Block *block)
Definition: Utils.cpp:1560
static std::optional< int64_t > getMemoryFootprintBytes(Block &block, Block::iterator start, Block::iterator end, int memorySpace)
Definition: Utils.cpp:2124
static AffineMap addConstToResults(AffineMap map, int64_t val)
Add val to each result of map.
Definition: Utils.cpp:2241
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints, BoundType type, unsigned pos, AffineMap map, ValueRange operands)
Bound an identifier pos in a given FlatAffineValueConstraints with constraints drawn from an affine m...
Definition: Utils.cpp:2226
static void unpackOptionalValues(ArrayRef< std::optional< Value >> source, SmallVector< Value > &target)
Definition: Utils.cpp:2212
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
unsigned getPosition() const
Definition: AffineExpr.cpp:346
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:511
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:647
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:308
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:367
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:371
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:363
IntegerSet getAsIntegerSet(MLIRContext *context) const
Returns the constraint system as an integer set.
void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context, SmallVectorImpl< AffineMap > *lbMaps, SmallVectorImpl< AffineMap > *ubMaps, bool closedUB=false)
Computes the lower and upper bounds of the first num dimensional variables (starting at offset) as an...
std::optional< int64_t > getConstantBoundOnDimSize(MLIRContext *context, unsigned pos, AffineMap *lb=nullptr, AffineMap *ub=nullptr, unsigned *minLbPos=nullptr, unsigned *minUbPos=nullptr) const
Returns a non-negative constant bound on the extent (upper bound - lower bound) of the specified vari...
FlatLinearValueConstraints represents an extension of FlatLinearConstraints where each non-local vari...
LogicalResult unionBoundingBox(const FlatLinearValueConstraints &other)
Updates the constraints to be the smallest bounding (enclosing) box that contains the points of this ...
void mergeAndAlignVarsWithOther(unsigned offset, FlatLinearValueConstraints *other)
Merge and align the variables of this and other starting at offset, so that both constraint systems g...
SmallVector< std::optional< Value > > getMaybeValues() const
Value getValue(unsigned pos) const
Returns the Value associated with the pos^th variable.
void projectOut(Value val)
Projects out the variable that is associate with Value.
bool containsVar(Value val) const
Returns true if a variable with the specified Value exists, false otherwise.
void addBound(presburger::BoundType type, Value val, int64_t value)
Adds a constant bound for the variable associated with the given Value.
bool areVarsAlignedWithOther(const FlatLinearConstraints &other)
Returns true if this constraint system and other are in the same space, i.e., if they are associated ...
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols, MLIRContext *context)
Definition: IntegerSet.h:56
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:279
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
Value getOperand(unsigned i) const
unsigned getNumOperands() const
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
void convertLoopIVSymbolsToDims()
Changes all symbol variables which are loop IVs to dim variables.
LogicalResult addDomainFromSliceMaps(ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds constraints (lower and upper bounds) for each loop in the loop nest described by the bound maps ...
LogicalResult addAffineForOpDomain(AffineForOp forOp)
Adds constraints (lower and upper bounds) for the specified 'affine.for' operation's Value using IR i...
LogicalResult addSliceBounds(ArrayRef< Value > values, ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds slice lower bounds represented by lower bounds in lbMaps and upper bounds in ubMaps to each vari...
void removeTrivialRedundancy()
Removes duplicate constraints, trivially true constraints, and constraints that can be detected as re...
std::optional< int64_t > getConstantBound64(BoundType type, unsigned pos) const
The same, but casts to int64_t.
bool isEmpty() const
Checks for emptiness by performing variable elimination on all variables, running the GCD test on eac...
unsigned getNumCols() const
Returns the number of columns in the constraint system.
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
bool isIntegerEmpty() const
Return true if all the sets in the union are known to be integer empty false otherwise.
PresburgerSet subtract(const PresburgerRelation &set) const
std::optional< uint64_t > getConstantTripCount(AffineForOp forOp)
Returns the trip count of the loop if it's a constant, std::nullopt otherwise.
IntegerSet simplifyIntegerSet(IntegerSet set)
Simplify the integer set by simplifying the underlying affine expressions by flattening and some simp...
Definition: Utils.cpp:2200
void getEnclosingAffineOps(Operation &op, SmallVectorImpl< Operation * > *ops)
Populates 'ops' with affine operations enclosing op ordered from outermost to innermost while stoppin...
Definition: Utils.cpp:865
SliceComputationResult computeSliceUnion(ArrayRef< Operation * > opsA, ArrayRef< Operation * > opsB, unsigned loopDepth, unsigned numCommonLoops, bool isBackwardSlice, ComputationSliceState *sliceUnion)
Computes in 'sliceUnion' the union of all slice bounds computed at 'loopDepth' between all dependent ...
Definition: Utils.cpp:1633
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2797
bool isLoopParallelAndContainsReduction(AffineForOp forOp)
Returns whether a loop is a parallel loop and contains a reduction loop.
Definition: Utils.cpp:2182
unsigned getNumCommonSurroundingLoops(Operation &a, Operation &b)
Returns the number of surrounding loops common to both A and B.
Definition: Utils.cpp:2108
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2801
void getAffineIVs(Operation &op, SmallVectorImpl< Value > &ivs)
Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel ops ordered from the outer...
Definition: Utils.cpp:2091
void getSequentialLoops(AffineForOp forOp, llvm::SmallDenseSet< Value, 8 > *sequentialLoops)
Returns in 'sequentialLoops' all sequential loops in loop nest rooted at 'forOp'.
Definition: Utils.cpp:2191
DependenceResult checkMemrefAccessDependence(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints=nullptr, SmallVector< DependenceComponent, 2 > *dependenceComponents=nullptr, bool allowRAR=false)
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1619
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2789
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
Definition: AffineOps.cpp:275
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition: Utils.cpp:851
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition: Utils.cpp:2173
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:1601
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:412
void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth, bool isBackwardSlice, ComputationSliceState *sliceState)
Computes the computation slice loop bounds for one loop nest as affine maps of the other loop nest's ...
Definition: Utils.cpp:1870
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns true if the provided value is among the induction variables of an AffineParallelOp.
Definition: AffineOps.cpp:2812
std::optional< uint64_t > getIntOrFloatMemRefSizeInBytes(MemRefType memRefType)
Returns the size of a memref with element type int or float in bytes if it's statically shaped,...
Definition: Utils.cpp:1463
unsigned getNestingDepth(Operation *op)
Returns the nesting depth of this operation, i.e., the number of loops surrounding this operation.
Definition: Utils.cpp:2063
uint64_t getSliceIterationCount(const llvm::SmallDenseMap< Operation *, uint64_t, 8 > &sliceTripCountMap)
Return the number of iterations for the slicetripCountMap provided.
Definition: Utils.cpp:1856
bool isLoopParallel(AffineForOp forOp, SmallVectorImpl< LoopReduction > *parallelReductions=nullptr)
Returns true if ‘forOp’ is a parallel loop.
LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError=true)
Checks a load or store op for an out of bound access; returns failure if the access is out of bounds ...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
Definition: Utils.cpp:2391
bool noDependence(DependenceResult result)
Returns true if the provided DependenceResult corresponds to the absence of a dependence.
bool buildSliceTripCountMap(const ComputationSliceState &slice, llvm::SmallDenseMap< Operation *, uint64_t, 8 > *tripCountMap)
Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop nest surrounding represe...
Definition: Utils.cpp:1818
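A minimal sketch pairing this with getSliceIterationCount above, assuming `slice` is a previously computed ComputationSliceState:

static uint64_t
sliceIterations(const mlir::affine::ComputationSliceState &slice) {
  llvm::SmallDenseMap<mlir::Operation *, uint64_t, 8> tripCountMap;
  if (!mlir::affine::buildSliceTripCountMap(slice, &tripCountMap))
    return 0; // Some slice loop lacks a constant trip count.
  // Product of the constant trip counts of all loops in the slice.
  return mlir::affine::getSliceIterationCount(tripCountMap);
}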
AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState)
Creates a clone of the computation contained in the loop nest surrounding 'srcOpInst',...
Definition: Utils.cpp:1982
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition: Utils.cpp:2286
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition: Utils.cpp:1419
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
BoundType
The type of bound: equal, lower bound or upper bound.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands, ValueRange dims, ValueRange syms, SmallVector< Value > *newSyms=nullptr)
Re-indexes the dimensions and symbols of an affine map with given operands values to align with dims ...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:318
std::optional< bool > isSliceValid() const
Checks the validity of the slice computed.
Definition: Utils.cpp:1036
SmallVector< Value, 4 > ivs
Definition: Utils.h:321
LogicalResult getAsConstraints(FlatAffineValueConstraints *cst) const
Definition: Utils.cpp:898
LogicalResult getSourceAsConstraints(FlatAffineValueConstraints &cst) const
Adds to 'cst' constraints which represent the original loop bounds on 'ivs' in 'this'.
Definition: Utils.cpp:882
std::vector< SmallVector< Value, 4 > > ubOperands
Definition: Utils.h:329
SmallVector< AffineMap, 4 > ubs
Definition: Utils.h:325
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition: Utils.cpp:1106
SmallVector< AffineMap, 4 > lbs
Definition: Utils.h:323
std::vector< SmallVector< Value, 4 > > lbOperands
Definition: Utils.h:327
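A minimal sketch of querying a computed slice, assuming `sliceState` was populated by getComputationSliceState or computeSliceUnion:

static bool
sliceUsable(const mlir::affine::ComputationSliceState &sliceState) {
  // isSliceValid() returns std::nullopt when validity cannot be determined;
  // a maximal slice (isMaximal()) would enclose every iteration of the
  // sliced loop nest.
  std::optional<bool> valid = sliceState.isSliceValid();
  return valid.value_or(false);
}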
Checks whether two accesses to the same memref access the same element.
enum mlir::affine::DependenceResult::ResultEnum value
SmallVector< Operation *, 4 > memrefFrees
Definition: Utils.h:49
SmallVector< Operation *, 4 > loadOpInsts
Definition: Utils.h:41
SmallVector< Operation *, 4 > memrefStores
Definition: Utils.h:47
void collect(Operation *opToWalk)
Definition: Utils.cpp:43
SmallVector< Operation *, 4 > memrefLoads
Definition: Utils.h:45
SmallVector< Operation *, 4 > storeOpInsts
Definition: Utils.h:43
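A minimal sketch of running the collector over a hypothetical nest root `rootForOp`:

static void summarizeNest(mlir::affine::AffineForOp rootForOp) {
  mlir::affine::LoopNestStateCollector collector;
  collector.collect(rootForOp.getOperation());
  // Affine accesses land in loadOpInsts/storeOpInsts; non-affine ops with
  // read/write/free memory effects land in memrefLoads/memrefStores/
  // memrefFrees.
  llvm::errs() << "affine loads: " << collector.loadOpInsts.size() << "\n";
}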
Encapsulates information about a memref load or store access.
unsigned getRank() const
Definition: Utils.cpp:2053
void getAccessMap(AffineValueMap *accessMap) const
Populates 'accessMap' with composition of AffineApplyOps reachable from 'indices'.
bool operator==(const MemRefAccess &rhs) const
Equal if both affine accesses can be proved to be equivalent at compile time (considering the memrefs...
Definition: Utils.cpp:2081
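A minimal sketch of inspecting an access, where `loadOp` stands for a hypothetical affine load operation:

static void inspectAccess(mlir::Operation *loadOp) {
  mlir::affine::MemRefAccess access(loadOp);
  llvm::errs() << "rank: " << access.getRank() << "\n";
  mlir::affine::AffineValueMap accessMap;
  // Composes any affine.apply ops reachable from the access indices.
  access.getAccessMap(&accessMap);
}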
SmallVector< Operation *, 4 > loads
Definition: Utils.h:73
SmallVector< Operation *, 4 > stores
Definition: Utils.h:77
unsigned hasFree(Value memref) const
Definition: Utils.cpp:123
SmallVector< Operation *, 4 > memrefLoads
Definition: Utils.h:75
SmallVector< Operation *, 4 > memrefStores
Definition: Utils.h:79
unsigned getStoreOpCount(Value memref) const
Definition: Utils.cpp:93
unsigned hasStore(Value memref) const
Returns true if there exists an operation with a write memory effect to memref in this node.
Definition: Utils.cpp:109
SmallVector< Operation *, 4 > memrefFrees
Definition: Utils.h:81
unsigned addNode(Operation *op)
Definition: Utils.cpp:467
bool writesToLiveInOrEscapingMemrefs(unsigned id) const
Definition: Utils.cpp:497
void removeEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:543
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:532
DenseMap< unsigned, Node > nodes
Definition: Utils.h:141
void gatherDefiningNodes(unsigned id, DenseSet< unsigned > &definingNodes) const
Return all nodes which define SSA values used in node 'id'.
Definition: Utils.cpp:631
bool hasDependencePath(unsigned srcId, unsigned dstId) const
Definition: Utils.cpp:570
void clearNodeLoadAndStores(unsigned id)
Definition: Utils.cpp:797
const Node * getForOpNode(AffineForOp forOp) const
Definition: Utils.cpp:459
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) const
Definition: Utils.cpp:645
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition: Utils.cpp:720
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:805
bool init(bool fullAffineDependences=true)
Definition: Utils.cpp:336
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr) const
Definition: Utils.cpp:621
const Node * getNode(unsigned id) const
Definition: Utils.cpp:452
void removeNode(unsigned id)
Definition: Utils.cpp:474
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:813
void forEachMemRefEdge(ArrayRef< Edge > edges, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:821
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
Definition: Utils.cpp:784
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const
Definition: Utils.cpp:605
bool hasEdge(unsigned srcId, unsigned dstId, Value value=nullptr) const
Definition: Utils.cpp:517
void print(raw_ostream &os) const
Definition: Utils.cpp:833
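A minimal sketch of building and printing the graph, assuming `block` is the body block of a function containing affine loop nests:

static void buildGraph(mlir::Block &block) {
  mlir::affine::MemRefDependenceGraph graph(block);
  if (!graph.init())
    return; // e.g., an unsupported region-holding op was encountered.
  // Nodes correspond to top-level statements; edges record memref accesses
  // and SSA values flowing between them.
  graph.print(llvm::errs());
}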
A region of a memref's data space; this is typically constructed by analyzing load/store ops on this...
Definition: Utils.h:489
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition: Utils.cpp:1160
FlatAffineValueConstraints * getConstraints()
Definition: Utils.h:535
unsigned getRank() const
Returns the rank of the memref that this region corresponds to.
Definition: Utils.cpp:1156
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition: Utils.cpp:1257
void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const
Gets the lower and upper bound map for the dimensional variable at pos.
Definition: Utils.cpp:1216
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition: Utils.cpp:1438
LogicalResult unionBoundingBox(const MemRefRegion &other)
Definition: Utils.cpp:1235
Value memref
Memref that this region corresponds to.
Definition: Utils.h:570
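A minimal sketch of bounding the region touched by a hypothetical affine load `loadOp`, computed symbolically at the outermost loop depth:

static void boundRegion(mlir::Operation *loadOp) {
  mlir::affine::MemRefRegion region(loadOp->getLoc());
  if (failed(region.compute(loadOp, /*loopDepth=*/0)))
    return;
  std::optional<int64_t> bytes = region.getRegionSize();
  llvm::SmallVector<int64_t> shape;
  // Constant bounding shape of the region, one extent per memref dimension.
  region.getConstantBoundingSizeAndShape(&shape);
  if (bytes)
    llvm::errs() << "region size: " << *bytes << " bytes\n";
}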
Enumerates different result statuses of slice computation by computeSliceUnion.
Definition: Utils.h:305
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.