MLIR  20.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous analysis routines for non-loop IR
10 // structures.
11 //
12 //===----------------------------------------------------------------------===//
13 
22 #include "mlir/IR/IntegerSet.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <optional>
29 
30 #define DEBUG_TYPE "analysis-utils"
31 
32 using namespace mlir;
33 using namespace affine;
34 using namespace presburger;
35 
36 using llvm::SmallDenseMap;
37 
39 
40 // LoopNestStateCollector walks loop nests and collects load and store
41 // operations, and whether or not a region holding op other than ForOp and IfOp
42 // was encountered in the loop nest.
// Classifies every op visited by 'opToWalk->walk':
//  - affine.for ops are appended to 'forOps';
//  - any other region-holding op that is not affine.if sets
//    'hasNonAffineRegionOp';
//  - affine read ops go to 'loadOpInsts', affine write ops to 'storeOpInsts'.
// The checks form an if/else chain, so each op lands in at most one bucket and
// the first matching check wins.
44  opToWalk->walk([&](Operation *op) {
45  if (isa<AffineForOp>(op))
46  forOps.push_back(cast<AffineForOp>(op));
47  else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
48  hasNonAffineRegionOp = true;
49  else if (isa<AffineReadOpInterface>(op))
50  loadOpInsts.push_back(op);
51  else if (isa<AffineWriteOpInterface>(op))
52  storeOpInsts.push_back(op);
53  });
54 }
55 
56 // Returns the load op count for 'memref'.
57 unsigned Node::getLoadOpCount(Value memref) const {
58  unsigned loadOpCount = 0;
59  for (Operation *loadOp : loads) {
60  if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
61  ++loadOpCount;
62  }
63  return loadOpCount;
64 }
65 
66 // Returns the store op count for 'memref'.
67 unsigned Node::getStoreOpCount(Value memref) const {
68  unsigned storeOpCount = 0;
69  for (Operation *storeOp : stores) {
70  if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
71  ++storeOpCount;
72  }
73  return storeOpCount;
74 }
75 
76 // Returns all store ops in 'storeOps' which access 'memref'.
77 void Node::getStoreOpsForMemref(Value memref,
78  SmallVectorImpl<Operation *> *storeOps) const {
79  for (Operation *storeOp : stores) {
80  if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
81  storeOps->push_back(storeOp);
82  }
83 }
84 
85 // Returns all load ops in 'loadOps' which access 'memref'.
86 void Node::getLoadOpsForMemref(Value memref,
87  SmallVectorImpl<Operation *> *loadOps) const {
88  for (Operation *loadOp : loads) {
89  if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
90  loadOps->push_back(loadOp);
91  }
92 }
93 
94 // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
95 // has at least one load and store operation.
96 void Node::getLoadAndStoreMemrefSet(
97  DenseSet<Value> *loadAndStoreMemrefSet) const {
98  llvm::SmallDenseSet<Value, 2> loadMemrefs;
99  for (Operation *loadOp : loads) {
100  loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
101  }
102  for (Operation *storeOp : stores) {
103  auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
104  if (loadMemrefs.count(memref) > 0)
105  loadAndStoreMemrefSet->insert(memref);
106  }
107 }
108 
109 // Initializes the data dependence graph by walking operations in `block`.
110 // Assigns each node in the graph a node id based on program order in `block`.
// Returns false if `block` contains a region-holding op this analysis does not
// handle; nodes created before that point are not removed. Returns true
// otherwise.
112  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
113  // Map from a memref to the set of ids of the nodes that have ops accessing
114  // the memref.
115  DenseMap<Value, SetVector<unsigned>> memrefAccesses;
116 
// Maps each top-level affine.for op to the id of the node created for it.
117  DenseMap<Operation *, unsigned> forToNodeMap;
// Phase 1: create one node per top-level op of interest. Ids come from
// 'nextNodeId++' while walking `block` in order, so ids increase with program
// order.
118  for (Operation &op : block) {
119  if (dyn_cast<AffineForOp>(op)) {
120  // Create graph node 'id' to represent top-level 'forOp' and record
121  // all loads and store accesses it contains.
122  LoopNestStateCollector collector;
123  collector.collect(&op);
124  // Return false if a region holding op other than 'affine.for' and
125  // 'affine.if' was found (not currently supported).
126  if (collector.hasNonAffineRegionOp)
127  return false;
128  Node node(nextNodeId++, &op);
129  for (auto *opInst : collector.loadOpInsts) {
130  node.loads.push_back(opInst);
131  auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
132  memrefAccesses[memref].insert(node.id);
133  }
134  for (auto *opInst : collector.storeOpInsts) {
135  node.stores.push_back(opInst);
136  auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
137  memrefAccesses[memref].insert(node.id);
138  }
139  forToNodeMap[&op] = node.id;
140  nodes.insert({node.id, node});
141  } else if (dyn_cast<AffineReadOpInterface>(op)) {
142  // Create graph node for top-level load op.
143  Node node(nextNodeId++, &op);
144  node.loads.push_back(&op);
145  auto memref = cast<AffineReadOpInterface>(op).getMemRef();
146  memrefAccesses[memref].insert(node.id);
147  nodes.insert({node.id, node});
148  } else if (dyn_cast<AffineWriteOpInterface>(op)) {
149  // Create graph node for top-level store op.
150  Node node(nextNodeId++, &op);
151  node.stores.push_back(&op);
152  auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
153  memrefAccesses[memref].insert(node.id);
154  nodes.insert({node.id, node});
155  } else if (op.getNumResults() > 0 && !op.use_empty()) {
156  // Create graph node for top-level producer of SSA values, which
157  // could be used by loop nest nodes.
158  Node node(nextNodeId++, &op);
159  nodes.insert({node.id, node});
160  } else if (!isMemoryEffectFree(&op) &&
161  (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
162  // Create graph node for top-level op unless it is known to be
163  // memory-effect free. This covers all unknown/unregistered ops,
164  // non-affine ops with memory effects, and region-holding ops with a
165  // well-defined control flow. During the fusion validity checks, we look
166  // for non-affine ops on the path from source to destination, at which
167  // point we check which memrefs if any are used in the region.
168  Node node(nextNodeId++, &op);
169  nodes.insert({node.id, node});
170  } else if (op.getNumRegions() != 0) {
171  // Return false if non-handled/unknown region-holding ops are found. We
172  // won't know what such ops do or what its regions mean; for e.g., it may
173  // not be an imperative op.
174  LLVM_DEBUG(llvm::dbgs()
175  << "MDG init failed; unknown region-holding op found!\n");
176  return false;
177  }
178  }
179 
// Debug-only dump of the created nodes. '(void)idAndNode' keeps the loop
// variable "used" when LLVM_DEBUG expands to nothing.
180  for (auto &idAndNode : nodes) {
181  LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
182  << *(idAndNode.second.op) << "\n");
183  (void)idAndNode;
184  }
185 
186  // Add dependence edges between nodes which produce SSA values and their
187  // users. Load ops can be considered as the ones producing SSA values.
// Only users nested in an affine.for that sits directly in `block` get an
// edge (to that loop's node); all other users are skipped.
188  for (auto &idAndNode : nodes) {
189  const Node &node = idAndNode.second;
190  // Stores don't define SSA values, skip them.
191  if (!node.stores.empty())
192  continue;
193  Operation *opInst = node.op;
194  for (Value value : opInst->getResults()) {
195  for (Operation *user : value.getUsers()) {
196  // Ignore users outside of the block.
197  if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
198  &block)
199  continue;
201  getAffineForIVs(*user, &loops);
202  // Find the surrounding affine.for nested immediately within the
203  // block.
204  auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
205  return loop->getBlock() == &block;
206  });
207  if (it == loops.end())
208  continue;
209  assert(forToNodeMap.count(*it) > 0 && "missing mapping");
210  unsigned userLoopNestId = forToNodeMap[*it];
211  addEdge(node.id, userLoopNestId, value);
212  }
213  }
214  }
215 
216  // Walk memref access lists and add graph edges between dependent nodes.
// Each access list holds node ids in insertion order, which is increasing id
// (program) order, so every memref edge goes from an earlier node to a later
// one. Pairs where neither node stores to the memref (read-read) get no edge.
217  for (auto &memrefAndList : memrefAccesses) {
218  unsigned n = memrefAndList.second.size();
219  for (unsigned i = 0; i < n; ++i) {
220  unsigned srcId = memrefAndList.second[i];
221  bool srcHasStore =
222  getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
223  for (unsigned j = i + 1; j < n; ++j) {
224  unsigned dstId = memrefAndList.second[j];
225  bool dstHasStore =
226  getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
227  if (srcHasStore || dstHasStore)
228  addEdge(srcId, dstId, memrefAndList.first);
229  }
230  }
231  }
232  return true;
233 }
234 
235 // Returns the graph node for 'id'.
// The node must exist. This is only checked with assert(), so with assertions
// disabled an unknown 'id' leads to taking the address of 'nodes.end()->second'
// (undefined behavior).
237  auto it = nodes.find(id);
238  assert(it != nodes.end());
239  return &it->second;
240 }
241 
242 // Returns the graph node for 'forOp'.
// Linear scan over all nodes; returns a node whose op is 'forOp', or nullptr
// if no node wraps it.
244  for (auto &idAndNode : nodes)
245  if (idAndNode.second.op == forOp)
246  return &idAndNode.second;
247  return nullptr;
248 }
249 
250 // Adds a node with 'op' to the graph and returns its unique identifier.
// The id is taken from 'nextNodeId', which is then incremented. No loads,
// stores or edges are recorded for the new node here.
252  Node node(nextNodeId++, op);
253  nodes.insert({node.id, node});
254  return node.id;
255 }
256 
257 // Remove node 'id' (and its associated edges) from graph.
// The edge lists are copied before iterating because removeEdge() erases
// entries from the very lists being walked.
259  // Remove each edge in 'inEdges[id]'.
260  if (inEdges.count(id) > 0) {
261  SmallVector<Edge, 2> oldInEdges = inEdges[id];
262  for (auto &inEdge : oldInEdges) {
263  removeEdge(inEdge.id, id, inEdge.value);
264  }
265  }
266  // Remove each edge in 'outEdges[id]'.
267  if (outEdges.count(id) > 0) {
268  SmallVector<Edge, 2> oldOutEdges = outEdges[id];
269  for (auto &outEdge : oldOutEdges) {
270  removeEdge(id, outEdge.id, outEdge.value);
271  }
272  }
273  // Erase remaining node state.
274  inEdges.erase(id);
275  outEdges.erase(id);
276  nodes.erase(id);
277 }
278 
279 // Returns true if node 'id' writes to any memref which escapes (or is an
280 // argument to) the block. Returns false otherwise.
// The test is conservative. A stored-to memref counts as escaping if it has no
// defining op, or if any of its users does not implement
// AffineMapAccessInterface.
282  Node *node = getNode(id);
283  for (auto *storeOpInst : node->stores) {
284  auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
285  auto *op = memref.getDefiningOp();
286  // Return true if 'memref' is a block argument.
287  if (!op)
288  return true;
289  // Return true if any use of 'memref' does not dereference it in an affine
290  // way.
291  for (auto *user : memref.getUsers())
292  if (!isa<AffineMapAccessInterface>(*user))
293  return true;
294  }
295  return false;
296 }
297 
298 // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
299 // is for 'value' if non-null, or for any value otherwise. Returns false
300 // otherwise.
301 bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
302  Value value) {
303  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
304  return false;
305  }
306  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
307  return edge.id == dstId && (!value || edge.value == value);
308  });
309  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
310  return edge.id == srcId && (!value || edge.value == value);
311  });
312  return hasOutEdge && hasInEdge;
313 }
314 
315 // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
316 void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
317  Value value) {
318  if (!hasEdge(srcId, dstId, value)) {
319  outEdges[srcId].push_back({dstId, value});
320  inEdges[dstId].push_back({srcId, value});
321  if (isa<MemRefType>(value.getType()))
322  memrefEdgeCount[value]++;
323  }
324 }
325 
326 // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
327 void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
328  Value value) {
329  assert(inEdges.count(dstId) > 0);
330  assert(outEdges.count(srcId) > 0);
331  if (isa<MemRefType>(value.getType())) {
332  assert(memrefEdgeCount.count(value) > 0);
333  memrefEdgeCount[value]--;
334  }
335  // Remove 'srcId' from 'inEdges[dstId]'.
336  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
337  if ((*it).id == srcId && (*it).value == value) {
338  inEdges[dstId].erase(it);
339  break;
340  }
341  }
342  // Remove 'dstId' from 'outEdges[srcId]'.
343  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
344  if ((*it).id == dstId && (*it).value == value) {
345  outEdges[srcId].erase(it);
346  break;
347  }
348  }
349 }
350 
351 // Returns true if there is a path in the dependence graph from node 'srcId'
352 // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
353 // operations connected by the edges are expected to be from the same block.
// Implemented as an iterative DFS: the top worklist entry's edge index is
// advanced in place, and an entry is popped once all its out edges were tried.
// NOTE(review): there is no visited set, so a node reachable along several
// paths may be explored more than once; only the "after dstId" pruning below
// bounds the search.
354 bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
355  // Worklist state is: <node-id, next-output-edge-index-to-visit>
357  worklist.push_back({srcId, 0});
358  Operation *dstOp = getNode(dstId)->op;
359  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
360  while (!worklist.empty()) {
361  auto &idAndIndex = worklist.back();
362  // Return true if we have reached 'dstId'.
363  if (idAndIndex.first == dstId)
364  return true;
365  // Pop and continue if node has no out edges, or if all out edges have
366  // already been visited.
367  if (outEdges.count(idAndIndex.first) == 0 ||
368  idAndIndex.second == outEdges[idAndIndex.first].size()) {
369  worklist.pop_back();
370  continue;
371  }
372  // Get graph edge to traverse.
373  Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
374  // Increment next output edge index for 'idAndIndex'.
375  ++idAndIndex.second;
376  // Add node at 'edge.id' to the worklist. We don't need to consider
377  // nodes that are "after" dstId in the containing block; one can't have a
378  // path to `dstId` from any of those nodes.
// Self-edges are skipped as well. 'idAndIndex' refers into 'worklist' and may
// dangle after the push_back below, but it is not used again in this iteration.
379  bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
380  if (!afterDst && edge.id != idAndIndex.first)
381  worklist.push_back({edge.id, 0});
382  }
383  return false;
384 }
385 
386 // Returns the input edge count for node 'id' and 'memref' from src nodes
387 // which access 'memref' with a store operation.
// In edges carrying 'memref' whose source node only loads from it are not
// counted. A node with no in edges yields 0.
389  Value memref) {
390  unsigned inEdgeCount = 0;
391  if (inEdges.count(id) > 0)
392  for (auto &inEdge : inEdges[id])
393  if (inEdge.value == memref) {
394  Node *srcNode = getNode(inEdge.id);
395  // Only count in edges from 'srcNode' if 'srcNode' stores to 'memref'.
396  if (srcNode->getStoreOpCount(memref) > 0)
397  ++inEdgeCount;
398  }
399  return inEdgeCount;
400 }
401 
402 // Returns the output edge count for node 'id' and 'memref' (if non-null),
403 // otherwise returns the total output edge count from node 'id'.
404 unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
405  unsigned outEdgeCount = 0;
406  if (outEdges.count(id) > 0)
407  for (auto &outEdge : outEdges[id])
408  if (!memref || outEdge.value == memref)
409  ++outEdgeCount;
410  return outEdgeCount;
411 }
412 
413 /// Return all nodes which define SSA values used in node 'id'.
/// The ids are inserted into 'definingNodes'; existing contents are kept.
/// Note: 'inEdges[id]' default-constructs an empty edge list for 'id' if it
/// has none yet.
415  unsigned id, DenseSet<unsigned> &definingNodes) {
416  for (MemRefDependenceGraph::Edge edge : inEdges[id])
417  // By definition of edge, if the edge value is a non-memref value,
418  // then the dependence is between a graph node which defines an SSA value
419  // and another graph node which uses the SSA value.
420  if (!isa<MemRefType>(edge.value.getType()))
421  definingNodes.insert(edge.id);
422 }
423 
424 // Computes and returns an insertion point operation, before which the
425 // fused <srcId, dstId> loop nest can be inserted while preserving
426 // dependences. Returns nullptr if no such insertion point is found.
// If 'srcId' has no out edges, the dst node's op itself is returned.
// NOTE(review): the block walk below assumes the src op precedes the dst op
// in the same block; otherwise the iteration would never reach the dst op.
427 Operation *
429  unsigned dstId) {
430  if (outEdges.count(srcId) == 0)
431  return getNode(dstId)->op;
432 
433  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
434  DenseSet<unsigned> definingNodes;
435  gatherDefiningNodes(dstId, definingNodes);
436  if (llvm::any_of(definingNodes,
437  [&](unsigned id) { return hasDependencePath(srcId, id); })) {
438  LLVM_DEBUG(llvm::dbgs()
439  << "Can't fuse: a defining op with a user in the dst "
440  "loop has dependence from the src loop\n");
441  return nullptr;
442  }
443 
444  // Collect ops of nodes (other than 'dstId') that depend on 'srcId'; only
// those lying strictly between src and dst in the block matter below.
445  SmallPtrSet<Operation *, 2> srcDepInsts;
446  for (auto &outEdge : outEdges[srcId])
447  if (outEdge.id != dstId)
448  srcDepInsts.insert(getNode(outEdge.id)->op);
449 
450  // Collect ops of nodes (other than 'srcId') on which 'dstId' depends; only
// those lying strictly between src and dst in the block matter below.
451  SmallPtrSet<Operation *, 2> dstDepInsts;
452  for (auto &inEdge : inEdges[dstId])
453  if (inEdge.id != srcId)
454  dstDepInsts.insert(getNode(inEdge.id)->op);
455 
456  Operation *srcNodeInst = getNode(srcId)->op;
457  Operation *dstNodeInst = getNode(dstId)->op;
458 
459  // Computing insertion point:
460  // *) Walk all operation positions in Block operation list in the
461  // range (src, dst). For each operation 'op' visited in this search:
462  // *) Store in 'firstSrcDepPos' the first position where 'op' has a
463  // dependence edge from 'srcNode'.
464  // *) Store in 'lastDstDepPos' the last position where 'op' has a
465  // dependence edge to 'dstNode'.
466  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
467  // operation insertion point (or return null pointer if no such
468  // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
// 'depInsts' records every op visited by the walk, so a position maps back to
// its op.
470  std::optional<unsigned> firstSrcDepPos;
471  std::optional<unsigned> lastDstDepPos;
472  unsigned pos = 0;
473  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
474  it != Block::iterator(dstNodeInst); ++it) {
475  Operation *op = &(*it);
476  if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
477  firstSrcDepPos = pos;
478  if (dstDepInsts.count(op) > 0)
479  lastDstDepPos = pos;
480  depInsts.push_back(op);
481  ++pos;
482  }
483 
484  if (firstSrcDepPos.has_value()) {
485  if (lastDstDepPos.has_value()) {
486  if (*firstSrcDepPos <= *lastDstDepPos) {
487  // No valid insertion point exists which preserves dependences.
488  return nullptr;
489  }
490  }
491  // Return the insertion point at 'firstSrcDepPos'.
492  return depInsts[*firstSrcDepPos];
493  }
494  // No dependence targets in range (or only dst deps in range), return
495  // 'dstNodeInst' insertion point.
496  return dstNodeInst;
497 }
498 
499 // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
500 // taking into account that:
501 // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
502 // *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
503 // private memref.
504 void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
505  const DenseSet<Value> &privateMemRefs,
506  bool removeSrcId) {
507  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
508  if (inEdges.count(srcId) > 0) {
509  SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
510  for (auto &inEdge : oldInEdges) {
511  // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
512  if (privateMemRefs.count(inEdge.value) == 0)
513  addEdge(inEdge.id, dstId, inEdge.value);
514  }
515  }
516  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
517  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
518  if (outEdges.count(srcId) > 0) {
519  SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
520  for (auto &outEdge : oldOutEdges) {
521  // Remove any out edges from 'srcId' to 'dstId' across memrefs.
522  if (outEdge.id == dstId)
523  removeEdge(srcId, outEdge.id, outEdge.value);
524  else if (removeSrcId) {
525  addEdge(dstId, outEdge.id, outEdge.value);
526  removeEdge(srcId, outEdge.id, outEdge.value);
527  }
528  }
529  }
530  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
531  // replaced by a private memref). These edges could come from nodes
532  // other than 'srcId' which were removed in the previous step.
533  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
534  SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
535  for (auto &inEdge : oldInEdges)
536  if (privateMemRefs.count(inEdge.value) > 0)
537  removeEdge(inEdge.id, dstId, inEdge.value);
538  }
539 }
540 
541 // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
542 // of sibling node 'sibId' into node 'dstId'.
543 void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
544  // For each edge in 'inEdges[sibId]':
545  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
546  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
547  if (inEdges.count(sibId) > 0) {
548  SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
549  for (auto &inEdge : oldInEdges) {
550  addEdge(inEdge.id, dstId, inEdge.value);
551  removeEdge(inEdge.id, sibId, inEdge.value);
552  }
553  }
554 
555  // For each edge in 'outEdges[sibId]' to node 'id'
556  // *) Add new edge from 'dstId' to 'outEdge.id'.
557  // *) Remove edge from 'sibId' to 'outEdge.id'.
558  if (outEdges.count(sibId) > 0) {
559  SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
560  for (auto &outEdge : oldOutEdges) {
561  addEdge(dstId, outEdge.id, outEdge.value);
562  removeEdge(sibId, outEdge.id, outEdge.value);
563  }
564  }
565 }
566 
567 // Adds ops in 'loads' and 'stores' to node at 'id'.
// The ops are appended after any already recorded, without deduplication.
// Graph edges are not updated here.
569  unsigned id, const SmallVectorImpl<Operation *> &loads,
570  const SmallVectorImpl<Operation *> &stores) {
571  Node *node = getNode(id);
572  llvm::append_range(node->loads, loads);
573  llvm::append_range(node->stores, stores);
574 }
575 
// Clears the recorded load and store ops of node 'id'. The node itself and its
// edges are left in place.
577  Node *node = getNode(id);
578  node->loads.clear();
579  node->stores.clear();
580 }
581 
582 // Calls 'callback' for each input edge incident to node 'id' which carries a
583 // memref dependence.
// Filtering is done by forEachMemRefEdge(). A node without in edges is a no-op.
585  unsigned id, const std::function<void(Edge)> &callback) {
586  if (inEdges.count(id) > 0)
587  forEachMemRefEdge(inEdges[id], callback);
588 }
589 
590 // Calls 'callback' for each output edge from node 'id' which carries a
591 // memref dependence.
// Filtering is done by forEachMemRefEdge(). A node without out edges is a
// no-op.
593  unsigned id, const std::function<void(Edge)> &callback) {
594  if (outEdges.count(id) > 0)
595  forEachMemRefEdge(outEdges[id], callback);
596 }
597 
598 // Calls 'callback' for each edge in 'edges' which carries a memref
599 // dependence.
// An edge is passed to 'callback' only if its value is memref-typed AND the
// node it points at ('edge.id') wraps an affine.for op; all other edges are
// skipped.
601  ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
602  for (const auto &edge : edges) {
603  // Skip if 'edge' is not a memref dependence edge.
604  if (!isa<MemRefType>(edge.value.getType()))
605  continue;
606  assert(nodes.count(edge.id) > 0);
607  // Skip if 'edge.id' is not a loop nest.
608  if (!isa<AffineForOp>(getNode(edge.id)->op))
609  continue;
610  // Visit 'edge' (an input or output edge, depending on the caller).
611  callback(edge);
612  }
613 }
614 
615 void MemRefDependenceGraph::print(raw_ostream &os) const {
616  os << "\nMemRefDependenceGraph\n";
617  os << "\nNodes:\n";
618  for (const auto &idAndNode : nodes) {
619  os << "Node: " << idAndNode.first << "\n";
620  auto it = inEdges.find(idAndNode.first);
621  if (it != inEdges.end()) {
622  for (const auto &e : it->second)
623  os << " InEdge: " << e.id << " " << e.value << "\n";
624  }
625  it = outEdges.find(idAndNode.first);
626  if (it != outEdges.end()) {
627  for (const auto &e : it->second)
628  os << " OutEdge: " << e.id << " " << e.value << "\n";
629  }
630  }
631 }
632 
635  auto *currOp = op.getParentOp();
636  AffineForOp currAffineForOp;
637  // Traverse up the hierarchy collecting all 'affine.for' operation while
638  // skipping over 'affine.if' operations.
639  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
640  if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
641  loops->push_back(currAffineForOp);
642  currOp = currOp->getParentOp();
643  }
644  std::reverse(loops->begin(), loops->end());
645 }
646 
// Replaces the contents of 'ops' with the affine.if, affine.for and
// affine.parallel ops enclosing 'op'. The walk stops below the nearest
// ancestor with the AffineScope trait, and the result is ordered outermost
// first.
649  ops->clear();
650  Operation *currOp = op.getParentOp();
651 
652  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
653  // affine.parallel operations.
654  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
655  if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
656  ops->push_back(currOp);
657  currOp = currOp->getParentOp();
658  }
// The ops were collected innermost-first; reverse to get outermost-first.
659  std::reverse(ops->begin(), ops->end());
660 }
661 
662 // Populates 'cst' with FlatAffineValueConstraints which represent original
663 // domain of the loop bounds that define 'ivs'.
// 'cst' is reset to one dim variable per IV (no symbols, no locals). Each IV's
// owning affine.for then contributes its domain. Returns failure if any loop's
// domain cannot be added, in which case 'cst' may be partially populated.
665  FlatAffineValueConstraints &cst) const {
666  assert(!ivs.empty() && "Cannot have a slice without its IVs");
667  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
668  /*numLocals=*/0, ivs);
669  for (Value iv : ivs) {
670  AffineForOp loop = getForInductionVarOwner(iv);
671  assert(loop && "Expected affine for");
672  if (failed(cst.addAffineForOpDomain(loop)))
673  return failure();
674  }
675  return success();
676 }
677 
678 // Populates 'cst' with FlatAffineValueConstraints which represent slice bounds.
// Layout: one dim variable per src IV in 'ivs', followed by one symbol per
// operand in 'lbOperands[0]' (dst IVs and symbols). Symbol operands with a
// known constant value are pinned by an equality, and operands that are
// affine.for IVs get their loop's domain. The slice bounds from 'lbs'/'ubs'
// are then added, using 'lbOperands[0]' as the operand list for every map.
// NOTE(review): this presumes all lb/ub maps share the operands
// 'lbOperands[0]'; 'ubOperands' is not consulted here.
// Returns failure only if adding a loop domain fails. A failure of
// addSliceBounds() is only caught by the assert below.
679 LogicalResult
681  assert(!lbOperands.empty());
682  // Adds src 'ivs' as dimension variables in 'cst'.
683  unsigned numDims = ivs.size();
684  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
685  unsigned numSymbols = lbOperands[0].size();
686 
687  SmallVector<Value, 4> values(ivs);
688  // 'values' starts as 'ivs'; append the operands after them.
689  values.append(lbOperands[0].begin(), lbOperands[0].end());
690  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);
691 
692  // Add loop bound constraints for values which are loop IVs of the destination
693  // of fusion and equality constraints for symbols which are constants.
694  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
695  Value value = values[i];
696  assert(cst->containsVar(value) && "value expected to be present");
697  if (isValidSymbol(value)) {
698  // Check if the symbol is a constant.
699  if (std::optional<int64_t> cOp = getConstantIntValue(value))
700  cst->addBound(BoundType::EQ, value, cOp.value());
701  } else if (auto loop = getForInductionVarOwner(value)) {
702  if (failed(cst->addAffineForOpDomain(loop)))
703  return failure();
704  }
705  }
706 
707  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
708  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
709  assert(succeeded(ret) &&
710  "should not fail as we never have semi-affine slice maps");
711  (void)ret;
712  return success();
713 }
714 
715 // Clears state bounds and operand state.
// Only 'lbs', 'ubs', 'lbOperands' and 'ubOperands' are cleared; 'ivs' is left
// untouched.
717  lbs.clear();
718  ubs.clear();
719  lbOperands.clear();
720  ubOperands.clear();
721 }
722 
// Debug helper: prints the slice IVs to llvm::errs(), then each lower and
// upper bound map followed by its operand list.
// NOTE(review): 'lbOperands'/'ubOperands' are indexed in lockstep with
// 'lbs'/'ubs', so they are assumed to have at least as many entries.
724  llvm::errs() << "\tIVs:\n";
725  for (Value iv : ivs)
726  llvm::errs() << "\t\t" << iv << "\n";
727 
728  llvm::errs() << "\tLBs:\n";
729  for (auto en : llvm::enumerate(lbs)) {
730  llvm::errs() << "\t\t" << en.value() << "\n";
731  llvm::errs() << "\t\tOperands:\n";
732  for (Value lbOp : lbOperands[en.index()])
733  llvm::errs() << "\t\t\t" << lbOp << "\n";
734  }
735 
736  llvm::errs() << "\tUBs:\n";
737  for (auto en : llvm::enumerate(ubs)) {
738  llvm::errs() << "\t\t" << en.value() << "\n";
739  llvm::errs() << "\t\tOperands:\n";
740  for (Value ubOp : ubOperands[en.index()])
741  llvm::errs() << "\t\t\t" << ubOp << "\n";
742  }
743 }
744 
745 /// Fast check to determine if the computation slice is maximal. Returns true if
746 /// each slice dimension maps to an existing dst dimension and both the src
747 /// and the dst loops for those dimensions have the same bounds. Returns false
748 /// as soon as a dimension is found whose src and dst loops differ in their
749 /// constant bounds or step. Returns std::nullopt if neither can be proven.
/// Dimensions are checked in order, so the first undecidable dimension yields
/// std::nullopt even if a later dimension would prove a mismatch.
750 std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
751  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
752  "Unexpected number of lbs, ubs and ivs in slice");
753 
754  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
755  AffineMap lbMap = lbs[i];
756  AffineMap ubMap = ubs[i];
757 
758  // Check if this slice is just an equality along this dimension.
759  if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
760  ubMap.getNumResults() != 1 ||
761  lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
762  // The condition above will be true for maps describing a single
763  // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
764  // Make sure we skip those cases by checking that the lb result is not
765  // just a constant.
766  isa<AffineConstantExpr>(lbMap.getResult(0)))
767  return std::nullopt;
768 
769  // Limited support: we expect the lb result to be just a loop dimension for
770  // now.
771  AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
772  if (!result)
773  return std::nullopt;
774 
775  // Retrieve dst loop bounds.
// The dim position indexes this dimension's lb operand list; that operand must
// be the IV of an affine.for in the dst nest.
776  AffineForOp dstLoop =
777  getForInductionVarOwner(lbOperands[i][result.getPosition()]);
778  if (!dstLoop)
779  return std::nullopt;
780  AffineMap dstLbMap = dstLoop.getLowerBoundMap();
781  AffineMap dstUbMap = dstLoop.getUpperBoundMap();
782 
783  // Retrieve src loop bounds.
784  AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
785  assert(srcLoop && "Expected affine for");
786  AffineMap srcLbMap = srcLoop.getLowerBoundMap();
787  AffineMap srcUbMap = srcLoop.getUpperBoundMap();
788 
789  // Limited support: we expect simple src and dst loops with a single
790  // constant component per bound for now.
791  if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
792  dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
793  return std::nullopt;
794 
795  AffineExpr srcLbResult = srcLbMap.getResult(0);
796  AffineExpr dstLbResult = dstLbMap.getResult(0);
797  AffineExpr srcUbResult = srcUbMap.getResult(0);
798  AffineExpr dstUbResult = dstUbMap.getResult(0);
799  if (!isa<AffineConstantExpr>(srcLbResult) ||
800  !isa<AffineConstantExpr>(srcUbResult) ||
801  !isa<AffineConstantExpr>(dstLbResult) ||
802  !isa<AffineConstantExpr>(dstUbResult))
803  return std::nullopt;
804 
805  // Check if src and dst loop bounds are the same. If not, we can guarantee
806  // that the slice is not maximal.
807  if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
808  srcLoop.getStep() != dstLoop.getStep())
809  return false;
810  }
811 
812  return true;
813 }
814 
815 /// Returns true if it is deterministically verified that the original iteration
816 /// space of the slice is contained within the new iteration space that is
817 /// created after fusing 'this' slice into its destination.
818 std::optional<bool> ComputationSliceState::isSliceValid() const {
819  // Fast check to determine if the slice is valid. If the following conditions
820  // are verified to be true, slice is declared valid by the fast check:
821  // 1. Each slice loop is a single iteration loop bound in terms of a single
822  // destination loop IV.
823  // 2. Loop bounds of the destination loop IV (from above) and those of the
824  // source loop IV are exactly the same.
825  // If the fast check is inconclusive or false, we proceed with a more
826  // expensive analysis.
827  // TODO: Store the result of the fast check, as it might be used again in
828  // `canRemoveSrcNodeAfterFusion`.
829  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
830  if (isValidFastCheck && *isValidFastCheck)
831  return true;
832 
833  // Create constraints for the source loop nest using which slice is computed.
834  FlatAffineValueConstraints srcConstraints;
835  // TODO: Store the source's domain to avoid computation at each depth.
836  if (failed(getSourceAsConstraints(srcConstraints))) {
837  LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
838  return std::nullopt;
839  }
840  // As the set difference utility currently cannot handle symbols in its
841  // operands, validity of the slice cannot be determined.
842  if (srcConstraints.getNumSymbolVars() > 0) {
843  LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
844  return std::nullopt;
845  }
846  // TODO: Handle local vars in the source domains while using the 'projectOut'
847  // utility below. Currently, aligning is not done assuming that there will be
848  // no local vars in the source domain.
849  if (srcConstraints.getNumLocalVars() != 0) {
850  LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
851  return std::nullopt;
852  }
853 
854  // Create constraints for the slice loop nest that would be created if the
855  // fusion succeeds.
856  FlatAffineValueConstraints sliceConstraints;
857  if (failed(getAsConstraints(&sliceConstraints))) {
858  LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
859  return std::nullopt;
860  }
861 
862  // Projecting out every dimension other than the 'ivs' to express slice's
863  // domain completely in terms of source's IVs.
864  sliceConstraints.projectOut(ivs.size(),
865  sliceConstraints.getNumVars() - ivs.size());
866 
867  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
868  LLVM_DEBUG(srcConstraints.dump());
869  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
870  "(expressed in terms of its source's IVs):\n");
871  LLVM_DEBUG(sliceConstraints.dump());
872 
873  // TODO: Store 'srcSet' to avoid recalculating for each depth.
874  PresburgerSet srcSet(srcConstraints);
875  PresburgerSet sliceSet(sliceConstraints);
876  PresburgerSet diffSet = sliceSet.subtract(srcSet);
877 
878  if (!diffSet.isIntegerEmpty()) {
879  LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
880  return false;
881  }
882  return true;
883 }
884 
885 /// Returns true if the computation slice encloses all the iterations of the
886 /// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
887 /// cannot determine if the slice is maximal or not.
888 std::optional<bool> ComputationSliceState::isMaximal() const {
889  // Fast check to determine if the computation slice is maximal. If the result
890  // is inconclusive, we proceed with a more expensive analysis.
891  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
892  if (isMaximalFastCheck)
893  return isMaximalFastCheck;
894 
895  // Create constraints for the src loop nest being sliced.
896  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
897  /*numSymbols=*/0,
898  /*numLocals=*/0, ivs);
899  for (Value iv : ivs) {
900  AffineForOp loop = getForInductionVarOwner(iv);
901  assert(loop && "Expected affine for");
902  if (failed(srcConstraints.addAffineForOpDomain(loop)))
903  return std::nullopt;
904  }
905 
906  // Create constraints for the slice using the dst loop nest information. We
907  // retrieve existing dst loops from the lbOperands.
908  SmallVector<Value> consumerIVs;
909  for (Value lbOp : lbOperands[0])
910  if (getForInductionVarOwner(lbOp))
911  consumerIVs.push_back(lbOp);
912 
913  // Add empty IV Values for those new loops that are not equalities and,
914  // therefore, are not yet materialized in the IR.
915  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
916  consumerIVs.push_back(Value());
917 
918  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
919  /*numSymbols=*/0,
920  /*numLocals=*/0, consumerIVs);
921 
922  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
923  return std::nullopt;
924 
925  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
926  // Constraint dims are different. The integer set difference can't be
927  // computed so we don't know if the slice is maximal.
928  return std::nullopt;
929 
930  // Compute the difference between the src loop nest and the slice integer
931  // sets.
932  PresburgerSet srcSet(srcConstraints);
933  PresburgerSet sliceSet(sliceConstraints);
934  PresburgerSet diffSet = srcSet.subtract(sliceSet);
935  return diffSet.isIntegerEmpty();
936 }
937 
938 unsigned MemRefRegion::getRank() const {
939  return cast<MemRefType>(memref.getType()).getRank();
940 }
941 
943  SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
944  SmallVectorImpl<int64_t> *lbDivisors) const {
945  auto memRefType = cast<MemRefType>(memref.getType());
946  unsigned rank = memRefType.getRank();
947  if (shape)
948  shape->reserve(rank);
949 
950  assert(rank == cst.getNumDimVars() && "inconsistent memref region");
951 
952  // Use a copy of the region constraints that has upper/lower bounds for each
953  // memref dimension with static size added to guard against potential
954  // over-approximation from projection or union bounding box. We may not add
955  // this on the region itself since they might just be redundant constraints
956  // that will need non-trivials means to eliminate.
957  FlatAffineValueConstraints cstWithShapeBounds(cst);
958  for (unsigned r = 0; r < rank; r++) {
959  cstWithShapeBounds.addBound(BoundType::LB, r, 0);
960  int64_t dimSize = memRefType.getDimSize(r);
961  if (ShapedType::isDynamic(dimSize))
962  continue;
963  cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
964  }
965 
966  // Find a constant upper bound on the extent of this memref region along each
967  // dimension.
968  int64_t numElements = 1;
969  int64_t diffConstant;
970  int64_t lbDivisor;
971  for (unsigned d = 0; d < rank; d++) {
973  std::optional<int64_t> diff =
974  cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
975  if (diff.has_value()) {
976  diffConstant = *diff;
977  assert(diffConstant >= 0 && "Dim size bound can't be negative");
978  assert(lbDivisor > 0);
979  } else {
980  // If no constant bound is found, then it can always be bound by the
981  // memref's dim size if the latter has a constant size along this dim.
982  auto dimSize = memRefType.getDimSize(d);
983  if (dimSize == ShapedType::kDynamic)
984  return std::nullopt;
985  diffConstant = dimSize;
986  // Lower bound becomes 0.
987  lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0);
988  lbDivisor = 1;
989  }
990  numElements *= diffConstant;
991  if (lbs) {
992  lbs->push_back(lb);
993  assert(lbDivisors && "both lbs and lbDivisor or none");
994  lbDivisors->push_back(lbDivisor);
995  }
996  if (shape) {
997  shape->push_back(diffConstant);
998  }
999  }
1000  return numElements;
1001 }
1002 
1004  AffineMap &ubMap) const {
1005  assert(pos < cst.getNumDimVars() && "invalid position");
1006  auto memRefType = cast<MemRefType>(memref.getType());
1007  unsigned rank = memRefType.getRank();
1008 
1009  assert(rank == cst.getNumDimVars() && "inconsistent memref region");
1010 
1011  auto boundPairs = cst.getLowerAndUpperBound(
1012  pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
1013  /*localExprs=*/{}, memRefType.getContext());
1014  lbMap = boundPairs.first;
1015  ubMap = boundPairs.second;
1016  assert(lbMap && "lower bound for a region must exist");
1017  assert(ubMap && "upper bound for a region must exist");
1018  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
1019  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
1020 }
1021 
1022 LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
1023  assert(memref == other.memref);
1024  return cst.unionBoundingBox(*other.getConstraints());
1025 }
1026 
1027 /// Computes the memory region accessed by this memref with the region
1028 /// represented as constraints symbolic/parametric in 'loopDepth' loops
1029 /// surrounding opInst and any additional Function symbols.
1030 // For example, the memref region for this load operation at loopDepth = 1 will
1031 // be as below:
1032 //
1033 // affine.for %i = 0 to 32 {
1034 // affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
1035 // load %A[%ii]
1036 // }
1037 // }
1038 //
1039 // region: {memref = %A, write = false, {%i <= m0 <= %i + 7} }
1040 // The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
1041 //
1042 // TODO: extend this to any other memref dereferencing ops
1043 // (dma_start, dma_wait).
1044 LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
1045  const ComputationSliceState *sliceState,
1046  bool addMemRefDimBounds) {
1047  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
1048  "affine read/write op expected");
1049 
1050  MemRefAccess access(op);
1051  memref = access.memref;
1052  write = access.isStore();
1053 
1054  unsigned rank = access.getRank();
1055 
1056  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
1057  << "\ndepth: " << loopDepth << "\n";);
1058 
1059  // 0-d memrefs.
1060  if (rank == 0) {
1062  getAffineIVs(*op, ivs);
1063  assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
1064  // The first 'loopDepth' IVs are symbols for this region.
1065  ivs.resize(loopDepth);
1066  // A 0-d memref has a 0-d region.
1067  cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
1068  return success();
1069  }
1070 
1071  // Build the constraints for this region.
1072  AffineValueMap accessValueMap;
1073  access.getAccessMap(&accessValueMap);
1074  AffineMap accessMap = accessValueMap.getAffineMap();
1075 
1076  unsigned numDims = accessMap.getNumDims();
1077  unsigned numSymbols = accessMap.getNumSymbols();
1078  unsigned numOperands = accessValueMap.getNumOperands();
1079  // Merge operands with slice operands.
1080  SmallVector<Value, 4> operands;
1081  operands.resize(numOperands);
1082  for (unsigned i = 0; i < numOperands; ++i)
1083  operands[i] = accessValueMap.getOperand(i);
1084 
1085  if (sliceState != nullptr) {
1086  operands.reserve(operands.size() + sliceState->lbOperands[0].size());
1087  // Append slice operands to 'operands' as symbols.
1088  for (auto extraOperand : sliceState->lbOperands[0]) {
1089  if (!llvm::is_contained(operands, extraOperand)) {
1090  operands.push_back(extraOperand);
1091  numSymbols++;
1092  }
1093  }
1094  }
1095  // We'll first associate the dims and symbols of the access map to the dims
1096  // and symbols resp. of cst. This will change below once cst is
1097  // fully constructed out.
1098  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);
1099 
1100  // Add equality constraints.
1101  // Add inequalities for loop lower/upper bounds.
1102  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
1103  auto operand = operands[i];
1104  if (auto affineFor = getForInductionVarOwner(operand)) {
1105  // Note that cst can now have more dimensions than accessMap if the
1106  // bounds expressions involve outer loops or other symbols.
1107  // TODO: rewrite this to use getInstIndexSet; this way
1108  // conditionals will be handled when the latter supports it.
1109  if (failed(cst.addAffineForOpDomain(affineFor)))
1110  return failure();
1111  } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
1112  if (failed(cst.addAffineParallelOpDomain(parallelOp)))
1113  return failure();
1114  } else if (isValidSymbol(operand)) {
1115  // Check if the symbol is a constant.
1116  Value symbol = operand;
1117  if (auto constVal = getConstantIntValue(symbol))
1118  cst.addBound(BoundType::EQ, symbol, constVal.value());
1119  } else {
1120  LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
1121  return failure();
1122  }
1123  }
1124 
1125  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
1126  if (sliceState != nullptr) {
1127  // Add dim and symbol slice operands.
1128  for (auto operand : sliceState->lbOperands[0]) {
1129  cst.addInductionVarOrTerminalSymbol(operand);
1130  }
1131  // Add upper/lower bounds from 'sliceState' to 'cst'.
1132  LogicalResult ret =
1133  cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
1134  sliceState->lbOperands[0]);
1135  assert(succeeded(ret) &&
1136  "should not fail as we never have semi-affine slice maps");
1137  (void)ret;
1138  }
1139 
1140  // Add access function equalities to connect loop IVs to data dimensions.
1141  if (failed(cst.composeMap(&accessValueMap))) {
1142  op->emitError("getMemRefRegion: compose affine map failed");
1143  LLVM_DEBUG(accessValueMap.getAffineMap().dump());
1144  return failure();
1145  }
1146 
1147  // Set all variables appearing after the first 'rank' variables as
1148  // symbolic variables - so that the ones corresponding to the memref
1149  // dimensions are the dimensional variables for the memref region.
1150  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);
1151 
1152  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
1153  // this memref region is symbolic.
1154  SmallVector<Value, 4> enclosingIVs;
1155  getAffineIVs(*op, enclosingIVs);
1156  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
1157  enclosingIVs.resize(loopDepth);
1158  SmallVector<Value, 4> vars;
1159  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
1160  for (Value var : vars) {
1161  if ((isAffineInductionVar(var)) && !llvm::is_contained(enclosingIVs, var)) {
1162  cst.projectOut(var);
1163  }
1164  }
1165 
1166  // Project out any local variables (these would have been added for any
1167  // mod/divs).
1168  cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());
1169 
1170  // Constant fold any symbolic variables.
1171  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
1172  /*num=*/cst.getNumSymbolVars());
1173 
1174  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");
1175 
1176  // Add upper/lower bounds for each memref dimension with static size
1177  // to guard against potential over-approximation from projection.
1178  // TODO: Support dynamic memref dimensions.
1179  if (addMemRefDimBounds) {
1180  auto memRefType = cast<MemRefType>(memref.getType());
1181  for (unsigned r = 0; r < rank; r++) {
1182  cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
1183  if (memRefType.isDynamicDim(r))
1184  continue;
1185  cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
1186  }
1187  }
1188  cst.removeTrivialRedundancy();
1189 
1190  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
1191  LLVM_DEBUG(cst.dump());
1192  return success();
1193 }
1194 
1195 std::optional<int64_t>
1197  auto elementType = memRefType.getElementType();
1198 
1199  unsigned sizeInBits;
1200  if (elementType.isIntOrFloat()) {
1201  sizeInBits = elementType.getIntOrFloatBitWidth();
1202  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
1203  if (vectorType.getElementType().isIntOrFloat())
1204  sizeInBits =
1205  vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
1206  else
1207  return std::nullopt;
1208  } else {
1209  return std::nullopt;
1210  }
1211  return llvm::divideCeil(sizeInBits, 8);
1212 }
1213 
1214 // Returns the size of the region.
1215 std::optional<int64_t> MemRefRegion::getRegionSize() {
1216  auto memRefType = cast<MemRefType>(memref.getType());
1217 
1218  if (!memRefType.getLayout().isIdentity()) {
1219  LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
1220  return false;
1221  }
1222 
1223  // Indices to use for the DmaStart op.
1224  // Indices for the original memref being DMAed from/to.
1225  SmallVector<Value, 4> memIndices;
1226  // Indices for the faster buffer being DMAed into/from.
1227  SmallVector<Value, 4> bufIndices;
1228 
1229  // Compute the extents of the buffer.
1230  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
1231  if (!numElements) {
1232  LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
1233  return std::nullopt;
1234  }
1235  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
1236  if (!eltSize)
1237  return std::nullopt;
1238  return *eltSize * *numElements;
1239 }
1240 
1241 /// Returns the size of memref data in bytes if it's statically shaped,
1242 /// std::nullopt otherwise. If the element of the memref has vector type, takes
1243 /// into account size of the vector as well.
1244 // TODO: improve/complete this when we have target data.
1245 std::optional<uint64_t>
1247  if (!memRefType.hasStaticShape())
1248  return std::nullopt;
1249  auto elementType = memRefType.getElementType();
1250  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
1251  return std::nullopt;
1252 
1253  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
1254  if (!sizeInBytes)
1255  return std::nullopt;
1256  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
1257  sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
1258  }
1259  return sizeInBytes;
1260 }
1261 
1262 template <typename LoadOrStoreOp>
1263 LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
1264  bool emitError) {
1265  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
1266  AffineWriteOpInterface>::value,
1267  "argument should be either a AffineReadOpInterface or a "
1268  "AffineWriteOpInterface");
1269 
1270  Operation *op = loadOrStoreOp.getOperation();
1271  MemRefRegion region(op->getLoc());
1272  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
1273  /*addMemRefDimBounds=*/false)))
1274  return success();
1275 
1276  LLVM_DEBUG(llvm::dbgs() << "Memory region");
1277  LLVM_DEBUG(region.getConstraints()->dump());
1278 
1279  bool outOfBounds = false;
1280  unsigned rank = loadOrStoreOp.getMemRefType().getRank();
1281 
1282  // For each dimension, check for out of bounds.
1283  for (unsigned r = 0; r < rank; r++) {
1284  FlatAffineValueConstraints ucst(*region.getConstraints());
1285 
1286  // Intersect memory region with constraint capturing out of bounds (both out
1287  // of upper and out of lower), and check if the constraint system is
1288  // feasible. If it is, there is at least one point out of bounds.
1289  SmallVector<int64_t, 4> ineq(rank + 1, 0);
1290  int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
1291  // TODO: handle dynamic dim sizes.
1292  if (dimSize == -1)
1293  continue;
1294 
1295  // Check for overflow: d_i >= memref dim size.
1296  ucst.addBound(BoundType::LB, r, dimSize);
1297  outOfBounds = !ucst.isEmpty();
1298  if (outOfBounds && emitError) {
1299  loadOrStoreOp.emitOpError()
1300  << "memref out of upper bound access along dimension #" << (r + 1);
1301  }
1302 
1303  // Check for a negative index.
1304  FlatAffineValueConstraints lcst(*region.getConstraints());
1305  std::fill(ineq.begin(), ineq.end(), 0);
1306  // d_i <= -1;
1307  lcst.addBound(BoundType::UB, r, -1);
1308  outOfBounds = !lcst.isEmpty();
1309  if (outOfBounds && emitError) {
1310  loadOrStoreOp.emitOpError()
1311  << "memref out of lower bound access along dimension #" << (r + 1);
1312  }
1313  }
1314  return failure(outOfBounds);
1315 }
1316 
1317 // Explicitly instantiate the template so that the compiler knows we need them!
1318 template LogicalResult
1319 mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
1320  bool emitError);
1321 template LogicalResult
1322 mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
1323  bool emitError);
1324 
1325 // Returns in 'positions' the Block positions of 'op' in each ancestor
1326 // Block from the Block containing operation, stopping at 'limitBlock'.
1327 static void findInstPosition(Operation *op, Block *limitBlock,
1328  SmallVectorImpl<unsigned> *positions) {
1329  Block *block = op->getBlock();
1330  while (block != limitBlock) {
1331  // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1332  // rely on linear scans.
1333  int instPosInBlock = std::distance(block->begin(), op->getIterator());
1334  positions->push_back(instPosInBlock);
1335  op = block->getParentOp();
1336  block = op->getBlock();
1337  }
1338  std::reverse(positions->begin(), positions->end());
1339 }
1340 
1341 // Returns the Operation in a possibly nested set of Blocks, where the
1342 // position of the operation is represented by 'positions', which has a
1343 // Block position for each level of nesting.
1345  unsigned level, Block *block) {
1346  unsigned i = 0;
1347  for (auto &op : *block) {
1348  if (i != positions[level]) {
1349  ++i;
1350  continue;
1351  }
1352  if (level == positions.size() - 1)
1353  return &op;
1354  if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
1355  return getInstAtPosition(positions, level + 1,
1356  childAffineForOp.getBody());
1357 
1358  for (auto &region : op.getRegions()) {
1359  for (auto &b : region)
1360  if (auto *ret = getInstAtPosition(positions, level + 1, &b))
1361  return ret;
1362  }
1363  return nullptr;
1364  }
1365  return nullptr;
1366 }
1367 
1368 // Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1371  for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1372  auto value = cst->getValue(i);
1373  if (ivs.count(value) == 0) {
1374  assert(isAffineForInductionVar(value));
1375  auto loop = getForInductionVarOwner(value);
1376  if (failed(cst->addAffineForOpDomain(loop)))
1377  return failure();
1378  }
1379  }
1380  return success();
1381 }
1382 
1383 /// Returns the innermost common loop depth for the set of operations in 'ops'.
1384 // TODO: Move this to LoopUtils.
1386  ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1387  unsigned numOps = ops.size();
1388  assert(numOps > 0 && "Expected at least one operation");
1389 
1390  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1391  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1392  for (unsigned i = 0; i < numOps; ++i) {
1393  getAffineForIVs(*ops[i], &loops[i]);
1394  loopDepthLimit =
1395  std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
1396  }
1397 
1398  unsigned loopDepth = 0;
1399  for (unsigned d = 0; d < loopDepthLimit; ++d) {
1400  unsigned i;
1401  for (i = 1; i < numOps; ++i) {
1402  if (loops[i - 1][d] != loops[i][d])
1403  return loopDepth;
1404  }
1405  if (surroundingLoops)
1406  surroundingLoops->push_back(loops[i - 1][d]);
1407  ++loopDepth;
1408  }
1409  return loopDepth;
1410 }
1411 
1412 /// Computes in 'sliceUnion' the union of all slice bounds computed at
1413 /// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1414 /// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1415 /// union was computed correctly, an appropriate failure otherwise.
1418  ArrayRef<Operation *> opsB, unsigned loopDepth,
1419  unsigned numCommonLoops, bool isBackwardSlice,
1420  ComputationSliceState *sliceUnion) {
1421  // Compute the union of slice bounds between all pairs in 'opsA' and
1422  // 'opsB' in 'sliceUnionCst'.
1423  FlatAffineValueConstraints sliceUnionCst;
1424  assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1425  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1426  for (auto *i : opsA) {
1427  MemRefAccess srcAccess(i);
1428  for (auto *j : opsB) {
1429  MemRefAccess dstAccess(j);
1430  if (srcAccess.memref != dstAccess.memref)
1431  continue;
1432  // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1433  if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) ||
1434  (isBackwardSlice && loopDepth > getNestingDepth(j))) {
1435  LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
1437  }
1438 
1439  bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1440  isa<AffineReadOpInterface>(dstAccess.opInst);
1441  FlatAffineValueConstraints dependenceConstraints;
1442  // Check dependence between 'srcAccess' and 'dstAccess'.
1444  srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1445  &dependenceConstraints, /*dependenceComponents=*/nullptr,
1446  /*allowRAR=*/readReadAccesses);
1447  if (result.value == DependenceResult::Failure) {
1448  LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
1450  }
1451  if (result.value == DependenceResult::NoDependence)
1452  continue;
1453  dependentOpPairs.emplace_back(i, j);
1454 
1455  // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1456  ComputationSliceState tmpSliceState;
1457  mlir::affine::getComputationSliceState(i, j, &dependenceConstraints,
1458  loopDepth, isBackwardSlice,
1459  &tmpSliceState);
1460 
1461  if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1462  // Initialize 'sliceUnionCst' with the bounds computed in previous step.
1463  if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
1464  LLVM_DEBUG(llvm::dbgs()
1465  << "Unable to compute slice bound constraints\n");
1467  }
1468  assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1469  continue;
1470  }
1471 
1472  // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1473  FlatAffineValueConstraints tmpSliceCst;
1474  if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
1475  LLVM_DEBUG(llvm::dbgs()
1476  << "Unable to compute slice bound constraints\n");
1478  }
1479 
1480  // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1481  if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {
1482 
1483  // Pre-constraint var alignment: record loop IVs used in each constraint
1484  // system.
1485  SmallPtrSet<Value, 8> sliceUnionIVs;
1486  for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1487  sliceUnionIVs.insert(sliceUnionCst.getValue(k));
1488  SmallPtrSet<Value, 8> tmpSliceIVs;
1489  for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1490  tmpSliceIVs.insert(tmpSliceCst.getValue(k));
1491 
1492  sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);
1493 
1494  // Post-constraint var alignment: add loop IV bounds missing after
1495  // var alignment to constraint systems. This can occur if one constraint
1496  // system uses an loop IV that is not used by the other. The call
1497  // to unionBoundingBox below expects constraints for each Loop IV, even
1498  // if they are the unsliced full loop bounds added here.
1499  if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
1501  if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
1503  }
1504  // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
1505  if (sliceUnionCst.getNumLocalVars() > 0 ||
1506  tmpSliceCst.getNumLocalVars() > 0 ||
1507  failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
1508  LLVM_DEBUG(llvm::dbgs()
1509  << "Unable to compute union bounding box of slice bounds\n");
1511  }
1512  }
1513  }
1514 
1515  // Empty union.
1516  if (sliceUnionCst.getNumDimAndSymbolVars() == 0)
1518 
1519  // Gather loops surrounding ops from loop nest where slice will be inserted.
1521  for (auto &dep : dependentOpPairs) {
1522  ops.push_back(isBackwardSlice ? dep.second : dep.first);
1523  }
1524  SmallVector<AffineForOp, 4> surroundingLoops;
1525  unsigned innermostCommonLoopDepth =
1526  getInnermostCommonLoopDepth(ops, &surroundingLoops);
1527  if (loopDepth > innermostCommonLoopDepth) {
1528  LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
1530  }
1531 
1532  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1533  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1534 
1535  // Convert any dst loop IVs which are symbol variables to dim variables.
1536  sliceUnionCst.convertLoopIVSymbolsToDims();
1537  sliceUnion->clearBounds();
1538  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
1539  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
1540 
1541  // Get slice bounds from slice union constraints 'sliceUnionCst'.
1542  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
1543  opsA[0]->getContext(), &sliceUnion->lbs,
1544  &sliceUnion->ubs);
1545 
1546  // Add slice bound operands of union.
1547  SmallVector<Value, 4> sliceBoundOperands;
1548  sliceUnionCst.getValues(numSliceLoopIVs,
1549  sliceUnionCst.getNumDimAndSymbolVars(),
1550  &sliceBoundOperands);
1551 
1552  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1553  sliceUnion->ivs.clear();
1554  sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);
1555 
1556  // Set loop nest insertion point to block start at 'loopDepth'.
1557  sliceUnion->insertPoint =
1558  isBackwardSlice
1559  ? surroundingLoops[loopDepth - 1].getBody()->begin()
1560  : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1561 
1562  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1563  // canonicalization.
1564  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1565  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1566 
1567  // Check if the slice computed is valid. Return success only if it is verified
1568  // that the slice is valid, otherwise return appropriate failure status.
1569  std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1570  if (!isSliceValid) {
1571  LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
1573  }
1574  if (!*isSliceValid)
1576 
1578 }
1579 
1580 // TODO: extend this to handle multiple result maps.
1581 static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1582  AffineMap ubMap) {
1583  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1584  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1585  assert(lbMap.getNumDims() == ubMap.getNumDims());
1586  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1587  AffineExpr lbExpr(lbMap.getResult(0));
1588  AffineExpr ubExpr(ubMap.getResult(0));
1589  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
1590  lbMap.getNumSymbols());
1591  auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
1592  if (!cExpr)
1593  return std::nullopt;
1594  return cExpr.getValue();
1595 }
1596 
1597 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
1598 // nest surrounding represented by slice loop bounds in 'slice'. Returns true
1599 // on success, false otherwise (if a non-constant trip count was encountered).
1600 // TODO: Make this work with non-unit step loops.
1602  const ComputationSliceState &slice,
1603  llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1604  unsigned numSrcLoopIVs = slice.ivs.size();
1605  // Populate map from AffineForOp -> trip count
1606  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1607  AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1608  auto *op = forOp.getOperation();
1609  AffineMap lbMap = slice.lbs[i];
1610  AffineMap ubMap = slice.ubs[i];
1611  // If lower or upper bound maps are null or provide no results, it implies
1612  // that source loop was not at all sliced, and the entire loop will be a
1613  // part of the slice.
1614  if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1615  ubMap.getNumResults() == 0) {
1616  // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1617  if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1618  (*tripCountMap)[op] =
1619  forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1620  continue;
1621  }
1622  std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1623  if (maybeConstTripCount.has_value()) {
1624  (*tripCountMap)[op] = *maybeConstTripCount;
1625  continue;
1626  }
1627  return false;
1628  }
1629  std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1630  // Slice bounds are created with a constant ub - lb difference.
1631  if (!tripCount.has_value())
1632  return false;
1633  (*tripCountMap)[op] = *tripCount;
1634  }
1635  return true;
1636 }
1637 
1638 // Return the number of iterations in the given slice.
1640  const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1641  uint64_t iterCount = 1;
1642  for (const auto &count : sliceTripCountMap) {
1643  iterCount *= count.second;
1644  }
1645  return iterCount;
1646 }
1647 
// Name of the attribute that marks a slice fusion barrier.
const char *const kSliceFusionBarrierAttrName{"slice_fusion_barrier"};
1649 // Computes slice bounds by projecting out any loop IVs from
1650 // 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
1651 // bounds in 'sliceState' which represent the one loop nest's IVs in terms of
1652 // the other loop nest's IVs, symbols and constants (using 'isBackwardSlice').
//
// When 'isBackwardSlice' is true, the slice loops are the loops surrounding
// 'depSourceOp', and their bounds are expressed in terms of the outermost
// 'loopDepth' loops surrounding 'depSinkOp'. Otherwise the roles are swapped.
// 'dependenceConstraints' is modified in place (variables are projected out).
// On return, the ivs, lbs, ubs, lbOperands, ubOperands and insertPoint fields
// of 'sliceState' are populated. A null lb/ub map means that loop keeps its
// original bounds. Bounds are reset this way starting at the first
// sequential slice loop (read-read pairs only) or at the first slice loop
// carrying 'kSliceFusionBarrierAttrName'.
1654  Operation *depSourceOp, Operation *depSinkOp,
1655  FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
1656  bool isBackwardSlice, ComputationSliceState *sliceState) {
1657  // Get loop nest surrounding src operation.
1658  SmallVector<AffineForOp, 4> srcLoopIVs;
1659  getAffineForIVs(*depSourceOp, &srcLoopIVs);
1660  unsigned numSrcLoopIVs = srcLoopIVs.size();
1661 
1662  // Get loop nest surrounding dst operation.
1663  SmallVector<AffineForOp, 4> dstLoopIVs;
1664  getAffineForIVs(*depSinkOp, &dstLoopIVs);
1665  unsigned numDstLoopIVs = dstLoopIVs.size();
1666 
1667  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1668  (isBackwardSlice && loopDepth <= numDstLoopIVs));
1669 
  // NOTE(review): 'loopDepth - 1' is used to index a loop list below (insert
  // point), so 'loopDepth' is presumably required to be >= 1. The assert
  // above does not check this.
