      !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
  assert(commonBlock &&
         "common block of producer stores and slice should exist");

  for (Operation *store : producerStores) {
    assert(ancestor && "producer store should be contained in common block");
    firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
                        ? ancestor
                        : firstAncestor;
  }
  return firstAncestor;
}
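
// Returns the amount of additional (redundant) computation that would be done
// if `srcForOp` were fused into `dstForOp` at `depth`, expressed as a fraction
// of the combined original compute cost:
//   fusedLoopNestComputeCost / (srcLoopNestCost + dstLoopNestCost) - 1.
// Returns std::nullopt if any of the required cost estimates cannot be
// computed.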
static std::optional<double>
getAdditionalComputeFraction(AffineForOp srcForOp, AffineForOp dstForOp,
                             unsigned depth,
                             ArrayRef<ComputationSliceState> depthSliceUnions,
                             int64_t &sliceCost,
                             int64_t &fusedLoopNestComputeCost) {
  LDBG() << "Determining additional compute fraction...";

  LDBG() << "Failed to get source loop nest stats.";

  LDBG() << "Failed to get destination loop nest stats.";

  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

  LDBG() << "Slice wasn't computed.";

                            dstLoopNestStats, slice,
                            &fusedLoopNestComputeCost)) {
    LDBG() << "Unable to compute fusion compute cost";
  double additionalComputeFraction =
      fusedLoopNestComputeCost /
          (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
      1;

  return additionalComputeFraction;
}
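
// Creates and returns a private memref for the fused loop nest rooted at
// `forOp`, sized to the memref region written by `storeOps` at
// `dstLoopDepth`. The buffer is placed in `fastMemorySpace` when its size is
// within `localBufSizeThreshold`; producer stores and the loads dominated by
// the slice insertion point are then remapped to the new buffer.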
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 Block *sliceInsertionBlock,
                                 uint64_t localBufSizeThreshold) {
  assert(!storeOps.empty() && "no source stores supplied");
  // A single private buffer can only be created when all producer stores
  // share the same access function.
  if (storeOps.size() > 1 &&
      !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
                  [](Operation *a, Operation *b) {
                    MemRefAccess aM(cast<AffineWriteOpInterface>(a));
                    MemRefAccess bM(cast<AffineWriteOpInterface>(b));
                    return aM == bM;
                  })) {
    LDBG() << "Private memref creation unsupported for multiple producer "
           << "stores with different access functions.";
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  unsigned rank = oldMemRefType.getRank();

  bool validRegion = succeeded(
      region.compute(srcStoreOp, dstLoopDepth, nullptr,
  assert(validRegion && "unexpected memref region failure");
  std::optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs);
  assert(numElements && "non-constant number of elts in local buffer");
  offsets.reserve(rank);

  for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)

  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].getNumResults() == 1 &&
           "invalid private memref bound calculation");
    offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
  }

  assert(eltSize && "memrefs with size elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),

  Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
    remapExprs.push_back(remapExpr);
  }
  auto indexRemap =
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
  auto userFilterFn = [&](Operation *user) {
    auto domInfo = std::make_unique<DominanceInfo>(
    return domInfo->dominates(domFilter, user);
  };
  LogicalResult res = replaceAllMemRefUsesWith(
      oldMemRef, newMemRef, {}, indexRemap,
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  LDBG() << "Created private memref of type: " << newMemRefType;
                               AffineForOp dstForOp,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LDBG() << "Checking whether fusion is profitable between source nest:";
  LDBG() << ' ' << srcForOp << " and destination nest:";

  if (maxLegalFusionDepth == 0) {
    LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
  // Limited profitability analysis: with multiple producer stores, only check
  // that fusion at the max legal depth obeys the compute tolerance.
  if (producerStores.size() > 1) {
    LDBG() << "Limited profitability analysis. Not "
           << "supported for multiple producer store case.";
    int64_t sliceCost;
    int64_t fusedLoopNestComputeCost;
    auto fraction = getAdditionalComputeFraction(
        srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
        fusedLoopNestComputeCost);
    if (!fraction || fraction > computeToleranceThreshold) {
      LDBG() << "Additional computation exceeds "
             << "compute tolerance. Not fusing.";
      return false;
    }
    LDBG() << "Considering fusion profitable at max legal depth.";
    return true;
  }

  Operation *srcStoreOp = producerStores.front();
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  std::optional<uint64_t> sliceMemEstimate;

  std::optional<unsigned> bestDstLoopDepth;

  if (failed(srcWriteRegion.compute(srcStoreOp, 0))) {
    LDBG() << "Unable to compute MemRefRegion for source operation";
    return false;
  }

  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
  if (!maybeSrcWriteRegionSizeBytes.has_value())
  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
  // Evaluate each depth choice, from the deepest legal depth upwards.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    int64_t fusedLoopNestComputeCost;

    auto mayAdditionalComputeFraction =
        getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
                                     sliceCost, fusedLoopNestComputeCost);
    if (!mayAdditionalComputeFraction) {
      LDBG() << "Can't determine additional compute fraction.";
      continue;
    }
    double additionalComputeFraction = *mayAdditionalComputeFraction;

    if (failed(sliceWriteRegion.compute(srcStoreOp, 0, &slice))) {
      LDBG() << "Failed to compute slice write region at loopDepth: " << i;
      continue;
    }

    std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        *maybeSliceWriteRegionSizeBytes == 0) {
      LDBG() << "Failed to get slice write region size at loopDepth: " << i;
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);
    std::stringstream msg;
    msg << " evaluating fusion profitability at depth : " << i << "\n"
        << std::fixed << std::setprecision(2)
        << " additional compute fraction: "
        << 100.0 * additionalComputeFraction << "%\n"
        << " storage reduction factor: " << storageReduction << "x\n"
        << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
        << " src write region size: " << srcWriteRegionSizeBytes << "\n"
        << " slice write region size: " << sliceWriteRegionSizeBytes;
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction <= computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }
  if (!bestDstLoopDepth) {
    LDBG() << "All fusion choices involve more than the threshold amount of "
           << "redundant computation; NOT fusing.";
    return false;
  }

  if (!bestDstLoopDepth) {
    LDBG() << "no fusion depth could be evaluated.";
    return false;
  }
  *dstLoopDepth = *bestDstLoopDepth;

  LDBG() << " LoopFusion fusion stats:";
  LDBG() << " best loop depth: " << bestDstLoopDepth;
  LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
  LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
  LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
  std::optional<double> storageReduction;

  if (!dstMemSize || !srcMemSize) {
    LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
    return false;
  }

  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

  LDBG() << " src mem: " << srcMemSizeVal;
  LDBG() << " dst mem: " << dstMemSizeVal;
  LDBG() << " fused mem: " << fusedMem;
  LDBG() << " slice mem: " << sliceMemEstimate;
  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LDBG() << "Fusion is not profitable; NOT fusing.";
    return false;
  }

  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  std::stringstream msg;
  msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
      << std::setprecision(2) << additionalComputeFraction
      << "% redundant computation and a ";
  msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
  msg << "% storage reduction.";

  return true;
}
  MemRefDependenceGraph *mdg;
  SmallVector<unsigned, 8> worklist;
  unsigned localBufSizeThreshold;
  std::optional<unsigned> fastMemorySpace;
  bool maximalFusion;
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
    }
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }
  void runGreedyFusion() {
    fuseProducerConsumerNodes(1);
    fuseProducerConsumerNodes(std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }
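
  // Returns true if a private memref can be created for `memref` when fusing
  // producer node `producerId` into consumer node `consumerId`, given the set
  // of memrefs escaping the producer and whether the producer node will be
  // removed after fusion.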
  bool canCreatePrivateMemRef(Value memref,
                              const DenseSet<Value> &srcEscapingMemRefs,
                              unsigned producerId, unsigned consumerId,
                              bool removeSrcNode) {
    const Node *consumerNode = mdg->getNode(consumerId);
    // Do not privatize an escaping memref if the producer will be removed or
    // the consumer also writes to it: the original memref would then miss
    // those writes.
    if (srcEscapingMemRefs.count(memref) > 0 &&
        (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
      return false;

    // Do not privatize if the producer has incoming accesses on 'memref' or
    // the consumer has outgoing edges on it.
    if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 ||
        mdg->getOutEdgeCount(consumerId, memref) > 0)
      return false;

    any_of(mdg->outEdges[producerId], [&](const auto &edge) {
      return edge.value == memref && edge.id != consumerId;
    });
  }
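
  // Attempts to fuse candidate producer loop nests into the destination loop
  // nest `dstId`, repeating as long as a fusion is performed. Producers whose
  // memrefs have more than `maxSrcUserCount` users are skipped.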
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LDBG() << "Evaluating dst loop " << dstId;
    // Skip if this node was removed (fused into another node).
    if (mdg->nodes.count(dstId) == 0)
      return;
    auto *dstNode = mdg->getNode(dstId);
    // Skip if 'dstNode' is not a loop nest.
    if (!isa<AffineForOp>(dstNode->op))
      return;
    // Skip if 'dstNode' is a loop nest returning values.
    if (dstNode->op->getNumResults() > 0)
      return;

    LDBG() << "Evaluating dst loop " << dstId;

    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    do {
      dstNodeChanged = false;
      SmallVector<unsigned, 16> srcIdCandidates;
      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
        auto *srcNode = mdg->getNode(srcId);
        auto srcAffineForOp = cast<AffineForOp>(srcNode->op);

        LDBG() << "Trying to fuse producer loop nest " << srcId
               << " with consumer loop nest " << dstId;
        LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
        LDBG() << "Producer loop nest:";
        LDBG() << *srcNode->op << " and consumer loop nest:";
        LDBG() << *dstNode->op;

        LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;

        // Skip if 'srcNode' is a loop nest returning values.
        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
          continue;

                                          producerConsumerMemrefs);
        // Skip if fusion would exceed the allowed number of memref users.
        if (any_of(producerConsumerMemrefs, [&](Value memref) {
              return mdg->getOutEdgeCount(srcNode->id, memref) >
                     maxSrcUserCount;
            }))
          continue;

        Operation *fusedLoopInsPoint =
            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
        if (fusedLoopInsPoint == nullptr)
          continue;
        SmallVector<AffineForOp, 4> surroundingLoops;
        unsigned numSurroundingLoops = surroundingLoops.size();

        // Gather 'dstNode' loads and stores on the producer-consumer memrefs.
        SmallVector<Operation *, 2> dstMemrefOps;
        for (Operation *op : dstNode->loads)
          if (producerConsumerMemrefs.count(
                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
            dstMemrefOps.push_back(op);
        for (Operation *op : dstNode->stores)
          if (producerConsumerMemrefs.count(
                  cast<AffineWriteOpInterface>(op).getMemRef()))
            dstMemrefOps.push_back(op);
        if (dstMemrefOps.empty())
          continue;
        unsigned dstLoopDepthTest =

        // Check the feasibility of fusing the src loop nest into the dst loop
        // nest at loop depths in range [1, dstLoopDepthTest].
        unsigned maxLegalFusionDepth = 0;
        SmallVector<ComputationSliceState, 8> depthSliceUnions;
        depthSliceUnions.resize(dstLoopDepthTest);
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
                               i + numSurroundingLoops,
                               &depthSliceUnions[i - 1], strategy);
            maxLegalFusionDepth = i;
            LDBG() << "Found valid slice for depth: " << i;
        }

        if (maxLegalFusionDepth == 0) {
          LDBG() << "Can't fuse: fusion is not legal at any depth";
          continue;
        }
        LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
        // Limit fusion by the redundant-compute tolerance; with a cyclic
        // dependence in the source nest, only zero additional compute is
        // allowed.
        double computeToleranceThresholdToUse = computeToleranceThreshold;

          LDBG() << "Source nest has a cyclic dependence.";
          // Maximal fusion does not check the compute tolerance threshold, so
          // only perform it when the redundant computation is zero.
          if (maximalFusion) {
            auto srcForOp = cast<AffineForOp>(srcNode->op);
            auto dstForOp = cast<AffineForOp>(dstNode->op);
            int64_t sliceCost;
            int64_t fusedLoopNestComputeCost;
            auto fraction = getAdditionalComputeFraction(
                srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
                sliceCost, fusedLoopNestComputeCost);
            if (!fraction || fraction > 0) {
              LDBG() << "Can't perform maximal fusion with a cyclic dependence "
                     << "and non-zero additional compute.";
              continue;
            }
          } else {
            LDBG() << "Setting compute tolerance to zero since "
                   << "source has a cyclic dependence.";
            computeToleranceThresholdToUse = 0;
          }
        // Determine the best fusion depth (unless doing maximal fusion, in
        // which case the max legal depth is used).
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // Retrieve producer stores from the src loop nest.
          SmallVector<Operation *, 2> producerStores;
          for (Operation *op : srcNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              producerStores.push_back(op);

          assert(!producerStores.empty() && "Expected producer store");
                              dstAffineForOp, depthSliceUnions,
                              maxLegalFusionDepth, &bestDstLoopDepth,
                              computeToleranceThresholdToUse)) {
            continue;
          }
        }

        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        ComputationSliceState &bestSlice =
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");
            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,

        // Create private memrefs for the producer-consumer memrefs where
        // legal.
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
                                     removeSrcNode)) {
            LDBG() << "Creating private memref for " << memref;
            privateMemrefs.insert(memref);
          }
        }
        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        dstNodeChanged = true;

        LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
               << " at depth " << bestDstLoopDepth << ":";
        LDBG() << dstAffineForOp;

        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(fusedLoopInsPoint);

        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                         removeSrcNode);
        if (!privateMemrefs.empty()) {
          // Gather stores for all the private-to-be memrefs.
          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(storeOp);
          });

          // Replace original memrefs with private memrefs.
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
                dstAffineForOp, storesForMemref, bestDstLoopDepth,
                fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
          }
        }
        dstNode = mdg->getNode(dstId);

        // Collect dst loop stats after the memref privatization
        // transformation.
        LoopNestStateCollector dstLoopCollector;
        dstLoopCollector.collect(dstAffineForOp);

        // Clear and add back loads and stores.
        mdg->clearNodeLoadAndStores(dstNode->id);

        if (removeSrcNode) {
          LDBG() << "Removing src loop " << srcId << " after fusion";
          srcAffineForOp.erase();
          mdg->removeNode(srcId);
        }
      }
    } while (dstNodeChanged);
  }
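
  // Visits each node in the worklist and attempts producer/consumer fusion
  // into it, forwarding `maxSrcUserCount` to limit candidate producers.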
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LDBG() << "--- Producer/Consumer Fusion ---";
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      performFusionsIntoDest(dstId, maxSrcUserCount);
    }
  }
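
  // Visits each node in the worklist and attempts sibling fusion: fusing loop
  // nests which load from the same memref but have no dependences between
  // them.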
  void fuseSiblingNodes() {
    LDBG() << "--- Sibling Fusion ---";
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Skip if 'dstNode' is not a loop nest.
      auto *dstNode = mdg->getNode(dstId);
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      fuseWithSiblingNodes(dstNode);
    }
  }
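
  // Repeatedly finds sibling nodes that load from the same memref as
  // `dstNode`, checks the legality and profitability of fusing them into
  // `dstNode`, and performs the fusion when it pays off.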
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // Get 'sibNode' which was selected for fusion with 'dstNode'.
      auto *sibNode = mdg->getNode(sibId);

      // Compute an insertion point for the fused loop nest which preserves
      // dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Gather the load on 'memref' within 'sibNode' and the loads within
      // 'dstNode'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);

      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
      SmallVector<AffineForOp, 4> surroundingLoops;
      unsigned numSurroundingLoops = surroundingLoops.size();
      SmallVector<AffineForOp, 4> dstLoopIVs;
      unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute the legal fusion depths and slice unions.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
                             i + numSurroundingLoops,
                             &depthSliceUnions[i - 1], strategy);
          maxLegalFusionDepth = i;
      }
      LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;

      // Skip if fusion is not feasible at any loop depth.
      if (maxLegalFusionDepth == 0)
        continue;
      double computeToleranceThresholdToUse = computeToleranceThreshold;

        LDBG() << "Source nest has a cyclic dependence.";
        // Maximal fusion does not check the compute tolerance threshold, so
        // only perform it when the redundant computation is zero.
        if (maximalFusion) {
          auto dstForOp = cast<AffineForOp>(dstNode->op);
          int64_t sliceCost;
          int64_t fusedLoopNestComputeCost;
          auto fraction = getAdditionalComputeFraction(
              sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
              sliceCost, fusedLoopNestComputeCost);
          if (!fraction || fraction > 0) {
            LDBG() << "Can't perform maximal fusion with a cyclic dependence "
                   << "and non-zero additional compute.";
            continue;
          }
        } else {
          LDBG() << "Setting compute tolerance to zero since "
                 << "source has a cyclic dependence.";
          computeToleranceThresholdToUse = 0.0;
        }
      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable.
                                depthSliceUnions, maxLegalFusionDepth,
                                computeToleranceThresholdToUse))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      const ComputationSliceState &bestSlice =
          depthSliceUnions[bestDstLoopDepth - 1];
      assert(!bestSlice.isEmpty() && "Fusion depth has no computed slice union");

      // Do not perform sibling fusion if the slice isn't maximal.
      if (!isMaximal.value_or(false)) {
        LDBG() << "Slice isn't maximal; not performing sibling fusion.";
        continue;
      }
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
                isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // Update the operation position of the fused loop nest (if needed).
      if (insertPointInst != dstForInst)
        dstForInst->moveBefore(insertPointInst);

      LDBG() << "Fused sibling nest " << sibId << " into destination nest "
             << dstNode->id << " at depth " << bestDstLoopDepth << ":";
      LDBG() << dstAffineForOp;

      // Update the dependence graph state and remove the sibling node.
      updateStateAfterSiblingFusion(sibNode, dstNode);

      Operation *op = sibNode->op;
      mdg->removeNode(sibNode->id);
    }
  }
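
  // Searches the block and the dependence graph for a sibling node that
  // `dstNode` could be fused with: a loop nest that loads from the same
  // memref as `dstNode`, has no dependence path to or from it, and has not
  // been visited yet. On success, records the node id and memref in
  // `idAndMemrefToFuse` and returns true.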
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Fuse only when the sibling has a single load on 'memref'.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there is a dependence path between 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip 'sibNode' if it has an input dependence on any memref it
      // accesses.
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // The sibling's stores must all target at most one memref.
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      return storeMemrefs.size() <= 1;
    };
    // Search for candidate sibling nodes among users of block arguments that
    // 'dstNode' also loads from.
    Block *block = dstNode->op->getBlock();
        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
        SmallVector<AffineForOp, 4> loops;
        // Skip the use if it is not within a loop nest rooted in the block
        // being fused.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &mdg->block;
        });
        if (it == loops.end())
          continue;
        Node *sibNode = mdg->getForOpNode(*it);
        assert(sibNode != nullptr);
        // Skip 'use' if it is not a sibling of 'dstNode' or was already
        // visited.
        if (sibNode->id == dstNode->id)
          continue;
        if (visitedSibNodeIds->count(sibNode->id) > 0)
          continue;
        // Skip 'use' if it does not load from the same memref as 'dstNode'.
        auto memref = loadOp.getMemRef();
        if (dstNode->getLoadOpCount(memref) == 0)
          continue;
        // Check if 'sibNode' and 'dstNode' can be input-reuse fused on
        // 'memref'.
        if (canFuseWithSibNode(sibNode, memref)) {
          visitedSibNodeIds->insert(sibNode->id);
          idAndMemrefToFuse->first = sibNode->id;
          idAndMemrefToFuse->second = memref;
          return true;
        }
    // Search for siblings by following edges through an intermediate src
    // node. Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence or an edge
          // from a memref-defining op.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
               inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if it points to 'dstNode' or is for a
            // different memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode' and 'dstNode' can be input-reuse fused on
            // the memref.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              outEdges.push_back(outEdge);
            }
          });

      // Add the first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }
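
  // Refreshes the dependence-graph state for `dstNode` after `sibNode` has
  // been fused into it.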
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Collect dst loop stats after the sibling fusion.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst);
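
  // Erases memref allocations that no longer have any users after fusion.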
  void eraseUnusedMemRefAllocations() {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Erase the defining alloc if the memref has become dead.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
  }
};
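
// Runs fusion on `block`: builds the memref dependence graph and drives
// GreedyFusion in the configured fusion mode.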
void LoopFusion::runOnBlock(Block *block) {
  MemRefDependenceGraph g(*block);
  if (!g.init()) {
    LDBG() << "MDG init failed";
    return;
  }

  std::optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);

  if (affineFusionMode == FusionMode::ProducerConsumer)
    fusion.runProducerConsumerFusionOnly();
  else if (affineFusionMode == FusionMode::Sibling)
    fusion.runSiblingFusionOnly();
  else
    fusion.runGreedyFusion();
}
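
// Runs fusion on every block of the operation that contains more than one
// affine.for nest.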
void LoopFusion::runOnOperation() {
  getOperation()->walk([&](Operation *op) {
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        auto affineFors = block.getOps<AffineForOp>();
        if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
          runOnBlock(&block);
      }
    }
  });
}
    unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
    bool maximalFusion, enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}