#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "affine-loop-fusion"

namespace {
/// Loop fusion pass which greedily fuses affine loop nests in a function.
struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    // The pass option is expressed in KiB.
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnOperation() override;
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
                           uint64_t localBufSizeThreshold, bool maximalFusion,
                           enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}
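// A minimal usage sketch (not part of this file): the pass is typically
// scheduled on each function through a pass manager. The exact pipeline setup
// below is an assumption for illustration only.
//
//   PassManager pm(module.getContext());
//   pm.addNestedPass<func::FuncOp>(createLoopFusionPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "loop fusion pipeline failed\n";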
namespace {
// Gathers all AffineForOps and load/store ops in a loop nest, and records
// whether the nest contains any op holding a non-affine region.
struct LoopNestStateCollector {
  SmallVector<AffineForOp, 4> forOps;
  SmallVector<Operation *, 4> loadOpInsts;
  SmallVector<Operation *, 4> storeOpInsts;
  bool hasNonAffineRegionOp = false;

  void collect(Operation *opToWalk) {
    opToWalk->walk([&](Operation *op) {
      if (isa<AffineForOp>(op))
        forOps.push_back(cast<AffineForOp>(op));
      else if (op->getNumRegions() != 0)
        hasNonAffineRegionOp = true;
      else if (isa<AffineReadOpInterface>(op))
        loadOpInsts.push_back(op);
      else if (isa<AffineWriteOpInterface>(op))
        storeOpInsts.push_back(op);
    });
  }
};
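// The MemRefDependenceGraph below drives the whole pass: each memory-touching
// top-level operation of the function body becomes a node, and an edge
// connects two nodes when they access a common memref with at least one write
// involved, or when one node defines an SSA value the other uses. Fusion
// legality and candidate selection are queries over this graph.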
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a FuncOp which contain load/store ops, and edges
// are memref dependences between the nodes.
struct MemRefDependenceGraph {
public:
  // Node represents a node in the graph: either an entire loop nest rooted at
  // the top level which contains loads/stores, or a top-level load/store.
  struct Node {
    // The unique identifier of this node in the graph.
    unsigned id;
    // The top-level statement which is (or contains) a load/store.
    Operation *op;
    // List of load operations.
    SmallVector<Operation *, 4> loads;
    // List of store operations.
    SmallVector<Operation *, 4> stores;

    Node(unsigned id, Operation *op) : id(id), op(op) {}

    // Returns the load op count for 'memref'.
    unsigned getLoadOpCount(Value memref) {
      unsigned loadOpCount = 0;
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          ++loadOpCount;
      }
      return loadOpCount;
    }

    // Returns the store op count for 'memref'.
    unsigned getStoreOpCount(Value memref) {
      unsigned storeOpCount = 0;
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          ++storeOpCount;
      }
      return storeOpCount;
    }

    // Returns in 'storeOps' all store ops which access 'memref'.
    void getStoreOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *storeOps) {
      for (auto *storeOpInst : stores) {
        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
          storeOps->push_back(storeOpInst);
      }
    }

    // Returns in 'loadOps' all load ops which access 'memref'.
    void getLoadOpsForMemref(Value memref,
                             SmallVectorImpl<Operation *> *loadOps) {
      for (auto *loadOpInst : loads) {
        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
          loadOps->push_back(loadOpInst);
      }
    }

    // Returns in 'loadAndStoreMemrefSet' all memrefs for which this node has
    // at least one load and one store operation.
    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
      llvm::SmallDenseSet<Value, 2> loadMemrefs;
      for (auto *loadOpInst : loads) {
        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
      }
      for (auto *storeOpInst : stores) {
        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
        if (loadMemrefs.count(memref) > 0)
          loadAndStoreMemrefSet->insert(memref);
      }
    }
  };
  // Edge represents a data dependence between nodes in the graph.
  struct Edge {
    // The id of the node at the other end of the edge.
    unsigned id;
    // The SSA value on which this edge represents a dependence. If the value
    // is a memref, the dependence is between graph nodes which access the
    // same memref.
    Value value;
  };

  // Map from node id to Node.
  DenseMap<unsigned, Node> nodes;
  // Map from node id to list of input edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
  // Map from node id to list of output edges.
  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
  // Map from memref to the count of dependence edges associated with it.
  DenseMap<Value, unsigned> memrefEdgeCount;
  // The next unique identifier to use for newly created graph nodes.
  unsigned nextNodeId = 0;

  MemRefDependenceGraph() = default;

  // Initializes the dependence graph based on operations in 'f'. Returns true
  // on success, false otherwise.
  bool init(func::FuncOp f);
  // Returns the graph node for 'id'.
  Node *getNode(unsigned id) {
    auto it = nodes.find(id);
    assert(it != nodes.end());
    return &it->second;
  }

  // Returns the graph node for 'forOp', or nullptr if none exists.
  Node *getForOpNode(AffineForOp forOp) {
    for (auto &idAndNode : nodes)
      if (idAndNode.second.op == forOp.getOperation())
        return &idAndNode.second;
    return nullptr;
  }

  // Adds a node with 'op' to the graph and returns its unique identifier.
  unsigned addNode(Operation *op) {
    Node node(nextNodeId++, op);
    nodes.insert({node.id, node});
    return node.id;
  }
  // Removes node 'id' (and its associated edges) from the graph.
  void removeNode(unsigned id) {
    // Remove each edge in 'inEdges[id]'.
    if (inEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[id];
      for (auto &inEdge : oldInEdges) {
        removeEdge(inEdge.id, id, inEdge.value);
      }
    }
    // Remove each edge in 'outEdges[id]'.
    if (outEdges.count(id) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
      for (auto &outEdge : oldOutEdges) {
        removeEdge(id, outEdge.id, outEdge.value);
      }
    }
    // Erase remaining node state.
    inEdges.erase(id);
    outEdges.erase(id);
    nodes.erase(id);
  }
  // Returns true if node 'id' writes to any memref which escapes (or is an
  // argument to) the function. Returns false otherwise.
  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
    Node *node = getNode(id);
    for (auto *storeOpInst : node->stores) {
      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
      auto *op = memref.getDefiningOp();
      // Return true if 'memref' is a block argument.
      if (!op)
        return true;
      // Return true if any use of 'memref' does not dereference it in an
      // affine way.
      for (auto *user : memref.getUsers())
        if (!isa<AffineMapAccessInterface>(*user))
          return true;
    }
    return false;
  }
  // Returns true iff there is an edge from node 'srcId' to node 'dstId' for
  // 'value' if non-null, or for any value otherwise. Returns false otherwise.
  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
      return false;
    }
    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
      return edge.id == dstId && (!value || edge.value == value);
    });
    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
      return edge.id == srcId && (!value || edge.value == value);
    });
    return hasOutEdge && hasInEdge;
  }
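  // Note: 'inEdges' and 'outEdges' are kept strictly mirrored: every edge is
  // recorded once in 'outEdges[srcId]' and once in 'inEdges[dstId]', which is
  // why hasEdge() above requires both lists to agree.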
  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
  void addEdge(unsigned srcId, unsigned dstId, Value value) {
    if (!hasEdge(srcId, dstId, value)) {
      outEdges[srcId].push_back({dstId, value});
      inEdges[dstId].push_back({srcId, value});
      if (value.getType().isa<MemRefType>())
        memrefEdgeCount[value]++;
    }
  }
  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
    assert(inEdges.count(dstId) > 0);
    assert(outEdges.count(srcId) > 0);
    if (value.getType().isa<MemRefType>()) {
      assert(memrefEdgeCount.count(value) > 0);
      memrefEdgeCount[value]--;
    }
    // Remove 'srcId' from 'inEdges[dstId]'.
    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
      if ((*it).id == srcId && (*it).value == value) {
        inEdges[dstId].erase(it);
        break;
      }
    }
    // Remove 'dstId' from 'outEdges[srcId]'.
    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
         ++it) {
      if ((*it).id == dstId && (*it).value == value) {
        outEdges[srcId].erase(it);
        break;
      }
    }
  }
  // Returns true if there is a path in the dependence graph from node 'srcId'
  // to node 'dstId'. Returns false otherwise.
  bool hasDependencePath(unsigned srcId, unsigned dstId) {
    // Worklist state is: <node-id, next-output-edge-index-to-visit>.
    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
    worklist.push_back({srcId, 0});
    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
    while (!worklist.empty()) {
      auto &idAndIndex = worklist.back();
      // Return true if we have reached 'dstId'.
      if (idAndIndex.first == dstId)
        return true;
      // Pop and continue if the node has no out edges, or if all of its out
      // edges have already been visited.
      if (outEdges.count(idAndIndex.first) == 0 ||
          idAndIndex.second == outEdges[idAndIndex.first].size()) {
        worklist.pop_back();
        continue;
      }
      // Get the graph edge to traverse and bump the next output edge index.
      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
      ++idAndIndex.second;
      // Add node at 'edge.id' to the worklist.
      worklist.push_back({edge.id, 0});
    }
    return false;
  }
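  // The <node-id, next-edge-index> pair acts as an explicit DFS stack frame,
  // so the traversal needs no recursion; the edge index records how far the
  // out-edge list of each node on the stack has been explored.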
  // Returns the input edge count for node 'id' and 'memref', counting only
  // edges from src nodes which store to 'memref'.
  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
    unsigned inEdgeCount = 0;
    if (inEdges.count(id) > 0)
      for (auto &inEdge : inEdges[id])
        if (inEdge.value == memref) {
          Node *srcNode = getNode(inEdge.id);
          // Only count in-edges from 'srcNode' if it stores to 'memref'.
          if (srcNode->getStoreOpCount(memref) > 0)
            ++inEdgeCount;
        }
    return inEdgeCount;
  }
  // Returns the output edge count for node 'id' and 'memref' (if non-null),
  // otherwise returns the total output edge count from node 'id'.
  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
    unsigned outEdgeCount = 0;
    if (outEdges.count(id) > 0)
      for (auto &outEdge : outEdges[id])
        if (!memref || outEdge.value == memref)
          ++outEdgeCount;
    return outEdgeCount;
  }

  /// Returns in 'definingNodes' all nodes which define SSA values used in
  /// node 'id'.
  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
    for (MemRefDependenceGraph::Edge edge : inEdges[id])
      // By definition of edge, if the edge value is a non-memref value, the
      // dependence is between a graph node which defines an SSA value and
      // another graph node which uses that SSA value.
      if (!edge.value.getType().isa<MemRefType>())
        definingNodes.insert(edge.id);
  }
  // Computes and returns an insertion point operation, before which the
  // fused <srcId, dstId> loop nest can be inserted while preserving
  // dependences. Returns nullptr if no such insertion point is found.
  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
    if (outEdges.count(srcId) == 0)
      return getNode(dstId)->op;

    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
    DenseSet<unsigned> definingNodes;
    gatherDefiningNodes(dstId, definingNodes);
    if (llvm::any_of(definingNodes, [&](unsigned id) {
          return hasDependencePath(srcId, id);
        })) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Can't fuse: a defining op with a user in the dst "
                    "loop has dependence from the src loop\n");
      return nullptr;
    }

    // Build set of ops in range (srcId, dstId) which depend on 'srcId'.
    SmallPtrSet<Operation *, 2> srcDepInsts;
    for (auto &outEdge : outEdges[srcId])
      if (outEdge.id != dstId)
        srcDepInsts.insert(getNode(outEdge.id)->op);

    // Build set of ops in range (srcId, dstId) on which 'dstId' depends.
    SmallPtrSet<Operation *, 2> dstDepInsts;
    for (auto &inEdge : inEdges[dstId])
      if (inEdge.id != srcId)
        dstDepInsts.insert(getNode(inEdge.id)->op);

    Operation *srcNodeInst = getNode(srcId)->op;
    Operation *dstNodeInst = getNode(dstId)->op;

    // Walk the ops between 'srcNodeInst' and 'dstNodeInst', recording the
    // position of the first dependence target of the src node and the last
    // dependence source of the dst node.
    Optional<unsigned> firstSrcDepPos;
    Optional<unsigned> lastDstDepPos;
    unsigned pos = 0;
    SmallVector<Operation *, 2> depInsts;
    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
         it != Block::iterator(dstNodeInst); ++it) {
      Operation *op = &(*it);
      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
        firstSrcDepPos = pos;
      if (dstDepInsts.count(op) > 0)
        lastDstDepPos = pos;
      depInsts.push_back(op);
      ++pos;
    }

    if (firstSrcDepPos.has_value()) {
      if (lastDstDepPos.has_value()) {
        if (firstSrcDepPos.value() <= lastDstDepPos.value()) {
          // No valid insertion point exists which preserves dependences.
          return nullptr;
        }
      }
      // Return the insertion point at 'firstSrcDepPos'.
      return depInsts[firstSrcDepPos.value()];
    }
    // No dependence targets in range; return the 'dstNodeInst' insertion
    // point.
    return dstNodeInst;
  }
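  // Example: with block order [src, A, B, dst] where A depends on src and dst
  // depends on B, firstSrcDepPos (A) precedes lastDstDepPos (B), so no valid
  // insertion point exists. With order [src, B, A, dst], the fused nest can
  // be inserted just before A, after dst's last dependence source B.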
  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing
  // them, taking into account that:
  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
  //   *) memrefs in 'privateMemRefs' have been replaced in the node at
  //      'dstId' by private memrefs.
  void updateEdges(unsigned srcId, unsigned dstId,
                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
    // For each edge in 'inEdges[srcId]': add a new edge remapped to 'dstId'.
    if (inEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
      for (auto &inEdge : oldInEdges) {
        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
        if (privateMemRefs.count(inEdge.value) == 0)
          addEdge(inEdge.id, dstId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[srcId]': remove the edge from 'srcId' to
    // 'dstId'. If 'srcId' is going to be removed, remap all its out edges to
    // 'dstId'.
    if (outEdges.count(srcId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
      for (auto &outEdge : oldOutEdges) {
        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
        if (outEdge.id == dstId)
          removeEdge(srcId, outEdge.id, outEdge.value);
        else if (removeSrcId) {
          addEdge(dstId, outEdge.id, outEdge.value);
          removeEdge(srcId, outEdge.id, outEdge.value);
        }
      }
    }
    // Remove any edges in 'inEdges[dstId]' for 'privateMemRefs' (they have
    // been replaced by private memrefs).
    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
      for (auto &inEdge : oldInEdges)
        if (privateMemRefs.count(inEdge.value) > 0)
          removeEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // Updates edge mappings for nodes 'sibId' and 'dstId' to reflect fusion of
  // sibling node 'sibId' into node 'dstId'.
  void updateEdges(unsigned sibId, unsigned dstId) {
    // For each edge in 'inEdges[sibId]':
    // *) Add a new edge from source node 'inEdge.id' to 'dstId'.
    // *) Remove the edge from source node 'inEdge.id' to 'sibId'.
    if (inEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
      for (auto &inEdge : oldInEdges) {
        addEdge(inEdge.id, dstId, inEdge.value);
        removeEdge(inEdge.id, sibId, inEdge.value);
      }
    }
    // For each edge in 'outEdges[sibId]':
    // *) Add a new edge from 'dstId' to 'outEdge.id'.
    // *) Remove the edge from 'sibId' to 'outEdge.id'.
    if (outEdges.count(sibId) > 0) {
      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
      for (auto &outEdge : oldOutEdges) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(sibId, outEdge.id, outEdge.value);
      }
    }
  }
  // Adds the ops in 'loads' and 'stores' to the node at 'id'.
  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
                 const SmallVectorImpl<Operation *> &stores) {
    Node *node = getNode(id);
    llvm::append_range(node->loads, loads);
    llvm::append_range(node->stores, stores);
  }

  // Clears all recorded load and store ops for the node at 'id'.
  void clearNodeLoadAndStores(unsigned id) {
    Node *node = getNode(id);
    node->loads.clear();
    node->stores.clear();
  }
  // Calls 'callback' for each input edge incident to node 'id' which carries
  // a memref dependence.
  void forEachMemRefInputEdge(unsigned id,
                              const std::function<void(Edge)> &callback) {
    if (inEdges.count(id) > 0)
      forEachMemRefEdge(inEdges[id], callback);
  }

  // Calls 'callback' for each output edge from node 'id' which carries a
  // memref dependence.
  void forEachMemRefOutputEdge(unsigned id,
                               const std::function<void(Edge)> &callback) {
    if (outEdges.count(id) > 0)
      forEachMemRefEdge(outEdges[id], callback);
  }

  // Calls 'callback' for each edge in 'edges' which carries a memref
  // dependence.
  void forEachMemRefEdge(ArrayRef<Edge> edges,
                         const std::function<void(Edge)> &callback) {
    for (const auto &edge : edges) {
      // Skip if 'edge' is not a memref dependence edge.
      if (!edge.value.getType().isa<MemRefType>())
        continue;
      assert(nodes.count(edge.id) > 0);
      // Skip if 'edge.id' is not a loop nest.
      if (!isa<AffineForOp>(getNode(edge.id)->op))
        continue;
      // Visit current edge.
      callback(edge);
    }
  }
  void print(raw_ostream &os) const {
    os << "\nMemRefDependenceGraph\n";
    os << "\nNodes:\n";
    for (const auto &idAndNode : nodes) {
      os << "Node: " << idAndNode.first << "\n";
      auto it = inEdges.find(idAndNode.first);
      if (it != inEdges.end()) {
        for (const auto &e : it->second)
          os << "  InEdge: " << e.id << " " << e.value << "\n";
      }
      it = outEdges.find(idAndNode.first);
      if (it != outEdges.end()) {
        for (const auto &e : it->second)
          os << "  OutEdge: " << e.id << " " << e.value << "\n";
      }
    }
  }
  void dump() const { print(llvm::errs()); }
};
} // namespace
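// For reference, print() emits output of the following shape (a hypothetical
// two-node graph; memref values render in their usual IR form):
//
//   MemRefDependenceGraph
//
//   Nodes:
//   Node: 0
//     OutEdge: 1 <memref value>
//   Node: 1
//     InEdge: 0 <memref value>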
/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId': all of its remaining dependences must be dominated by the fused
/// loop's insertion point, and if it still has out-deps after fusion or
/// writes to an escaping memref, the fusion slice must be maximal.
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip the dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Check if the insertion point of the fused loop dominates the
    // dependence. Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If the src loop has dependences after fusion or writes to a live-in or
  // escaping memref, it can only be removed if the fusion slice is maximal
  // so that all its dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    Optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}
/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId': loop-nest nodes with an edge into 'dstId' that store to a memref
/// which 'dstId' loads from. Candidates are returned in node id order.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a
  // store to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  // Sort and deduplicate: multiple edges may exist between the same nodes.
  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}
/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between nodes 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
                                producerConsumerMemrefs);
}
/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node
/// 'id' that escape the function or are accessed in a non-affine way.
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                                  DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    // Check if 'memref' escapes because it's a block argument.
    if (memref.isa<BlockArgument>()) {
      escapingMemRefs.insert(memref);
      continue;
    }
    // Check if 'memref' escapes through a non-affine user (e.g., a call op
    // or a non-affine load/store).
    for (Operation *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        escapingMemRefs.insert(memref);
  }
}
// Initializes the data dependence graph by walking operations in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
bool MemRefDependenceGraph::init(func::FuncOp f) {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // TODO: support multi-block functions.
  if (!llvm::hasSingleElement(f))
    return false;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (auto &op : f.front()) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create a graph node to represent top-level 'forOp', and record all
      // the load and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region-holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region-holding op is found (not currently
      // supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for a top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for a top-level call op that takes any argument of
      // memref type (call ops returning memref results are handled above).
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return t.isa<MemRefType>(); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      // Create graph node for a top-level op which could have a memory write
      // or free side effect.
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      effectInterface.getEffects(effects);
      if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &it) {
            return isa<MemoryEffects::Write, MemoryEffects::Free>(
                it.getEffect());
          })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    }
  }
  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
  }
  // Add dependence edges between nodes which produce SSA values and their
  // users.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    auto *opInst = node.op;
    for (auto value : opInst->getResults()) {
      for (auto *user : value.getUsers()) {
        // Add an edge to the loop nest surrounding 'user', if any.
        SmallVector<AffineForOp, 4> loops;
        getLoopIVs(*user, &loops);
        if (loops.empty())
          continue;
        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }
  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        // An edge is needed only if at least one endpoint writes the memref;
        // load-load pairs carry no dependence.
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}
// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the outermost
// (while preserving relative order among them).
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
  node->op = newRootForOp.getOperation();
}
// Returns the memref element size in bytes, rounding bit widths up to full
// bytes. TODO: improve/complete this when we have target data.
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
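// For instance: an f32 element is 32 bits -> 4 bytes; a vector<8xf16>
// element is 8 * 16 = 128 bits -> 16 bytes; divideCeil rounds sub-byte
// widths such as i1 up to a full byte.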
// Creates and returns a private (single-user) memref for the fused loop
// rooted at 'forOp', with a (potentially reduced) shape based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 Optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  auto *forInst = forOp.getOperation();

  // Builder to insert the alloc op just before 'forOp', and a builder to
  // create constants at the top level of the function.
  OpBuilder b(forInst);
  OpBuilder top(forInst->getParentOfType<func::FuncOp>().getBody());
  // Create a new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  unsigned rank = oldMemRefType.getRank();

  // Compute the MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  // Query 'region' for 'newShape' and the lower bounds of the region
  // accessed by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  Optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this region is symbolic/parametric on;
  // these correspond to loop IVs surrounding the level at which the slice is
  // being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Build 'rank' AffineExprs from the MemRefRegion 'lbs'.
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    }
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Create 'newMemRefType' using 'newShape'. Use the fast memory space only
  // if the buffer fits under 'localBufSizeThreshold'.
  uint64_t bufSize =
      getMemRefEltSizeInBytes(oldMemRefType) * numElements.value();
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = fastMemorySpace.value();
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create a new private memref for the fused loop 'forOp'. 'newShape' is
  // always a constant shape. The alloc is placed at the start of the block
  // since loop nests can be reordered during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound
  // offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
    auto remapExpr =
        simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
    remapExprs.push_back(remapExpr);
  }

  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res = replaceAllMemRefUsesWith(
      oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
      /*extraOperands=*/outerIVs, /*symbolOperands=*/{},
      /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}
/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
/// 'dstId'), returns true if any non-affine operation on the path accesses
/// 'memref'. Returns false otherwise.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each node between 'srcNode' and 'dstNode' in program order, check
  // whether any non-affine operation in the node accesses 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Only consider operations strictly between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation and interrupt if a non-affine use is found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // A non-affine op that uses the memref: interrupt the walk.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}
/// Returns true if any memref accessed in node 'srcId' has a non-affine user
/// between node 'srcId' and node 'dstId' (exclusive of both).
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values used in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine.for ops; their operands are not memrefs.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (v.getType().isa<MemRefType>())
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Look for non-affine users between node 'srcId' and node 'dstId'.
  for (Value memref : memRefValues)
    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
      return true;
  return false;
}
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstForOp', and
// returns the most profitable destination loop depth in '*dstLoopDepth'.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
    return false;
  }
  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Compute cost of dst loop nest.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;

  // Search for the best fusion depth: track the depth with the greatest
  // storage reduction whose additional compute stays under the threshold.
  Optional<unsigned> bestDstLoopDepth;
  double maxStorageReduction = 0.0;
  Optional<uint64_t> sliceMemEstimate;
  uint64_t minFusedLoopNestComputeCost =
      std::numeric_limits<uint64_t>::max();

  // Compute op instance count for the src loop nest without slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute the src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation.\n");
    return false;
  }

  Optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    return false;
  int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.value();

  // Compute op instance count for the destination loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
  // Evaluate all depth choices for materializing the slice in the
  // destination loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip the slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n");
      continue;
    }

    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    Optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        maybeSliceWriteRegionSizeBytes.value() == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = maybeSliceWriteRegionSizeBytes.value();

    // If we are fusing for reuse, check that the write regions remain the
    // same.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << "  evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << "   additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << "   storage reduction factor: " << storageReduction << "x\n"
          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
          << "   slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    // Among all choices that add an acceptable amount of redundant
    // computation, pick the one that reduces intermediate storage the most.
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }
  if (!bestDstLoopDepth) {
    LLVM_DEBUG(llvm::dbgs()
               << "All fusion choices involve more than the threshold amount "
                  "of redundant computation; NOT fusing.\n");
    return false;
  }

  // Set dstLoopDepth based on the best value found by the search.
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG({
    llvm::dbgs() << " LoopFusion fusion stats:"
                 << "\n  best loop depth: " << bestDstLoopDepth
                 << "\n  src loop nest compute cost: " << srcLoopNestCost
                 << "\n  dst loop nest compute cost: " << dstLoopNestCost
                 << "\n  fused loop nest compute cost: "
                 << minFusedLoopNestComputeCost << "\n";
  });
  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  Optional<double> storageReduction;

  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = srcMemSize.value();
  auto dstMemSizeVal = dstMemSize.value();

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + sliceMemEstimate.value();

  LLVM_DEBUG(llvm::dbgs() << "  src mem: " << srcMemSizeVal << "\n"
                          << "  dst mem: " << dstMemSizeVal << "\n"
                          << "  fused mem: " << fusedMem << "\n"
                          << "  slice mem: " << sliceMemEstimate << "\n");

  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }

  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });
  return true;
}
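// Worked example of the cost test above (hypothetical numbers): with
// srcLoopNestCost = 1000, dstLoopNestCost = 1000 and a fused cost of 2200,
// additionalComputeFraction = 2200 / (1000 + 1000) - 1 = 0.10, i.e. 10%
// redundant computation; the depth is accepted only if 0.10 is below
// computeToleranceThreshold and its storage reduction beats prior depths.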
namespace {
// GreedyFusion greedily fuses loop nests which have a producer-consumer or
// sibling (input-reuse) relationship, driven by the MemRefDependenceGraph
// and a worklist of candidate destination nodes.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for the local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for the fast memory space.
  Optional<unsigned> fastMemorySpace;
  // If true, ignore the additional (redundant) computation tolerance
  // threshold that would otherwise prevent fusion.
  bool maximalFusion;
  // The amount of additional computation tolerated while fusing pair-wise,
  // as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               Optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes the worklist with graph nodes in program order.
  void init() {
    worklist.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  }
  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass:
  // *) a first pass fuses single-user producer nodes into their consumer,
  // *) a second pass fuses sibling nodes which share no dependence edges,
  // *) a third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }
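  // Note on the phase ordering above: restricting the first pass to
  // single-user producers (maxSrcUserCount == 1) targets the case where the
  // producer's work need not be replicated across several consumers; the
  // final unrestricted pass then considers the remaining multi-user
  // producers.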
  /// Attempt to fuse producer-consumer candidate loops into each destination
  /// node on the worklist. No fusion is performed for producers whose user
  /// count on any involved memref exceeds `maxSrcUserCount`.
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Skip if 'dstNode' is a loop nest returning values.
      // TODO: support loop nests that return values.
      if (dstNode->op->getNumResults() > 0)
        continue;

      LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");

      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
      // while preserving relative order. This can increase the maximum loop
      // depth at which a slice of the src loop nest can be fused.
      sinkSequentialLoops(dstNode);
      auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

      // Try to fuse 'dstNode' with candidate src loop nests. Continue fusing
      // until a fixed point is reached.
      bool dstNodeChanged;
      do {
        // Gather src loop candidates for 'dstNode' and visit them in reverse
        // program order to minimize the iterations needed to reach the fixed
        // point.
        dstNodeChanged = false;
        SmallVector<unsigned, 16> srcIdCandidates;
        getProducerCandidates(dstId, mdg, srcIdCandidates);

        for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
          auto *srcNode = mdg->getNode(srcId);
          auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
          LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                  << " for dst loop " << dstId << "\n");

          // Skip if 'srcNode' is a loop nest returning values.
          // TODO: support loop nests that return values.
          if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
            continue;

          DenseSet<Value> producerConsumerMemrefs;
          gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                        producerConsumerMemrefs);
          // Skip if 'srcNode's out edge count on any memref exceeds
          // 'maxSrcUserCount'.
          if (any_of(producerConsumerMemrefs, [&](Value memref) {
                return mdg->getOutEdgeCount(srcNode->id, memref) >
                       maxSrcUserCount;
              }))
            continue;

          // Gather memrefs in 'srcNode' that are written to and escape the
          // function (e.g., memref function arguments, returned memrefs,
          // memrefs passed to function calls, etc.).
          DenseSet<Value> srcEscapingMemRefs;
          gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

          // Skip if there are non-affine operations in between 'srcNode' and
          // 'dstNode' using their memrefs; we can't compute a legal insertion
          // point in that case. Memrefs with non-affine users are considered
          // escaping, so the check is limited to escaping-memref scenarios.
          if (!srcEscapingMemRefs.empty() &&
              hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Can't fuse: non-affine users in between the "
                          "loops.\n");
            continue;
          }

          // Compute an operation list insertion point for the fused loop
          // nest which preserves dependences.
          Operation *fusedLoopInsPoint =
              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
          if (fusedLoopInsPoint == nullptr)
            continue;

          // Compute the innermost common loop depth for the 'dstNode'
          // loads/stores that access any of the producer-consumer memrefs.
          SmallVector<Operation *, 2> dstMemrefOps;
          for (Operation *op : dstNode->loads)
            if (producerConsumerMemrefs.count(
                    cast<AffineReadOpInterface>(op).getMemRef()) > 0)
              dstMemrefOps.push_back(op);
          for (Operation *op : dstNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              dstMemrefOps.push_back(op);
          unsigned dstLoopDepthTest =
              getInnermostCommonLoopDepth(dstMemrefOps);

          // Check the feasibility of fusing the src loop nest into the dst
          // loop nest at loop depths in range [1, dstLoopDepthTest].
          unsigned maxLegalFusionDepth = 0;
          SmallVector<ComputationSliceState, 8> depthSliceUnions;
          depthSliceUnions.resize(dstLoopDepthTest);
          FusionStrategy strategy(FusionStrategy::ProducerConsumer);
          for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
            FusionResult result =
                mlir::canFuseLoops(srcAffineForOp, dstAffineForOp,
                                   /*dstLoopDepth=*/i, &depthSliceUnions[i - 1],
                                   strategy);
            if (result.value == FusionResult::Success)
              maxLegalFusionDepth = i;
          }

          if (maxLegalFusionDepth == 0) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Can't fuse: fusion is not legal at any depth\n");
            continue;
          }
          // Check if fusion would be profitable. Profitability analysis is
          // skipped for maximal fusion since the maximal legal depth is
          // already known.
          unsigned bestDstLoopDepth = maxLegalFusionDepth;
          if (!maximalFusion) {
            // Retrieve producer stores from the src loop.
            SmallVector<Operation *, 2> producerStores;
            for (Operation *op : srcNode->stores)
              if (producerConsumerMemrefs.count(
                      cast<AffineWriteOpInterface>(op).getMemRef()))
                producerStores.push_back(op);

            // TODO: Support multiple producer stores in the profitability
            // analysis; it is currently limited to a single producer store.
            assert(!producerStores.empty() && "Expected producer store");
            if (producerStores.size() > 1)
              LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                         "supported for this case\n");
            else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                         dstAffineForOp, depthSliceUnions,
                                         maxLegalFusionDepth, &bestDstLoopDepth,
                                         computeToleranceThreshold))
              continue;
          }

          assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
          ComputationSliceState &bestSlice =
              depthSliceUnions[bestDstLoopDepth - 1];
          assert(!bestSlice.isEmpty() && "Missing slice union for depth");
          // Determine if 'srcId' can be removed after fusion, taking into
          // account its remaining dependences, escaping memrefs and the
          // fusion insertion point.
          bool removeSrcNode = canRemoveSrcNodeAfterFusion(
              srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
              mdg);

          DenseSet<Value> privateMemrefs;
          for (Value memref : producerConsumerMemrefs) {
            // If 'memref' escapes, don't create a private memref when either
            // the source is to be removed after fusion or the destination
            // writes to 'memref': otherwise the writes meant for the escaping
            // memref would land in the private copy instead.
            if (srcEscapingMemRefs.count(memref) > 0 &&
                (removeSrcNode || dstNode->getStoreOpCount(memref) > 0))
              continue;

            // Don't create a private memref if 'srcNode' has in edges on
            // 'memref' or 'dstNode' has out edges on 'memref'.
            if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
                mdg->getOutEdgeCount(dstId, memref) > 0)
              continue;

            // If 'srcNode' will be removed but has out edges on 'memref' to
            // nodes other than 'dstNode', those dependences must be preserved
            // and no private memref can be created.
            if (removeSrcNode &&
                any_of(mdg->outEdges[srcId], [&](const auto &edge) {
                  return edge.value == memref && edge.id != dstId;
                }))
              continue;

            // Create a private version of this memref.
            privateMemrefs.insert(memref);
          }

          // Fuse the computation slice of 'srcLoopNest' into 'dstLoopNest'.
          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
          dstNodeChanged = true;

          LLVM_DEBUG(llvm::dbgs()
                     << "Fused src loop " << srcId << " into dst loop " << dstId
                     << " at depth " << bestDstLoopDepth << ":\n"
                     << dstAffineForOp << "\n");
          // Move 'dstAffineForOp' before 'fusedLoopInsPoint' if needed.
          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);

          // Update edges between 'srcNode' and 'dstNode'.
          mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                           removeSrcNode);

          // Create private memrefs.
          if (!privateMemrefs.empty()) {
            // Gather stores for all the private-to-be memrefs.
            DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
            dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
              Value storeMemRef = storeOp.getMemRef();
              if (privateMemrefs.count(storeMemRef) > 0)
                privateMemRefToStores[storeMemRef].push_back(
                    storeOp.getOperation());
            });

            // Replace the original memrefs with private ones. All the loads
            // and stores on them are replaced with new ops; references to the
            // originals become invalid after this point.
            for (auto &memrefToStoresPair : privateMemRefToStores) {
              // TODO: Use the union of memref write regions to compute the
              // private memref footprint.
              SmallVector<Operation *, 4> &storesForMemref =
                  memrefToStoresPair.second;
              Value newMemRef = createPrivateMemRef(
                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                  fastMemorySpace, localBufSizeThreshold);
              // Create a new node in the dependence graph for the 'newMemRef'
              // alloc op.
              unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
              // Add an edge from the 'newMemRef' node to 'dstNode'.
              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
            }
            // One or more 'newMemRef' alloc nodes were inserted; refresh
            // 'dstNode' as it may have been invalidated.
            dstNode = mdg->getNode(dstId);
          }

          // Collect dst loop stats after the memref privatization
          // transformation.
          LoopNestStateCollector dstLoopCollector;
          dstLoopCollector.collect(dstAffineForOp.getOperation());

          // Clear and re-add the dst node's loads and stores.
          mdg->clearNodeLoadAndStores(dstNode->id);
          mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                         dstLoopCollector.storeOpInsts);

          if (removeSrcNode) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Removing src loop " << srcId << " after fusion\n");
            // 'srcNode' is no longer valid after removal from the mdg.
            srcAffineForOp.erase();
            mdg->removeNode(srcId);
            srcNode = nullptr;
          }
        }
      } while (dstNodeChanged);
    }
  }
  // Visits each node in the graph and attempts to fuse it with its sibling
  // nodes (nodes which share a parent but no dependence edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }
  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other stores to
      // the same memref in the 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop nest
      // which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Get the unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      assert(!sibNode->stores.empty());
      // TODO: Choose the store which postdominates all other stores.
      auto *sibStoreOpInst = sibNode->stores.back();

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      SmallVector<AffineForOp, 4> dstLoopIVs;
      getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute the loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result =
            mlir::canFuseLoops(sibAffineForOp, dstAffineForOp,
                               /*dstLoopDepth=*/i, &depthSliceUnions[i - 1],
                               strategy);
        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      // Skip if fusion is not feasible at any loop depth.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable.
        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");

      // Fuse the computation slice of 'sibLoopNest' into 'dstLoopNest'.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      fuseLoops(sibAffineForOp, dstAffineForOp,
                depthSliceUnions[bestDstLoopDepth - 1], isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update the position of the fused loop nest (if needed).
      if (insertPointInst != dstForInst.getOperation()) {
        dstForInst->moveBefore(insertPointInst);
      }
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }
  // Searches function argument uses and the graph from 'dstNode' looking for
  // a fusion candidate sibling node which loads from the same memref but
  // shares no dependences with 'dstNode'. Returns true and sets
  // 'idAndMemrefToFuse' on success; returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'sibNode' does not have exactly one load on 'memref'.
      // TODO: Remove the single-load-op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependence edges between 'sibNode'
      // and 'dstNode' in either direction.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip 'sibNode' if it loads from (and stores to) a memref on which it
      // also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all of 'sibNode's stores are to the same memref.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() != 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref function argument.
    auto fn = dstNode->op->getParentOfType<func::FuncOp>();
    for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
      for (auto *user : fn.getArgument(i).getUsers()) {
        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
          // Gather loops surrounding 'user'.
          SmallVector<AffineForOp, 4> loops;
          getLoopIVs(*user, &loops);
          // Skip 'user' if it is not within a loop nest.
          if (loops.empty())
            continue;
          Node *sibNode = mdg->getForOpNode(loops[0]);
          assert(sibNode != nullptr);
          // Skip 'user' if it is not a sibling of 'dstNode'.
          if (sibNode->id == dstNode->id)
            continue;
          // Skip 'user' if it has already been visited.
          if (visitedSibNodeIds->count(sibNode->id) > 0)
            continue;
          // Skip 'user' if it does not load from the same memref as
          // 'dstNode'.
          auto memref = loadOp.getMemRef();
          if (dstNode->getLoadOpCount(memref) == 0)
            continue;
          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
          if (canFuseWithSibNode(sibNode, memref)) {
            visitedSibNodeIds->insert(sibNode->id);
            idAndMemrefToFuse->first = sibNode->id;
            idAndMemrefToFuse->second = memref;
            return true;
          }
        }
      }
    }

    // Search for siblings by following edges through an intermediate src
    // node. Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip the output edge if it is not a sibling using the same
            // memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to the sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Take the first candidate, if any were found.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }
  /// Updates the dependence graph to reflect sibling fusion of 'sibNode'
  /// into 'dstNode'.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect dst loop stats after fusion.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst.getOperation());
    // Clear and re-add the dst node's loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove the old sibling loop nest if it no longer has outgoing
    // dependence edges.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
      op->erase();
    }
  }
  // Cleans up memref allocs which no longer have any users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // The use list is expected to match the dependence graph info.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};
} // namespace
void LoopFusion::runOnOperation() {
  MemRefDependenceGraph g;
  if (!g.init(getOperation()))
    return;

  Optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}
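// A typical way to exercise this pass from the command line (the option
// spellings below are assumptions for illustration; the pass flag matches
// DEBUG_TYPE above):
//
//   mlir-opt input.mlir -affine-loop-fusion
//   mlir-opt input.mlir -affine-loop-fusion="fusion-maximal"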