1670  // Project out dimensions other than those up to 'loopDepth'.
  // The offsets used here imply the column layout [src IVs | dst IVs | ...].
  // The IVs of the non-slice nest that lie deeper than 'loopDepth' are removed.
1671  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1672  unsigned num =
1673  isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1674  dependenceConstraints->projectOut(pos, num);
1675 
1676  // Add slice loop IV values to 'sliceState'.
  // After the projection, the slice IVs start at column 0 (backward) or at
  // column 'loopDepth' (forward).
1677  unsigned offset = isBackwardSlice ? 0 : loopDepth;
1678  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1679  dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
1680  &sliceState->ivs);
1681 
1682  // Set up lower/upper bound affine maps for the slice.
1683  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
1684  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
1685 
1686  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1687  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
1688  depSourceOp->getContext(),
1689  &sliceState->lbs, &sliceState->ubs);
1690 
1691  // Set up bound operands for the slice's lower and upper bounds.
  // Every dim/symbol value outside the slice-IV range becomes an operand.
1692  SmallVector<Value, 4> sliceBoundOperands;
1693  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolVars();
1694  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1695  if (i < offset || i >= offset + numSliceLoopIVs) {
1696  sliceBoundOperands.push_back(dependenceConstraints->getValue(i));
1697  }
1698  }
1699 
1700  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1701  // canonicalization.
1702  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1703  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1704 
1705  // Set the slice insertion point. For a backward slice it is the start of the
  // body of the destination loop at 'loopDepth'. For a forward slice it is just
  // before the last op (presumably the terminator) of the body of the source
  // loop at 'loopDepth'.
