27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
44 #define DEBUG_TYPE "affine-loop-fusion"
55 struct LoopFusion :
public affine::impl::AffineLoopFusionBase<LoopFusion> {
56 LoopFusion() =
default;
57 LoopFusion(
unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58 bool maximalFusion,
enum FusionMode affineFusionMode) {
59 this->fastMemorySpace = fastMemorySpace;
60 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61 this->maximalFusion = maximalFusion;
62 this->affineFusionMode = affineFusionMode;
65 void runOnBlock(
Block *block);
66 void runOnOperation()
override;
84 bool hasOutDepsAfterFusion =
false;
86 for (
auto &outEdge : mdg->
outEdges[srcId]) {
89 if (depNodeOp == dstNodeOp)
99 if (fusedLoopInsPoint != depNodeOp &&
101 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: dst loop doesn't "
102 "dominate dependence\n");
106 hasOutDepsAfterFusion =
true;
112 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113 std::optional<bool> isMaximal = fusionSlice.
isMaximal();
115 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: can't determine "
116 "if fusion is maximal\n");
121 LLVM_DEBUG(llvm::dbgs()
122 <<
"Src loop can't be removed: fusion is not maximal\n");
140 if (mdg->
inEdges.count(dstId) == 0)
144 auto *dstNode = mdg->
getNode(dstId);
147 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
151 for (
auto &srcEdge : mdg->
inEdges[dstId]) {
152 auto *srcNode = mdg->
getNode(srcEdge.id);
154 if (!isa<AffineForOp>(srcNode->op))
157 if (any_of(srcNode->stores, [&](
Operation *op) {
158 auto storeOp = cast<AffineWriteOpInterface>(op);
159 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
161 srcIdCandidates.push_back(srcNode->id);
164 llvm::sort(srcIdCandidates);
165 srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
174 auto *dstNode = mdg->
getNode(dstId);
175 auto *srcNode = mdg->
getNode(srcId);
177 producerConsumerMemrefs);
194 if (
auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
199 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
206 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
209 if (ancestorOp->getBlock() != block)
211 return !isa<AffineMapAccessInterface>(*user);
220 for (
Operation *storeOp : node->stores) {
221 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222 if (escapingMemRefs.count(memref))
225 escapingMemRefs.insert(memref);
235 assert(isa<AffineForOp>(node->
op));
237 node->
op = newRootForOp;
246 unsigned dstLoopDepth,
247 std::optional<unsigned> fastMemorySpace,
248 uint64_t localBufSizeThreshold) {
249 Operation *forInst = forOp.getOperation();
256 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
257 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
258 unsigned rank = oldMemRefType.getRank();
262 bool validRegion = succeeded(region.
compute(srcStoreOpInst, dstLoopDepth));
264 assert(validRegion &&
"unexpected memref region failure");
266 std::vector<SmallVector<int64_t, 4>> lbs;
271 std::optional<int64_t> numElements =
273 assert(numElements &&
"non-constant number of elts in local buffer");
284 offsets.reserve(rank);
285 for (
unsigned d = 0; d < rank; ++d) {
286 assert(lbs[d].size() == cst->
getNumCols() - rank &&
"incorrect bound size");
289 for (
unsigned j = 0, e = cst->
getNumCols() - rank - 1;
j < e;
j++) {
292 assert(lbDivisors[d] > 0);
294 (offset + lbs[d][cst->
getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
295 offsets.push_back(offset);
301 assert(eltSize &&
"memrefs with size elt types expected");
302 uint64_t bufSize = *eltSize * *numElements;
303 unsigned newMemSpace;
304 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
305 newMemSpace = *fastMemorySpace;
307 newMemSpace = oldMemRefType.getMemorySpaceAsInt();
309 auto newMemRefType =
MemRefType::get(newShape, oldMemRefType.getElementType(),
318 Value newMemRef = top.
create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
322 remapExprs.reserve(rank);
323 for (
unsigned i = 0; i < rank; i++) {
328 remapExprs.push_back(remapExpr);
332 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
339 &*forOp.getBody()->begin());
340 assert(succeeded(res) &&
341 "replaceAllMemrefUsesWith should always succeed here");
352 auto *srcNode = mdg->
getNode(srcId);
353 auto *dstNode = mdg->
getNode(dstId);
358 for (
auto &idAndNode : mdg->
nodes) {
361 if (srcNode->op->isBeforeInBlock(op) && op->
isBeforeInBlock(dstNode->op)) {
366 if (isa<AffineMapAccessInterface>(*user))
369 if (llvm::is_contained(users, user))
373 if (walkResult.wasInterrupted())
386 auto *srcNode = mdg->
getNode(srcId);
387 llvm::SmallDenseSet<Value, 2> memRefValues;
390 if (isa<AffineForOp>(op))
394 if (isa<MemRefType>(v.getType()))
395 memRefValues.insert(v);
399 return llvm::any_of(memRefValues, [&](
Value memref) {
444 AffineForOp dstForOp,
446 unsigned maxLegalFusionDepth,
447 unsigned *dstLoopDepth,
448 double computeToleranceThreshold) {
450 llvm::dbgs() <<
"Checking whether fusion is profitable between src op:\n";
451 llvm::dbgs() <<
' ' << *srcOpInst <<
" and destination loop:\n";
452 llvm::dbgs() << dstForOp <<
"\n";
455 if (maxLegalFusionDepth == 0) {
456 LLVM_DEBUG(llvm::dbgs() <<
"Can't fuse: maxLegalFusionDepth is 0\n");
481 double maxStorageReduction = 0.0;
482 std::optional<uint64_t> sliceMemEstimate;
485 std::optional<unsigned> bestDstLoopDepth;
488 uint64_t srcLoopNestCost =
getComputeCost(srcLoopIVs[0], srcLoopNestStats);
492 if (failed(srcWriteRegion.
compute(srcStoreOpInst, 0))) {
493 LLVM_DEBUG(llvm::dbgs()
494 <<
"Unable to compute MemRefRegion for source operation\n");
498 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
500 if (!maybeSrcWriteRegionSizeBytes.has_value())
502 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
505 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
509 for (
unsigned i = maxLegalFusionDepth; i >= 1; --i) {
515 int64_t fusedLoopNestComputeCost;
517 dstLoopNestStats, slice,
518 &fusedLoopNestComputeCost)) {
519 LLVM_DEBUG(llvm::dbgs() <<
"Unable to compute fusion compute cost\n");
523 double additionalComputeFraction =
524 fusedLoopNestComputeCost /
525 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
532 if (failed(sliceWriteRegion.
compute(srcStoreOpInst, 0,
534 LLVM_DEBUG(llvm::dbgs()
535 <<
"Failed to compute slice write region at loopDepth: " << i
540 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
542 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
543 *maybeSliceWriteRegionSizeBytes == 0) {
544 LLVM_DEBUG(llvm::dbgs()
545 <<
"Failed to get slice write region size at loopDepth: " << i
549 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
555 if (srcOpInst != srcStoreOpInst &&
556 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
559 double storageReduction =
static_cast<double>(srcWriteRegionSizeBytes) /
560 static_cast<double>(sliceWriteRegionSizeBytes);
563 std::stringstream msg;
564 msg <<
" evaluating fusion profitability at depth : " << i <<
"\n"
565 << std::fixed << std::setprecision(2)
566 <<
" additional compute fraction: "
567 << 100.0 * additionalComputeFraction <<
"%\n"
568 <<
" storage reduction factor: " << storageReduction <<
"x\n"
569 <<
" fused nest cost: " << fusedLoopNestComputeCost <<
"\n"
570 <<
" src write region size: " << srcWriteRegionSizeBytes <<
"\n"
571 <<
" slice write region size: " << sliceWriteRegionSizeBytes
573 llvm::dbgs() << msg.str();
580 if ((storageReduction > maxStorageReduction) &&
581 (additionalComputeFraction < computeToleranceThreshold)) {
582 maxStorageReduction = storageReduction;
583 bestDstLoopDepth = i;
584 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
585 sliceMemEstimate = sliceWriteRegionSizeBytes;
591 if (!bestDstLoopDepth) {
594 <<
"All fusion choices involve more than the threshold amount of "
595 "redundant computation; NOT fusing.\n");
599 if (!bestDstLoopDepth) {
600 LLVM_DEBUG(llvm::dbgs() <<
"no fusion depth could be evaluated.\n");
605 *dstLoopDepth = *bestDstLoopDepth;
608 llvm::dbgs() <<
" LoopFusion fusion stats:"
609 <<
"\n best loop depth: " << bestDstLoopDepth
610 <<
"\n src loop nest compute cost: " << srcLoopNestCost
611 <<
"\n dst loop nest compute cost: " << dstLoopNestCost
612 <<
"\n fused loop nest compute cost: "
613 << minFusedLoopNestComputeCost <<
"\n");
618 std::optional<double> storageReduction;
620 if (!dstMemSize || !srcMemSize) {
621 LLVM_DEBUG(llvm::dbgs()
622 <<
" fusion memory benefit cannot be evaluated; NOT fusing.\n");
626 auto srcMemSizeVal = *srcMemSize;
627 auto dstMemSizeVal = *dstMemSize;
629 assert(sliceMemEstimate &&
"expected value");
630 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
632 LLVM_DEBUG(llvm::dbgs() <<
" src mem: " << srcMemSizeVal <<
"\n"
633 <<
" dst mem: " << dstMemSizeVal <<
"\n"
634 <<
" fused mem: " << fusedMem <<
"\n"
635 <<
" slice mem: " << sliceMemEstimate <<
"\n");
637 if (
static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
638 LLVM_DEBUG(llvm::dbgs() <<
"Fusion is not profitable; NOT fusing.\n");
643 (1.0 - fusedMem / (
static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
645 double additionalComputeFraction =
646 100.0 * (minFusedLoopNestComputeCost /
647 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
649 (void)additionalComputeFraction;
651 std::stringstream msg;
652 msg <<
" fusion is most profitable at depth " << *dstLoopDepth <<
" with "
653 << std::setprecision(2) << additionalComputeFraction
654 <<
"% redundant computation and a ";
655 msg << (storageReduction ? std::to_string(*storageReduction) :
"<unknown>");
656 msg <<
"% storage reduction.\n";
657 llvm::dbgs() << msg.str();
711 struct GreedyFusion {
718 unsigned localBufSizeThreshold;
720 std::optional<unsigned> fastMemorySpace;
726 double computeToleranceThreshold;
731 std::optional<unsigned> fastMemorySpace,
bool maximalFusion,
732 double computeToleranceThreshold)
733 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
734 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
735 computeToleranceThreshold(computeToleranceThreshold) {}
742 for (
auto &idAndNode : mdg->
nodes) {
743 const Node &node = idAndNode.second;
744 worklist.push_back(node.id);
748 void runSiblingFusionOnly() {
750 eraseUnusedMemRefAllocations();
754 void runProducerConsumerFusionOnly() {
755 fuseProducerConsumerNodes(
757 eraseUnusedMemRefAllocations();
765 void runGreedyFusion() {
767 fuseProducerConsumerNodes(1);
769 fuseProducerConsumerNodes(
771 eraseUnusedMemRefAllocations();
776 bool canCreatePrivateMemRef(
Value memref,
778 unsigned producerId,
unsigned consumerId,
779 bool removeSrcNode) {
780 const Node *consumerNode = mdg->
getNode(consumerId);
788 if (srcEscapingMemRefs.count(memref) > 0 &&
789 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
802 any_of(mdg->
outEdges[producerId], [&](
const auto &edge) {
803 return edge.value == memref && edge.id != consumerId;
813 void performFusionsIntoDest(
unsigned dstId,
unsigned maxSrcUserCount) {
814 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
816 if (mdg->
nodes.count(dstId) == 0)
819 auto *dstNode = mdg->
getNode(dstId);
821 if (!isa<AffineForOp>(dstNode->op))
825 if (dstNode->op->getNumResults() > 0)
828 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
835 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
846 dstNodeChanged =
false;
850 for (
unsigned srcId : llvm::reverse(srcIdCandidates)) {
852 auto *srcNode = mdg->
getNode(srcId);
853 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
854 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating src loop " << srcId
855 <<
" for dst loop " << dstId <<
"\n");
859 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
864 producerConsumerMemrefs);
868 if (any_of(producerConsumerMemrefs, [&](
Value memref) {
886 if (!srcEscapingMemRefs.empty() &&
888 LLVM_DEBUG(llvm::dbgs()
889 <<
"Can't fuse: non-affine users in between the loops\n");
897 if (fusedLoopInsPoint ==
nullptr)
908 unsigned numSurroundingLoops = surroundingLoops.size();
914 if (producerConsumerMemrefs.count(
915 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
916 dstMemrefOps.push_back(op);
918 if (producerConsumerMemrefs.count(
919 cast<AffineWriteOpInterface>(op).getMemRef()))
920 dstMemrefOps.push_back(op);
921 unsigned dstLoopDepthTest =
926 unsigned maxLegalFusionDepth = 0;
928 depthSliceUnions.resize(dstLoopDepthTest);
930 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
933 i + numSurroundingLoops,
934 &depthSliceUnions[i - 1], strategy);
937 maxLegalFusionDepth = i;
940 if (maxLegalFusionDepth == 0) {
941 LLVM_DEBUG(llvm::dbgs()
942 <<
"Can't fuse: fusion is not legal at any depth\n");
949 unsigned bestDstLoopDepth = maxLegalFusionDepth;
950 if (!maximalFusion) {
954 if (producerConsumerMemrefs.count(
955 cast<AffineWriteOpInterface>(op).getMemRef()))
956 producerStores.push_back(op);
964 assert(!producerStores.empty() &&
"Expected producer store");
965 if (producerStores.size() > 1)
966 LLVM_DEBUG(llvm::dbgs() <<
"Skipping profitability analysis. Not "
967 "supported for this case\n");
969 dstAffineForOp, depthSliceUnions,
970 maxLegalFusionDepth, &bestDstLoopDepth,
971 computeToleranceThreshold))
975 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
977 depthSliceUnions[bestDstLoopDepth - 1];
978 assert(!bestSlice.
isEmpty() &&
"Missing slice union for depth");
984 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
988 for (
Value memref : producerConsumerMemrefs) {
989 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
992 LLVM_DEBUG(llvm::dbgs()
993 <<
"Creating private memref for " << memref <<
'\n');
995 privateMemrefs.insert(memref);
1000 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1001 dstNodeChanged =
true;
1003 LLVM_DEBUG(llvm::dbgs()
1004 <<
"Fused src loop " << srcId <<
" into dst loop " << dstId
1005 <<
" at depth " << bestDstLoopDepth <<
":\n"
1006 << dstAffineForOp <<
"\n");
1009 if (fusedLoopInsPoint != dstAffineForOp)
1010 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1013 mdg->
updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1017 if (!privateMemrefs.empty()) {
1020 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1021 Value storeMemRef = storeOp.getMemRef();
1022 if (privateMemrefs.count(storeMemRef) > 0)
1023 privateMemRefToStores[storeMemRef].push_back(storeOp);
1030 for (
auto &memrefToStoresPair : privateMemRefToStores) {
1034 memrefToStoresPair.second;
1036 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1037 fastMemorySpace, localBufSizeThreshold);
1041 mdg->
addEdge(newMemRefNodeId, dstId, newMemRef);
1046 dstNode = mdg->
getNode(dstId);
1051 dstLoopCollector.
collect(dstAffineForOp);
1058 if (removeSrcNode) {
1059 LLVM_DEBUG(llvm::dbgs()
1060 <<
"Removing src loop " << srcId <<
" after fusion\n");
1062 srcAffineForOp.erase();
1067 }
while (dstNodeChanged);
1074 void fuseProducerConsumerNodes(
unsigned maxSrcUserCount) {
1075 LLVM_DEBUG(llvm::dbgs() <<
"--- Producer/Consumer Fusion ---\n");
1077 while (!worklist.empty()) {
1078 unsigned dstId = worklist.back();
1079 worklist.pop_back();
1080 performFusionsIntoDest(dstId, maxSrcUserCount);
1086 void fuseSiblingNodes() {
1087 LLVM_DEBUG(llvm::dbgs() <<
"--- Sibling Fusion ---\n");
1089 while (!worklist.empty()) {
1090 unsigned dstId = worklist.back();
1091 worklist.pop_back();
1094 if (mdg->
nodes.count(dstId) == 0)
1097 auto *dstNode = mdg->
getNode(dstId);
1099 if (!isa<AffineForOp>(dstNode->op))
1102 fuseWithSiblingNodes(dstNode);
1107 void fuseWithSiblingNodes(
Node *dstNode) {
1109 std::pair<unsigned, Value> idAndMemref;
1110 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1112 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1113 unsigned sibId = idAndMemref.first;
1114 Value memref = idAndMemref.second;
1117 auto *sibNode = mdg->
getNode(sibId);
1120 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1125 if (insertPointInst ==
nullptr)
1132 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1134 assert(sibLoadOpInsts.size() == 1);
1135 Operation *sibLoadOpInst = sibLoadOpInsts[0];
1139 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1149 unsigned numSurroundingLoops = surroundingLoops.size();
1152 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1153 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1157 depthSliceUnions.resize(dstLoopDepthTest);
1158 unsigned maxLegalFusionDepth = 0;
1160 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1163 i + numSurroundingLoops,
1164 &depthSliceUnions[i - 1], strategy);
1167 maxLegalFusionDepth = i;
1170 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1171 << maxLegalFusionDepth <<
'\n');
1174 if (maxLegalFusionDepth == 0)
1177 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1178 if (!maximalFusion) {
1184 depthSliceUnions, maxLegalFusionDepth,
1185 &bestDstLoopDepth, computeToleranceThreshold))
1189 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1190 assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1191 "Fusion depth has no computed slice union");
1195 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1198 depthSliceUnions[bestDstLoopDepth - 1],
1199 isInnermostInsertion);
1201 auto dstForInst = cast<AffineForOp>(dstNode->op);
1203 if (insertPointInst != dstForInst) {
1204 dstForInst->moveBefore(insertPointInst);
1207 updateStateAfterSiblingFusion(sibNode, dstNode);
1215 bool findSiblingNodeToFuse(
Node *dstNode,
1217 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1220 auto canFuseWithSibNode = [&](
Node *sibNode,
Value memref) {
1223 if (sibNode->getLoadOpCount(memref) != 1)
1233 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1234 if (llvm::any_of(loadAndStoreMemrefSet, [=](
Value memref) {
1241 for (
auto *storeOpInst : sibNode->stores) {
1242 storeMemrefs.insert(
1243 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1245 if (storeMemrefs.size() > 1)
1257 Block *block = dstNode->op->getBlock();
1260 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1269 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1270 return loop->getBlock() == &mdg->
block;
1273 if (it == loops.end())
1276 assert(sibNode !=
nullptr);
1278 if (sibNode->id == dstNode->id)
1281 if (visitedSibNodeIds->count(sibNode->id) > 0)
1284 auto memref = loadOp.getMemRef();
1285 if (dstNode->getLoadOpCount(memref) == 0)
1288 if (canFuseWithSibNode(sibNode, memref)) {
1289 visitedSibNodeIds->insert(sibNode->id);
1290 idAndMemrefToFuse->first = sibNode->id;
1291 idAndMemrefToFuse->second = memref;
1303 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1304 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1305 inEdges.push_back(inEdge);
1310 for (
auto &inEdge : inEdges) {
1315 unsigned sibNodeId = outEdge.id;
1316 if (visitedSibNodeIds->count(sibNodeId) > 0)
1319 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1321 auto *sibNode = mdg->getNode(sibNodeId);
1322 if (!isa<AffineForOp>(sibNode->op))
1325 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1327 outEdges.push_back(outEdge);
1332 if (!outEdges.empty()) {
1333 visitedSibNodeIds->insert(outEdges[0].
id);
1334 idAndMemrefToFuse->first = outEdges[0].id;
1335 idAndMemrefToFuse->second = outEdges[0].value;
1344 void updateStateAfterSiblingFusion(
Node *sibNode,
Node *dstNode) {
1349 auto dstForInst = cast<AffineForOp>(dstNode->op);
1351 dstLoopCollector.
collect(dstForInst);
1366 void eraseUnusedMemRefAllocations() {
1368 if (pair.second > 0)
1370 auto memref = pair.first;
1376 if (isa_and_nonnull<memref::AllocOp>(op))
1385 void LoopFusion::runOnBlock(
Block *block) {
1388 LLVM_DEBUG(llvm::dbgs() <<
"MDG init failed\n");
1392 std::optional<unsigned> fastMemorySpaceOpt;
1393 if (fastMemorySpace.hasValue())
1394 fastMemorySpaceOpt = fastMemorySpace;
1395 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1396 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1397 maximalFusion, computeToleranceThreshold);
1400 fusion.runProducerConsumerFusionOnly();
1402 fusion.runSiblingFusionOnly();
1404 fusion.runGreedyFusion();
1407 void LoopFusion::runOnOperation() {
1410 getOperation()->walk([&](
Operation *op) {
1412 for (Block &block : region.getBlocks()) {
1413 auto affineFors = block.getOps<AffineForOp>();
1414 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1422 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1423 bool maximalFusion,
enum FusionMode affineFusionMode) {
1424 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1425 maximalFusion, affineFusionMode);
MemRefDependenceGraph::Node Node
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, MemRefDependenceGraph *mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg)
Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 'dstId'), if there is any non-affine operation accessing 'memref', return true; otherwise, return false.
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or are accessed in a non-affine way.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, uint64_t localBufSizeThreshold)
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Affine dialect ops.
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumVars() const
unsigned getNumCols() const
Returns the number of columns in the constraint system.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'.
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between write ops in 'srcOps' and read ops in 'dstOps'.
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bounds specified in 'srcSlice'.
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine.for' operation to the innermost one.
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory space is unspecified, considers all memory spaces.
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass is created on, according to the type of fusion specified in 'fusionMode'.
FusionMode
Fusion mode to attempt.
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (e.g., loop trip count and operation count) in 'stats' for the loop nest rooted at 'forOp'.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dstForOp' at 'dstLoopDepth'.
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memref's indices using the supplied affine map, 'indexRemap'.
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a vector of such types.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
bool isEmpty() const
Returns true if the computation slice is empty.
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
void collect(Operation *opToWalk)
SmallVector< Operation *, 4 > storeOpInsts
LoopNestStats aggregates various per-loop statistics (e.g., loop trip count and operation count) for a loop nest up until the innermost loop body.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Block & block
The block for which this graph is created to perform fusion.
unsigned addNode(Operation *op)
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
void addEdge(unsigned srcId, unsigned dstId, Value value)
Node * getForOpNode(AffineForOp forOp)
DenseMap< unsigned, Node > nodes
bool hasDependencePath(unsigned srcId, unsigned dstId)
void clearNodeLoadAndStores(unsigned id)
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
void removeNode(unsigned id)
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Node * getNode(unsigned id)
DenseMap< Value, unsigned > memrefEdgeCount
A region of a memref's data space; this is typically constructed by analyzing load/store ops on this memref and the index space of loops surrounding such ops.
FlatAffineValueConstraints * getConstraints()
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constant (always possible for static shapes), std::nullopt otherwise.
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symbolic/parametric in 'loopDepth' loops surrounding 'op'.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.