27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
44 #define DEBUG_TYPE "affine-loop-fusion"
55 struct LoopFusion :
public affine::impl::AffineLoopFusionBase<LoopFusion> {
56 LoopFusion() =
default;
57 LoopFusion(
unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58 bool maximalFusion,
enum FusionMode affineFusionMode) {
59 this->fastMemorySpace = fastMemorySpace;
60 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61 this->maximalFusion = maximalFusion;
62 this->affineFusionMode = affineFusionMode;
65 void runOnBlock(
Block *block);
66 void runOnOperation()
override;
84 bool hasOutDepsAfterFusion =
false;
86 for (
auto &outEdge : mdg->
outEdges[srcId]) {
89 if (depNodeOp == dstNodeOp)
99 if (fusedLoopInsPoint != depNodeOp &&
101 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: dst loop doesn't "
102 "dominate dependence\n");
106 hasOutDepsAfterFusion =
true;
112 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113 std::optional<bool> isMaximal = fusionSlice.
isMaximal();
115 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: can't determine "
116 "if fusion is maximal\n");
121 LLVM_DEBUG(llvm::dbgs()
122 <<
"Src loop can't be removed: fusion is not maximal\n");
140 if (mdg->
inEdges.count(dstId) == 0)
144 auto *dstNode = mdg->
getNode(dstId);
147 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
151 for (
auto &srcEdge : mdg->
inEdges[dstId]) {
152 auto *srcNode = mdg->
getNode(srcEdge.id);
154 if (!isa<AffineForOp>(srcNode->op))
157 if (any_of(srcNode->stores, [&](
Operation *op) {
158 auto storeOp = cast<AffineWriteOpInterface>(op);
159 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
161 srcIdCandidates.push_back(srcNode->id);
164 llvm::sort(srcIdCandidates);
165 srcIdCandidates.erase(
166 std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
167 srcIdCandidates.end());
176 auto *dstNode = mdg->
getNode(dstId);
177 auto *srcNode = mdg->
getNode(srcId);
179 producerConsumerMemrefs);
196 if (
auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
201 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
208 if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
210 return !isa<AffineMapAccessInterface>(*user);
219 for (
Operation *storeOp : node->stores) {
220 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
221 if (escapingMemRefs.count(memref))
224 escapingMemRefs.insert(memref);
234 assert(isa<AffineForOp>(node->
op));
236 node->
op = newRootForOp;
245 unsigned dstLoopDepth,
246 std::optional<unsigned> fastMemorySpace,
247 uint64_t localBufSizeThreshold) {
248 Operation *forInst = forOp.getOperation();
255 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
256 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
257 unsigned rank = oldMemRefType.getRank();
263 assert(validRegion &&
"unexpected memref region failure");
265 std::vector<SmallVector<int64_t, 4>> lbs;
270 std::optional<int64_t> numElements =
272 assert(numElements &&
"non-constant number of elts in local buffer");
283 offsets.reserve(rank);
284 for (
unsigned d = 0; d < rank; ++d) {
285 assert(lbs[d].size() == cst->
getNumCols() - rank &&
"incorrect bound size");
288 for (
unsigned j = 0, e = cst->
getNumCols() - rank - 1;
j < e;
j++) {
291 assert(lbDivisors[d] > 0);
294 offsets.push_back(offset);
300 assert(eltSize &&
"memrefs with size elt types expected");
301 uint64_t bufSize = *eltSize * *numElements;
302 unsigned newMemSpace;
303 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
304 newMemSpace = *fastMemorySpace;
306 newMemSpace = oldMemRefType.getMemorySpaceAsInt();
308 auto newMemRefType =
MemRefType::get(newShape, oldMemRefType.getElementType(),
317 Value newMemRef = top.
create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
321 remapExprs.reserve(rank);
322 for (
unsigned i = 0; i < rank; i++) {
327 remapExprs.push_back(remapExpr);
331 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
338 &*forOp.getBody()->begin());
340 "replaceAllMemrefUsesWith should always succeed here");
351 auto *srcNode = mdg->
getNode(srcId);
352 auto *dstNode = mdg->
getNode(dstId);
357 for (
auto &idAndNode : mdg->
nodes) {
360 if (srcNode->op->isBeforeInBlock(op) && op->
isBeforeInBlock(dstNode->op)) {
365 if (isa<AffineMapAccessInterface>(*user))
368 if (llvm::is_contained(users, user))
372 if (walkResult.wasInterrupted())
385 auto *srcNode = mdg->
getNode(srcId);
386 llvm::SmallDenseSet<Value, 2> memRefValues;
389 if (isa<AffineForOp>(op))
393 if (isa<MemRefType>(v.getType()))
394 memRefValues.insert(v);
398 return llvm::any_of(memRefValues, [&](
Value memref) {
443 AffineForOp dstForOp,
445 unsigned maxLegalFusionDepth,
446 unsigned *dstLoopDepth,
447 double computeToleranceThreshold) {
449 llvm::dbgs() <<
"Checking whether fusion is profitable between src op:\n";
450 llvm::dbgs() <<
' ' << *srcOpInst <<
" and destination loop:\n";
451 llvm::dbgs() << dstForOp <<
"\n";
454 if (maxLegalFusionDepth == 0) {
455 LLVM_DEBUG(llvm::dbgs() <<
"Can't fuse: maxLegalFusionDepth is 0\n");
480 double maxStorageReduction = 0.0;
481 std::optional<uint64_t> sliceMemEstimate;
484 std::optional<unsigned> bestDstLoopDepth;
487 uint64_t srcLoopNestCost =
getComputeCost(srcLoopIVs[0], srcLoopNestStats);
492 LLVM_DEBUG(llvm::dbgs()
493 <<
"Unable to compute MemRefRegion for source operation\n");
497 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
499 if (!maybeSrcWriteRegionSizeBytes.has_value())
501 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
504 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
508 for (
unsigned i = maxLegalFusionDepth; i >= 1; --i) {
514 int64_t fusedLoopNestComputeCost;
516 dstLoopNestStats, slice,
517 &fusedLoopNestComputeCost)) {
518 LLVM_DEBUG(llvm::dbgs() <<
"Unable to compute fusion compute cost\n");
522 double additionalComputeFraction =
523 fusedLoopNestComputeCost /
524 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
533 LLVM_DEBUG(llvm::dbgs()
534 <<
"Failed to compute slice write region at loopDepth: " << i
539 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
541 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
542 *maybeSliceWriteRegionSizeBytes == 0) {
543 LLVM_DEBUG(llvm::dbgs()
544 <<
"Failed to get slice write region size at loopDepth: " << i
548 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
554 if (srcOpInst != srcStoreOpInst &&
555 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
558 double storageReduction =
static_cast<double>(srcWriteRegionSizeBytes) /
559 static_cast<double>(sliceWriteRegionSizeBytes);
562 std::stringstream msg;
563 msg <<
" evaluating fusion profitability at depth : " << i <<
"\n"
564 << std::fixed << std::setprecision(2)
565 <<
" additional compute fraction: "
566 << 100.0 * additionalComputeFraction <<
"%\n"
567 <<
" storage reduction factor: " << storageReduction <<
"x\n"
568 <<
" fused nest cost: " << fusedLoopNestComputeCost <<
"\n"
569 <<
" src write region size: " << srcWriteRegionSizeBytes <<
"\n"
570 <<
" slice write region size: " << sliceWriteRegionSizeBytes
572 llvm::dbgs() << msg.str();
579 if ((storageReduction > maxStorageReduction) &&
580 (additionalComputeFraction < computeToleranceThreshold)) {
581 maxStorageReduction = storageReduction;
582 bestDstLoopDepth = i;
583 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
584 sliceMemEstimate = sliceWriteRegionSizeBytes;
590 if (!bestDstLoopDepth) {
593 <<
"All fusion choices involve more than the threshold amount of "
594 "redundant computation; NOT fusing.\n");
598 if (!bestDstLoopDepth) {
599 LLVM_DEBUG(llvm::dbgs() <<
"no fusion depth could be evaluated.\n");
604 *dstLoopDepth = *bestDstLoopDepth;
607 llvm::dbgs() <<
" LoopFusion fusion stats:"
608 <<
"\n best loop depth: " << bestDstLoopDepth
609 <<
"\n src loop nest compute cost: " << srcLoopNestCost
610 <<
"\n dst loop nest compute cost: " << dstLoopNestCost
611 <<
"\n fused loop nest compute cost: "
612 << minFusedLoopNestComputeCost <<
"\n");
617 std::optional<double> storageReduction;
619 if (!dstMemSize || !srcMemSize) {
620 LLVM_DEBUG(llvm::dbgs()
621 <<
" fusion memory benefit cannot be evaluated; NOT fusing.\n");
625 auto srcMemSizeVal = *srcMemSize;
626 auto dstMemSizeVal = *dstMemSize;
628 assert(sliceMemEstimate &&
"expected value");
629 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
631 LLVM_DEBUG(llvm::dbgs() <<
" src mem: " << srcMemSizeVal <<
"\n"
632 <<
" dst mem: " << dstMemSizeVal <<
"\n"
633 <<
" fused mem: " << fusedMem <<
"\n"
634 <<
" slice mem: " << sliceMemEstimate <<
"\n");
636 if (
static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
637 LLVM_DEBUG(llvm::dbgs() <<
"Fusion is not profitable; NOT fusing.\n");
642 (1.0 - fusedMem / (
static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
644 double additionalComputeFraction =
645 100.0 * (minFusedLoopNestComputeCost /
646 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
648 (void)additionalComputeFraction;
650 std::stringstream msg;
651 msg <<
" fusion is most profitable at depth " << *dstLoopDepth <<
" with "
652 << std::setprecision(2) << additionalComputeFraction
653 <<
"% redundant computation and a ";
654 msg << (storageReduction ? std::to_string(*storageReduction) :
"<unknown>");
655 msg <<
"% storage reduction.\n";
656 llvm::dbgs() << msg.str();
710 struct GreedyFusion {
717 unsigned localBufSizeThreshold;
719 std::optional<unsigned> fastMemorySpace;
725 double computeToleranceThreshold;
730 std::optional<unsigned> fastMemorySpace,
bool maximalFusion,
731 double computeToleranceThreshold)
732 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
733 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
734 computeToleranceThreshold(computeToleranceThreshold) {}
741 for (
auto &idAndNode : mdg->
nodes) {
742 const Node &node = idAndNode.second;
743 worklist.push_back(node.id);
747 void runSiblingFusionOnly() {
749 eraseUnusedMemRefAllocations();
753 void runProducerConsumerFusionOnly() {
754 fuseProducerConsumerNodes(
756 eraseUnusedMemRefAllocations();
764 void runGreedyFusion() {
766 fuseProducerConsumerNodes(1);
768 fuseProducerConsumerNodes(
770 eraseUnusedMemRefAllocations();
775 bool canCreatePrivateMemRef(
Value memref,
777 unsigned producerId,
unsigned consumerId,
778 bool removeSrcNode) {
779 const Node *consumerNode = mdg->
getNode(consumerId);
787 if (srcEscapingMemRefs.count(memref) > 0 &&
788 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
801 any_of(mdg->
outEdges[producerId], [&](
const auto &edge) {
802 return edge.value == memref && edge.id != consumerId;
812 void performFusionsIntoDest(
unsigned dstId,
unsigned maxSrcUserCount) {
813 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
815 if (mdg->
nodes.count(dstId) == 0)
818 auto *dstNode = mdg->
getNode(dstId);
820 if (!isa<AffineForOp>(dstNode->op))
824 if (dstNode->op->getNumResults() > 0)
827 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
834 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
845 dstNodeChanged =
false;
849 for (
unsigned srcId : llvm::reverse(srcIdCandidates)) {
851 auto *srcNode = mdg->
getNode(srcId);
852 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
853 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating src loop " << srcId
854 <<
" for dst loop " << dstId <<
"\n");
858 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
863 producerConsumerMemrefs);
867 if (any_of(producerConsumerMemrefs, [&](
Value memref) {
885 if (!srcEscapingMemRefs.empty() &&
887 LLVM_DEBUG(llvm::dbgs()
888 <<
"Can't fuse: non-affine users in between the loops\n");
896 if (fusedLoopInsPoint ==
nullptr)
907 unsigned numSurroundingLoops = surroundingLoops.size();
913 if (producerConsumerMemrefs.count(
914 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
915 dstMemrefOps.push_back(op);
917 if (producerConsumerMemrefs.count(
918 cast<AffineWriteOpInterface>(op).getMemRef()))
919 dstMemrefOps.push_back(op);
920 unsigned dstLoopDepthTest =
925 unsigned maxLegalFusionDepth = 0;
927 depthSliceUnions.resize(dstLoopDepthTest);
929 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
932 i + numSurroundingLoops,
933 &depthSliceUnions[i - 1], strategy);
936 maxLegalFusionDepth = i;
939 if (maxLegalFusionDepth == 0) {
940 LLVM_DEBUG(llvm::dbgs()
941 <<
"Can't fuse: fusion is not legal at any depth\n");
948 unsigned bestDstLoopDepth = maxLegalFusionDepth;
949 if (!maximalFusion) {
953 if (producerConsumerMemrefs.count(
954 cast<AffineWriteOpInterface>(op).getMemRef()))
955 producerStores.push_back(op);
963 assert(!producerStores.empty() &&
"Expected producer store");
964 if (producerStores.size() > 1)
965 LLVM_DEBUG(llvm::dbgs() <<
"Skipping profitability analysis. Not "
966 "supported for this case\n");
968 dstAffineForOp, depthSliceUnions,
969 maxLegalFusionDepth, &bestDstLoopDepth,
970 computeToleranceThreshold))
974 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
976 depthSliceUnions[bestDstLoopDepth - 1];
977 assert(!bestSlice.
isEmpty() &&
"Missing slice union for depth");
983 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
987 for (
Value memref : producerConsumerMemrefs) {
988 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
991 LLVM_DEBUG(llvm::dbgs()
992 <<
"Creating private memref for " << memref <<
'\n');
994 privateMemrefs.insert(memref);
999 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1000 dstNodeChanged =
true;
1002 LLVM_DEBUG(llvm::dbgs()
1003 <<
"Fused src loop " << srcId <<
" into dst loop " << dstId
1004 <<
" at depth " << bestDstLoopDepth <<
":\n"
1005 << dstAffineForOp <<
"\n");
1008 if (fusedLoopInsPoint != dstAffineForOp)
1009 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1012 mdg->
updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1016 if (!privateMemrefs.empty()) {
1019 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1020 Value storeMemRef = storeOp.getMemRef();
1021 if (privateMemrefs.count(storeMemRef) > 0)
1022 privateMemRefToStores[storeMemRef].push_back(storeOp);
1029 for (
auto &memrefToStoresPair : privateMemRefToStores) {
1033 memrefToStoresPair.second;
1035 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1036 fastMemorySpace, localBufSizeThreshold);
1040 mdg->
addEdge(newMemRefNodeId, dstId, newMemRef);
1045 dstNode = mdg->
getNode(dstId);
1050 dstLoopCollector.
collect(dstAffineForOp);
1057 if (removeSrcNode) {
1058 LLVM_DEBUG(llvm::dbgs()
1059 <<
"Removing src loop " << srcId <<
" after fusion\n");
1061 srcAffineForOp.erase();
1066 }
while (dstNodeChanged);
1073 void fuseProducerConsumerNodes(
unsigned maxSrcUserCount) {
1074 LLVM_DEBUG(llvm::dbgs() <<
"--- Producer/Consumer Fusion ---\n");
1076 while (!worklist.empty()) {
1077 unsigned dstId = worklist.back();
1078 worklist.pop_back();
1079 performFusionsIntoDest(dstId, maxSrcUserCount);
1085 void fuseSiblingNodes() {
1086 LLVM_DEBUG(llvm::dbgs() <<
"--- Sibling Fusion ---\n");
1088 while (!worklist.empty()) {
1089 unsigned dstId = worklist.back();
1090 worklist.pop_back();
1093 if (mdg->
nodes.count(dstId) == 0)
1096 auto *dstNode = mdg->
getNode(dstId);
1098 if (!isa<AffineForOp>(dstNode->op))
1101 fuseWithSiblingNodes(dstNode);
1106 void fuseWithSiblingNodes(
Node *dstNode) {
1108 std::pair<unsigned, Value> idAndMemref;
1109 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1111 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1112 unsigned sibId = idAndMemref.first;
1113 Value memref = idAndMemref.second;
1116 auto *sibNode = mdg->
getNode(sibId);
1119 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1124 if (insertPointInst ==
nullptr)
1131 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1133 assert(sibLoadOpInsts.size() == 1);
1134 Operation *sibLoadOpInst = sibLoadOpInsts[0];
1138 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1148 unsigned numSurroundingLoops = surroundingLoops.size();
1151 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1152 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1156 depthSliceUnions.resize(dstLoopDepthTest);
1157 unsigned maxLegalFusionDepth = 0;
1159 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1162 i + numSurroundingLoops,
1163 &depthSliceUnions[i - 1], strategy);
1166 maxLegalFusionDepth = i;
1169 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1170 << maxLegalFusionDepth <<
'\n');
1173 if (maxLegalFusionDepth == 0)
1176 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1177 if (!maximalFusion) {
1183 depthSliceUnions, maxLegalFusionDepth,
1184 &bestDstLoopDepth, computeToleranceThreshold))
1188 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1189 assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1190 "Fusion depth has no computed slice union");
1194 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1197 depthSliceUnions[bestDstLoopDepth - 1],
1198 isInnermostInsertion);
1200 auto dstForInst = cast<AffineForOp>(dstNode->op);
1202 if (insertPointInst != dstForInst) {
1203 dstForInst->moveBefore(insertPointInst);
1206 updateStateAfterSiblingFusion(sibNode, dstNode);
1214 bool findSiblingNodeToFuse(
Node *dstNode,
1216 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1219 auto canFuseWithSibNode = [&](
Node *sibNode,
Value memref) {
1222 if (sibNode->getLoadOpCount(memref) != 1)
1232 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1233 if (llvm::any_of(loadAndStoreMemrefSet, [=](
Value memref) {
1240 for (
auto *storeOpInst : sibNode->stores) {
1241 storeMemrefs.insert(
1242 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1244 if (storeMemrefs.size() > 1)
1256 Block *block = dstNode->op->getBlock();
1259 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1268 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1269 return loop->getBlock() == &mdg->
block;
1272 if (it == loops.end())
1275 assert(sibNode !=
nullptr);
1277 if (sibNode->id == dstNode->id)
1280 if (visitedSibNodeIds->count(sibNode->id) > 0)
1283 auto memref = loadOp.getMemRef();
1284 if (dstNode->getLoadOpCount(memref) == 0)
1287 if (canFuseWithSibNode(sibNode, memref)) {
1288 visitedSibNodeIds->insert(sibNode->id);
1289 idAndMemrefToFuse->first = sibNode->id;
1290 idAndMemrefToFuse->second = memref;
1302 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1303 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1304 inEdges.push_back(inEdge);
1309 for (
auto &inEdge : inEdges) {
1314 unsigned sibNodeId = outEdge.id;
1315 if (visitedSibNodeIds->count(sibNodeId) > 0)
1318 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1320 auto *sibNode = mdg->getNode(sibNodeId);
1321 if (!isa<AffineForOp>(sibNode->op))
1324 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1326 outEdges.push_back(outEdge);
1331 if (!outEdges.empty()) {
1332 visitedSibNodeIds->insert(outEdges[0].
id);
1333 idAndMemrefToFuse->first = outEdges[0].id;
1334 idAndMemrefToFuse->second = outEdges[0].value;
1343 void updateStateAfterSiblingFusion(
Node *sibNode,
Node *dstNode) {
1348 auto dstForInst = cast<AffineForOp>(dstNode->op);
1350 dstLoopCollector.
collect(dstForInst);
1365 void eraseUnusedMemRefAllocations() {
1367 if (pair.second > 0)
1369 auto memref = pair.first;
1375 if (isa_and_nonnull<memref::AllocOp>(op))
1384 void LoopFusion::runOnBlock(
Block *block) {
1387 LLVM_DEBUG(llvm::dbgs() <<
"MDG init failed\n");
1391 std::optional<unsigned> fastMemorySpaceOpt;
1392 if (fastMemorySpace.hasValue())
1393 fastMemorySpaceOpt = fastMemorySpace;
1394 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1395 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1396 maximalFusion, computeToleranceThreshold);
1399 fusion.runProducerConsumerFusionOnly();
1401 fusion.runSiblingFusionOnly();
1403 fusion.runGreedyFusion();
1406 void LoopFusion::runOnOperation() {
1409 getOperation()->walk([&](
Operation *op) {
1411 for (Block &block : region.getBlocks()) {
1412 auto affineFors = block.getOps<AffineForOp>();
1413 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1421 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1422 bool maximalFusion,
enum FusionMode affineFusionMode) {
1423 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1424 maximalFusion, affineFusionMode);
MemRefDependenceGraph::Node Node
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, MemRefDependenceGraph *mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg)
Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 'dstId'), if there is any non-aff...
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, uint64_t localBufSizeThreshold)
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumVars() const
unsigned getNumCols() const
Returns the number of columns in the constraint system.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Include the generated interface declarations.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
bool isEmpty() const
Returns true if the computation slice is empty.
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
void collect(Operation *opToWalk)
SmallVector< Operation *, 4 > storeOpInsts
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Block & block
The block for which this graph is created to perform fusion.
unsigned addNode(Operation *op)
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
void addEdge(unsigned srcId, unsigned dstId, Value value)
Node * getForOpNode(AffineForOp forOp)
DenseMap< unsigned, Node > nodes
bool hasDependencePath(unsigned srcId, unsigned dstId)
void clearNodeLoadAndStores(unsigned id)
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
void removeNode(unsigned id)
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Node * getNode(unsigned id)
DenseMap< Value, unsigned > memrefEdgeCount
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
FlatAffineValueConstraints * getConstraints()
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.