1706  sliceState->insertPoint =
1707  isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1708  : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1709 
1710  llvm::SmallDenseSet<Value, 8> sequentialLoops;
1711  if (isa<AffineReadOpInterface>(depSourceOp) &&
1712  isa<AffineReadOpInterface>(depSinkOp)) {
1713  // For read-read access pairs, clear any slice bounds on sequential loops.
1714  // Get sequential loops in the slice loop nest, i.e. the nest rooted at
  // 'srcLoopIVs[0]' for a backward slice and at 'dstLoopIVs[0]' otherwise.
1715  getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1716  &sequentialLoops);
1717  }
  // Returns the i-th slice loop (source loops for a backward slice,
  // destination loops for a forward slice).
1718  auto getSliceLoop = [&](unsigned i) {
1719  return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1720  };
  // NOTE(review): a backward slice is inserted into the destination nest at
  // 'loopDepth', yet this compares 'loopDepth' against the number of *source*
  // loops (and vice versa for a forward slice). The two only coincide when
  // both nests have the same depth. Confirm this is intended.
1721  auto isInnermostInsertion = [&]() {
1722  return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1723  : loopDepth >= dstLoopIVs.size());
1724  };
1725  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  // True iff every slice loop has a known trip count and their product is 1.
  // The trip-count map is rebuilt on every call.
