#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
#include <sstream>

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINELOOPFUSION
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;
using namespace mlir::affine;
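
// Loop fusion pass driver: a thin wrapper around the greedy fusion engine
// defined further below. Its options mirror the constructor arguments: the
// fast memory space to place promoted buffers in, the local buffer size
// threshold (stored internally in KiB, hence the division by 1024), whether to
// fuse maximally regardless of profitability, and which fusion mode to run
// (producer-consumer only, sibling only, or greedy).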
struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  void runOnBlock(Block *block);
  void runOnOperation() override;
};
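
// canRemoveSrcNodeAfterFusion: returns true if the source node 'srcId' can be
// erased after being fused into 'dstId'. The source may only be removed if
// every remaining outgoing dependence is dominated by the fused loop's
// insertion point and, when such dependences exist or the source writes to
// escaping memrefs, the fusion slice is provably maximal (it covers all
// iterations of the source loop nest).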
  bool hasOutDepsAfterFusion = false;
  // ...
  for (auto &outEdge : mdg->outEdges[srcId]) {
    // ...
    if (depNodeOp == dstNodeOp)
      // ...
    if (fusedLoopInsPoint != depNodeOp &&
        // ...
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      // ...
    hasOutDepsAfterFusion = true;
  // ...
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    std::optional<bool> isMaximal = fusionSlice.isMaximal();
    // ...
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
    // ...
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
// getProducerCandidates: populates 'srcIdCandidates' with the ids of producer
// nodes that are fusion candidates for consumer 'dstId', i.e. loop-nest nodes
// with an incoming edge into 'dstId' that store to a memref 'dstId' loads
// from. Candidates are sorted and de-duplicated.
  if (mdg->inEdges.count(dstId) == 0)
    // ...
  auto *dstNode = mdg->getNode(dstId);
  // ...
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
  // ...
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // ...
    if (!isa<AffineForOp>(srcNode->op))
      // ...
    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }
  // ...
  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
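
// gatherProducerConsumerMemrefs (node-id overload): collects in
// 'producerConsumerMemrefs' the memrefs on which there is a producer-consumer
// relationship between the stores of 'srcId' and the loads of 'dstId', by
// delegating to the ArrayRef-based gatherProducerConsumerMemrefs overload.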
  auto *dstNode = mdg->getNode(dstId);
  auto *srcNode = mdg->getNode(srcId);
  // ...
                                producerConsumerMemrefs);
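
// isEscapingMemref: a memref escapes in the context of this pass if it is not
// the result of a plain allocation in this block (e.g. it is an alias created
// through a view-like op or has no local defining op), or if it has users
// outside 'block' or users that are not affine load/store-like ops (calls and
// other non-dereferencing ops).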
  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
    // ...
  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
    // ...
    Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
    // ...
    if (ancestorOp->getBlock() != block)
      // ...
    return !isa<AffineMapAccessInterface>(*user);
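
// gatherEscapingMemrefs: records in 'escapingMemRefs' every memref written by
// an affine store in node 'id' that escapes the block (see isEscapingMemref
// above).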
  for (Operation *storeOp : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (escapingMemRefs.count(memref))
      // ...
      escapingMemRefs.insert(memref);
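
// Re-points a dependence-graph node at the new root affine.for operation after
// its loop nest has been restructured (e.g. by sinkSequentialLoops), keeping
// the MDG consistent with the rewritten IR.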
  assert(isa<AffineForOp>(node->op));
  // ...
  node->op = newRootForOp;
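
// createPrivateMemRef: builds a smaller, private buffer for the memref written
// by 'srcStoreOpInst' once the source slice has been placed at 'dstLoopDepth'
// inside 'forOp'. The buffer shape comes from the constant bounding size of
// the accessed region, accesses are remapped through an affine map that
// subtracts the region's lower-bound offsets, and the allocation is placed in
// 'fastMemorySpace' when its size fits under 'localBufSizeThreshold'.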
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();
  // ...
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  unsigned rank = oldMemRefType.getRank();
  // ...
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  // ...
  assert(validRegion && "unexpected memref region failure");
  // ...
  std::vector<SmallVector<int64_t, 4>> lbs;
  // ...
  std::optional<int64_t> numElements =
      // ...
  assert(numElements && "non-constant number of elts in local buffer");
  // ...
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == cst->getNumCols() - rank &&
           "incorrect bound size");
    // ...
    for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
      // ...
    assert(lbDivisors[d] > 0);
    offset =
        (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  // ...
  assert(eltSize && "memrefs with size elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  unsigned newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = *fastMemorySpace;
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
      // ...
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
  // ...
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    // ...
    remapExprs.push_back(remapExpr);
  // ...
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
  // ...
                               &*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
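
// hasNonAffineUsersOnPath: returns true if 'memref' has a non-affine user on
// an operation strictly between 'start' and 'end' in their block. An overload
// further below collects every memref-typed operand appearing on the
// operations in that range and applies the same check to each of them.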
static bool hasNonAffineUsersOnPath(Operation *start, Operation *end,
                                    Value memref) {
  assert(start->isBeforeInBlock(end) && "start expected to be before end");
  // ...
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(user))
      // ...
    Operation *ancestor = block->findAncestorOpInBlock(*user);
    return ancestor && start->isBeforeInBlock(ancestor) &&
           ancestor->isBeforeInBlock(end);
  // ...
  llvm::SmallDenseSet<Value, 2> memRefValues;
  // ...
      if (isa<MemRefType>(v.getType()))
        memRefValues.insert(v);
  // ...
  return llvm::any_of(memRefValues, [&](Value memref) {
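
// isFusionProfitable: cost model for producer-consumer fusion. For each legal
// fusion depth, from deepest to shallowest, it estimates the extra computation
// introduced by inserting the source slice at that depth and the reduction of
// the source write region's footprint, then picks the depth with the best
// storage reduction whose additional compute fraction stays below
// 'computeToleranceThreshold'. Returns false when no depth qualifies or when
// the memory benefit cannot be evaluated.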
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  // ...
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  // ...
  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
    // ...
  double maxStorageReduction = 0.0;
  std::optional<uint64_t> sliceMemEstimate;
  // ...
  std::optional<unsigned> bestDstLoopDepth;
  // ...
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
  // ...
  if (failed(srcWriteRegion.compute(srcStoreOpInst, 0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation\n");
    // ...
  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
      // ...
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    // ...
  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
  // ...
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
  // ...
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    // ...
    int64_t fusedLoopNestComputeCost;
    // ...
                               dstLoopNestStats, slice,
                               &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
      // ...
    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;
    // ...
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, 0,
        // ...
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
      // ...
    std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
        // ...
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        *maybeSliceWriteRegionSizeBytes == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
      // ...
    int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
    // ...
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      // ...
    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);
    // ...
      std::stringstream msg;
      msg << " evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << " additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << " storage reduction factor: " << storageReduction << "x\n"
          << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << " src write region size: " << srcWriteRegionSizeBytes << "\n"
          << " slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    // ...
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    // ...
  if (!bestDstLoopDepth) {
    // ...
        << "All fusion choices involve more than the threshold amount of "
           "redundant computation; NOT fusing.\n");
    // ...
  if (!bestDstLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
    // ...
  *dstLoopDepth = *bestDstLoopDepth;
  // ...
    llvm::dbgs() << " LoopFusion fusion stats:"
                 << "\n best loop depth: " << bestDstLoopDepth
                 << "\n src loop nest compute cost: " << srcLoopNestCost
                 << "\n dst loop nest compute cost: " << dstLoopNestCost
                 << "\n fused loop nest compute cost: "
                 << minFusedLoopNestComputeCost << "\n");
  // ...
  std::optional<double> storageReduction;
  // ...
  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
    // ...
  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;
  // ...
  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
  // ...
  LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
                          << " dst mem: " << dstMemSizeVal << "\n"
                          << " fused mem: " << fusedMem << "\n"
                          << " slice mem: " << sliceMemEstimate << "\n");
  // ...
  if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    // ...
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  // ...
  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  (void)additionalComputeFraction;
  // ...
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
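
// GreedyFusion: the fusion engine driven by LoopFusion::runOnBlock. It keeps a
// worklist of dependence-graph nodes and fuses producer-consumer pairs as well
// as sibling loop nests (nests that access a common memref without a
// producer-consumer dependence between them), rewriting the MDG and creating
// private memrefs as it goes.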
struct GreedyFusion {
  // ...
  unsigned localBufSizeThreshold;
  // ...
  std::optional<unsigned> fastMemorySpace;
  // ...
  double computeToleranceThreshold;
  // ...
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}
  // ...
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
  // ...
  void runSiblingFusionOnly() {
    // ...
    eraseUnusedMemRefAllocations();
  // ...
  void runProducerConsumerFusionOnly() {
    fuseProducerConsumerNodes(
        // ...
    eraseUnusedMemRefAllocations();
  // ...
  void runGreedyFusion() {
    // ...
    fuseProducerConsumerNodes(1);
    // ...
    fuseProducerConsumerNodes(
        // ...
    eraseUnusedMemRefAllocations();
  // ...
  bool canCreatePrivateMemRef(Value memref,
                              // ...
                              unsigned producerId, unsigned consumerId,
                              bool removeSrcNode) {
    const Node *consumerNode = mdg->getNode(consumerId);
    // ...
    if (srcEscapingMemRefs.count(memref) > 0 &&
        (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
      // ...
        any_of(mdg->outEdges[producerId], [&](const auto &edge) {
          return edge.value == memref && edge.id != consumerId;
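
  // performFusionsIntoDest: attempts to fuse every producer candidate into the
  // destination loop nest 'dstId', repeating until no further candidate can be
  // fused. For each candidate it checks legality at every depth with
  // canFuseLoops, takes the deepest legal depth (or the depth chosen by
  // isFusionProfitable when maximal fusion is off), materializes the slice
  // with fuseLoops, creates private memrefs where that is safe, updates the
  // MDG edges, and erases the source nest when canRemoveSrcNodeAfterFusion
  // allows it.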
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // ...
    if (mdg->nodes.count(dstId) == 0)
      // ...
    auto *dstNode = mdg->getNode(dstId);
    // ...
    if (!isa<AffineForOp>(dstNode->op))
      // ...
    if (dstNode->op->getNumResults() > 0)
      // ...
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // ...
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
    // ...
      dstNodeChanged = false;
      // ...
      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
        // ...
        auto *srcNode = mdg->getNode(srcId);
        auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
        LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                << " for dst loop " << dstId << "\n");
        // ...
        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
          // ...
                                      producerConsumerMemrefs);
        // ...
        if (any_of(producerConsumerMemrefs, [&](Value memref) {
        // ...
        if (!srcEscapingMemRefs.empty() &&
            // ...
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: non-affine users in between the loops\n");
        // ...
        if (fusedLoopInsPoint == nullptr)
          // ...
        unsigned numSurroundingLoops = surroundingLoops.size();
        // ...
          if (producerConsumerMemrefs.count(
                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
            dstMemrefOps.push_back(op);
        // ...
          if (producerConsumerMemrefs.count(
                  cast<AffineWriteOpInterface>(op).getMemRef()))
            dstMemrefOps.push_back(op);
        unsigned dstLoopDepthTest =
            // ...
        unsigned maxLegalFusionDepth = 0;
        // ...
        depthSliceUnions.resize(dstLoopDepthTest);
        // ...
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
          // ...
                                      i + numSurroundingLoops,
                                      &depthSliceUnions[i - 1], strategy);
          // ...
            maxLegalFusionDepth = i;
        // ...
        if (maxLegalFusionDepth == 0) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: fusion is not legal at any depth\n");
          // ...
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // ...
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              producerStores.push_back(op);
          // ...
          assert(!producerStores.empty() && "Expected producer store");
          if (producerStores.size() > 1)
            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                       "supported for this case\n");
          // ...
                                        dstAffineForOp, depthSliceUnions,
                                        maxLegalFusionDepth, &bestDstLoopDepth,
                                        computeToleranceThreshold))
            // ...
        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        // ...
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");
        // ...
            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
        // ...
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
              // ...
            LLVM_DEBUG(llvm::dbgs()
                       << "Creating private memref for " << memref << '\n');
            // ...
            privateMemrefs.insert(memref);
        // ...
        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        dstNodeChanged = true;
        // ...
        LLVM_DEBUG(llvm::dbgs()
                   << "Fused src loop " << srcId << " into dst loop " << dstId
                   << " at depth " << bestDstLoopDepth << ":\n"
                   << dstAffineForOp << "\n");
        // ...
        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(fusedLoopInsPoint);
        // ...
        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
        // ...
        if (!privateMemrefs.empty()) {
          // ...
          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(storeOp);
          // ...
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            // ...
                memrefToStoresPair.second;
            // ...
                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                fastMemorySpace, localBufSizeThreshold);
            // ...
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
          // ...
          dstNode = mdg->getNode(dstId);
        // ...
        dstLoopCollector.collect(dstAffineForOp);
        // ...
        if (removeSrcNode) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Removing src loop " << srcId << " after fusion\n");
          // ...
          srcAffineForOp.erase();
          // ...
    } while (dstNodeChanged);
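
  // fuseProducerConsumerNodes: pops destination ids off the worklist and runs
  // performFusionsIntoDest on each; 'maxSrcUserCount', by its name, bounds the
  // number of users a candidate producer's memref may have.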
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    // ...
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      performFusionsIntoDest(dstId, maxSrcUserCount);
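
  // Sibling fusion: for each loop nest left on the worklist, repeatedly look
  // for a sibling node that accesses the same memref (findSiblingNodeToFuse),
  // check legality at every depth with canFuseLoops under the sibling fusion
  // strategy, fuse at the deepest (or most profitable) legal depth with
  // fuseLoops, and then update the dependence graph state.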
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    // ...
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      // ...
      if (mdg->nodes.count(dstId) == 0)
        // ...
      auto *dstNode = mdg->getNode(dstId);
      // ...
      if (!isa<AffineForOp>(dstNode->op))
        // ...
      fuseWithSiblingNodes(dstNode);
  // ...
  void fuseWithSiblingNodes(Node *dstNode) {
    // ...
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
    // ...
    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // ...
      auto *sibNode = mdg->getNode(sibId);
      // ...
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      // ...
      if (insertPointInst == nullptr)
        // ...
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // ...
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      // ...
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
      // ...
      unsigned numSurroundingLoops = surroundingLoops.size();
      // ...
      unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
      // ...
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      // ...
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        // ...
                                    i + numSurroundingLoops,
                                    &depthSliceUnions[i - 1], strategy);
        // ...
          maxLegalFusionDepth = i;
      // ...
      LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
                              << maxLegalFusionDepth << '\n');
      // ...
      if (maxLegalFusionDepth == 0)
        // ...
      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // ...
                                  depthSliceUnions, maxLegalFusionDepth,
                                  &bestDstLoopDepth, computeToleranceThreshold))
          // ...
      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // ...
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // ...
                depthSliceUnions[bestDstLoopDepth - 1],
                isInnermostInsertion);
      // ...
      auto dstForInst = cast<AffineForOp>(dstNode->op);
      // ...
      if (insertPointInst != dstForInst) {
        dstForInst->moveBefore(insertPointInst);
      // ...
      updateStateAfterSiblingFusion(sibNode, dstNode);
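
  // findSiblingNodeToFuse: searches among the users of the memrefs loaded by
  // 'dstNode', and among nodes reachable through the MDG's input/output edges,
  // for an unvisited affine.for node that qualifies as a sibling fusion
  // candidate (a single load of the shared memref and at most one stored
  // memref), returning its id and the shared memref in 'idAndMemrefToFuse'.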
  bool findSiblingNodeToFuse(Node *dstNode,
                             // ...
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // ...
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // ...
      if (sibNode->getLoadOpCount(memref) != 1)
        // ...
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
      // ...
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      // ...
      if (storeMemrefs.size() > 1)
        // ...
    Block *block = dstNode->op->getBlock();
    // ...
      auto loadOp = dyn_cast<AffineReadOpInterface>(user);
      // ...
      auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
        return loop->getBlock() == &mdg->block;
      // ...
      if (it == loops.end())
        // ...
      assert(sibNode != nullptr);
      // ...
      if (sibNode->id == dstNode->id)
        // ...
      if (visitedSibNodeIds->count(sibNode->id) > 0)
        // ...
      auto memref = loadOp.getMemRef();
      if (dstNode->getLoadOpCount(memref) == 0)
        // ...
      if (canFuseWithSibNode(sibNode, memref)) {
        visitedSibNodeIds->insert(sibNode->id);
        idAndMemrefToFuse->first = sibNode->id;
        idAndMemrefToFuse->second = memref;
    // ...
      if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
          mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
        inEdges.push_back(inEdge);
    // ...
    for (auto &inEdge : inEdges) {
      // ...
        unsigned sibNodeId = outEdge.id;
        if (visitedSibNodeIds->count(sibNodeId) > 0)
          // ...
        if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
          // ...
        auto *sibNode = mdg->getNode(sibNodeId);
        if (!isa<AffineForOp>(sibNode->op))
          // ...
        if (canFuseWithSibNode(sibNode, outEdge.value)) {
          // ...
          outEdges.push_back(outEdge);
      // ...
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
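
  // updateStateAfterSiblingFusion re-collects the loads and stores of the
  // fused destination nest and refreshes the corresponding MDG node;
  // eraseUnusedMemRefAllocations then erases memref.alloc ops whose memrefs no
  // longer have any dependence edges or remaining uses.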
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // ...
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    // ...
    dstLoopCollector.collect(dstForInst);
    // ...
  void eraseUnusedMemRefAllocations() {
    // ...
      if (pair.second > 0)
        // ...
      auto memref = pair.first;
      // ...
      if (isa_and_nonnull<memref::AllocOp>(op))
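
// LoopFusion::runOnBlock builds the MemRefDependenceGraph for 'block' and
// dispatches to the configured fusion mode; runOnOperation walks the operation
// and runs fusion on every block that contains more than one top-level
// affine.for nest.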
void LoopFusion::runOnBlock(Block *block) {
  // ...
    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
    // ...
  std::optional<unsigned> fastMemorySpaceOpt;
  if (fastMemorySpace.hasValue())
    fastMemorySpaceOpt = fastMemorySpace;
  unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
  GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                      maximalFusion, computeToleranceThreshold);
  // ...
    fusion.runProducerConsumerFusionOnly();
  // ...
    fusion.runSiblingFusionOnly();
  // ...
    fusion.runGreedyFusion();
}
void LoopFusion::runOnOperation() {
  // ...
  getOperation()->walk([&](Operation *op) {
    // ...
    for (Block &block : region.getBlocks()) {
      auto affineFors = block.getOps<AffineForOp>();
      if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
    unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
    bool maximalFusion, enum FusionMode affineFusionMode) {
  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
                                      maximalFusion, affineFusionMode);
}
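
// Example usage (a sketch, not taken from this file): the pass is normally
// driven through its registered flag or through the factory above. The flag
// below matches DEBUG_TYPE / GEN_PASS_DEF_AFFINELOOPFUSION; the func.func
// nesting shown in the C++ snippet is an assumption about the surrounding
// pipeline rather than something this file prescribes.
//
//   mlir-opt input.mlir -affine-loop-fusion
//
//   // Or programmatically, using the declared defaults:
//   PassManager pm(&context);
//   pm.addNestedPass<func::FuncOp>(createLoopFusionPass(
//       /*fastMemorySpace=*/0, /*localBufSizeThreshold=*/0,
//       /*maximalFusion=*/false, FusionMode::Greedy));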