//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"

#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;
using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations (both affine and non-affine, the latter identified via their
// memory effects), as well as memref-freeing operations, in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      forOps.push_back(forOp);
    } else if (isa<AffineReadOpInterface>(op)) {
      loadOpInsts.push_back(op);
    } else if (isa<AffineWriteOpInterface>(op)) {
      storeOpInsts.push_back(op);
    } else {
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!memInterface) {
        if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
          // This op itself is memory-effect free.
          return;
        // Check operands. E.g., ops like the `call` op are handled here.
        for (Value v : op->getOperands()) {
          if (!isa<MemRefType>(v.getType()))
            continue;
          // Conservatively, we assume the memref is read and written to.
          memrefLoads.push_back(op);
          memrefStores.push_back(op);
        }
      } else {
        // Non-affine loads and stores.
        if (hasEffect<MemoryEffects::Read>(op))
          memrefLoads.push_back(op);
        if (hasEffect<MemoryEffects::Write>(op))
          memrefStores.push_back(op);
        if (hasEffect<MemoryEffects::Free>(op))
          memrefFrees.push_back(op);
      }
    }
  });
}

// Returns the load op count for 'memref'.
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    // Common case: affine reads.
    if (auto affineLoad = dyn_cast<AffineReadOpInterface>(loadOp)) {
      if (memref == affineLoad.getMemRef())
        ++loadOpCount;
    } else if (hasEffect<MemoryEffects::Read>(loadOp, memref)) {
      ++loadOpCount;
    }
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (auto *storeOp : llvm::concat<Operation *const>(stores, memrefStores)) {
    // Common case: affine writes.
    if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
      if (memref == affineStore.getMemRef())
        ++storeOpCount;
    } else if (hasEffect<MemoryEffects::Write>(const_cast<Operation *>(storeOp),
                                               memref)) {
      ++storeOpCount;
    }
  }
  return storeOpCount;
}

// Returns true if there exists a store op (affine or non-affine) on 'memref'.
unsigned Node::hasStore(Value memref) const {
  return llvm::any_of(
      llvm::concat<Operation *const>(stores, memrefStores),
      [&](Operation *storeOp) {
        if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
          if (memref == affineStore.getMemRef())
            return true;
        } else if (hasEffect<MemoryEffects::Write>(storeOp, memref)) {
          return true;
        }
        return false;
      });
}

unsigned Node::hasFree(Value memref) const {
  return llvm::any_of(memrefFrees, [&](Operation *freeOp) {
    return hasEffect<MemoryEffects::Free>(freeOp, memref);
  });
}

// Returns in 'storeOps' all store ops of this node that access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Returns in 'loadOps' all load ops of this node that access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Returns in 'loadAndStoreMemrefSet' all memrefs for which this node has at
// least one load and at least one store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

/// Returns the values that this op has a memref effect of type `EffectTys` on,
/// not considering recursive effects.
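/// For example, for a `memref.store`, getEffectedValues&lt;MemoryEffects::Write&gt;
/// would return the memref operand being stored to.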
template <typename... EffectTys>
static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
  auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memOp) {
    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
      // No effects.
      return;
    // Memref operands have to be considered as being affected.
    for (Value operand : op->getOperands()) {
      if (isa<MemRefType>(operand.getType()))
        values.push_back(operand);
    }
    return;
  }
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
  memOp.getEffects(effects);
  for (auto &effect : effects) {
    Value effectVal = effect.getValue();
    if (isa<EffectTys...>(effect.getEffect()) && effectVal &&
        isa<MemRefType>(effectVal.getType()))
      values.push_back(effectVal);
  }
}

/// Add `op` to MDG creating a new node and adding its memory accesses (affine
/// or non-affine) to the `memrefAccesses` (memref -> list of nodes with
/// accesses) map.
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
  auto &nodes = mdg.nodes;
  // Create graph node 'id' to represent top-level 'forOp' and record
  // all loads and store accesses it contains.
  LoopNestStateCollector collector;
  collector.collect(nodeOp);
  unsigned newNodeId = mdg.nextNodeId++;
  Node &node = nodes.insert({newNodeId, Node(newNodeId, nodeOp)}).first->second;
  for (Operation *op : collector.loadOpInsts) {
    node.loads.push_back(op);
    auto memref = cast<AffineReadOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.storeOpInsts) {
    node.stores.push_back(op);
    auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.memrefLoads) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Read>(op, effectedValues);
    if (llvm::any_of(((ValueRange)effectedValues).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      // We do not know the interaction here.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefLoads.push_back(op);
  }
  for (Operation *op : collector.memrefStores) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Write>(op, effectedValues);
    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefStores.push_back(op);
  }
  for (Operation *op : collector.memrefFrees) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Free>(op, effectedValues);
    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefFrees.push_back(op);
  }

  return &node;
}

/// Returns the memref being read/written by a memref/affine load/store op.
static Value getMemRef(Operation *memOp) {
  if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
    return memrefLoad.getMemRef();
  if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
    return affineLoad.getMemRef();
  if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
    return memrefStore.getMemRef();
  if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
    return affineStore.getMemRef();
  llvm_unreachable("unexpected op");
}

/// Returns true if there may be a dependence on `memref` from srcNode's
/// memory ops to dstNode's memory ops, while using the affine memory
/// dependence analysis checks. The method assumes that there is at least one
/// memory op in srcNode's loads and stores on `memref`, and similarly for
/// `dstNode`. `srcNode.op` and `dstNode.op` are expected to be nested in the
/// same block and so the dependences are tested at the depth of that block.
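/// For example, if srcNode's loop nest stores to %A[%i] for %i in [0, 10) and
/// dstNode's loop nest loads %A[%i] over the same range, a dependence on %A
/// is found at depth d + 1, where d is the number of loops surrounding the
/// pair.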
static bool mayDependence(const Node &srcNode, const Node &dstNode,
                          Value memref) {
  assert(srcNode.op->getBlock() == dstNode.op->getBlock());
  if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
    return true;

  // Conservatively handle dependences involving non-affine load/stores. Return
  // true if there exists a conflicting read/write access involving such.

  // Check whether there is a dependence from a source read/write op to a
  // destination read/write one; all expected to be memref/affine load/store.
  auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps,
                             ArrayRef<Operation *> dstMemOps) {
    return llvm::any_of(srcMemOps, [&](Operation *srcOp) {
      Value srcMemref = getMemRef(srcOp);
      if (srcMemref != memref)
        return false;
      return llvm::find_if(dstMemOps, [&](Operation *dstOp) {
               return srcMemref == getMemRef(dstOp);
             }) != dstMemOps.end();
    });
  };

  SmallVector<Operation *> dstOps;
  // Between non-affine src stores and dst load/store.
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.loads, dstNode.stores,
                                 dstNode.memrefLoads, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.memrefStores, dstOps))
    return true;
  // Between non-affine loads and dst stores.
  dstOps.clear();
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.stores, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
    return true;
  // Between affine stores and memref load/stores.
  dstOps.clear();
  llvm::append_range(dstOps, llvm::concat<Operation *const>(
                                 dstNode.memrefLoads, dstNode.memrefStores));
  if (hasNonAffineDep(srcNode.stores, dstOps))
    return true;
  // Between affine loads and memref stores.
  dstOps.clear();
  llvm::append_range(dstOps, dstNode.memrefStores);
  if (hasNonAffineDep(srcNode.loads, dstOps))
    return true;

  // Affine load/store pairs. We don't need to check for locally allocated
  // memrefs since the dependence analysis here is between mem ops from
  // srcNode's for op to dstNode's for op at the depth at which those
  // `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
  // `d` is the number of common surrounding loops.
  for (auto *srcMemOp :
       llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
    MemRefAccess srcAcc(srcMemOp);
    if (srcAcc.memref != memref)
      continue;
    for (auto *destMemOp :
         llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
      MemRefAccess destAcc(destMemOp);
      if (destAcc.memref != memref)
        continue;
      // Check for a top-level dependence between srcNode and destNode's ops.
      if (hasDependence(checkMemrefAccessDependence(
              srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
        return true;
    }
  }
  return false;
}
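
// For example (illustrative IR), for:
//
//   affine.for %i = 0 to 10 { affine.store %cf, %A[%i] : memref<10xf32> }
//   affine.for %j = 0 to 10 { %v = affine.load %A[%j] : memref<10xf32> }
//
// init() creates one graph node per `affine.for`; since the first nest stores
// to %A and the second loads from it, a memref dependence edge on %A is added
// between the two nodes.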
bool MemRefDependenceGraph::init(bool fullAffineDependences) {
  LDBG() << "--- Initializing MDG ---";
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Create graph nodes.
  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
      forToNodeMap[&op] = node->id;
    } else if (isa<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (isa<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, edges
      // to/from these ops get looked at.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., they may
      // not be imperative ops.
      LDBG() << "MDG init failed; unknown region-holding op found!";
      return false;
    }
    // We don't create nodes for memory-effect-free ops that either have no
    // regions (unless they have results that are used) or that implement the
    // region branch op interface.
  }

  LDBG() << "Created " << nodes.size() << " nodes";

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    Value srcMemRef = memrefAndList.first;
    // Add edges between all dependent pairs among the node IDs on this memref.
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      Node *srcNode = getNode(srcId);
      bool srcHasStoreOrFree =
          srcNode->hasStore(srcMemRef) || srcNode->hasFree(srcMemRef);
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        Node *dstNode = getNode(dstId);
        bool dstHasStoreOrFree =
            dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
        if (srcHasStoreOrFree || dstHasStoreOrFree) {
          // Check precise affine deps if asked for; otherwise, conservative.
          if (!fullAffineDependences ||
              mayDependence(*srcNode, *dstNode, srcMemRef))
            addEdge(srcId, dstId, srcMemRef);
        }
      }
    }
  }
  return true;
}

// Returns the graph node for 'id'.
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Remove node 'id' (and its associated edges) from graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.contains(id)) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
  const Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) const {
  if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
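// For example, with edges 0 -> 1 and 1 -> 2 (nodes listed in block order),
// hasDependencePath(0, 2) returns true.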
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
                                              unsigned dstId) const {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (!outEdges.contains(idAndIndex.first) ||
        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) const {
  unsigned inEdgeCount = 0;
  for (const Edge &inEdge : inEdges.lookup(id)) {
    if (inEdge.value == memref) {
      const Node *srcNode = getNode(inEdge.id);
      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
      // with a store.
      if (srcNode->getStoreOpCount(memref) > 0)
        ++inEdgeCount;
    }
  }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
                                                Value memref) const {
  unsigned outEdgeCount = 0;
  for (const auto &outEdge : outEdges.lookup(id))
    if (!memref || outEdge.value == memref)
      ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) const {
  for (const Edge &edge : inEdges.lookup(id))
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) const {
  if (!outEdges.contains(srcId))
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LDBG() << "Can't fuse: a defining op with a user in the dst "
           << "loop has dependence from the src loop";
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  llvm::SmallDenseSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges.lookup(srcId))
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  llvm::SmallDenseSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges.lookup(dstId))
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  //   *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //      operation insertion point (or return null pointer if no such
  //      insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
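  //
  // For example, if ops op0 and op1 appear between the src and dst nodes,
  // 'dstId' depends on op0 (lastDstDepPos = 0), and op1 depends on 'srcId'
  // (firstSrcDepPos = 1), then firstSrcDepPos > lastDstDepPos and the fused
  // loop nest can be inserted right before op1.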
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
//    private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (!privateMemRefs.contains(inEdge.value))
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads' and 'stores' to node at 'id'.
void MemRefDependenceGraph::addToNode(unsigned id, ArrayRef<Operation *> loads,
                                      ArrayRef<Operation *> stores,
                                      ArrayRef<Operation *> memrefLoads,
                                      ArrayRef<Operation *> memrefStores,
                                      ArrayRef<Operation *> memrefFrees) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
  llvm::append_range(node->memrefLoads, memrefLoads);
  llvm::append_range(node->memrefStores, memrefStores);
  llvm::append_range(node->memrefFrees, memrefFrees);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges.at(id), callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges.at(id), callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << "  InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << "  OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the original
// domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent slice bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
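/// For example, a slice dimension with lb = (d0) -> (d0) and
/// ub = (d0) -> (d0 + 1), where d0 is a dst loop IV whose constant bounds and
/// step match those of the corresponding src loop, passes this check.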
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest using which slice is computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LDBG() << "Unable to compute source's domain";
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LDBG() << "Cannot handle locals in source domain";
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LDBG() << "Unable to compute slice's domain";
    return std::nullopt;
  }

  // Project out every dimension other than the 'ivs' to express the slice's
  // domain completely in terms of the source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());
  srcConstraints.projectOut(ivs.size(),
                            srcConstraints.getNumVars() - ivs.size());

  LDBG() << "Domain of the source of the slice:\n"
         << "Source constraints:" << srcConstraints
         << "\nDomain of the slice if this fusion succeeds "
         << "(expressed in terms of its source's IVs):\n"
         << "Slice constraints:" << sliceConstraints;

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LDBG() << "Incorrect slice";
    return false;
  }
  return true;
}

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}
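
/// Returns a constant upper bound on the number of elements in this region if
/// one can be determined, filling in the bounding constant shape and symbolic
/// lower bounds if requested; e.g., a region constrained by
/// {%i <= d0 <= %i + 7} has a constant bounding shape of [8] along d0.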
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  MLIRContext *context = memref.getContext();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatLinearValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  for (unsigned d = 0; d < rank; d++) {
    AffineMap lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize(context, d, &lb);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "dim size bound cannot be negative");
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (ShapedType::isDynamic(dimSize))
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                          /*result=*/getAffineConstantExpr(0, context));
    }
    numElements *= diffConstant;
    // Populate outputs if available.
    if (lbs)
      lbs->push_back(lb);
    if (shape)
      shape->push_back(diffConstant);
  }
  return numElements;
}

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds, bool dropLocalVars,
                                    bool dropOuterIvs) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed out.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LDBG() << "unknown affine dimensional value";
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      if (failed(cst.addInductionVarOrTerminalSymbol(operand)))
        return failure();
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LDBG() << "Access map: " << accessValueMap.getAffineMap();
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (auto en : llvm::enumerate(vars)) {
    if ((isAffineInductionVar(en.value())) &&
        !llvm::is_contained(enclosingIVs, en.value())) {
      if (dropOuterIvs) {
        cst.projectOut(en.value());
      } else {
        unsigned varPosition;
        cst.findVar(en.value(), &varPosition);
        auto varKind = cst.getVarKindAt(varPosition);
        varPosition -= cst.getNumDimVars();
        cst.convertToLocal(varKind, varPosition, varPosition + 1);
      }
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs) if specified.
  if (dropLocalVars)
    cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LDBG() << "Memory region: " << cst;
  return success();
}
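
/// Returns the byte size of the memref element type if it is an int/float or
/// a vector of int/float, e.g., 4 for f32 and 16 for vector<4xf32>; returns
/// std::nullopt for other element types.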
std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of the region in bytes.
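// E.g., a region spanning eight f32 elements has a size of 8 * 4 = 32 bytes.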
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LDBG() << "Non-identity layout map not yet supported";
    return std::nullopt;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LDBG() << "Dynamic shapes not yet supported";
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the element of the memref has vector type, takes
/// into account size of the vector as well.
// TODO: improve/complete this when we have target data.
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}
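
// For example (illustrative IR), an affine.load %A[%i + 10] inside
// `affine.for %i = 0 to 10` on memref<10xf32> accesses {10 <= d0 <= 19};
// intersecting this region with d0 >= 10 stays non-empty, so an
// out-of-upper-bound access is reported along that dimension.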
template <typename LoadOrStoreOp>
LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                                    bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either a AffineReadOpInterface or a "
                "AffineWriteOpInterface");

  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LDBG() << "Memory region: " << *region.getConstraints();

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineValueConstraints ucst(*region.getConstraints());

    // Intersect memory region with constraint capturing out of bounds (both
    // out of upper and out of lower), and check if the constraint system is
    // feasible. If it is, there is at least one point out of bounds.
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (ShapedType::isDynamic(dimSize))
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addBound(BoundType::LB, r, dimSize);
    outOfBounds = !ucst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index: d_i <= -1.
    FlatAffineValueConstraints lcst(*region.getConstraints());
    lcst.addBound(BoundType::UB, r, -1);
    outOfBounds = !lcst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}
1529
1530// Explicitly instantiate the template so that the compiler knows we need them!
1531template LogicalResult
1532mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
1533 bool emitError);
1534template LogicalResult
1535mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
1536 bool emitError);
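// Usage sketch (editorial, not part of the upstream file; mirrors what a
// bound-check test pass would do, with `funcOp` an assumed enclosing op):
//   funcOp.walk([](Operation *op) {
//     if (auto readOp = dyn_cast<AffineReadOpInterface>(op))
//       (void)boundCheckLoadOrStoreOp(readOp, /*emitError=*/true);
//     else if (auto writeOp = dyn_cast<AffineWriteOpInterface>(op))
//       (void)boundCheckLoadOrStoreOp(writeOp, /*emitError=*/true);
//   });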
1537
1538// Returns in 'positions' the Block positions of 'op' in each ancestor
1539// Block from the Block containing 'op', stopping at 'limitBlock'.
1540static void findInstPosition(Operation *op, Block *limitBlock,
1541 SmallVectorImpl<unsigned> *positions) {
1542 Block *block = op->getBlock();
1543 while (block != limitBlock) {
1544 // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1545 // rely on linear scans.
1546 int instPosInBlock = std::distance(block->begin(), op->getIterator());
1547 positions->push_back(instPosInBlock);
1548 op = block->getParentOp();
1549 block = op->getBlock();
1550 }
1551 std::reverse(positions->begin(), positions->end());
1552}
1553
1554// Returns the Operation in a possibly nested set of Blocks, where the
1555// position of the operation is represented by 'positions', which has a
1556// Block position for each level of nesting.
1557static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
1558 unsigned level, Block *block) {
1559 unsigned i = 0;
1560 for (auto &op : *block) {
1561 if (i != positions[level]) {
1562 ++i;
1563 continue;
1564 }
1565 if (level == positions.size() - 1)
1566 return &op;
1567 if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
1568 return getInstAtPosition(positions, level + 1,
1569 childAffineForOp.getBody());
1570
1571 for (auto &region : op.getRegions()) {
1572 for (auto &b : region)
1573 if (auto *ret = getInstAtPosition(positions, level + 1, &b))
1574 return ret;
1575 }
1576 return nullptr;
1577 }
1578 return nullptr;
1579}
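// Worked example (editorial): if 'op' is the operation at position 2 of an
// affine.for body, and that affine.for sits at position 0 of 'limitBlock',
// findInstPosition produces positions = {0, 2}; getInstAtPosition(positions,
// /*level=*/0, limitBlock) then walks the same path back down and returns 'op'.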
1580
1581// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1582static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
1583 FlatAffineValueConstraints *cst) {
1584 for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1585 auto value = cst->getValue(i);
1586 if (ivs.count(value) == 0) {
1587 assert(isAffineForInductionVar(value));
1588 auto loop = getForInductionVarOwner(value);
1589 if (failed(cst->addAffineForOpDomain(loop)))
1590 return failure();
1591 }
1592 }
1593 return success();
1594}
1595
1596/// Returns the innermost common loop depth for the set of operations in 'ops'.
1597// TODO: Move this to LoopUtils.
1598unsigned mlir::affine::getInnermostCommonLoopDepth(
1599 ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1600 unsigned numOps = ops.size();
1601 assert(numOps > 0 && "Expected at least one operation");
1602
1603 std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1604 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1605 for (unsigned i = 0; i < numOps; ++i) {
1606 getAffineForIVs(*ops[i], &loops[i]);
1607 loopDepthLimit =
1608 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
1609 }
1610
1611 unsigned loopDepth = 0;
1612 for (unsigned d = 0; d < loopDepthLimit; ++d) {
1613 unsigned i;
1614 for (i = 1; i < numOps; ++i) {
1615 if (loops[i - 1][d] != loops[i][d])
1616 return loopDepth;
1617 }
1618 if (surroundingLoops)
1619 surroundingLoops->push_back(loops[i - 1][d]);
1620 ++loopDepth;
1621 }
1622 return loopDepth;
1623}
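// Worked example (editorial): for two ops that share surrounding loops %i and
// %j but sit under different third-level loops, the IVs diverge at depth 2,
// so getInnermostCommonLoopDepth returns 2 and, if requested, fills
// 'surroundingLoops' with the %i and %j loops.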
1624
1625/// Computes in 'sliceUnion' the union of all slice bounds computed at
1626/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1627/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1628/// union was computed correctly, an appropriate failure otherwise.
1629SliceComputationResult
1630mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
1631 ArrayRef<Operation *> opsB, unsigned loopDepth,
1632 unsigned numCommonLoops, bool isBackwardSlice,
1633 ComputationSliceState *sliceUnion) {
1634 // Compute the union of slice bounds between all pairs in 'opsA' and
1635 // 'opsB' in 'sliceUnionCst'.
1636 FlatAffineValueConstraints sliceUnionCst;
1637 assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1638 std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1639 MemRefAccess srcAccess;
1640 MemRefAccess dstAccess;
1641 for (Operation *a : opsA) {
1642 srcAccess = MemRefAccess(a);
1643 for (Operation *b : opsB) {
1644 dstAccess = MemRefAccess(b);
1645 if (srcAccess.memref != dstAccess.memref)
1646 continue;
1647 // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1648 if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
1649 (isBackwardSlice && loopDepth > getNestingDepth(b))) {
1650 LDBG() << "Invalid loop depth";
1651 return SliceComputationResult::GenericFailure;
1652 }
1653
1654 bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1655 isa<AffineReadOpInterface>(dstAccess.opInst);
1656 FlatAffineValueConstraints dependenceConstraints;
1657 // Check dependence between 'srcAccess' and 'dstAccess'.
1658 DependenceResult result = checkMemrefAccessDependence(
1659 srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1660 &dependenceConstraints, /*dependenceComponents=*/nullptr,
1661 /*allowRAR=*/readReadAccesses);
1662 if (result.value == DependenceResult::Failure) {
1663 LDBG() << "Dependence check failed";
1664 return SliceComputationResult::GenericFailure;
1665 }
1666 if (result.value == DependenceResult::NoDependence)
1667 continue;
1668 dependentOpPairs.emplace_back(a, b);
1669
1670 // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1671 ComputationSliceState tmpSliceState;
1672 getComputationSliceState(a, b, dependenceConstraints, loopDepth,
1673 isBackwardSlice, &tmpSliceState);
1674
1675 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1676 // Initialize 'sliceUnionCst' with the bounds computed in previous step.
1677 if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
1678 LDBG() << "Unable to compute slice bound constraints";
1679 return SliceComputationResult::GenericFailure;
1680 }
1681 assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1682 continue;
1683 }
1684
1685 // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1686 FlatAffineValueConstraints tmpSliceCst;
1687 if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
1688 LDBG() << "Unable to compute slice bound constraints";
1689 return SliceComputationResult::GenericFailure;
1690 }
1691
1692 // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1693 if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {
1694
1695 // Pre-constraint var alignment: record loop IVs used in each constraint
1696 // system.
1697 SmallPtrSet<Value, 8> sliceUnionIVs;
1698 for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1699 sliceUnionIVs.insert(sliceUnionCst.getValue(k));
1700 SmallPtrSet<Value, 8> tmpSliceIVs;
1701 for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1702 tmpSliceIVs.insert(tmpSliceCst.getValue(k));
1703
1704 sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);
1705
1706 // Post-constraint var alignment: add loop IV bounds missing after
1707 // var alignment to constraint systems. This can occur if one constraint
1708 // system uses a loop IV that is not used by the other. The call
1709 // to unionBoundingBox below expects constraints for each loop IV, even
1710 // if they are the unsliced full loop bounds added here.
1711 if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
1712 return SliceComputationResult::GenericFailure;
1713 if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
1714 return SliceComputationResult::GenericFailure;
1715 }
1716 // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
1717 if (sliceUnionCst.getNumLocalVars() > 0 ||
1718 tmpSliceCst.getNumLocalVars() > 0 ||
1719 failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
1720 LDBG() << "Unable to compute union bounding box of slice bounds";
1721 return SliceComputationResult::GenericFailure;
1722 }
1723 }
1724 }
1725
1726 // Empty union.
1727 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1728 LDBG() << "empty slice union - unexpected";
1729 return SliceComputationResult::GenericFailure;
1730 }
1731
1732 // Gather loops surrounding ops from loop nest where slice will be inserted.
1733 SmallVector<Operation *, 4> ops;
1734 for (auto &dep : dependentOpPairs) {
1735 ops.push_back(isBackwardSlice ? dep.second : dep.first);
1736 }
1737 SmallVector<AffineForOp, 4> surroundingLoops;
1738 unsigned innermostCommonLoopDepth =
1739 getInnermostCommonLoopDepth(ops, &surroundingLoops);
1740 if (loopDepth > innermostCommonLoopDepth) {
1741 LDBG() << "Exceeds max loop depth";
1742 return SliceComputationResult::GenericFailure;
1743 }
1744
1745 // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1746 unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1747
1748 // Convert any dst loop IVs which are symbol variables to dim variables.
1749 sliceUnionCst.convertLoopIVSymbolsToDims();
1750 sliceUnion->clearBounds();
1751 sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
1752 sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
1753
1754 // Get slice bounds from slice union constraints 'sliceUnionCst'.
1755 sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
1756 opsA[0]->getContext(), &sliceUnion->lbs,
1757 &sliceUnion->ubs);
1758
1759 // Add slice bound operands of union.
1760 SmallVector<Value, 4> sliceBoundOperands;
1761 sliceUnionCst.getValues(numSliceLoopIVs,
1762 sliceUnionCst.getNumDimAndSymbolVars(),
1763 &sliceBoundOperands);
1764
1765 // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1766 sliceUnion->ivs.clear();
1767 sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);
1768
1769 // Set the loop nest insertion point to the block start at 'loopDepth' for
1770 // backward slices, and to the block end for forward slices.
1771 sliceUnion->insertPoint =
1772 isBackwardSlice
1773 ? surroundingLoops[loopDepth - 1].getBody()->begin()
1774 : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1775
1776 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1777 // canonicalization.
1778 sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1779 sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1780
1781 // Check if the slice computed is valid. Return success only if it is verified
1782 // that the slice is valid, otherwise return appropriate failure status.
1783 std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1784 if (!isSliceValid) {
1785 LDBG() << "Cannot determine if the slice is valid";
1786 return SliceComputationResult::GenericFailure;
1787 }
1788 if (!*isSliceValid)
1789 return SliceComputationResult::IncorrectSliceFailure;
1790
1791 return SliceComputationResult::Success;
1792}
1793
1794// TODO: extend this to handle multiple result maps.
1795static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1796 AffineMap ubMap) {
1797 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1798 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1799 assert(lbMap.getNumDims() == ubMap.getNumDims());
1800 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1801 AffineExpr lbExpr(lbMap.getResult(0));
1802 AffineExpr ubExpr(ubMap.getResult(0));
1803 auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
1804 lbMap.getNumSymbols());
1805 auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
1806 if (!cExpr)
1807 return std::nullopt;
1808 return cExpr.getValue();
1809}
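// Worked example (editorial): with lbMap = (d0) -> (d0) and
// ubMap = (d0) -> (d0 + 4), the difference simplifies to the constant 4, which
// is returned; with ubMap = (d0) -> (d0 * 2) the difference d0 stays
// non-constant and the result is std::nullopt.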
1810
1811// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
1812// loop nest surrounding the slice, as represented by the slice loop bounds in
1813// 'slice'. Returns true on success, false if a non-constant trip count is met.
1814// TODO: Make this work with non-unit step loops.
1815bool mlir::affine::buildSliceTripCountMap(
1816 const ComputationSliceState &slice,
1817 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1818 unsigned numSrcLoopIVs = slice.ivs.size();
1819 // Populate map from AffineForOp -> trip count
1820 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1821 AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1822 auto *op = forOp.getOperation();
1823 AffineMap lbMap = slice.lbs[i];
1824 AffineMap ubMap = slice.ubs[i];
1825 // If lower or upper bound maps are null or provide no results, it implies
1826 // that source loop was not at all sliced, and the entire loop will be a
1827 // part of the slice.
1828 if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1829 ubMap.getNumResults() == 0) {
1830 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1831 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1832 (*tripCountMap)[op] =
1833 forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1834 continue;
1835 }
1836 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1837 if (maybeConstTripCount.has_value()) {
1838 (*tripCountMap)[op] = *maybeConstTripCount;
1839 continue;
1840 }
1841 return false;
1842 }
1843 std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1844 // Slice bounds are created with a constant ub - lb difference.
1845 if (!tripCount.has_value())
1846 return false;
1847 (*tripCountMap)[op] = *tripCount;
1848 }
1849 return true;
1850}
1851
1852// Return the number of iterations in the given slice.
1853uint64_t mlir::affine::getSliceIterationCount(
1854 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1855 uint64_t iterCount = 1;
1856 for (const auto &count : sliceTripCountMap) {
1857 iterCount *= count.second;
1858 }
1859 return iterCount;
1860}
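// Worked example (editorial): if buildSliceTripCountMap produced trip counts
// {outer loop: 4, inner loop: 2}, getSliceIterationCount returns 4 * 2 = 8,
// the total number of iterations the slice executes.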
1861
1862const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
1863// Computes slice bounds by projecting out any loop IVs from
1864// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
1865// bounds in 'sliceState' which represent the one loop nest's IVs in terms of
1866// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice').
1867void mlir::affine::getComputationSliceState(
1868 Operation *depSourceOp, Operation *depSinkOp,
1869 const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth,
1870 bool isBackwardSlice, ComputationSliceState *sliceState) {
1871 // Get loop nest surrounding src operation.
1872 SmallVector<AffineForOp, 4> srcLoopIVs;
1873 getAffineForIVs(*depSourceOp, &srcLoopIVs);
1874 unsigned numSrcLoopIVs = srcLoopIVs.size();
1875
1876 // Get loop nest surrounding dst operation.
1877 SmallVector<AffineForOp, 4> dstLoopIVs;
1878 getAffineForIVs(*depSinkOp, &dstLoopIVs);
1879 unsigned numDstLoopIVs = dstLoopIVs.size();
1880
1881 assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1882 (isBackwardSlice && loopDepth <= numDstLoopIVs));
1883
1884 // Project out dimensions other than those up to 'loopDepth'.
1885 unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1886 unsigned num =
1887 isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1888 FlatAffineValueConstraints sliceCst(dependenceConstraints);
1889 sliceCst.projectOut(pos, num);
1890
1891 // Add slice loop IV values to 'sliceState'.
1892 unsigned offset = isBackwardSlice ? 0 : loopDepth;
1893 unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1894 sliceCst.getValues(offset, offset + numSliceLoopIVs, &sliceState->ivs);
1895
1896 // Set up lower/upper bound affine maps for the slice.
1897 sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
1898 sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
1899
1900 // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1901 sliceCst.getSliceBounds(offset, numSliceLoopIVs, depSourceOp->getContext(),
1902 &sliceState->lbs, &sliceState->ubs);
1903
1904 // Set up bound operands for the slice's lower and upper bounds.
1905 SmallVector<Value, 4> sliceBoundOperands;
1906 unsigned numDimsAndSymbols = sliceCst.getNumDimAndSymbolVars();
1907 for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1908 if (i < offset || i >= offset + numSliceLoopIVs)
1909 sliceBoundOperands.push_back(sliceCst.getValue(i));
1910 }
1911
1912 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1913 // canonicalization.
1914 sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1915 sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
1916
1917 // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
1918 sliceState->insertPoint =
1919 isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1920 : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1921
1922 llvm::SmallDenseSet<Value, 8> sequentialLoops;
1923 if (isa<AffineReadOpInterface>(depSourceOp) &&
1924 isa<AffineReadOpInterface>(depSinkOp)) {
1925 // For read-read access pairs, clear any slice bounds on sequential loops.
1926 // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
1927 getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1928 &sequentialLoops);
1929 }
1930 auto getSliceLoop = [&](unsigned i) {
1931 return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1932 };
1933 auto isInnermostInsertion = [&]() {
1934 return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1935 : loopDepth >= dstLoopIVs.size());
1936 };
1937 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
1938 auto srcIsUnitSlice = [&]() {
1939 return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
1940 (getSliceIterationCount(sliceTripCountMap) == 1));
1941 };
1942 // Clear all sliced loop bounds beginning at the first sequential loop or the
1943 // first loop with a slice fusion barrier attribute.
1944
1945 for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1946 Value iv = getSliceLoop(i).getInductionVar();
1947 if (sequentialLoops.count(iv) == 0 &&
1948 getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1949 continue;
1950 // Skip resetting the bounds of a reduction loop inserted in the destination
1951 // loop nest when it meets all of the following conditions:
1952 // 1. Slice is single trip count.
1953 // 2. Loop bounds of the source and destination match.
1954 // 3. Is being inserted at the innermost insertion point.
1955 std::optional<bool> isMaximal = sliceState->isMaximal();
1956 if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1957 isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1958 continue;
1959 for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1960 sliceState->lbs[j] = AffineMap();
1961 sliceState->ubs[j] = AffineMap();
1962 }
1963 break;
1964 }
1965}
1966
1967/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1968/// updates the slice loop bounds with any non-null bound maps specified in
1969/// 'sliceState', and inserts this slice into the loop nest surrounding
1970/// 'dstOpInst' at loop depth 'dstLoopDepth'.
1971// TODO: extend the slicing utility to compute slices that
1972// aren't necessarily a one-to-one relation between the source and destination. The
1973// relation between the source and destination could be many-to-many in general.
1974// TODO: the slice computation is incorrect in the cases
1975// where the dependence from the source to the destination does not cover the
1976// entire destination index set. Subtract out the dependent destination
1977// iterations from destination index set and check for emptiness --- this is one
1978// solution.
1979AffineForOp mlir::affine::insertBackwardComputationSlice(
1980 Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1981 ComputationSliceState *sliceState) {
1982 // Get loop nest surrounding src operation.
1983 SmallVector<AffineForOp, 4> srcLoopIVs;
1984 getAffineForIVs(*srcOpInst, &srcLoopIVs);
1985 unsigned numSrcLoopIVs = srcLoopIVs.size();
1986
1987 // Get loop nest surrounding dst operation.
1988 SmallVector<AffineForOp, 4> dstLoopIVs;
1989 getAffineForIVs(*dstOpInst, &dstLoopIVs);
1990 unsigned dstLoopIVsSize = dstLoopIVs.size();
1991 if (dstLoopDepth > dstLoopIVsSize) {
1992 dstOpInst->emitError("invalid destination loop depth");
1993 return AffineForOp();
1994 }
1995
1996 // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
1997 SmallVector<unsigned, 4> positions;
1998 // TODO: This code is incorrect since srcLoopIVs can be 0-d.
1999 findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
2000
2001 // Clone the src loop nest and insert it at the beginning of the operation block
2002 // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
2003 auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
2004 OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
2005 auto sliceLoopNest =
2006 cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
2007
2008 Operation *sliceInst =
2009 getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
2010 // Get loop nest surrounding 'sliceInst'.
2011 SmallVector<AffineForOp, 4> sliceSurroundingLoops;
2012 getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
2013
2014 // Sanity check.
2015 unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
2016 (void)sliceSurroundingLoopsSize;
2017 assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
2018 unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
2019 (void)sliceLoopLimit;
2020 assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
2021
2022 // Update loop bounds for loops in 'sliceLoopNest'.
2023 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
2024 auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
2025 if (AffineMap lbMap = sliceState->lbs[i])
2026 forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
2027 if (AffineMap ubMap = sliceState->ubs[i])
2028 forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
2029 }
2030 return sliceLoopNest;
2031}
2032
2033// Constructs a MemRefAccess, populating it with the memref, its indices, and
2034// the op, all taken from 'memOp'.
2035MemRefAccess::MemRefAccess(Operation *memOp) {
2036 if (auto loadOp = dyn_cast<AffineReadOpInterface>(memOp)) {
2037 memref = loadOp.getMemRef();
2038 opInst = memOp;
2039 llvm::append_range(indices, loadOp.getMapOperands());
2040 } else {
2041 assert(isa<AffineWriteOpInterface>(memOp) &&
2042 "Affine read/write op expected");
2043 auto storeOp = cast<AffineWriteOpInterface>(memOp);
2044 opInst = memOp;
2045 memref = storeOp.getMemRef();
2046 llvm::append_range(indices, storeOp.getMapOperands());
2047 }
2048}
2049
2050unsigned MemRefAccess::getRank() const {
2051 return cast<MemRefType>(memref.getType()).getRank();
2052}
2053
2054bool MemRefAccess::isStore() const {
2055 return isa<AffineWriteOpInterface>(opInst);
2056}
2057
2058/// Returns the nesting depth of this operation, i.e., the number of loops
2059/// surrounding this operation.
2060unsigned mlir::affine::getNestingDepth(Operation *op) {
2061 Operation *currOp = op;
2062 unsigned depth = 0;
2063 while ((currOp = currOp->getParentOp())) {
2064 if (isa<AffineForOp>(currOp))
2065 depth++;
2066 if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2067 depth += parOp.getNumDims();
2068 }
2069 return depth;
2070}
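// Worked example (editorial): for a load nested as
//   affine.for %i { affine.for %j { affine.parallel (%k, %l) { <load> } } }
// getNestingDepth(<load>) returns 4: one per 'affine.for' plus two for the two
// dimensions of the 'affine.parallel'.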
2071
2072/// Equal if both affine accesses are provably equivalent (at compile
2073/// time) when considering the memref, the affine maps and their respective
2074/// operands. The equality of access functions + operands is checked by
2075/// subtracting fully composed value maps, and then simplifying the difference
2076/// using the expression flattener.
2077/// TODO: this does not account for aliasing of memrefs.
2078bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
2079 if (memref != rhs.memref)
2080 return false;
2081
2082 AffineValueMap thisMap, rhsMap;
2083 getAccessMap(&thisMap);
2084 rhs.getAccessMap(&rhsMap);
2085 return thisMap == rhsMap;
2086}
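// Worked example (editorial): two loads a[%i + 1] and a[%i + 1] on the same
// memref with the same operand %i compare equal, since their composed access
// maps and operands match; a[%i] and a[%j] with distinct IVs do not, even if
// the two IVs happen to be equal at runtime.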
2087
2088void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
2089 auto *currOp = op.getParentOp();
2090 AffineForOp currAffineForOp;
2091 // Traverse up the hierarchy collecting all 'affine.for' and 'affine.parallel'
2092 // operations while skipping over 'affine.if' operations.
2093 while (currOp) {
2094 if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
2095 ivs.push_back(currAffineForOp.getInductionVar());
2096 else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2097 llvm::append_range(ivs, parOp.getIVs());
2098 currOp = currOp->getParentOp();
2099 }
2100 std::reverse(ivs.begin(), ivs.end());
2101}
2102
2103/// Returns the number of surrounding loops common to ops 'a' and 'b', where
2104/// the loops surrounding each op are listed from outermost to innermost.
2105unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
2106 Operation &b) {
2107 SmallVector<Value, 4> loopsA, loopsB;
2108 getAffineIVs(a, loopsA);
2109 getAffineIVs(b, loopsB);
2110
2111 unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
2112 unsigned numCommonLoops = 0;
2113 for (unsigned i = 0; i < minNumLoops; ++i) {
2114 if (loopsA[i] != loopsB[i])
2115 break;
2116 ++numCommonLoops;
2117 }
2118 return numCommonLoops;
2119}
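// Worked example (editorial): with 'a' surrounded by loops (%i, %j, %k) and
// 'b' surrounded by (%i, %j, %m), the common prefix is (%i, %j), so
// getNumCommonSurroundingLoops returns 2.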
2120
2121static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
2122 Block::iterator start,
2123 Block::iterator end,
2124 int memorySpace) {
2125 SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
2126
2127 // Walk this 'affine.for' operation to gather all memory regions.
2128 auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
2129 if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
2130 // Neither load nor a store op.
2131 return WalkResult::advance();
2132 }
2133
2134 // Compute the memref region symbolic in any IVs enclosing this block.
2135 auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2136 if (failed(
2137 region->compute(opInst,
2138 /*loopDepth=*/getNestingDepth(&*block.begin())))) {
2139 LDBG() << "Error obtaining memory region";
2140 opInst->emitError("error obtaining memory region");
2141 return failure();
2142 }
2143
2144 auto [it, inserted] = regions.try_emplace(region->memref);
2145 if (inserted) {
2146 it->second = std::move(region);
2147 } else if (failed(it->second->unionBoundingBox(*region))) {
2148 LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
2149 "memory region";
2150 opInst->emitWarning(
2151 "getMemoryFootprintBytes: unable to perform a union on a memory "
2152 "region");
2153 return failure();
2154 }
2155 return WalkResult::advance();
2156 });
2157 if (result.wasInterrupted())
2158 return std::nullopt;
2159
2160 int64_t totalSizeInBytes = 0;
2161 for (const auto &region : regions) {
2162 std::optional<int64_t> size = region.second->getRegionSize();
2163 if (!size.has_value())
2164 return std::nullopt;
2165 totalSizeInBytes += *size;
2166 }
2167 return totalSizeInBytes;
2168}
2169
2170std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
2171 int memorySpace) {
2172 auto *forInst = forOp.getOperation();
2173 return ::getMemoryFootprintBytes(
2174 *forInst->getBlock(), Block::iterator(forInst),
2175 std::next(Block::iterator(forInst)), memorySpace);
2176}
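// Worked example (editorial): for
//   affine.for %i = 0 to 100 { %v = affine.load %A[%i] : memref<100xf32> }
// a single region covering all 100 f32 elements is built, so the footprint is
// 100 * 4 = 400 bytes.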
2177
2178/// Returns whether a loop is parallel and contains a reduction loop.
2179bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
2180 SmallVector<LoopReduction> reductions;
2181 if (!isLoopParallel(forOp, &reductions))
2182 return false;
2183 return !reductions.empty();
2184}
2185
2186/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
2187/// at 'forOp'.
2188void mlir::affine::getSequentialLoops(
2189 AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
2190 forOp->walk([&](Operation *op) {
2191 if (auto innerFor = dyn_cast<AffineForOp>(op))
2192 if (!isLoopParallel(innerFor))
2193 sequentialLoops->insert(innerFor.getInductionVar());
2194 });
2195}
2196
2197IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
2198 FlatAffineValueConstraints fac(set);
2199 if (fac.isEmpty())
2200 return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
2201 set.getContext());
2202 fac.removeTrivialRedundancy();
2203
2204 auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
2205 assert(simplifiedSet && "guaranteed to succeed while roundtripping");
2206 return simplifiedSet;
2207}
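// Worked example (editorial): the set (d0) : (d0 >= 0, -d0 - 1 >= 0) contains
// no integer points, so simplifyIntegerSet returns the canonical empty set
// with one dimension and zero symbols.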
2208
2209static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
2210 SmallVector<Value> &target) {
2211 target = llvm::map_to_vector<4>(source, [](std::optional<Value> val) {
2212 return val.has_value() ? *val : Value();
2213 });
2214}
2215
2216/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2217/// constraints drawn from an affine map. Before adding the constraint, the
2218/// dimensions/symbols of the affine map are aligned with `constraints`.
2219/// `operands` are the SSA Value operands used with the affine map.
2220/// Note: This function adds a new symbol column to the `constraints` for each
2221/// dimension/symbol that exists in the affine map but not in `constraints`.
2222static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2223 BoundType type, unsigned pos,
2224 AffineMap map, ValueRange operands) {
2225 SmallVector<Value> dims, syms, newSyms;
2226 unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
2227 unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
2228
2229 AffineMap alignedMap =
2230 alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
2231 for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2232 constraints.appendSymbolVar(newSyms[i]);
2233 return constraints.addBound(type, pos, alignedMap);
2234}
2235
2236/// Add `val` to each result of `map`.
2237static AffineMap addConstToResults(AffineMap map, int64_t val) {
2238 SmallVector<AffineExpr> newResults;
2239 for (AffineExpr r : map.getResults())
2240 newResults.push_back(r + val);
2241 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
2242 map.getContext());
2243}
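// Worked example (editorial): addConstToResults((d0, d1) -> (d0, d1 + 2), 1)
// yields (d0, d1) -> (d0 + 1, d1 + 3); it is used below to convert between the
// inclusive results of affine.min/max and the exclusive upper-bound convention
// of the constraint system.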
2244
2245// Attempt to simplify the given min/max operation by proving that its value is
2246// bounded by the same lower and upper bound.
2247//
2248// Bounds are computed by FlatAffineValueConstraints. Invariants required for
2249// finding/proving bounds should be supplied via `constraints`.
2250//
2251// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2252// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2253// case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2254// `op` but are not part of `constraints`, are added as extra symbols.
2255// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2256// * If `isMin`: r_i >= opBound
2257// * If `isMax`: r_i <= opBound
2258// If this is the case, ub(op) == lb(op).
2259// 4. Replace `op` with `opBound`.
2260//
2261// In summary, the following constraints are added throughout this function.
2262// Note: `invar` are dimensions added by the caller to express the invariants.
2263// (Showing only the case where `isMin`.)
2264//
2265// invar | op | opBound | r_i | extra syms... | const | eq/ineq
2266// ------+-------+---------+-----+---------------+-------+-------------------
2267// (various eq./ineq. constraining `invar`, added by the caller)
2268// ... | 0 | 0 | 0 | 0 | ... | ...
2269// ------+-------+---------+-----+---------------+-------+-------------------
2270// (various ineq. constraining `op` in terms of `op` operands (`invar` and
2271// extra `op` operands "extra syms" that are not in `invar`)).
2272// ... | -1 | 0 | 0 | ... | ... | >= 0
2273// ------+-------+---------+-----+---------------+-------+-------------------
2274// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2275// ... | 0 | -1 | 0 | ... | ... | = 0
2276// ------+-------+---------+-----+---------------+-------+-------------------
2277// (for each `op` map result r_i: set r_i to corresponding map result,
2278// prove that r_i >= minOpUb via contradiction)
2279// ... | 0 | 0 | -1 | ... | ... | = 0
2280// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2281//
2282FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
2283 Operation *op, FlatAffineValueConstraints constraints) {
2284 bool isMin = isa<AffineMinOp>(op);
2285 assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2286 MLIRContext *ctx = op->getContext();
2287 Builder builder(ctx);
2288 AffineMap map =
2289 isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2290 ValueRange operands = op->getOperands();
2291 unsigned numResults = map.getNumResults();
2292
2293 // Add a few extra dimensions.
2294 unsigned dimOp = constraints.appendDimVar(); // `op`
2295 unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2296 unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2297
2298 // Add an inequality for each result expr_i of map:
2299 // isMin: op <= expr_i, !isMin: op >= expr_i
2300 auto boundType = isMin ? BoundType::UB : BoundType::LB;
2301 // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2302 AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
2303 if (failed(
2304 alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
2305 return failure();
2306
2307 // Try to compute a lower/upper bound for op, expressed in terms of the other
2308 // `dims` and extra symbols.
2309 SmallVector<AffineMap> opLb(1), opUb(1);
2310 constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
2311 AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2312 // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2313 // a TODO of `getSliceBounds` and not handled here.
2314 if (!sliceBound || sliceBound.getNumResults() != 1)
2315 return failure(); // No or multiple bounds found.
2316 // Recover the inclusive UB in the case of an `affine.min`.
2317 AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
2318
2319 // Add an equality: Set dimOpBound to computed bound.
2320 // Add back dimension for op. (Was removed by `getSliceBounds`.)
2321 AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2322 if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
2323 return failure();
2324
2325 // If the constraint system is empty, there is an inconsistency. (E.g., this
2326 // can happen if loop lb > ub.)
2327 if (constraints.isEmpty())
2328 return failure();
2329
2330 // In the case of `isMin` (`!isMin` is inversed):
2331 // Prove that each result of `map` has a lower bound that is equal to (or
2332 // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2333 // can be replaced with the bound. I.e., prove that for each result
2334 // expr_i (represented by dimension r_i):
2335 //
2336 // r_i >= opBound
2337 //
2338 // To prove this inequality, add its negation to the constraint set and prove
2339 // that the constraint set is empty.
2340 for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
2341 FlatAffineValueConstraints newConstr(constraints);
2342
2343 // Add an equality: r_i = expr_i
2344 // Note: These equalities could have been added earlier and used to express
2345 // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
2346 // computes minOpUb in terms of r_i dims, which is not desired.
2347 if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
2348 map.getSubMap({i - resultDimStart}), operands)))
2349 return failure();
2350
2351 // If `isMin`: Add inequality: r_i < opBound
2352 // equiv.: opBound - r_i - 1 >= 0
2353 // If `!isMin`: Add inequality: r_i > opBound
2354 // equiv.: -opBound + r_i - 1 >= 0
2355 SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2356 ineq[dimOpBound] = isMin ? 1 : -1;
2357 ineq[i] = isMin ? -1 : 1;
2358 ineq[newConstr.getNumCols() - 1] = -1;
2359 newConstr.addInequality(ineq);
2360 if (!newConstr.isEmpty())
2361 return failure();
2362 }
2363
2364 // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
2365 AffineMap newMap = alignedBoundMap;
2366 SmallVector<Value> newOperands;
2367 unpackOptionalValues(constraints.getMaybeValues(), newOperands);
2368 // If dims/symbols have known constant values, use those in order to simplify
2369 // the affine map further.
2370 for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2371 // Skip unused operands and operands that are already constants.
2372 if (!newOperands[i] || getConstantIntValue(newOperands[i]))
2373 continue;
2374 if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
2375 AffineExpr expr =
2376 i < newMap.getNumDims()
2377 ? builder.getAffineDimExpr(i)
2378 : builder.getAffineSymbolExpr(i - newMap.getNumDims());
2379 newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
2380 newMap.getNumDims(), newMap.getNumSymbols());
2381 }
2382 }
2383 affine::canonicalizeMapAndOperands(&newMap, &newOperands);
2384 return AffineValueMap(newMap, newOperands);
2385}
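// Editorial sketch of the intent (not upstream documentation): for
//   %m = affine.min affine_map<(d0) -> (d0, 10)>(%i)
// with 0 <= %i <= 5 recorded in 'constraints', the derived upper bound of %m
// is %i, and both map results can be proven >= %i, so the op simplifies to the
// single-result map (d0) -> (d0) applied to %i, assuming getSliceBounds yields
// the single bound discussed in the TODO above.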
2386
2387Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
2388 Operation *b) {
2389 Region *aScope = getAffineAnalysisScope(a);
2390 Region *bScope = getAffineAnalysisScope(b);
2391 if (aScope != bScope)
2392 return nullptr;
2393
2394 // Get the block ancestry of `op` while stopping at the affine scope `aScope`
2395 // and store them in `ancestry`.
2396 auto getBlockAncestry = [&](Operation *op,
2397 SmallVectorImpl<Block *> &ancestry) {
2398 Operation *curOp = op;
2399 do {
2400 ancestry.push_back(curOp->getBlock());
2401 if (curOp->getParentRegion() == aScope)
2402 break;
2403 curOp = curOp->getParentOp();
2404 } while (curOp);
2405 assert(curOp && "can't reach root op without passing through affine scope");
2406 std::reverse(ancestry.begin(), ancestry.end());
2407 };
2408
2409 SmallVector<Block *, 4> aAncestors, bAncestors;
2410 getBlockAncestry(a, aAncestors);
2411 getBlockAncestry(b, bAncestors);
2412 assert(!aAncestors.empty() && !bAncestors.empty() &&
2413 "at least one Block ancestor expected");
2414
2415 Block *innermostCommonBlock = nullptr;
2416 for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
2417 a < e && b < f; ++a, ++b) {
2418 if (aAncestors[a] != bAncestors[b])
2419 break;
2420 innermostCommonBlock = aAncestors[a];
2421 }
2422 return innermostCommonBlock;
2423}
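// Worked example (editorial): for two ops in the bodies of two sibling
// 'affine.for' ops in the same function, the innermost common block is the
// function's entry block; for ops in different affine scopes, nullptr is
// returned.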
return success()
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:247
static std::optional< uint64_t > getConstDifference(AffineMap lbMap, AffineMap ubMap)
Definition Utils.cpp:1795
static void findInstPosition(Operation *op, Block *limitBlock, SmallVectorImpl< unsigned > *positions)
Definition Utils.cpp:1540
static Node * addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, DenseMap< Value, SetVector< unsigned > > &memrefAccesses)
Add op to MDG creating a new node and adding its memory accesses (affine or non-affine to memrefAcces...
Definition Utils.cpp:192
static bool mayDependence(const Node &srcNode, const Node &dstNode, Value memref)
Returns true if there may be a dependence on memref from srcNode's memory ops to dstNode's memory ops...
Definition Utils.cpp:265
const char *const kSliceFusionBarrierAttrName
Definition Utils.cpp:1862
static LogicalResult addMissingLoopIVBounds(SmallPtrSet< Value, 8 > &ivs, FlatAffineValueConstraints *cst)
Definition Utils.cpp:1582
static void getEffectedValues(Operation *op, SmallVectorImpl< Value > &values)
Returns the values that this op has a memref effect of type EffectTys on, not considering recursive e...
Definition Utils.cpp:166
MemRefDependenceGraph::Node Node
Definition Utils.cpp:39
static void unpackOptionalValues(ArrayRef< std::optional< Value > > source, SmallVector< Value > &target)
Definition Utils.cpp:2209
static AffineMap addConstToResults(AffineMap map, int64_t val)
Add val to each result of map.
Definition Utils.cpp:2237
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints, BoundType type, unsigned pos, AffineMap map, ValueRange operands)
Bound an identifier pos in a given FlatAffineValueConstraints with constraints drawn from an affine m...
Definition Utils.cpp:2222
static Operation * getInstAtPosition(ArrayRef< unsigned > positions, unsigned level, Block *block)
Definition Utils.cpp:1557
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
template bool mlir::hasEffect< MemoryEffects::Free >(Operation *)
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:376
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
IntegerSet getAsIntegerSet(MLIRContext *context) const
Returns the constraint system as an integer set.
void getSliceBounds(unsigned offset, unsigned num, MLIRContext *context, SmallVectorImpl< AffineMap > *lbMaps, SmallVectorImpl< AffineMap > *ubMaps, bool closedUB=false)
Computes the lower and upper bounds of the first num dimensional variables (starting at offset) as an...
std::optional< int64_t > getConstantBoundOnDimSize(MLIRContext *context, unsigned pos, AffineMap *lb=nullptr, AffineMap *ub=nullptr, unsigned *minLbPos=nullptr, unsigned *minUbPos=nullptr) const
Returns a non-negative constant bound on the extent (upper bound - lower bound) of the specified vari...
FlatLinearValueConstraints represents an extension of FlatLinearConstraints where each non-local vari...
LogicalResult unionBoundingBox(const FlatLinearValueConstraints &other)
Updates the constraints to be the smallest bounding (enclosing) box that contains the points of this ...
void mergeAndAlignVarsWithOther(unsigned offset, FlatLinearValueConstraints *other)
Merge and align the variables of this and other starting at offset, so that both constraint systems g...
Value getValue(unsigned pos) const
Returns the Value associated with the pos^th variable.
void projectOut(Value val)
Projects out the variable that is associate with Value.
bool containsVar(Value val) const
Returns true if a variable with the specified Value exists, false otherwise.
void addBound(presburger::BoundType type, Value val, int64_t value)
Adds a constant bound for the variable associated with the given Value.
bool areVarsAlignedWithOther(const FlatLinearConstraints &other)
Returns true if this constraint system and other are in the same space, i.e., if they are associated ...
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
SmallVector< std::optional< Value > > getMaybeValues() const
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition IntegerSet.h:44
unsigned getNumDims() const
MLIRContext * getContext() const
static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols, MLIRContext *context)
Definition IntegerSet.h:56
unsigned getNumSymbols() const
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
Value getOperand(unsigned i) const
unsigned getNumOperands() const
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
void convertLoopIVSymbolsToDims()
Changes all symbol variables which are loop IVs to dim variables.
LogicalResult addDomainFromSliceMaps(ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds constraints (lower and upper bounds) for each loop in the loop nest described by the bound maps ...
LogicalResult addAffineForOpDomain(AffineForOp forOp)
Adds constraints (lower and upper bounds) for the specified 'affine.for' operation's Value using IR i...
LogicalResult addSliceBounds(ArrayRef< Value > values, ArrayRef< AffineMap > lbMaps, ArrayRef< AffineMap > ubMaps, ArrayRef< Value > operands)
Adds slice lower bounds represented by lower bounds in lbMaps and upper bounds in ubMaps to each vari...
void removeTrivialRedundancy()
Removes duplicate constraints, trivially true constraints, and constraints that can be detected as re...
bool isEmpty() const
Checks for emptiness by performing variable elimination on all variables, running the GCD test on eac...
unsigned getNumCols() const
Returns the number of columns in the constraint system.
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
std::optional< int64_t > getConstantBound64(BoundType type, unsigned pos) const
The same, but casts to int64_t.
bool isIntegerEmpty() const
Return true if all the sets in the union are known to be integer empty false otherwise.
PresburgerSet subtract(const PresburgerRelation &set) const
std::optional< uint64_t > getConstantTripCount(AffineForOp forOp)
Returns the trip count of the loop if it's a constant, std::nullopt otherwise.
IntegerSet simplifyIntegerSet(IntegerSet set)
Simplify the integer set by simplifying the underlying affine expressions by flattening and some simp...
Definition Utils.cpp:2197
void getEnclosingAffineOps(Operation &op, SmallVectorImpl< Operation * > *ops)
Populates 'ops' with affine operations enclosing op ordered from outermost to innermost while stoppin...
Definition Utils.cpp:866
SliceComputationResult computeSliceUnion(ArrayRef< Operation * > opsA, ArrayRef< Operation * > opsB, unsigned loopDepth, unsigned numCommonLoops, bool isBackwardSlice, ComputationSliceState *sliceUnion)
Computes in 'sliceUnion' the union of all slice bounds computed at 'loopDepth' between all dependent ...
Definition Utils.cpp:1630
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
bool isLoopParallelAndContainsReduction(AffineForOp forOp)
Returns whether a loop is a parallel loop and contains a reduction loop.
Definition Utils.cpp:2179
unsigned getNumCommonSurroundingLoops(Operation &a, Operation &b)
Returns the number of surrounding loops common to both A and B.
Definition Utils.cpp:2105
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void getAffineIVs(Operation &op, SmallVectorImpl< Value > &ivs)
Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel ops ordered from the outer...
Definition Utils.cpp:2088
void getSequentialLoops(AffineForOp forOp, llvm::SmallDenseSet< Value, 8 > *sequentialLoops)
Returns in 'sequentialLoops' all sequential loops in loop nest rooted at 'forOp'.
Definition Utils.cpp:2188
DependenceResult checkMemrefAccessDependence(const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, unsigned loopDepth, FlatAffineValueConstraints *dependenceConstraints=nullptr, SmallVector< DependenceComponent, 2 > *dependenceComponents=nullptr, bool allowRAR=false)
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
Definition Utils.cpp:852
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
Definition Utils.cpp:2170
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
Definition Utils.cpp:1598
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp, const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth, bool isBackwardSlice, ComputationSliceState *sliceState)
Computes the computation slice loop bounds for one loop nest as affine maps of the other loop nest's ...
Definition Utils.cpp:1867
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns true if the provided value is among the induction variables of an AffineParallelOp.
std::optional< uint64_t > getIntOrFloatMemRefSizeInBytes(MemRefType memRefType)
Returns the size of a memref with element type int or float in bytes if it's statically shaped,...
Definition Utils.cpp:1460
unsigned getNestingDepth(Operation *op)
Returns the nesting depth of this operation, i.e., the number of loops surrounding this operation.
Definition Utils.cpp:2060
uint64_t getSliceIterationCount(const llvm::SmallDenseMap< Operation *, uint64_t, 8 > &sliceTripCountMap)
Return the number of iterations for the slicetripCountMap provided.
Definition Utils.cpp:1853
LogicalResult boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, bool emitError=true)
Checks a load or store op for an out of bound access; returns failure if the access is out of bounds ...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
Definition Utils.cpp:2387
bool noDependence(DependenceResult result)
Returns true if the provided DependenceResult corresponds to the absence of a dependence.
bool buildSliceTripCountMap(const ComputationSliceState &slice, llvm::SmallDenseMap< Operation *, uint64_t, 8 > *tripCountMap)
Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop nest surrounding represe...
Definition Utils.cpp:1815
AffineForOp insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState)
Creates a clone of the computation contained in the loop nest surrounding 'srcOpInst',...
Definition Utils.cpp:1979
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition Utils.cpp:2282
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Definition Utils.cpp:1416
BoundType
The type of bound: equal, lower bound or upper bound.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands, ValueRange dims, ValueRange syms, SmallVector< Value > *newSyms=nullptr)
Re-indexes the dimensions and symbols of an affine map with given operands values to align with dims ...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
Definition Utils.h:318
std::optional< bool > isSliceValid() const
Checks the validity of the slice computed.
Definition Utils.cpp:1037
SmallVector< Value, 4 > ivs
Definition Utils.h:321
LogicalResult getAsConstraints(FlatAffineValueConstraints *cst) const
Definition Utils.cpp:899
LogicalResult getSourceAsConstraints(FlatAffineValueConstraints &cst) const
Adds to 'cst' constraints which represent the original loop bounds on 'ivs' in 'this'.
Definition Utils.cpp:883
std::vector< SmallVector< Value, 4 > > ubOperands
Definition Utils.h:329
SmallVector< AffineMap, 4 > ubs
Definition Utils.h:325
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Definition Utils.cpp:1103
SmallVector< AffineMap, 4 > lbs
Definition Utils.h:323
std::vector< SmallVector< Value, 4 > > lbOperands
Definition Utils.h:327
Checks whether two accesses to the same memref access the same element.
SmallVector< Operation *, 4 > memrefFrees
Definition Utils.h:49
SmallVector< AffineForOp, 4 > forOps
Definition Utils.h:39
SmallVector< Operation *, 4 > loadOpInsts
Definition Utils.h:41
SmallVector< Operation *, 4 > memrefStores
Definition Utils.h:47
void collect(Operation *opToWalk)
Definition Utils.cpp:44
SmallVector< Operation *, 4 > memrefLoads
Definition Utils.h:45
SmallVector< Operation *, 4 > storeOpInsts
Definition Utils.h:43
Encapsulates a memref load or store access information.
unsigned getRank() const
Definition Utils.cpp:2050
SmallVector< Value, 4 > indices
void getAccessMap(AffineValueMap *accessMap) const
Populates 'accessMap' with composition of AffineApplyOps reachable from 'indices'.
MemRefAccess(Operation *memOp)
Constructs a MemRefAccess from an affine read/write operation.
Definition Utils.cpp:2035
bool operator==(const MemRefAccess &rhs) const
Equal if both affine accesses can be proved to be equivalent at compile time (considering the memrefs...
Definition Utils.cpp:2078
void getStoreOpsForMemref(Value memref, SmallVectorImpl< Operation * > *storeOps) const
Definition Utils.cpp:131
SmallVector< Operation *, 4 > loads
Definition Utils.h:73
SmallVector< Operation *, 4 > stores
Definition Utils.h:77
void getLoadAndStoreMemrefSet(DenseSet< Value > *loadAndStoreMemrefSet) const
Definition Utils.cpp:150
unsigned hasFree(Value memref) const
Definition Utils.cpp:124
SmallVector< Operation *, 4 > memrefLoads
Definition Utils.h:75
SmallVector< Operation *, 4 > memrefStores
Definition Utils.h:79
unsigned getLoadOpCount(Value memref) const
Definition Utils.cpp:79
unsigned getStoreOpCount(Value memref) const
Definition Utils.cpp:94
unsigned hasStore(Value memref) const
Returns true if there exists an operation with a write memory effect on memref in this node.
Definition Utils.cpp:110
void getLoadOpsForMemref(Value memref, SmallVectorImpl< Operation * > *loadOps) const
Definition Utils.cpp:140
SmallVector< Operation *, 4 > memrefFrees
Definition Utils.h:81
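A sketch of per-memref queries on a graph node; node and memref are hypothetical:

  unsigned numLoads = node.getLoadOpCount(memref);
  unsigned numStores = node.getStoreOpCount(memref);
  llvm::errs() << numLoads << " loads, " << numStores << " stores\n";
  if (node.hasFree(memref)) {
    // `memref` is deallocated somewhere within this node.
  }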
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Definition Utils.h:147
Block & block
The block for which this graph is created to perform fusion.
Definition Utils.h:270
unsigned addNode(Operation *op)
Definition Utils.cpp:468
bool writesToLiveInOrEscapingMemrefs(unsigned id) const
Definition Utils.cpp:498
void removeEdge(unsigned srcId, unsigned dstId, Value value)
Definition Utils.cpp:544
void addEdge(unsigned srcId, unsigned dstId, Value value)
Definition Utils.cpp:533
DenseMap< unsigned, Node > nodes
Definition Utils.h:141
void gatherDefiningNodes(unsigned id, DenseSet< unsigned > &definingNodes) const
Returns, in 'definingNodes', all nodes which define SSA values used in node 'id'.
Definition Utils.cpp:632
bool hasDependencePath(unsigned srcId, unsigned dstId) const
Definition Utils.cpp:571
void clearNodeLoadAndStores(unsigned id)
Definition Utils.cpp:798
const Node * getForOpNode(AffineForOp forOp) const
Definition Utils.cpp:460
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) const
Definition Utils.cpp:646
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
Definition Utils.cpp:721
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
Definition Utils.h:144
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition Utils.cpp:806
bool init(bool fullAffineDependences=true)
Definition Utils.cpp:337
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr) const
Definition Utils.cpp:622
const Node * getNode(unsigned id) const
Definition Utils.cpp:453
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Definition Utils.cpp:814
void forEachMemRefEdge(ArrayRef< Edge > edges, const std::function< void(Edge)> &callback)
Definition Utils.cpp:822
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
Definition Utils.cpp:785
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const
Definition Utils.cpp:606
bool hasEdge(unsigned srcId, unsigned dstId, Value value=nullptr) const
Definition Utils.cpp:518
void print(raw_ostream &os) const
Definition Utils.cpp:834
DenseMap< Value, unsigned > memrefEdgeCount
Definition Utils.h:150
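A hedged end-to-end sketch, assuming the graph is constructed from the Block & member above; srcId and dstId are hypothetical node ids:

  MemRefDependenceGraph graph(block);
  if (!graph.init())
    return; // The block contains IR the graph cannot model.
  if (graph.hasDependencePath(srcId, dstId)) {
    // `dstId` transitively depends on `srcId` via memref or SSA edges.
  }
  graph.print(llvm::errs());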
A region of a memref's data space; this is typically constructed by analyzing load/store ops on this memref and the index space of loops surrounding such ops.
Definition Utils.h:489
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constant, std::nullopt otherwise.
Definition Utils.cpp:1157
unsigned getRank() const
Returns the rank of the memref that this region corresponds to.
Definition Utils.cpp:1153
FlatAffineValueConstraints cst
Region (data space) of the memref accessed.
Definition Utils.h:586
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symbolic/parametric in 'loopDepth' loops surrounding the operation.
Definition Utils.cpp:1254
void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const
Gets the lower and upper bound map for the dimensional variable at pos.
Definition Utils.cpp:1213
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Definition Utils.cpp:1435
LogicalResult unionBoundingBox(const MemRefRegion &other)
Definition Utils.cpp:1232
bool write
Read or write.
Definition Utils.h:573
FlatAffineValueConstraints * getConstraints()
Definition Utils.h:535
Value memref
Memref that this region corresponds to.
Definition Utils.h:570
MemRefRegion(Location loc)
Definition Utils.h:490
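A minimal sketch of computing an access footprint; loadOp is a hypothetical affine load (an Operation*) nested within loopDepth affine loops:

  MemRefRegion region(loadOp->getLoc());
  if (failed(region.compute(loadOp, loopDepth)))
    return; // The region could not be computed at this depth.
  if (std::optional<int64_t> bytes = region.getRegionSize())
    llvm::errs() << "footprint: " << *bytes << " bytes\n";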
Enumerates the different result statuses of slice computation by computeSliceUnion.
Definition Utils.h:305
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.