1726  auto srcIsUnitSlice = [&]() {
1727  return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
1728  (getSliceIterationCount(sliceTripCountMap) == 1));
1729  };
1730  // Clear all sliced loop bounds beginning at the first sequential loop, or
1731  // first loop with a slice fusion barrier attribute.
1732 
1733  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1734  Value iv = getSliceLoop(i).getInductionVar();
  // Loops that are neither sequential nor fusion barriers keep their bounds.
1735  if (sequentialLoops.count(iv) == 0 &&
1736  getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1737  continue;
1738  // Skip reset of bounds of reduction loop inserted in the destination loop
1739  // that meets the following conditions:
1740  // 1. Slice is single trip count.
1741  // 2. Loop bounds of the source and destination match.
1742  // 3. Is being inserted at the innermost insertion point.
1743  std::optional<bool> isMaximal = sliceState->isMaximal();
1744  if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1745  isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1746  continue;
  // Reset the bounds of this loop and of every subsequent slice loop, so
  // they keep their original (unsliced) bounds.
1747  for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1748  sliceState->lbs[j] = AffineMap();
1749  sliceState->ubs[j] = AffineMap();
1750  }
1751  break;
1752  }
1753 }
1754 
1755 /// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1756 /// updates the slice loop bounds with any non-null bound maps specified in
1757 /// 'sliceState', and inserts this slice into the loop nest surrounding
1758 /// 'dstOpInst' at loop depth 'dstLoopDepth'.
///
/// The slice is a deep copy of the outermost loop around 'srcOpInst'. It is
/// placed at the start of the body of the destination loop at 'dstLoopDepth'.
/// Returns the cloned outermost slice loop. Returns a null AffineForOp (after
/// emitting an error on 'dstOpInst') if 'dstLoopDepth' exceeds the number of
/// loops surrounding 'dstOpInst'.
1759 // TODO: extend the slicing utility to compute slices that
1760 // aren't necessarily a one-to-one relation b/w the source and destination. The
1761 // relation between the source and destination could be many-to-many in general.
1762 // TODO: the slice computation is incorrect in the cases
1763 // where the dependence from the source to the destination does not cover the
1764 // entire destination index set. Subtract out the dependent destination
1765 // iterations from destination index set and check for emptiness --- this is one
1766 // solution.
1768  Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1769  ComputationSliceState *sliceState) {
1770  // Get loop nest surrounding src operation.
1771  SmallVector<AffineForOp, 4> srcLoopIVs;
1772  getAffineForIVs(*srcOpInst, &srcLoopIVs);
1773  unsigned numSrcLoopIVs = srcLoopIVs.size();
1774 
1775  // Get loop nest surrounding dst operation.
1776  SmallVector<AffineForOp, 4> dstLoopIVs;
1777  getAffineForIVs(*dstOpInst, &dstLoopIVs);
1778  unsigned dstLoopIVsSize = dstLoopIVs.size();
1779  if (dstLoopDepth > dstLoopIVsSize) {
1780  dstOpInst->emitError("invalid destination loop depth");
1781  return AffineForOp();
1782  }
  // NOTE(review): dstLoopDepth == 0 is not rejected here, yet
  // 'dstLoopIVs[dstLoopDepth - 1]' is accessed below.
1783 
1784  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  // (Presumably the path of op indices from the block of 'srcLoopIVs[0]' down
  // to 'srcOpInst'. It is replayed on the clone via getInstAtPosition.)
1785  SmallVector<unsigned, 4> positions;
1786  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
1787  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
1788 
1789  // Clone src loop nest and insert it at the beginning of the operation block
1790  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
1791  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
1792  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
1793  auto sliceLoopNest =
1794  cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
1795 
  // Locate the clone of 'srcOpInst' inside the cloned nest.
1796  Operation *sliceInst =
1797  getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
1798  // Get loop nest surrounding 'sliceInst'.
1799  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
1800  getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
1801 
1802  // Sanity check.
  // NOTE(review): both asserts check the same upper bound on the number of
  // loops around 'sliceInst'. The loop below reads
  // sliceSurroundingLoops[dstLoopDepth + i] for i < numSrcLoopIVs, which needs
  // at least dstLoopDepth + numSrcLoopIVs entries. That lower bound is not
  // asserted.
1803  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
1804  (void)sliceSurroundingLoopsSize;
1805  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
1806  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
1807  (void)sliceLoopLimit;
1808  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
1809 
1810  // Update loop bounds for loops in 'sliceLoopNest'.
  // Null bound maps in 'sliceState' leave the cloned loop's bound untouched.
1811  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1812  auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
1813  if (AffineMap lbMap = sliceState->lbs[i])
1814  forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
1815  if (AffineMap ubMap = sliceState->ubs[i])
1816  forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
1817  }
1818  return sliceLoopNest;
1819 }
1820 
1821 // Constructs MemRefAccess populating it with the memref, its indices and
1822 // opinst from 'loadOrStoreOpInst'.
// 'loadOrStoreOpInst' must implement AffineReadOpInterface or
// AffineWriteOpInterface (asserted). Note that 'indices' receives the
// operands of the access map (getMapOperands()), not the final subscripts
// obtained by applying the map.
1824  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1825  memref = loadOp.getMemRef();
1826  opInst = loadOrStoreOpInst;
1827  llvm::append_range(indices, loadOp.getMapOperands());
1828  } else {
1829  assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1830  "Affine read/write op expected");
1831  auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1832  opInst = loadOrStoreOpInst;
1833  memref = storeOp.getMemRef();
1834  llvm::append_range(indices, storeOp.getMapOperands());
1835  }
1836 }
1837 
1838 unsigned MemRefAccess::getRank() const {
1839  return cast<MemRefType>(memref.getType()).getRank();
1840 }
1841 
  // An access is a store iff its op implements AffineWriteOpInterface.
