//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"

#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region-holding op other than ForOp and IfOp
// was encountered in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      forOps.push_back(forOp);
    } else if (isa<AffineReadOpInterface>(op)) {
      loadOpInsts.push_back(op);
    } else if (isa<AffineWriteOpInterface>(op)) {
      storeOpInsts.push_back(op);
    } else {
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!memInterface) {
        if (isMemoryEffectFree(op))
          // This op itself is memory-effect free.
          return;
        // Check operands. E.g., ops like the `call` op are handled here.
        for (Value v : op->getOperands()) {
          if (!isa<MemRefType>(v.getType()))
            continue;
          // Conservatively, we assume the memref is read and written to.
          memrefLoads.push_back(op);
          memrefStores.push_back(op);
        }
      } else {
        // Non-affine loads and stores.
        if (hasEffect<MemoryEffects::Read>(op))
          memrefLoads.push_back(op);
        if (hasEffect<MemoryEffects::Write>(op))
          memrefStores.push_back(op);
        if (hasEffect<MemoryEffects::Free>(op))
          memrefFrees.push_back(op);
      }
    }
  });
}

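// Illustrative usage sketch (not part of the original file): collecting the
// memory operations of a loop nest rooted at a hypothetical `forOp`.
//
//   LoopNestStateCollector collector;
//   collector.collect(forOp);
//   // collector.loadOpInsts/storeOpInsts now hold the affine accesses, while
//   // collector.memrefLoads/memrefStores/memrefFrees hold the non-affine
//   // ops that may read, write, or free a memref.
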
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    // Common case: affine reads.
    if (auto affineLoad = dyn_cast<AffineReadOpInterface>(loadOp)) {
      if (memref == affineLoad.getMemRef())
        ++loadOpCount;
    } else if (hasEffect<MemoryEffects::Read>(loadOp, memref)) {
      ++loadOpCount;
    }
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (auto *storeOp : llvm::concat<Operation *const>(stores, memrefStores)) {
    // Common case: affine writes.
    if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
      if (memref == affineStore.getMemRef())
        ++storeOpCount;
    } else if (hasEffect<MemoryEffects::Write>(const_cast<Operation *>(storeOp),
                                               memref)) {
      ++storeOpCount;
    }
  }
  return storeOpCount;
}

// Returns true if this node has a store op or an op with a write effect on
// 'memref'.
unsigned Node::hasStore(Value memref) const {
  return llvm::any_of(
      llvm::concat<Operation *const>(stores, memrefStores),
      [&](Operation *storeOp) {
        if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
          if (memref == affineStore.getMemRef())
            return true;
        } else if (hasEffect<MemoryEffects::Write>(storeOp, memref)) {
          return true;
        }
        return false;
      });
}

unsigned Node::hasFree(Value memref) const {
  return llvm::any_of(memrefFrees, [&](Operation *freeOp) {
    return hasEffect<MemoryEffects::Free>(freeOp, memref);
  });
}

// Returns in 'storeOps' all store ops in this node which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Returns in 'loadOps' all load ops in this node which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Returns in 'loadAndStoreMemrefSet' all memrefs for which this node
// has at least one load and store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

/// Returns the values that this op has a memref effect of type `EffectTys` on,
/// not considering recursive effects.
template <typename... EffectTys>
static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
  auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memOp) {
    if (isMemoryEffectFree(op))
      // No effects.
      return;
    // Memref operands have to be considered as being affected.
    for (Value operand : op->getOperands()) {
      if (isa<MemRefType>(operand.getType()))
        values.push_back(operand);
    }
    return;
  }
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
  memOp.getEffects(effects);
  for (auto &effect : effects) {
    Value effectVal = effect.getValue();
    if (isa<EffectTys...>(effect.getEffect()) && effectVal &&
        isa<MemRefType>(effectVal.getType()))
      values.push_back(effectVal);
  }
}

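// Illustrative instantiation (not part of the original file): gathering all
// memrefs a hypothetical op `op` may read.
//
//   SmallVector<Value> readValues;
//   getEffectedValues<MemoryEffects::Read>(op, readValues);
//   // Multiple effect types can be queried at once, e.g.:
//   //   getEffectedValues<MemoryEffects::Read, MemoryEffects::Write>(...)
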
/// Add `op` to the MDG, creating a new node and adding its memory accesses
/// (affine or non-affine) to the memrefAccesses (memref -> list of nodes with
/// accesses) map.
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
  auto &nodes = mdg.nodes;
  // Create graph node 'id' to represent top-level 'forOp' and record
  // all loads and store accesses it contains.
  LoopNestStateCollector collector;
  collector.collect(nodeOp);
  unsigned newNodeId = mdg.nextNodeId++;
  Node &node = nodes.insert({newNodeId, Node(newNodeId, nodeOp)}).first->second;
  for (Operation *op : collector.loadOpInsts) {
    node.loads.push_back(op);
    auto memref = cast<AffineReadOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.storeOpInsts) {
    node.stores.push_back(op);
    auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.memrefLoads) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Read>(op, effectedValues);
    if (llvm::any_of(((ValueRange)effectedValues).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      // We do not know the interaction here.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefLoads.push_back(op);
  }
  for (Operation *op : collector.memrefStores) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Write>(op, effectedValues);
    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefStores.push_back(op);
  }
  for (Operation *op : collector.memrefFrees) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Free>(op, effectedValues);
    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefFrees.push_back(op);
  }

  return &node;
}

/// Returns the memref being read/written by a memref/affine load/store op.
static Value getMemRef(Operation *memOp) {
  if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
    return memrefLoad.getMemRef();
  if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
    return affineLoad.getMemRef();
  if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
    return memrefStore.getMemRef();
  if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
    return affineStore.getMemRef();
  llvm_unreachable("unexpected op");
}

/// Returns true if there may be a dependence on `memref` from srcNode's
/// memory ops to dstNode's memory ops, while using the affine memory
/// dependence analysis checks. The method assumes that there is at least one
/// memory op in srcNode's loads and stores on `memref`, and similarly for
/// `dstNode`. `srcNode.op` and `dstNode.op` are expected to be nested in the
/// same block and so the dependences are tested at the depth of that block.
static bool mayDependence(const Node &srcNode, const Node &dstNode,
                          Value memref) {
  assert(srcNode.op->getBlock() == dstNode.op->getBlock());
  if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
    return true;

  // Conservatively handle dependences involving non-affine load/stores. Return
  // true if there exists a conflicting read/write access involving such.

  // Check whether there is a dependence from a source read/write op to a
  // destination read/write one; all expected to be memref/affine load/store.
  auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps,
                             ArrayRef<Operation *> dstMemOps) {
    return llvm::any_of(srcMemOps, [&](Operation *srcOp) {
      Value srcMemref = getMemRef(srcOp);
      if (srcMemref != memref)
        return false;
      return llvm::find_if(dstMemOps, [&](Operation *dstOp) {
               return srcMemref == getMemRef(dstOp);
             }) != dstMemOps.end();
    });
  };

  SmallVector<Operation *> dstOps;
  // Between non-affine src stores and dst load/store.
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.loads, dstNode.stores,
                                 dstNode.memrefLoads, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.memrefStores, dstOps))
    return true;
  // Between non-affine loads and dst stores.
  dstOps.clear();
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.stores, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
    return true;
  // Between affine stores and memref load/stores.
  dstOps.clear();
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.memrefLoads, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.stores, dstOps))
    return true;
  // Between affine loads and memref stores.
  dstOps.clear();
  llvm::append_range(dstOps, dstNode.memrefStores);
  if (hasNonAffineDep(srcNode.loads, dstOps))
    return true;

  // Affine load/store pairs. We don't need to check for locally allocated
  // memrefs since the dependence analysis here is between mem ops from
  // srcNode's for op to dstNode's for op at the depth at which those
  // `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
  // `d` is the number of common surrounding loops.
  for (auto *srcMemOp :
       llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
    MemRefAccess srcAcc(srcMemOp);
    if (srcAcc.memref != memref)
      continue;
    for (auto *destMemOp :
         llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
      MemRefAccess destAcc(destMemOp);
      if (destAcc.memref != memref)
        continue;
      // Check for a top-level dependence between srcNode and destNode's ops.
      if (!noDependence(checkMemrefAccessDependence(
              srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
        return true;
    }
  }
  return false;
}

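// Illustrative example (not from the original file): with the two sibling
// nests below, mayDependence(src, dst, %A) is true, since both nests access
// the same elements of %A and the first one writes to it.
//
//   affine.for %i = 0 to 10 {
//     affine.store %v, %A[%i] : memref<10xf32>   // srcNode
//   }
//   affine.for %i = 0 to 10 {
//     %x = affine.load %A[%i] : memref<10xf32>   // dstNode
//   }
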
bool MemRefDependenceGraph::init(bool fullAffineDependences) {
  LDBG() << "--- Initializing MDG ---";
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Create graph nodes.
  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
      forToNodeMap[&op] = node->id;
    } else if (isa<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (isa<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, edges
      // to/from these ops get looked at.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., they may
      // not be imperative ops.
      LDBG() << "MDG init failed; unknown region-holding op found!";
      return false;
    }
    // We aren't creating nodes for memory-effect free ops that either have no
    // regions (unless they have results being used) or have the region branch
    // op interface.
  }

  LDBG() << "Created " << nodes.size() << " nodes";

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    Value srcMemRef = memrefAndList.first;
    // Add edges between all dependent pairs among the node IDs on this memref.
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      Node *srcNode = getNode(srcId);
      bool srcHasStoreOrFree =
          srcNode->hasStore(srcMemRef) || srcNode->hasFree(srcMemRef);
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        Node *dstNode = getNode(dstId);
        bool dstHasStoreOrFree =
            dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
        if (srcHasStoreOrFree || dstHasStoreOrFree) {
          // Check precise affine deps if asked for; otherwise, conservative.
          if (!fullAffineDependences ||
              mayDependence(*srcNode, *dstNode, srcMemRef))
            addEdge(srcId, dstId, srcMemRef);
        }
      }
    }
  }
  return true;
}

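// Illustrative usage sketch (not part of the original file; the constructor
// shape is assumed from the header): building the MDG for a block and
// inspecting it.
//
//   MemRefDependenceGraph mdg(block);
//   if (mdg.init(/*fullAffineDependences=*/true))
//     mdg.print(llvm::dbgs());
//   // init() returns false e.g. when an unknown region-holding op is found.
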
// Returns the graph node for 'id'.
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Removes node 'id' (and its associated edges) from the graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.contains(id)) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
  const Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' for
// 'value' if non-null, or for any value otherwise. Returns false otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) const {
  if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
                                              unsigned dstId) const {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (!outEdges.contains(idAndIndex.first) ||
        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) const {
  unsigned inEdgeCount = 0;
  for (const Edge &inEdge : inEdges.lookup(id)) {
    if (inEdge.value == memref) {
      const Node *srcNode = getNode(inEdge.id);
      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
      // with a store operation.
      if (srcNode->getStoreOpCount(memref) > 0)
        ++inEdgeCount;
    }
  }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
                                                Value memref) const {
  unsigned outEdgeCount = 0;
  for (const auto &outEdge : outEdges.lookup(id))
    if (!memref || outEdge.value == memref)
      ++outEdgeCount;
  return outEdgeCount;
}

/// Returns all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) const {
  for (const Edge &edge : inEdges.lookup(id))
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) const {
  if (!outEdges.contains(srcId))
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LDBG() << "Can't fuse: a defining op with a user in the dst "
           << "loop has dependence from the src loop";
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  llvm::SmallDenseSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges.lookup(srcId))
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  llvm::SmallDenseSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges.lookup(dstId))
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //    *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //       dependence edge from 'srcNode'.
  //    *) Store in 'lastDstDepPos' the last position where 'op' has a
  //       dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //    operation insertion point (or return nullptr if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

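// Illustrative example (not from the original file): if the block contains
// src, A, B, dst in that order, where A depends on src and dst depends on B,
// then firstSrcDepPos = pos(A) and lastDstDepPos = pos(B). Fusion is only
// possible when A comes after B, and the fused nest is then inserted right
// before A.
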
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
//    private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (!privateMemRefs.contains(inEdge.value))
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Updates edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads', 'stores', 'memrefLoads', 'memrefStores', and
// 'memrefFrees' to the node at 'id'.
void MemRefDependenceGraph::addToNode(unsigned id, ArrayRef<Operation *> loads,
                                      ArrayRef<Operation *> stores,
                                      ArrayRef<Operation *> memrefLoads,
                                      ArrayRef<Operation *> memrefStores,
                                      ArrayRef<Operation *> memrefFrees) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
  llvm::append_range(node->memrefLoads, memrefLoads);
  llvm::append_range(node->memrefStores, memrefStores);
  llvm::append_range(node->memrefFrees, memrefFrees);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges.at(id), callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges.at(id), callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << " InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << " OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the original
// domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent the slice
// bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the destination
  // of fusion and equality constraints for symbols which are constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears the slice bounds and the corresponding operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

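// Illustrative example (not from the original file): for a slice of
// `affine.for %i = 0 to 16` fused into `affine.for %j = 0 to 16`, a
// single-iteration slice has lbMap = (d0) -> (d0) and
// ubMap = (d0) -> (d0 + 1) with %j as the lb operand. Since both loops have
// identical constant bounds and step, the fast check returns true.
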
/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest using which the slice is
  // computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LDBG() << "Unable to compute source's domain";
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LDBG() << "Cannot handle symbols in source domain";
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LDBG() << "Cannot handle locals in source domain";
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LDBG() << "Unable to compute slice's domain";
    return std::nullopt;
  }

  // Project out every dimension other than the 'ivs' to express the slice's
  // domain completely in terms of the source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LDBG() << "Domain of the source of the slice:\n"
         << "Source constraints:" << srcConstraints
         << "\nDomain of the slice if this fusion succeeds "
         << "(expressed in terms of its source's IVs):\n"
         << "Slice constraints:" << sliceConstraints;

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LDBG() << "Incorrect slice";
    return false;
  }
  return true;
}

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  MLIRContext *context = memref.getContext();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatLinearValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  for (unsigned d = 0; d < rank; d++) {
    AffineMap lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize(context, d, &lb);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "dim size bound cannot be negative");
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (ShapedType::isDynamic(dimSize))
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                          /*result=*/getAffineConstantExpr(0, context));
    }
    numElements *= diffConstant;
    // Populate outputs if available.
    if (lbs)
      lbs->push_back(lb);
    if (shape)
      shape->push_back(diffConstant);
  }
  return numElements;
}

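// Illustrative example (not from the original file): for a region of a
// memref<64xf32> constrained by %i <= d0 <= %i + 7, the extent along d0 is
// the constant 8, so this returns numElements = 8 and shape = {8}, with a
// lower bound map symbolic in %i.
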
void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds, bool dropLocalVars,
                                    bool dropOuterIvs) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LDBG() << "unknown affine dimensional value";
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      if (failed(cst.addInductionVarOrTerminalSymbol(operand)))
        return failure();
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LDBG() << "Access map: " << accessValueMap.getAffineMap();
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (auto en : llvm::enumerate(vars)) {
    if ((isAffineInductionVar(en.value())) &&
        !llvm::is_contained(enclosingIVs, en.value())) {
      if (dropOuterIvs) {
        cst.projectOut(en.value());
      } else {
        unsigned varPosition;
        cst.findVar(en.value(), &varPosition);
        auto varKind = cst.getVarKindAt(varPosition);
        varPosition -= cst.getNumDimVars();
        cst.convertToLocal(varKind, varPosition, varPosition + 1);
      }
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs) if specified.
  if (dropLocalVars)
    cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LDBG() << "Memory region: " << cst;
  return success();
}

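// Illustrative usage sketch (not part of the original file; the trailing
// parameters of compute() are assumed to have defaults): computing the region
// accessed by a hypothetical affine load `loadOp`, symbolic in its outermost
// loop.
//
//   MemRefRegion region(loadOp->getLoc());
//   if (succeeded(region.compute(loadOp, /*loopDepth=*/1))) {
//     if (std::optional<int64_t> size = region.getRegionSize())
//       LDBG() << "region size in bytes: " << *size;
//   }
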
std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}

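// Illustrative examples (not from the original file):
//   memref<16x4xf32>        -> element size 4 bytes
//   memref<8xvector<4xf32>> -> element size 16 bytes (vector counted whole)
//   memref<4xi1>            -> element size 1 byte (divideCeil rounds up)
//   memref<2xindex>         -> std::nullopt (index is not int or float)
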
// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LDBG() << "Non-identity layout map not yet supported";
    return std::nullopt;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LDBG() << "Dynamic shapes not yet supported";
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the element of the memref has vector type, takes
/// into account the size of the vector as well.
// TODO: improve/complete this when we have target data.
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}

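// Illustrative example (not from the original file):
//   memref<4x4xf32> -> 4 * 4 * 4 = 64 bytes;
//   memref<?x4xf32> -> std::nullopt (not statically shaped).
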
template <typename LoadOrStoreOp>
LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                                    bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either an AffineReadOpInterface or an "
                "AffineWriteOpInterface");

  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LDBG() << "Memory region: " << region.getConstraints();

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineValueConstraints ucst(*region.getConstraints());

    // Intersect the memory region with a constraint capturing out of bounds
    // (both out of upper and out of lower), and check if the constraint
    // system is feasible. If it is, there is at least one point out of bounds.
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (dimSize == -1)
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addBound(BoundType::LB, r, dimSize);
    outOfBounds = !ucst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index: d_i <= -1.
    FlatAffineValueConstraints lcst(*region.getConstraints());
    lcst.addBound(BoundType::UB, r, -1);
    outOfBounds = !lcst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}

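// Illustrative example (not from the original file): the access below reads
// %A[32] on the last iteration, so the bound check reports an out-of-upper-
// bound access along dimension #1.
//
//   affine.for %i = 0 to 33 {
//     %x = affine.load %A[%i] : memref<32xf32>
//   }
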
1533// Explicitly instantiate the template so that the compiler knows we need them!
1534template LogicalResult
1535mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
1536 bool emitError);
1537template LogicalResult
1538mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
1539 bool emitError);
1540
1541// Returns in 'positions' the Block positions of 'op' in each ancestor
1542// Block, from the Block containing 'op' up to (but excluding) 'limitBlock'.
1543static void findInstPosition(Operation *op, Block *limitBlock,
1544 SmallVectorImpl<unsigned> *positions) {
1545 Block *block = op->getBlock();
1546 while (block != limitBlock) {
1547 // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1548 // rely on linear scans.
1549 int instPosInBlock = std::distance(block->begin(), op->getIterator());
1550 positions->push_back(instPosInBlock);
1551 op = block->getParentOp();
1552 block = op->getBlock();
1553 }
1554 std::reverse(positions->begin(), positions->end());
1555}
1556
1557// Returns the Operation in a possibly nested set of Blocks, where the
1558// position of the operation is represented by 'positions', which has a
1559// Block position for each level of nesting.
1560static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
1561                                    unsigned level, Block *block) {
1562 unsigned i = 0;
1563 for (auto &op : *block) {
1564 if (i != positions[level]) {
1565 ++i;
1566 continue;
1567 }
1568 if (level == positions.size() - 1)
1569 return &op;
1570 if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
1571 return getInstAtPosition(positions, level + 1,
1572 childAffineForOp.getBody());
1573
1574 for (auto &region : op.getRegions()) {
1575 for (auto &b : region)
1576 if (auto *ret = getInstAtPosition(positions, level + 1, &b))
1577 return ret;
1578 }
1579 return nullptr;
1580 }
1581 return nullptr;
1582}
1583
1584// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1585static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
1586                                            FlatAffineValueConstraints *cst) {
1587  for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1588 auto value = cst->getValue(i);
1589 if (ivs.count(value) == 0) {
1590 assert(isAffineForInductionVar(value));
1591 auto loop = getForInductionVarOwner(value);
1592 if (failed(cst->addAffineForOpDomain(loop)))
1593 return failure();
1594 }
1595 }
1596 return success();
1597}
1598
1599/// Returns the innermost common loop depth for the set of operations in 'ops'.
1600// TODO: Move this to LoopUtils.
1601unsigned mlir::affine::getInnermostCommonLoopDepth(
1602    ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1603 unsigned numOps = ops.size();
1604 assert(numOps > 0 && "Expected at least one operation");
1605
1606 std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1607 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1608 for (unsigned i = 0; i < numOps; ++i) {
1609 getAffineForIVs(*ops[i], &loops[i]);
1610 loopDepthLimit =
1611 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
1612 }
1613
1614 unsigned loopDepth = 0;
1615 for (unsigned d = 0; d < loopDepthLimit; ++d) {
1616 unsigned i;
1617 for (i = 1; i < numOps; ++i) {
1618 if (loops[i - 1][d] != loops[i][d])
1619 return loopDepth;
1620 }
1621 if (surroundingLoops)
1622 surroundingLoops->push_back(loops[i - 1][d]);
1623 ++loopDepth;
1624 }
1625 return loopDepth;
1626}
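// Illustrative example (assumed IR): for ops placed at (1) and (2) below, the
// surrounding loop lists are [%i, %j] and [%i, %k]; they agree only on %i's
// loop, so the innermost common loop depth is 1 and 'surroundingLoops', if
// provided, receives just that loop.
//
//   affine.for %i = 0 to 100 {
//     affine.for %j = 0 to 50 { ... (1) ... }
//     affine.for %k = 0 to 50 { ... (2) ... }
//   }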
1627
1628/// Computes in 'sliceUnion' the union of all slice bounds computed at
1629/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1630/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1631/// union was computed correctly, an appropriate failure otherwise.
1632SliceComputationResult
1633mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
1634                                ArrayRef<Operation *> opsB, unsigned loopDepth,
1635 unsigned numCommonLoops, bool isBackwardSlice,
1636 ComputationSliceState *sliceUnion) {
1637 // Compute the union of slice bounds between all pairs in 'opsA' and
1638 // 'opsB' in 'sliceUnionCst'.
1639 FlatAffineValueConstraints sliceUnionCst;
1640 assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1641 std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1642 MemRefAccess srcAccess;
1643 MemRefAccess dstAccess;
1644 for (Operation *a : opsA) {
1645 srcAccess = MemRefAccess(a);
1646 for (Operation *b : opsB) {
1647 dstAccess = MemRefAccess(b);
1648 if (srcAccess.memref != dstAccess.memref)
1649 continue;
1650 // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1651 if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
1652 (isBackwardSlice && loopDepth > getNestingDepth(b))) {
1653        LDBG() << "Invalid loop depth";
1654        return SliceComputationResult::GenericFailure;
1655 }
1656
1657 bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1658 isa<AffineReadOpInterface>(dstAccess.opInst);
1659 FlatAffineValueConstraints dependenceConstraints;
1660 // Check dependence between 'srcAccess' and 'dstAccess'.
1661      DependenceResult result = checkMemrefAccessDependence(
1662          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1663 &dependenceConstraints, /*dependenceComponents=*/nullptr,
1664 /*allowRAR=*/readReadAccesses);
1665 if (result.value == DependenceResult::Failure) {
1666        LDBG() << "Dependence check failed";
1667        return SliceComputationResult::GenericFailure;
1668      }
1669      if (result.value == DependenceResult::NoDependence)
1670        continue;
1671 dependentOpPairs.emplace_back(a, b);
1672
1673 // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1674 ComputationSliceState tmpSliceState;
1675 getComputationSliceState(a, b, dependenceConstraints, loopDepth,
1676 isBackwardSlice, &tmpSliceState);
1677
1678 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1679 // Initialize 'sliceUnionCst' with the bounds computed in previous step.
1680 if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
1681          LDBG() << "Unable to compute slice bound constraints";
1682          return SliceComputationResult::GenericFailure;
1683 }
1684 assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1685 continue;
1686 }
1687
1688 // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1689 FlatAffineValueConstraints tmpSliceCst;
1690 if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
1691        LDBG() << "Unable to compute slice bound constraints";
1692        return SliceComputationResult::GenericFailure;
1693 }
1694
1695 // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1696 if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {
1697
1698 // Pre-constraint var alignment: record loop IVs used in each constraint
1699 // system.
1700 SmallPtrSet<Value, 8> sliceUnionIVs;
1701 for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1702 sliceUnionIVs.insert(sliceUnionCst.getValue(k));
1703 SmallPtrSet<Value, 8> tmpSliceIVs;
1704 for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1705 tmpSliceIVs.insert(tmpSliceCst.getValue(k));
1706
1707 sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);
1708
1709 // Post-constraint var alignment: add loop IV bounds missing after
1710 // var alignment to constraint systems. This can occur if one constraint
1711        // system uses a loop IV that is not used by the other. The call
1712 // to unionBoundingBox below expects constraints for each Loop IV, even
1713 // if they are the unsliced full loop bounds added here.
1714        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
1715          return SliceComputationResult::GenericFailure;
1716        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
1717          return SliceComputationResult::GenericFailure;
1718 }
1719 // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
1720 if (sliceUnionCst.getNumLocalVars() > 0 ||
1721 tmpSliceCst.getNumLocalVars() > 0 ||
1722 failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
1723        LDBG() << "Unable to compute union bounding box of slice bounds";
1724        return SliceComputationResult::GenericFailure;
1725 }
1726 }
1727 }
1728
1729 // Empty union.
1730 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1731    LDBG() << "empty slice union - unexpected";
1732    return SliceComputationResult::GenericFailure;
1733 }
1734
1735 // Gather loops surrounding ops from loop nest where slice will be inserted.
1736  SmallVector<Operation *, 4> ops;
1737  for (auto &dep : dependentOpPairs) {
1738 ops.push_back(isBackwardSlice ? dep.second : dep.first);
1739 }
1740 SmallVector<AffineForOp, 4> surroundingLoops;
1741 unsigned innermostCommonLoopDepth =
1742 getInnermostCommonLoopDepth(ops, &surroundingLoops);
1743 if (loopDepth > innermostCommonLoopDepth) {
1744    LDBG() << "Exceeds max loop depth";
1745    return SliceComputationResult::GenericFailure;
1746 }
1747
1748 // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1749 unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1750
1751 // Convert any dst loop IVs which are symbol variables to dim variables.
1752 sliceUnionCst.convertLoopIVSymbolsToDims();
1753 sliceUnion->clearBounds();
1754 sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
1755 sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
1756
1757 // Get slice bounds from slice union constraints 'sliceUnionCst'.
1758 sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
1759 opsA[0]->getContext(), &sliceUnion->lbs,
1760 &sliceUnion->ubs);
1761
1762 // Add slice bound operands of union.
1763 SmallVector<Value, 4> sliceBoundOperands;
1764 sliceUnionCst.getValues(numSliceLoopIVs,
1765 sliceUnionCst.getNumDimAndSymbolVars(),
1766 &sliceBoundOperands);
1767
1768 // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1769 sliceUnion->ivs.clear();
1770 sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);
1771
1772  // Set the loop nest insertion point to the block start at 'loopDepth' for
1773  // backward slices, and to the block end for forward slices.
1774 sliceUnion->insertPoint =
1775 isBackwardSlice
1776 ? surroundingLoops[loopDepth - 1].getBody()->begin()
1777 : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1778
1779 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1780 // canonicalization.
1781 sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1782 sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1783
1784 // Check if the slice computed is valid. Return success only if it is verified
1785 // that the slice is valid, otherwise return appropriate failure status.
1786 std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1787 if (!isSliceValid) {
1788    LDBG() << "Cannot determine if the slice is valid";
1789    return SliceComputationResult::GenericFailure;
1790 }
1791  if (!*isSliceValid)
1792    return SliceComputationResult::IncorrectSliceFailure;
1793
1794  return SliceComputationResult::Success;
1795}
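// A hedged reading of the contract above: on SliceComputationResult::Success,
// 'sliceUnion' holds, for each slice IV in 'ivs', lower/upper bound maps and
// their bound operands expressed over the other loop nest's IVs and symbols,
// plus the 'insertPoint' where the slice may be materialized. Callers such as
// affine loop fusion treat every failure kind as "do not fuse this op pair".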
1796
1797// TODO: extend this to handle multiple result maps.
1798static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1799 AffineMap ubMap) {
1800 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1801 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1802 assert(lbMap.getNumDims() == ubMap.getNumDims());
1803 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1804 AffineExpr lbExpr(lbMap.getResult(0));
1805 AffineExpr ubExpr(ubMap.getResult(0));
1806 auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
1807 lbMap.getNumSymbols());
1808 auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
1809 if (!cExpr)
1810 return std::nullopt;
1811 return cExpr.getValue();
1812}
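// Illustrative example (assumed maps): for lbMap = (d0) -> (d0) and
// ubMap = (d0) -> (d0 + 4), the difference simplifies to the constant 4,
// which is returned; for lbMap = (d0, d1) -> (d0) and ubMap = (d0, d1) -> (d1),
// the difference d1 - d0 is not constant and std::nullopt is returned.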
1813
1814// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
1815// loop nest represented by the slice loop bounds in 'slice'. Returns true on
1816// success, false otherwise (if a non-constant trip count was encountered).
1817// TODO: Make this work with non-unit step loops.
1818bool mlir::affine::buildSliceTripCountMap(
1819    const ComputationSliceState &slice,
1820 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1821 unsigned numSrcLoopIVs = slice.ivs.size();
1822 // Populate map from AffineForOp -> trip count
1823 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1824 AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1825 auto *op = forOp.getOperation();
1826 AffineMap lbMap = slice.lbs[i];
1827 AffineMap ubMap = slice.ubs[i];
1828    // If the lower or upper bound maps are null or provide no results, it
1829    // implies that the source loop was not sliced at all, and the entire loop
1830    // will be part of the slice.
1831 if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1832 ubMap.getNumResults() == 0) {
1833 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1834 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1835 (*tripCountMap)[op] =
1836 forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1837 continue;
1838 }
1839 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1840 if (maybeConstTripCount.has_value()) {
1841 (*tripCountMap)[op] = *maybeConstTripCount;
1842 continue;
1843 }
1844 return false;
1845 }
1846 std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1847 // Slice bounds are created with a constant ub - lb difference.
1848 if (!tripCount.has_value())
1849 return false;
1850 (*tripCountMap)[op] = *tripCount;
1851 }
1852 return true;
1853}
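// A small worked case (illustrative): if slice.lbs[i] = (d0) -> (d0) and
// slice.ubs[i] = (d0) -> (d0 + 8), the loop owning slice.ivs[i] gets a trip
// count of 8 in 'tripCountMap'; null or empty bound maps fall back to the full
// loop bounds, and any non-constant span makes the function return false.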
1854
1855// Return the number of iterations in the given slice.
1856uint64_t mlir::affine::getSliceIterationCount(
1857    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1858 uint64_t iterCount = 1;
1859 for (const auto &count : sliceTripCountMap) {
1860 iterCount *= count.second;
1861 }
1862 return iterCount;
1863}
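// For instance (illustrative): a trip count map {loop_i: 4, loop_j: 8} yields
// 4 * 8 = 32 iterations; an empty map yields the multiplicative identity 1.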
1864
1865const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
1866// Computes slice bounds by projecting out any loop IVs from
1867// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
1868// bounds in 'sliceState' which represent one loop nest's IVs in terms of the
1869// other loop nest's IVs, symbols, and constants (using 'isBackwardSlice').
1870void mlir::affine::getComputationSliceState(
1871    Operation *depSourceOp, Operation *depSinkOp,
1872 const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth,
1873 bool isBackwardSlice, ComputationSliceState *sliceState) {
1874 // Get loop nest surrounding src operation.
1875 SmallVector<AffineForOp, 4> srcLoopIVs;
1876 getAffineForIVs(*depSourceOp, &srcLoopIVs);
1877 unsigned numSrcLoopIVs = srcLoopIVs.size();
1878
1879 // Get loop nest surrounding dst operation.
1880 SmallVector<AffineForOp, 4> dstLoopIVs;
1881 getAffineForIVs(*depSinkOp, &dstLoopIVs);
1882 unsigned numDstLoopIVs = dstLoopIVs.size();
1883
1884 assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1885 (isBackwardSlice && loopDepth <= numDstLoopIVs));
1886
1887 // Project out dimensions other than those up to 'loopDepth'.
1888 unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1889 unsigned num =
1890 isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1891 FlatAffineValueConstraints sliceCst(dependenceConstraints);
1892 sliceCst.projectOut(pos, num);
1893
1894 // Add slice loop IV values to 'sliceState'.
1895 unsigned offset = isBackwardSlice ? 0 : loopDepth;
1896 unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1897 sliceCst.getValues(offset, offset + numSliceLoopIVs, &sliceState->ivs);
1898
1899 // Set up lower/upper bound affine maps for the slice.
1900 sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
1901 sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
1902
1903 // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1904 sliceCst.getSliceBounds(offset, numSliceLoopIVs, depSourceOp->getContext(),
1905 &sliceState->lbs, &sliceState->ubs);
1906
1907 // Set up bound operands for the slice's lower and upper bounds.
1908 SmallVector<Value, 4> sliceBoundOperands;
1909 unsigned numDimsAndSymbols = sliceCst.getNumDimAndSymbolVars();
1910 for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1911 if (i < offset || i >= offset + numSliceLoopIVs)
1912 sliceBoundOperands.push_back(sliceCst.getValue(i));
1913 }
1914
1915 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1916 // canonicalization.
1917 sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1918 sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1919
1920  // Set the insertion point to block start at 'loopDepth' (backward slices) or
1921  // block end (forward slices).
1921 sliceState->insertPoint =
1922 isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1923 : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1924
1925 llvm::SmallDenseSet<Value, 8> sequentialLoops;
1926 if (isa<AffineReadOpInterface>(depSourceOp) &&
1927 isa<AffineReadOpInterface>(depSinkOp)) {
1928 // For read-read access pairs, clear any slice bounds on sequential loops.
1929 // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
1930 getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1931 &sequentialLoops);
1932 }
1933 auto getSliceLoop = [&](unsigned i) {
1934 return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1935 };
1936 auto isInnermostInsertion = [&]() {
1937 return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1938 : loopDepth >= dstLoopIVs.size());
1939 };
1940 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
1941 auto srcIsUnitSlice = [&]() {
1942 return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
1943 (getSliceIterationCount(sliceTripCountMap) == 1));
1944 };
1945 // Clear all sliced loop bounds beginning at the first sequential loop, or
1946  // the first loop with a slice fusion barrier attribute.
1947
1948 for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1949 Value iv = getSliceLoop(i).getInductionVar();
1950 if (sequentialLoops.count(iv) == 0 &&
1951 getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1952 continue;
1953    // Skip resetting the bounds of a reduction loop inserted in the destination
1954    // loop nest when it meets the following conditions:
1955 // 1. Slice is single trip count.
1956 // 2. Loop bounds of the source and destination match.
1957 // 3. Is being inserted at the innermost insertion point.
1958 std::optional<bool> isMaximal = sliceState->isMaximal();
1959 if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1960 isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1961 continue;
1962 for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1963 sliceState->lbs[j] = AffineMap();
1964 sliceState->ubs[j] = AffineMap();
1965 }
1966 break;
1967 }
1968}
1969
1970/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1971/// updates the slice loop bounds with any non-null bound maps specified in
1972/// 'sliceState', and inserts this slice into the loop nest surrounding
1973/// 'dstOpInst' at loop depth 'dstLoopDepth'.
1974// TODO: extend the slicing utility to compute slices that
1975// aren't necessarily a one-to-one relation between the source and destination.
1976// The relation between the source and destination could be many-to-many in general.
1977// TODO: the slice computation is incorrect in the cases
1978// where the dependence from the source to the destination does not cover the
1979// entire destination index set. Subtract out the dependent destination
1980// iterations from destination index set and check for emptiness --- this is one
1981// solution.
1982AffineForOp mlir::affine::insertBackwardComputationSlice(
1983    Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1984 ComputationSliceState *sliceState) {
1985 // Get loop nest surrounding src operation.
1986 SmallVector<AffineForOp, 4> srcLoopIVs;
1987 getAffineForIVs(*srcOpInst, &srcLoopIVs);
1988 unsigned numSrcLoopIVs = srcLoopIVs.size();
1989
1990 // Get loop nest surrounding dst operation.
1991 SmallVector<AffineForOp, 4> dstLoopIVs;
1992 getAffineForIVs(*dstOpInst, &dstLoopIVs);
1993 unsigned dstLoopIVsSize = dstLoopIVs.size();
1994 if (dstLoopDepth > dstLoopIVsSize) {
1995 dstOpInst->emitError("invalid destination loop depth");
1996 return AffineForOp();
1997 }
1998
1999 // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
2000 SmallVector<unsigned, 4> positions;
2001 // TODO: This code is incorrect since srcLoopIVs can be 0-d.
2002 findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
2003
2004  // Clone the src loop nest and insert it at the beginning of the operation block
2005 // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
2006 auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
2007 OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
2008 auto sliceLoopNest =
2009 cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
2010
2011 Operation *sliceInst =
2012 getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
2013 // Get loop nest surrounding 'sliceInst'.
2014 SmallVector<AffineForOp, 4> sliceSurroundingLoops;
2015 getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
2016
2017 // Sanity check.
2018 unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
2019 (void)sliceSurroundingLoopsSize;
2020 assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
2021 unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
2022 (void)sliceLoopLimit;
2023 assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
2024
2025 // Update loop bounds for loops in 'sliceLoopNest'.
2026 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
2027 auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
2028 if (AffineMap lbMap = sliceState->lbs[i])
2029 forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
2030 if (AffineMap ubMap = sliceState->ubs[i])
2031 forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
2032 }
2033 return sliceLoopNest;
2034}
2035
2036// Constructs a MemRefAccess, populating it with the memref, its indices, and
2037// the op from 'memOp'.
2038MemRefAccess::MemRefAccess(Operation *memOp) {
2039  if (auto loadOp = dyn_cast<AffineReadOpInterface>(memOp)) {
2040 memref = loadOp.getMemRef();
2041 opInst = memOp;
2042 llvm::append_range(indices, loadOp.getMapOperands());
2043 } else {
2044 assert(isa<AffineWriteOpInterface>(memOp) &&
2045 "Affine read/write op expected");
2046 auto storeOp = cast<AffineWriteOpInterface>(memOp);
2047 opInst = memOp;
2048 memref = storeOp.getMemRef();
2049 llvm::append_range(indices, storeOp.getMapOperands());
2050 }
2051}
2052
2053unsigned MemRefAccess::getRank() const {
2054 return cast<MemRefType>(memref.getType()).getRank();
2055}
2056
2057bool MemRefAccess::isStore() const {
2058  return isa<AffineWriteOpInterface>(opInst);
2059}
2060
2061/// Returns the nesting depth of this operation, i.e., the number of loops
2062/// surrounding this operation.
2063unsigned mlir::affine::getNestingDepth(Operation *op) {
2064  Operation *currOp = op;
2065 unsigned depth = 0;
2066 while ((currOp = currOp->getParentOp())) {
2067 if (isa<AffineForOp>(currOp))
2068 depth++;
2069 if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2070 depth += parOp.getNumDims();
2071 }
2072 return depth;
2073}
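// Illustrative example: an op nested as
//   affine.for %i ... { affine.parallel (%j, %k) ... { <op> } }
// has nesting depth 3: one loop from the affine.for plus two from the
// two-dimensional affine.parallel.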
2074
2075/// Equal if both affine accesses are provably equivalent (at compile
2076/// time) when considering the memref, the affine maps and their respective
2077/// operands. The equality of access functions + operands is checked by
2078/// subtracting fully composed value maps, and then simplifying the difference
2079/// using the expression flattener.
2080/// TODO: this does not account for aliasing of memrefs.
2081bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
2082  if (memref != rhs.memref)
2083 return false;
2084
2085 AffineValueMap diff, thisMap, rhsMap;
2086 getAccessMap(&thisMap);
2087 rhs.getAccessMap(&rhsMap);
2088 return thisMap == rhsMap;
2089}
2090
2091void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
2092  auto *currOp = op.getParentOp();
2093 AffineForOp currAffineForOp;
2094  // Traverse up the hierarchy collecting all 'affine.for' and 'affine.parallel'
2095  // operations while skipping over 'affine.if' operations.
2096 while (currOp) {
2097 if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
2098 ivs.push_back(currAffineForOp.getInductionVar());
2099 else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2100 llvm::append_range(ivs, parOp.getIVs());
2101 currOp = currOp->getParentOp();
2102 }
2103 std::reverse(ivs.begin(), ivs.end());
2104}
2105
2106/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
2107/// where each lists loops from outer-most to inner-most in loop nest.
2108unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
2109                                                    Operation &b) {
2110 SmallVector<Value, 4> loopsA, loopsB;
2111 getAffineIVs(a, loopsA);
2112 getAffineIVs(b, loopsB);
2113
2114 unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
2115 unsigned numCommonLoops = 0;
2116 for (unsigned i = 0; i < minNumLoops; ++i) {
2117 if (loopsA[i] != loopsB[i])
2118 break;
2119 ++numCommonLoops;
2120 }
2121 return numCommonLoops;
2122}
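// Illustrative example: for ops (1) and (2) in
//   affine.for %i { affine.for %j { ...(1)... } affine.for %k { ...(2)... } }
// the IV lists are [%i, %j] and [%i, %k]; only %i matches, so the number of
// common surrounding loops is 1.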
2123
2124static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
2125 Block::iterator start,
2126 Block::iterator end,
2127 int memorySpace) {
2128  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
2129
2130 // Walk this 'affine.for' operation to gather all memory regions.
2131 auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
2132 if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
2133      // Neither a load nor a store op.
2134 return WalkResult::advance();
2135 }
2136
2137 // Compute the memref region symbolic in any IVs enclosing this block.
2138 auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2139 if (failed(
2140 region->compute(opInst,
2141 /*loopDepth=*/getNestingDepth(&*block.begin())))) {
2142 LDBG() << "Error obtaining memory region";
2143 opInst->emitError("error obtaining memory region");
2144 return failure();
2145 }
2146
2147 auto [it, inserted] = regions.try_emplace(region->memref);
2148 if (inserted) {
2149 it->second = std::move(region);
2150 } else if (failed(it->second->unionBoundingBox(*region))) {
2151 LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
2152 "memory region";
2153 opInst->emitWarning(
2154 "getMemoryFootprintBytes: unable to perform a union on a memory "
2155 "region");
2156 return failure();
2157 }
2158 return WalkResult::advance();
2159 });
2160 if (result.wasInterrupted())
2161 return std::nullopt;
2162
2163 int64_t totalSizeInBytes = 0;
2164 for (const auto &region : regions) {
2165 std::optional<int64_t> size = region.second->getRegionSize();
2166 if (!size.has_value())
2167 return std::nullopt;
2168 totalSizeInBytes += *size;
2169 }
2170 return totalSizeInBytes;
2171}
2172
2173std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
2174 int memorySpace) {
2175 auto *forInst = forOp.getOperation();
2176 return ::getMemoryFootprintBytes(
2177 *forInst->getBlock(), Block::iterator(forInst),
2178 std::next(Block::iterator(forInst)), memorySpace);
2179}
2180
2181/// Returns whether a loop is parallel and contains a reduction loop.
2182bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
2183  SmallVector<LoopReduction> reductions;
2184 if (!isLoopParallel(forOp, &reductions))
2185 return false;
2186 return !reductions.empty();
2187}
2188
2189/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
2190/// at 'forOp'.
2191void mlir::affine::getSequentialLoops(
2192    AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
2193 forOp->walk([&](Operation *op) {
2194 if (auto innerFor = dyn_cast<AffineForOp>(op))
2195 if (!isLoopParallel(innerFor))
2196 sequentialLoops->insert(innerFor.getInductionVar());
2197 });
2198}
2199
2200IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
2201  FlatAffineValueConstraints fac(set);
2202  if (fac.isEmpty())
2203    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
2204                                   set.getContext());
2205  fac.removeTrivialRedundancy();
2206
2207 auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
2208 assert(simplifiedSet && "guaranteed to succeed while roundtripping");
2209 return simplifiedSet;
2210}
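// Example of the empty-set path (illustrative): the set
//   affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>
// is infeasible, so the canonical empty set for one dimension and zero
// symbols is returned instead of round-tripping the constraints.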
2211
2212static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
2213                                 SmallVector<Value> &target) {
2214  target =
2215 llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
2216 return val.has_value() ? *val : Value();
2217 }));
2218}
2219
2220/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2221/// constraints drawn from an affine map. Before adding the constraint, the
2222/// dimensions/symbols of the affine map are aligned with `constraints`.
2223/// `operands` are the SSA Value operands used with the affine map.
2224/// Note: This function adds a new symbol column to the `constraints` for each
2225/// dimension/symbol that exists in the affine map but not in `constraints`.
2226static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2227 BoundType type, unsigned pos,
2228 AffineMap map, ValueRange operands) {
2229 SmallVector<Value> dims, syms, newSyms;
2230 unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
2231 unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
2232
2233 AffineMap alignedMap =
2234 alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
2235 for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2236 constraints.appendSymbolVar(newSyms[i]);
2237 return constraints.addBound(type, pos, alignedMap);
2238}
2239
2240/// Add `val` to each result of `map`.
2241static AffineMap addConstToResults(AffineMap map, int64_t val) {
2242  SmallVector<AffineExpr> newResults;
2243 for (AffineExpr r : map.getResults())
2244 newResults.push_back(r + val);
2245 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
2246 map.getContext());
2247}
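// Illustrative example: applying this to (d0)[s0] -> (d0, d0 + s0) with
// val = 1 yields (d0)[s0] -> (d0 + 1, d0 + s0 + 1). It is used below to
// convert between the inclusive bounds of affine.min/max and the exclusive
// upper bounds of the constraint system.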
2248
2249// Attempt to simplify the given min/max operation by proving that its value is
2250// bounded by the same lower and upper bound.
2251//
2252// Bounds are computed by FlatAffineValueConstraints. Invariants required for
2253// finding/proving bounds should be supplied via `constraints`.
2254//
2255// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2256// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2257// case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2258// `op` but are not part of `constraints`, are added as extra symbols.
2259// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2260// * If `isMin`: r_i >= opBound
2261//    * If `!isMin`: r_i <= opBound
2262// If this is the case, ub(op) == lb(op).
2263// 4. Replace `op` with `opBound`.
2264//
2265// In summary, the following constraints are added throughout this function.
2266// Note: `invar` are dimensions added by the caller to express the invariants.
2267// (Showing only the case where `isMin`.)
2268//
2269// invar | op | opBound | r_i | extra syms... | const | eq/ineq
2270// ------+-------+---------+-----+---------------+-------+-------------------
2271// (various eq./ineq. constraining `invar`, added by the caller)
2272// ... | 0 | 0 | 0 | 0 | ... | ...
2273// ------+-------+---------+-----+---------------+-------+-------------------
2274// (various ineq. constraining `op` in terms of `op` operands (`invar` and
2275// extra `op` operands "extra syms" that are not in `invar`)).
2276// ... | -1 | 0 | 0 | ... | ... | >= 0
2277// ------+-------+---------+-----+---------------+-------+-------------------
2278// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2279// ... | 0 | -1 | 0 | ... | ... | = 0
2280// ------+-------+---------+-----+---------------+-------+-------------------
2281// (for each `op` map result r_i: set r_i to corresponding map result,
2282// prove that r_i >= minOpUb via contradiction)
2283// ... | 0 | 0 | -1 | ... | ... | = 0
2284// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2285//
2286FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
2287    Operation *op, FlatAffineValueConstraints constraints) {
2288 bool isMin = isa<AffineMinOp>(op);
2289 assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2290 MLIRContext *ctx = op->getContext();
2291 Builder builder(ctx);
2292 AffineMap map =
2293 isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2294 ValueRange operands = op->getOperands();
2295 unsigned numResults = map.getNumResults();
2296
2297 // Add a few extra dimensions.
2298 unsigned dimOp = constraints.appendDimVar(); // `op`
2299 unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2300 unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2301
2302 // Add an inequality for each result expr_i of map:
2303 // isMin: op <= expr_i, !isMin: op >= expr_i
2304 auto boundType = isMin ? BoundType::UB : BoundType::LB;
2305 // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2306 AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
2307 if (failed(
2308 alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
2309 return failure();
2310
2311 // Try to compute a lower/upper bound for op, expressed in terms of the other
2312 // `dims` and extra symbols.
2313 SmallVector<AffineMap> opLb(1), opUb(1);
2314 constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
2315 AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2316 // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2317 // a TODO of `getSliceBounds` and not handled here.
2318 if (!sliceBound || sliceBound.getNumResults() != 1)
2319 return failure(); // No or multiple bounds found.
2320 // Recover the inclusive UB in the case of an `affine.min`.
2321 AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
2322
2323 // Add an equality: Set dimOpBound to computed bound.
2324 // Add back dimension for op. (Was removed by `getSliceBounds`.)
2325 AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2326 if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
2327 return failure();
2328
2329 // If the constraint system is empty, there is an inconsistency. (E.g., this
2330 // can happen if loop lb > ub.)
2331 if (constraints.isEmpty())
2332 return failure();
2333
2334 // In the case of `isMin` (`!isMin` is inversed):
2335 // Prove that each result of `map` has a lower bound that is equal to (or
2336 // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2337 // can be replaced with the bound. I.e., prove that for each result
2338 // expr_i (represented by dimension r_i):
2339 //
2340 // r_i >= opBound
2341 //
2342 // To prove this inequality, add its negation to the constraint set and prove
2343 // that the constraint set is empty.
2344 for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
2345 FlatAffineValueConstraints newConstr(constraints);
2346
2347 // Add an equality: r_i = expr_i
2348 // Note: These equalities could have been added earlier and used to express
2349 // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
2350 // computes minOpUb in terms of r_i dims, which is not desired.
2351 if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
2352 map.getSubMap({i - resultDimStart}), operands)))
2353 return failure();
2354
2355 // If `isMin`: Add inequality: r_i < opBound
2356 // equiv.: opBound - r_i - 1 >= 0
2357 // If `!isMin`: Add inequality: r_i > opBound
2358 // equiv.: -opBound + r_i - 1 >= 0
2359 SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2360 ineq[dimOpBound] = isMin ? 1 : -1;
2361 ineq[i] = isMin ? -1 : 1;
2362 ineq[newConstr.getNumCols() - 1] = -1;
2363 newConstr.addInequality(ineq);
2364 if (!newConstr.isEmpty())
2365 return failure();
2366 }
2367
2368 // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
2369 AffineMap newMap = alignedBoundMap;
2370 SmallVector<Value> newOperands;
2371 unpackOptionalValues(constraints.getMaybeValues(), newOperands);
2372 // If dims/symbols have known constant values, use those in order to simplify
2373 // the affine map further.
2374 for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2375 // Skip unused operands and operands that are already constants.
2376 if (!newOperands[i] || getConstantIntValue(newOperands[i]))
2377 continue;
2378 if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
2379 AffineExpr expr =
2380 i < newMap.getNumDims()
2381 ? builder.getAffineDimExpr(i)
2382 : builder.getAffineSymbolExpr(i - newMap.getNumDims());
2383 newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
2384 newMap.getNumDims(), newMap.getNumSymbols());
2385 }
2386 }
2387 affine::canonicalizeMapAndOperands(&newMap, &newOperands);
2388 return AffineValueMap(newMap, newOperands);
2389}
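// A hedged usage sketch (illustrative IR and values): SCF loop peeling uses
// this to simplify the affine.min computing the tile size inside the main
// loop. With caller-supplied constraints asserting that only full tiles remain
// (e.g. %iv <= %ub - 32), both results of
//   %s = affine.min affine_map<(d0)[s0] -> (32, s0 - d0)>(%iv)[%ub]
// are provably >= its upper bound 32, so the op simplifies to () -> (32).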
2390
2391mlir::Block *mlir::affine::findInnermostCommonBlockInScope(mlir::Operation *a,
2392                                                           mlir::Operation *b) {
2393 Region *aScope = getAffineAnalysisScope(a);
2394 Region *bScope = getAffineAnalysisScope(b);
2395 if (aScope != bScope)
2396 return nullptr;
2397
2398 // Get the block ancestry of `op` while stopping at the affine scope `aScope`
2399 // and store them in `ancestry`.
2400 auto getBlockAncestry = [&](Operation *op,
2401 SmallVectorImpl<Block *> &ancestry) {
2402 Operation *curOp = op;
2403 do {
2404 ancestry.push_back(curOp->getBlock());
2405 if (curOp->getParentRegion() == aScope)
2406 break;
2407 curOp = curOp->getParentOp();
2408 } while (curOp);
2409 assert(curOp && "can't reach root op without passing through affine scope");
2410 std::reverse(ancestry.begin(), ancestry.end());
2411 };
2412
2413 SmallVector<Block *, 4> aAncestors, bAncestors;
2414 getBlockAncestry(a, aAncestors);
2415 getBlockAncestry(b, bAncestors);
2416 assert(!aAncestors.empty() && !bAncestors.empty() &&
2417 "at least one Block ancestor expected");
2418
2419 Block *innermostCommonBlock = nullptr;
2420 for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
2421 a < e && b < f; ++a, ++b) {
2422 if (aAncestors[a] != bAncestors[b])
2423 break;
2424 innermostCommonBlock = aAncestors[a];
2425 }
2426 return innermostCommonBlock;
2427}
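// Illustrative example: for two ops in sibling affine.for bodies within the
// same function, the recorded ancestries share the function's entry block,
// which is returned; ops with different affine scopes (e.g. in different
// functions) yield nullptr.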
return success()
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:246
static std::optional< uint64_t > getConstDifference(AffineMap lbMap, AffineMap ubMap)
Definition Utils.cpp:1798
static void findInstPosition(Operation *op, Block *limitBlock, SmallVectorImpl< unsigned > *positions)
Definition Utils.cpp:1543
static Node * addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, DenseMap< Value, SetVector< unsigned > > &memrefAccesses)
Add op to MDG creating a new node and adding its memory accesses (affine or non-affine to memrefAcces...
Definition Utils.cpp:191
static bool mayDependence(const Node &srcNode, const Node &dstNode, Value memref)
Returns true if there may be a dependence on memref from srcNode's memory ops to dstNode's memory ops...
Definition Utils.cpp:264
const char *const kSliceFusionBarrierAttrName
Definition Utils.cpp:1865
static LogicalResult addMissingLoopIVBounds(SmallPtrSet< Value, 8 > &ivs, FlatAffineValueConstraints *cst)
Definition Utils.cpp:1585
static void getEffectedValues(Operation *op, SmallVectorImpl< Value > &values)
Returns the values that this op has a memref effect of type EffectTys on, not considering recursive e...
Definition Utils.cpp:165
MemRefDependenceGraph::Node Node
Definition Utils.cpp:38
static void unpackOptionalValues(ArrayRef< std::optional< Value > > source, SmallVector< Value > &target)
Definition Utils.cpp:2212
static AffineMap addConstToResults(AffineMap map, int64_t val)
Add val to each result of map.
Definition Utils.cpp:2241
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints, BoundType type, unsigned pos, AffineMap map, ValueRange operands)
Bound an identifier pos in a given FlatAffineValueConstraints with constraints drawn from an affine m...
Definition Utils.cpp:2226
static Operation * getInstAtPosition(ArrayRef< unsigned > positions, unsigned level, Block *block)
Definition Utils.cpp:1560
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
template bool mlir::hasEffect< MemoryEffects::Free >(Operation *)
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:308
iterator begin()
Definition Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:372
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
IntegerSet getAsIntegerSet(MLIRContext *context) const
Returns the constraint system as an integer set.
void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context, SmallVectorImpl< AffineMap > *lbMaps, SmallVectorImpl< AffineMap > *ubMaps, bool closedUB=false)
Computes the lower and upper bounds of the first num dimensional variables (starting at offset) as an...
std::optional< int64_t > getConstantBoundOnDimSize(MLIRContext *context, unsigned pos, AffineMap *lb=nullptr, AffineMap *ub=nullptr, unsigned *minLbPos=nullptr, unsigned *minUbPos=nullptr) const
Returns a non-negative constant bound on the extent (upper bound - lower bound) of the specified vari...
FlatLinearValueConstraints represents an extension of FlatLinearConstraints where each non-local vari...
LogicalResult unionBoundingBox(const FlatLinearValueConstraints &other)
Updates the constraints to be the smallest bounding (enclosing) box that contains the points of this ...
void mergeAndAlignVarsWithOther(unsigned offset, FlatLinearValueConstraints *other)
Merge and align the variables of this and other starting at offset, so that both constraint systems g...
Value getValue(unsigned pos) const
Returns the Value associated with the pos^th variable.
void projectOut(Value val)
Projects out the variable that is associate with Value.
bool containsVar(Value val) const
Returns true if a variable with the specified Value exists, false otherwise.
void addBound(presburger::BoundType type, Value val, int64_t value)
Adds a constant bound for the variable associated with the given Value.
bool areVarsAlignedWithOther(const FlatLinearConstraints &other)
Returns true if this constraint system and other are in the same space, i.e., if they are associated ...
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
SmallVector< std::optional< Value > > getMaybeValues() const
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition IntegerSet.h:44
unsigned getNumDims() const
MLIRContext * getContext() const
static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols, MLIRContext *context)
Definition IntegerSet.h:56
unsigned getNumSymbols() const
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
Value getOperand(unsigned i) const
unsigned getNumOperands() const
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
void convertLoopIVSymbolsToDims()
Changes all symbol variables which are loop IVs to dim variables.
LogicalResult addDomainFromSliceMaps(ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds constraints (lower and upper bounds) for each loop in the loop nest described by the bound maps ...
LogicalResult addAffineForOpDomain(AffineForOp forOp)
Adds constraints (lower and upper bounds) for the specified 'affine.for' operation's Value using IR i...
LogicalResult addSliceBounds(ArrayRef< Value > values, ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds slice lower bounds represented by lower bounds in lbMaps and upper bounds in ubMaps to each vari...
void removeTrivialRedundancy()
Removes duplicate constraints, trivially true constraints, and constraints that can be detected as re...
bool isEmpty() const
Checks for emptiness by performing variable elimination on all variables, running the GCD test on eac...
unsigned getNumCols() const
Returns the number of columns in the constraint system.
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
std::optional< int64_t > getConstantBound64(BoundType type, unsigned pos) const
The same, but casts to int64_t.
bool isIntegerEmpty() const
Return true if all the sets in the union are known to be integer empty false otherwise.
PresburgerSet subtract(const PresburgerRelation &set) const
std::optional< uint64_t > getConstantTripCount(AffineForOp forOp)
Returns the trip count of the loop if it's a constant, std::nullopt otherwise.
IntegerSet simplifyIntegerSet(IntegerSet set)
Simplify the integer set by simplifying the underlying affine expressions by flattening and some simp...
Definition Utils.cpp:2200
void getEnclosingAffineOps(Operation &op, SmallVectorImpl< Operation * > *ops)
Populates 'ops' with affine operations enclosing op ordered from outermost to innermost while stoppin...
Definition Utils.cpp:865
SliceComputationResult computeSliceUnion(ArrayRef< Operation * > opsA, ArrayRef< Operation * > opsB, unsigned loopDepth, unsigned numCommonLoops, bool isBackwardSlice, ComputationSliceState *sliceUnion)
Computes in 'sliceUnion' the union of all slice bounds computed at 'loopDepth' between all dependent ...
Definition Utils.cpp:1633
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
bool isLoopParallelAndContainsReduction(AffineForOp forOp)
Returns whether a loop is a parallel loop and contains a reduction loop.
Definition Utils.cpp:2182
unsigned getNumCommonSurroundingLoops(Operation &a, Operation &b)
Returns the number of surrounding loops common to both A and B.
Definition Utils.cpp:2108
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void getAffineIVs(Operation &op, SmallVectorImpl< Value > &ivs)
Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel ops ordered from the outer...
Definition Utils.cpp:2091
void getSequentialLoops(AffineForOp forOp, llvm::SmallDenseSet< Value, 8 > *sequentialLoops)
Returns in 'sequentialLoops' all sequential loops in loop nest rooted at 'forOp'.
Definition Utils.cpp:2191
DependenceResult checkMemrefAccessDependence(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints=nullptr, SmallVector< DependenceComponent, 2 > *dependenceComponents=nullptr, bool allowRAR=false)
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition Utils.cpp:851
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition Utils.cpp:2173
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition Utils.cpp:1601
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth, bool isBackwardSlice, ComputationSliceState *sliceState)
Computes the computation slice loop bounds for one loop nest as affine maps of the other loop nest's ...
Definition Utils.cpp:1870
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns true if the provided value is among the induction variables of an AffineParallelOp.
std::optional< uint64_t > getIntOrFloatMemRefSizeInBytes(MemRefType memRefType)
Returns the size of a memref with element type int or float in bytes if it's statically shaped,...
Definition Utils.cpp:1463
unsigned getNestingDepth(Operation *op)
Returns the nesting depth of this operation, i.e., the number of loops surrounding this operation.
Definition Utils.cpp:2063
uint64_t getSliceIterationCount(const llvm::SmallDenseMap< Operation *, uint64_t, 8 > &sliceTripCountMap)
Return the number of iterations for the slicetripCountMap provided.
Definition Utils.cpp:1856
LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError=true)
Checks a load or store op for an out of bound access; returns failure if the access is out of bounds ...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
Definition Utils.cpp:2391
bool noDependence(DependenceResult result)
Returns true if the provided DependenceResult corresponds to the absence of a dependence.
bool buildSliceTripCountMap(const ComputationSliceState &slice, llvm::SmallDenseMap< Operation *, uint64_t, 8 > *tripCountMap)
Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop nest surrounding represe...
Definition Utils.cpp:1818
AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState)
Creates a clone of the computation contained in the loop nest surrounding 'srcOpInst',...
Definition Utils.cpp:1982
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition Utils.cpp:2286
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition Utils.cpp:1419
BoundType
The type of bound: equal, lower bound or upper bound.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands, ValueRange dims, ValueRange syms, SmallVector< Value > *newSyms=nullptr)
Re-indexes the dimensions and symbols of an affine map with given operands values to align with dims ...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition Utils.h:318
std::optional< bool > isSliceValid() const
Checks the validity of the slice computed.
Definition Utils.cpp:1036
SmallVector< Value, 4 > ivs
Definition Utils.h:321
LogicalResult getAsConstraints(FlatAffineValueConstraints *cst) const
Definition Utils.cpp:898
LogicalResult getSourceAsConstraints(FlatAffineValueConstraints &cst) const
Adds to 'cst' constraints which represent the original loop bounds on 'ivs' in 'this'.
Definition Utils.cpp:882
std::vector< SmallVector< Value, 4 > > ubOperands
Definition Utils.h:329
SmallVector< AffineMap, 4 > ubs
Definition Utils.h:325
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition Utils.cpp:1106
SmallVector< AffineMap, 4 > lbs
Definition Utils.h:323
std::vector< SmallVector< Value, 4 > > lbOperands
Definition Utils.h:327
Checks whether two accesses to the same memref access the same element.
SmallVector< Operation *, 4 > memrefFrees
Definition Utils.h:49
SmallVector< AffineForOp, 4 > forOps
Definition Utils.h:39
SmallVector< Operation *, 4 > loadOpInsts
Definition Utils.h:41
SmallVector< Operation *, 4 > memrefStores
Definition Utils.h:47
void collect(Operation *opToWalk)
Definition Utils.cpp:43
SmallVector< Operation *, 4 > memrefLoads
Definition Utils.h:45
SmallVector< Operation *, 4 > storeOpInsts
Definition Utils.h:43
Encapsulates a memref load or store access information.
unsigned getRank() const
Definition Utils.cpp:2053
SmallVector< Value, 4 > indices
void getAccessMap(AffineValueMap *accessMap) const
Populates 'accessMap' with composition of AffineApplyOps reachable from 'indices'.
MemRefAccess(Operation *memOp)
Constructs a MemRefAccess from an affine read/write operation.
Definition Utils.cpp:2038
bool operator==(const MemRefAccess &rhs) const
Equal if both affine accesses can be proved to be equivalent at compile time (considering the memrefs...
Definition Utils.cpp:2081
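A sketch of comparing two accesses with MemRefAccess; both operations are assumed to be affine read/write ops, as the constructor requires:

// True only when the two accesses provably touch the same element.
bool accessSameElement(Operation *a, Operation *b) {
  return MemRefAccess(a) == MemRefAccess(b);
}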
void getStoreOpsForMemref(Value memref, SmallVectorImpl< Operation * > *storeOps) const
Definition Utils.cpp:130
SmallVector< Operation *, 4 > loads
Definition Utils.h:73
SmallVector< Operation *, 4 > stores
Definition Utils.h:77
void getLoadAndStoreMemrefSet(DenseSet< Value > *loadAndStoreMemrefSet) const
Definition Utils.cpp:149
unsigned hasFree(Value memref) const
Definition Utils.cpp:123
SmallVector< Operation *, 4 > memrefLoads
Definition Utils.h:75
SmallVector< Operation *, 4 > memrefStores
Definition Utils.h:79
unsigned getLoadOpCount(Value memref) const
Definition Utils.cpp:78
unsigned getStoreOpCount(Value memref) const
Definition Utils.cpp:93
unsigned hasStore(Value memref) const
Returns true if there exists an operation with a write memory effect to memref in this node.
Definition Utils.cpp:109
void getLoadOpsForMemref(Value memref, SmallVectorImpl< Operation * > *loadOps) const
Definition Utils.cpp:139
SmallVector< Operation *, 4 > memrefFrees
Definition Utils.h:81
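A sketch of the Node query API documented above; nodeTouchesMemref is a hypothetical helper:

// True if the node loads from, stores to, or frees `memref`.
bool nodeTouchesMemref(const Node &node, Value memref) {
  return node.getLoadOpCount(memref) > 0 ||
         node.getStoreOpCount(memref) > 0 || node.hasFree(memref);
}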
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition Utils.h:147
Block & block
The block for which this graph is created to perform fusion.
Definition Utils.h:270
unsigned addNode(Operation *op)
Definition Utils.cpp:467
bool writesToLiveInOrEscapingMemrefs(unsigned id) const
Definition Utils.cpp:497
void removeEdge(unsigned srcId, unsigned dstId, Value value)
Definition Utils.cpp:543
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition Utils.cpp:532
DenseMap< unsigned, Node > nodes
Definition Utils.h:141
void gatherDefiningNodes(unsigned id, DenseSet< unsigned > &definingNodes) const
Return all nodes which define SSA values used in node 'id'.
Definition Utils.cpp:631
bool hasDependencePath(unsigned srcId, unsigned dstId) const
Definition Utils.cpp:570
void clearNodeLoadAndStores(unsigned id)
Definition Utils.cpp:797
const Node * getForOpNode(AffineForOp forOp) const
Definition Utils.cpp:459
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) const
Definition Utils.cpp:645
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition Utils.cpp:720
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition Utils.h:144
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition Utils.cpp:805
bool init(bool fullAffineDependences=true)
Definition Utils.cpp:336
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr) const
Definition Utils.cpp:621
const Node * getNode(unsigned id) const
Definition Utils.cpp:452
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition Utils.cpp:813
void forEachMemRefEdge(ArrayRef< Edge > edges, const std::function< void(Edge)> &callback)
Definition Utils.cpp:821
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
Definition Utils.cpp:784
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const
Definition Utils.cpp:605
bool hasEdge(unsigned srcId, unsigned dstId, Value value=nullptr) const
Definition Utils.cpp:517
void print(raw_ostream &os) const
Definition Utils.cpp:833
DenseMap< Value, unsigned > memrefEdgeCount
Definition Utils.h:150
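A sketch of building and printing the dependence graph for a block; analyzeBlock is a hypothetical helper, and the constructor taking a Block reference is assumed from the `block` member documented above:

// init() returns false when the block contains ops the analysis cannot model.
void analyzeBlock(Block &block) {
  MemRefDependenceGraph graph(block);
  if (!graph.init()) {
    llvm::errs() << "could not build memref dependence graph\n";
    return;
  }
  graph.print(llvm::errs());
}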
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this memref.
Definition Utils.h:489
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constant, and std::nullopt otherwise.
Definition Utils.cpp:1160
unsigned getRank() const
Returns the rank of the memref that this region corresponds to.
Definition Utils.cpp:1156
FlatAffineValueConstraints cst
Region (data space) of the memref accessed.
Definition Utils.h:586
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symbolic/parametric in 'loopDepth' surrounding loops.
Definition Utils.cpp:1257
void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const
Gets the lower and upper bound map for the dimensional variable at pos.
Definition Utils.cpp:1216
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition Utils.cpp:1438
LogicalResult unionBoundingBox(const MemRefRegion &other)
Definition Utils.cpp:1235
bool write
Read or write.
Definition Utils.h:573
FlatAffineValueConstraints * getConstraints()
Definition Utils.h:535
Value memref
Memref that this region corresponds to.
Definition Utils.h:570
MemRefRegion(Location loc)
Definition Utils.h:490
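A sketch of computing a MemRefRegion for an affine load/store and querying its footprint; getAccessedBytes is a hypothetical helper:

// With loopDepth = 0, the region is not parametric in any surrounding loop.
std::optional<int64_t> getAccessedBytes(Operation *loadOrStoreOp) {
  MemRefRegion region(loadOrStoreOp->getLoc());
  if (failed(region.compute(loadOrStoreOp, /*loopDepth=*/0)))
    return std::nullopt;
  return region.getRegionSize();
}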
Enumerates different result statuses of slice computation by computeSliceUnion
Definition Utils.h:305
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.