1843  return isa<AffineWriteOpInterface>(opInst);
1844 }
1845 
1846 /// Returns the nesting depth of this statement, i.e., the number of loops
1847 /// surrounding this statement.
/// Only 'affine.for' ancestors of 'op' are counted. Other ops, including
/// 'affine.parallel', are walked through without contributing to the depth.
1849  Operation *currOp = op;
1850  unsigned depth = 0;
1851  while ((currOp = currOp->getParentOp())) {
1852  if (isa<AffineForOp>(currOp))
1853  depth++;
1854  }
1855  return depth;
1856 }
1857 
1858 /// Equal if both affine accesses are provably equivalent (at compile
1859 /// time) when considering the memref, the affine maps and their respective
1860 /// operands. The equality of access functions + operands is checked by
1861 /// subtracting fully composed value maps, and then simplifying the difference
1862 /// using the expression flattener.
1863 /// TODO: this does not account for aliasing of memrefs.
1864 bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1865  if (memref != rhs.memref)
1866  return false;
1867 
1868  AffineValueMap diff, thisMap, rhsMap;
1869  getAccessMap(&thisMap);
1870  rhs.getAccessMap(&rhsMap);
1871  AffineValueMap::difference(thisMap, rhsMap, &diff);
1872  return llvm::all_of(diff.getAffineMap().getResults(),
1873  [](AffineExpr e) { return e == 0; });
1874 }
1875 
1877  auto *currOp = op.getParentOp();
1878  AffineForOp currAffineForOp;
1879  // Traverse up the hierarchy collecting all 'affine.for' and affine.parallel
1880  // operation while skipping over 'affine.if' operations.
1881  while (currOp) {
1882  if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
1883  ivs.push_back(currAffineForOp.getInductionVar());
1884  else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
1885  llvm::append_range(ivs, parOp.getIVs());
1886  currOp = currOp->getParentOp();
1887  }
1888  std::reverse(ivs.begin(), ivs.end());
1889 }
1890 
1891 /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
1892 /// where each lists loops from outer-most to inner-most in loop nest.
/// 'loopsA'/'loopsB' are the IVs enclosing ops 'a'/'b' as collected by
/// getAffineIVs (both 'affine.for' and 'affine.parallel' IVs, one entry per
/// IV). The result is the length of their longest common prefix.
1894  Operation &b) {
1895  SmallVector<Value, 4> loopsA, loopsB;
1896  getAffineIVs(a, loopsA);
1897  getAffineIVs(b, loopsB);
1898 
1899  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
1900  unsigned numCommonLoops = 0;
1901  for (unsigned i = 0; i < minNumLoops; ++i) {
1902  if (loopsA[i] != loopsB[i])
1903  break;
1904  ++numCommonLoops;
1905  }
1906  return numCommonLoops;
1907 }
1908 
1909 static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
1910  Block::iterator start,
1911  Block::iterator end,
1912  int memorySpace) {
1913  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
1914 
1915  // Walk this 'affine.for' operation to gather all memory regions.
1916  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
1917  if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
1918  // Neither load nor a store op.
1919  return WalkResult::advance();
1920  }
1921 
1922  // Compute the memref region symbolic in any IVs enclosing this block.
1923  auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
1924  if (failed(
1925  region->compute(opInst,
1926  /*loopDepth=*/getNestingDepth(&*block.begin())))) {
1927  return opInst->emitError("error obtaining memory region\n");
1928  }
1929 
1930  auto [it, inserted] = regions.try_emplace(region->memref);
1931  if (inserted) {
1932  it->second = std::move(region);
1933  } else if (failed(it->second->unionBoundingBox(*region))) {
1934  return opInst->emitWarning(
1935  "getMemoryFootprintBytes: unable to perform a union on a memory "
1936  "region");
1937  }
1938  return WalkResult::advance();
1939  });
1940  if (result.wasInterrupted())
1941  return std::nullopt;
1942 
1943  int64_t totalSizeInBytes = 0;
1944  for (const auto &region : regions) {
1945  std::optional<int64_t> size = region.second->getRegionSize();
1946  if (!size.has_value())
1947  return std::nullopt;
1948  totalSizeInBytes += *size;
1949  }
1950  return totalSizeInBytes;
1951 }
1952 
// Returns the memory footprint in bytes of the affine accesses nested in
// 'forOp'. The footprint is computed over the single-op range
// [forOp, std::next(forOp)) of its parent block. 'memorySpace' is passed
// through unchanged to the block-range overload. NOTE(review): that overload
// does not currently filter by memory space.
1953 std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
1954  int memorySpace) {
1955  auto *forInst = forOp.getOperation();
1957  *forInst->getBlock(), Block::iterator(forInst),
1958  std::next(Block::iterator(forInst)), memorySpace);
1959 }
1960 
1961 /// Returns whether a loop is parallel and contains a reduction loop.
/// Concretely: isLoopParallel(forOp, &reductions) must succeed and report at
/// least one LoopReduction. Presumably, when a reductions vector is supplied,
/// isLoopParallel treats recognized reductions as parallelizable; confirm in
/// isLoopParallel.
1963  SmallVector<LoopReduction> reductions;
1964  if (!isLoopParallel(forOp, &reductions))
1965  return false;
1966  return !reductions.empty();
1967 }
1968 
1969 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
1970 /// at 'forOp'.
/// The induction variables (not the ops) are inserted, for every 'affine.for'
/// in the nest, 'forOp' itself included (walk visits the root op), for which
/// isLoopParallel() returns false. Existing entries of the set are kept.
1972  AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
1973  forOp->walk([&](Operation *op) {
1974  if (auto innerFor = dyn_cast<AffineForOp>(op))
1975  if (!isLoopParallel(innerFor))
1976  sequentialLoops->insert(innerFor.getInductionVar());
1977  });
1978 }
1979 
  // Round-trip 'set' through FlatAffineValueConstraints. If the constraint
  // system is empty, return the canonical empty set with the same number of
  // dims and symbols. Otherwise convert the (possibly simplified) system back
  // into an IntegerSet.
1981  FlatAffineValueConstraints fac(set);
1982  if (fac.isEmpty())
1983  return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
1984  set.getContext());
1986 
1987  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
1988  assert(simplifiedSet && "guaranteed to succeed while roundtripping");
1989  return simplifiedSet;
1990 }
1991 
1992 static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
1993  SmallVector<Value> &target) {
1994  target =
1995  llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
1996  return val.has_value() ? *val : Value();
1997  }));
1998 }
1999 
2000 /// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2001 /// constraints drawn from an affine map. Before adding the constraint, the
2002 /// dimensions/symbols of the affine map are aligned with `constraints`.
2003 /// `operands` are the SSA Value operands used with the affine map.
2004 /// Note: This function adds a new symbol column to the `constraints` for each
2005 /// dimension/symbol that exists in the affine map but not in `constraints`.
2006 static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2007  BoundType type, unsigned pos,
2008  AffineMap map, ValueRange operands) {
2009  SmallVector<Value> dims, syms, newSyms;
2010  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
2011  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
2012 
2013  AffineMap alignedMap =
2014  alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
2015  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2016  constraints.appendSymbolVar(newSyms[i]);
2017  return constraints.addBound(type, pos, alignedMap);
2018 }
2019 
2020 /// Add `val` to each result of `map`.
2021 static AffineMap addConstToResults(AffineMap map, int64_t val) {
2022  SmallVector<AffineExpr> newResults;
2023  for (AffineExpr r : map.getResults())
2024  newResults.push_back(r + val);
2025  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
2026  map.getContext());
2027 }
2028 
2029 // Attempts to simplify the given min/max operation by proving that its value is
2030 // bounded by the same lower and upper bound.
2031 //
2032 // Bounds are computed by FlatAffineValueConstraints. Invariants required for
2033 // finding/proving bounds should be supplied via `constraints`.
// `constraints` is taken by value, so the caller's system is not modified.
2034 //
2035 // 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2036 // 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2037 // case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2038 // `op` but are not part of `constraints`, are added as extra symbols.
2039 // 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2040 // * If `isMin`: r_i >= opBound
2041 // * If `!isMin`: r_i <= opBound
2042 // If this is the case, ub(op) == lb(op).
2043 // 4. Replace `op` with `opBound`.
//
// On success, the bound is returned as an AffineValueMap whose operands are
// the constraint variables' values, with constant-valued operands folded and
// then canonicalized. failure() is returned if no single bound is found, the
// system is inconsistent, or the proof in step 3 fails for some result.
2044 //
2045 // In summary, the following constraints are added throughout this function.
2046 // Note: `invar` are dimensions added by the caller to express the invariants.
2047 // (Showing only the case where `isMin`.)
2048 //
2049 // invar | op | opBound | r_i | extra syms... | const | eq/ineq
2050 // ------+-------+---------+-----+---------------+-------+-------------------
2051 // (various eq./ineq. constraining `invar`, added by the caller)
2052 // ... | 0 | 0 | 0 | 0 | ... | ...
2053 // ------+-------+---------+-----+---------------+-------+-------------------
2054 // (various ineq. constraining `op` in terms of `op` operands (`invar` and
2055 // extra `op` operands "extra syms" that are not in `invar`)).
2056 // ... | -1 | 0 | 0 | ... | ... | >= 0
2057 // ------+-------+---------+-----+---------------+-------+-------------------
2058 // (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2059 // ... | 0 | -1 | 0 | ... | ... | = 0
2060 // ------+-------+---------+-----+---------------+-------+-------------------
2061 // (for each `op` map result r_i: set r_i to corresponding map result,
2062 // prove that r_i >= opBound via contradiction)
2063 // ... | 0 | 0 | -1 | ... | ... | = 0
2064 // 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2065 //
2067  Operation *op, FlatAffineValueConstraints constraints) {
2068  bool isMin = isa<AffineMinOp>(op);
2069  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2070  MLIRContext *ctx = op->getContext();
2071  Builder builder(ctx);
2072  AffineMap map =
2073  isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2074  ValueRange operands = op->getOperands();
2075  unsigned numResults = map.getNumResults();
2076 
2077  // Add a few extra dimensions.
  // One column for `op`, one for its bound, and one per map result r_i.
2078  unsigned dimOp = constraints.appendDimVar(); // `op`
2079  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2080  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2081 
2082  // Add an inequality for each result expr_i of map:
2083  // isMin: op <= expr_i, !isMin: op >= expr_i
2084  auto boundType = isMin ? BoundType::UB : BoundType::LB;
2085  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2086  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
2087  if (failed(
2088  alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
2089  return failure();
2090 
2091  // Try to compute a lower/upper bound for op, expressed in terms of the other
2092  // `dims` and extra symbols.
2093  SmallVector<AffineMap> opLb(1), opUb(1);
2094  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
2095  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2096  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2097  // a TODO of `getSliceBounds` and not handled here.
2098  if (!sliceBound || sliceBound.getNumResults() != 1)
2099  return failure(); // No or multiple bounds found.
2100  // Recover the inclusive UB in the case of an `affine.min`.
2101  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
2102 
2103  // Add an equality: Set dimOpBound to computed bound.
2104  // Add back dimension for op. (Was removed by `getSliceBounds`.)
2105  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2106  if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
2107  return failure();
2108 
2109  // If the constraint system is empty, there is an inconsistency. (E.g., this
2110  // can happen if loop lb > ub.)
2111  if (constraints.isEmpty())
2112  return failure();
2113 
2114  // In the case of `isMin` (the `!isMin` case is the mirror image):
2115  // Prove that each result of `map` has a lower bound that is equal to (or
2116  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2117  // can be replaced with the bound. I.e., prove that for each result
2118  // expr_i (represented by dimension r_i):
2119  //
2120  // r_i >= opBound
2121  //
2122  // To prove this inequality, add its negation to the constraint set and prove
2123  // that the constraint set is empty.
2124  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
  // Work on a copy so each result's proof starts from the same system.
2125  FlatAffineValueConstraints newConstr(constraints);
2126 
2127  // Add an equality: r_i = expr_i
2128  // Note: These equalities could have been added earlier and used to express
2129  // op <= expr_i. However, then we run the risk that `getSliceBounds`
2130  // computes opBound in terms of r_i dims, which is not desired.
2131  if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
2132  map.getSubMap({i - resultDimStart}), operands)))
2133  return failure();
2134 
2135  // If `isMin`: Add inequality: r_i < opBound
2136  // equiv.: opBound - r_i - 1 >= 0
2137  // If `!isMin`: Add inequality: r_i > opBound
2138  // equiv.: -opBound + r_i - 1 >= 0
  // (The last column holds the constant term.)
2139  SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2140  ineq[dimOpBound] = isMin ? 1 : -1;
2141  ineq[i] = isMin ? -1 : 1;
2142  ineq[newConstr.getNumCols() - 1] = -1;
2143  newConstr.addInequality(ineq);
2144  if (!newConstr.isEmpty())
2145  return failure();
2146  }
2147 
2148  // Lower and upper bound of `op` are equal. Replace `op` with its bound.
2149  AffineMap newMap = alignedBoundMap;
2150  SmallVector<Value> newOperands;
2151  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
2152  // If dims/symbols have known constant values, use those in order to simplify
2153  // the affine map further.
  // Variables without an associated SSA value unpack to a null Value and are
  // skipped.
2154  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2155  // Skip unused operands and operands that are already constants.
2156  if (!newOperands[i] || getConstantIntValue(newOperands[i]))
2157  continue;
2158  if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
2159  AffineExpr expr =
2160  i < newMap.getNumDims()
2161  ? builder.getAffineDimExpr(i)
2162  : builder.getAffineSymbolExpr(i - newMap.getNumDims());
2163  newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
2164  newMap.getNumDims(), newMap.getNumSymbols());
2165  }
2166  }
2167  affine::canonicalizeMapAndOperands(&newMap, &newOperands);
2168  return AffineValueMap(newMap, newOperands);
2169 }
static std::optional< uint64_t > getConstDifference(AffineMap lbMap, AffineMap ubMap)
Definition: Utils.cpp:1581
static void findInstPosition(Operation *op, Block *limitBlock, SmallVectorImpl< unsigned > *positions)
Definition: Utils.cpp:1327
const char *const kSliceFusionBarrierAttrName
Definition: Utils.cpp:1648
static LogicalResult addMissingLoopIVBounds(SmallPtrSet< Value, 8 > &ivs, FlatAffineValueConstraints *cst)
Definition: Utils.cpp:1369
MemRefDependenceGraph::Node Node
Definition: Utils.cpp:38
static Operation * getInstAtPosition(ArrayRef< unsigned > positions, unsigned level, Block *block)
Definition: Utils.cpp:1344
static std::optional< int64_t > getMemoryFootprintBytes(Block &block, Block::iterator start, Block::iterator end, int memorySpace)
Definition: Utils.cpp:1909
static AffineMap addConstToResults(AffineMap map, int64_t val)
Add val to each result of map.
Definition: Utils.cpp:2021
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints, BoundType type, unsigned pos, AffineMap map, ValueRange operands)
Bound an identifier pos in a given FlatAffineValueConstraints with constraints drawn from an affine m...
Definition: Utils.cpp:2006
static void unpackOptionalValues(ArrayRef< std::optional< Value >> source, SmallVector< Value > &target)
Definition: Utils.cpp:1992
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:348
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
void dump() const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:654
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition: Block.h:305
iterator begin()
Definition: Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:408
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
IntegerSet getAsIntegerSet(MLIRContext *context) const
Returns the constraint system as an integer set.
void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context, SmallVectorImpl< AffineMap > *lbMaps, SmallVectorImpl< AffineMap > *ubMaps, bool closedUB=false)
Computes the lower and upper bounds of the first num dimensional variables (starting at offset) as an...
LogicalResult unionBoundingBox(const FlatLinearValueConstraints &other)
Updates the constraints to be the smallest bounding (enclosing) box that contains the points of this ...
void mergeAndAlignVarsWithOther(unsigned offset, FlatLinearValueConstraints *other)
Merge and align the variables of this and other starting at offset, so that both constraint systems g...
SmallVector< std::optional< Value > > getMaybeValues() const
Value getValue(unsigned pos) const
Returns the Value associated with the pos^th variable.
void projectOut(Value val)
Projects out the variable that is associate with Value.
bool containsVar(Value val) const
Returns true if a variable with the specified Value exists, false otherwise.
bool areVarsAlignedWithOther(const FlatLinearConstraints &other)
Returns true if this constraint system and other are in the same space, i.e., if they are associated ...
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols, MLIRContext *context)
Definition: IntegerSet.h:56
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getResults()
Definition: Operation.h:415
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
Value getOperand(unsigned i) const
unsigned getNumOperands() const
static void difference(const AffineValueMap &a, const AffineValueMap &b, AffineValueMap *res)
Return the value map that is the difference of value maps 'a' and 'b', represented as an affine map a...
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
void convertLoopIVSymbolsToDims()
Changes all symbol variables which are loop IVs to dim variables.
LogicalResult addDomainFromSliceMaps(ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds constraints (lower and upper bounds) for each loop in the loop nest described by the bound maps ...
LogicalResult addAffineForOpDomain(AffineForOp forOp)
Adds constraints (lower and upper bounds) for the specified 'affine.for' operation's Value using IR i...
LogicalResult addSliceBounds(ArrayRef< Value > values, ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds slice lower bounds represented by lower bounds in lbMaps and upper bounds in ubMaps to each vari...
void removeTrivialRedundancy()
Removes duplicate constraints, trivially true constraints, and constraints that can be detected as re...
std::optional< int64_t > getConstantBoundOnDimSize64(unsigned pos, SmallVectorImpl< int64_t > *lb=nullptr, int64_t *boundFloorDivisor=nullptr, SmallVectorImpl< int64_t > *ub=nullptr, unsigned *minLbPos=nullptr, unsigned *minUbPos=nullptr) const
The same, but casts to int64_t.
std::optional< int64_t > getConstantBound64(BoundType type, unsigned pos) const
The same, but casts to int64_t.
bool isEmpty() const
Checks for emptiness by performing variable elimination on all variables, running the GCD test on eac...
unsigned getNumCols() const
Returns the number of columns in the constraint system.
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
bool isIntegerEmpty() const
Returns true if all the sets in the union are known to be integer empty, false otherwise.
PresburgerSet subtract(const PresburgerRelation &set) const
std::optional< uint64_t > getConstantTripCount(AffineForOp forOp)
Returns the trip count of the loop if it's a constant, std::nullopt otherwise.
IntegerSet simplifyIntegerSet(IntegerSet set)
Simplify the integer set by simplifying the underlying affine expressions by flattening and some simp...
Definition: Utils.cpp:1980
void getEnclosingAffineOps(Operation &op, SmallVectorImpl< Operation * > *ops)
Populates 'ops' with affine operations enclosing op ordered from outermost to innermost while stoppin...
Definition: Utils.cpp:647
SliceComputationResult computeSliceUnion(ArrayRef< Operation * > opsA, ArrayRef< Operation * > opsB, unsigned loopDepth, unsigned numCommonLoops, bool isBackwardSlice, ComputationSliceState *sliceUnion)
Computes in 'sliceUnion' the union of all slice bounds computed at 'loopDepth' between all dependent ...
Definition: Utils.cpp:1417
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2561
bool isLoopParallelAndContainsReduction(AffineForOp forOp)
Returns whether a loop is a parallel loop and contains a reduction loop.
Definition: Utils.cpp:1962
unsigned getNumCommonSurroundingLoops(Operation &a, Operation &b)
Returns the number of surrounding loops common to both A and B.
Definition: Utils.cpp:1893
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2565
void getAffineIVs(Operation &op, SmallVectorImpl< Value > &ivs)
Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel ops ordered from the outer...
Definition: Utils.cpp:1876
void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth, bool isBackwardSlice, ComputationSliceState *sliceState)
Computes the computation slice loop bounds for one loop nest as affine maps of the other loop nest's ...
Definition: Utils.cpp:1653
void getSequentialLoops(AffineForOp forOp, llvm::SmallDenseSet< Value, 8 > *sequentialLoops)
Returns in 'sequentialLoops' all sequential loops in loop nest rooted at 'forOp'.
Definition: Utils.cpp:1971
DependenceResult checkMemrefAccessDependence(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints=nullptr, SmallVector< DependenceComponent, 2 > *dependenceComponents=nullptr, bool allowRAR=false)
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1435
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2553
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition: Utils.cpp:633
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition: Utils.cpp:1953
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition: Utils.cpp:1385
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:392
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that has the provided value among its induction variables.
Definition: AffineOps.cpp:2576
std::optional< uint64_t > getIntOrFloatMemRefSizeInBytes(MemRefType memRefType)
Returns the size of a memref with element type int or float in bytes if it's statically shaped,...
Definition: Utils.cpp:1246
unsigned getNestingDepth(Operation *op)
Returns the nesting depth of this operation, i.e., the number of loops surrounding this operation.
Definition: Utils.cpp:1848
uint64_t getSliceIterationCount(const llvm::SmallDenseMap< Operation *, uint64_t, 8 > &sliceTripCountMap)
Returns the number of iterations for the provided sliceTripCountMap.
Definition: Utils.cpp:1639
bool isLoopParallel(AffineForOp forOp, SmallVectorImpl< LoopReduction > *parallelReductions=nullptr)
Returns true if 'forOp' is a parallel loop.
LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError=true)
Checks a load or store op for an out of bound access; returns failure if the access is out of bounds ...
bool buildSliceTripCountMap(const ComputationSliceState &slice, llvm::SmallDenseMap< Operation *, uint64_t, 8 > *tripCountMap)
Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop nest surrounding represe...
Definition: Utils.cpp:1601
AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState)
Creates a clone of the computation contained in the loop nest surrounding 'srcOpInst',...
Definition: Utils.cpp:1767
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition: Utils.cpp:2066
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition: Utils.cpp:1196
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
BoundType
The type of bound: equal, lower bound or upper bound.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands, ValueRange dims, ValueRange syms, SmallVector< Value > *newSyms=nullptr)
Re-indexes the dimensions and symbols of an affine map with given operands values to align with dims ...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition: Utils.h:259
std::optional< bool > isSliceValid() const
Checks the validity of the slice computed.
Definition: Utils.cpp:818
SmallVector< Value, 4 > ivs
Definition: Utils.h:262
LogicalResult getAsConstraints(FlatAffineValueConstraints *cst) const
Definition: Utils.cpp:680
LogicalResult getSourceAsConstraints(FlatAffineValueConstraints &cst) const
Adds to 'cst' constraints which represent the original loop bounds on 'ivs' in 'this'.
Definition: Utils.cpp:664
std::vector< SmallVector< Value, 4 > > ubOperands
Definition: Utils.h:270
SmallVector< AffineMap, 4 > ubs
Definition: Utils.h:266
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition: Utils.cpp:888
SmallVector< AffineMap, 4 > lbs
Definition: Utils.h:264
std::vector< SmallVector< Value, 4 > > lbOperands
Definition: Utils.h:268
Checks whether two accesses to the same memref access the same element.
enum mlir::affine::DependenceResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
Definition: Utils.h:40
void collect(Operation *opToWalk)
Definition: Utils.cpp:43
SmallVector< Operation *, 4 > storeOpInsts
Definition: Utils.h:41
Encapsulates a memref load or store access information.
unsigned getRank() const
Definition: Utils.cpp:1838
MemRefAccess(Operation *opInst)
Constructs a MemRefAccess from a load or store operation.
Definition: Utils.cpp:1823
void getAccessMap(AffineValueMap *accessMap) const
Populates 'accessMap' with composition of AffineApplyOps reachable from 'indices'.
bool operator==(const MemRefAccess &rhs) const
Equal if both affine accesses can be proved to be equivalent at compile time (considering the memrefs...
Definition: Utils.cpp:1864
SmallVector< Operation *, 4 > loads
Definition: Utils.h:64
SmallVector< Operation *, 4 > stores
Definition: Utils.h:66
unsigned getStoreOpCount(Value memref) const
Definition: Utils.cpp:67
unsigned addNode(Operation *op)
Definition: Utils.cpp:251
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
Definition: Utils.cpp:388
void removeEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:327
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition: Utils.cpp:316
Node * getForOpNode(AffineForOp forOp)
Definition: Utils.cpp:243
bool hasDependencePath(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:354
void clearNodeLoadAndStores(unsigned id)
Definition: Utils.cpp:576
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition: Utils.cpp:504
bool writesToLiveInOrEscapingMemrefs(unsigned id)
Definition: Utils.cpp:281
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
Definition: Utils.cpp:568
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:584
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
Definition: Utils.cpp:404
void removeNode(unsigned id)
Definition: Utils.cpp:258
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
Definition: Utils.cpp:428
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:592
void forEachMemRefEdge(ArrayRef< Edge > edges, const std::function< void(Edge)> &callback)
Definition: Utils.cpp:600
bool hasEdge(unsigned srcId, unsigned dstId, Value value=nullptr)
Definition: Utils.cpp:301
void gatherDefiningNodes(unsigned id, DenseSet< unsigned > &definingNodes)
Return all nodes which define SSA values used in node 'id'.
Definition: Utils.cpp:414
Node * getNode(unsigned id)
Definition: Utils.cpp:236
void print(raw_ostream &os) const
Definition: Utils.cpp:615
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
Definition: Utils.h:430
FlatAffineValueConstraints * getConstraints()
Definition: Utils.h:469
unsigned getRank() const
Returns the rank of the memref that this region corresponds to.
Definition: Utils.cpp:938
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
Definition: Utils.cpp:942
void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const
Gets the lower and upper bound map for the dimensional variable at pos.
Definition: Utils.cpp:1003
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition: Utils.cpp:1215
LogicalResult unionBoundingBox(const MemRefRegion &other)
Definition: Utils.cpp:1022
Value memref
Memref that this region corresponds to.
Definition: Utils.h:517
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Definition: Utils.cpp:1044
Enumerates different result statuses of slice computation by computeSliceUnion
Definition: Utils.h:246
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.