27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
39 #define GEN_PASS_DEF_AFFINELOOPFUSION
40 #include "mlir/Dialect/Affine/Passes.h.inc"
44 #define DEBUG_TYPE "affine-loop-fusion"
55 struct LoopFusion :
public affine::impl::AffineLoopFusionBase<LoopFusion> {
56 LoopFusion() =
default;
57 LoopFusion(
unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
58 bool maximalFusion,
enum FusionMode affineFusionMode) {
59 this->fastMemorySpace = fastMemorySpace;
60 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
61 this->maximalFusion = maximalFusion;
62 this->affineFusionMode = affineFusionMode;
65 void runOnBlock(
Block *block);
66 void runOnOperation()
override;
84 bool hasOutDepsAfterFusion =
false;
86 for (
auto &outEdge : mdg->
outEdges[srcId]) {
89 if (depNodeOp == dstNodeOp)
99 if (fusedLoopInsPoint != depNodeOp &&
101 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: dst loop doesn't "
102 "dominate dependence\n");
106 hasOutDepsAfterFusion =
true;
112 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
113 std::optional<bool> isMaximal = fusionSlice.
isMaximal();
115 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: can't determine "
116 "if fusion is maximal\n");
121 LLVM_DEBUG(llvm::dbgs()
122 <<
"Src loop can't be removed: fusion is not maximal\n");
140 if (mdg->
inEdges.count(dstId) == 0)
144 auto *dstNode = mdg->
getNode(dstId);
147 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
151 for (
auto &srcEdge : mdg->
inEdges[dstId]) {
152 auto *srcNode = mdg->
getNode(srcEdge.id);
154 if (!isa<AffineForOp>(srcNode->op))
157 if (any_of(srcNode->stores, [&](
Operation *op) {
158 auto storeOp = cast<AffineWriteOpInterface>(op);
159 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
161 srcIdCandidates.push_back(srcNode->id);
164 llvm::sort(srcIdCandidates);
165 srcIdCandidates.erase(
166 std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
167 srcIdCandidates.end());
176 auto *dstNode = mdg->
getNode(dstId);
177 auto *srcNode = mdg->
getNode(srcId);
179 producerConsumerMemrefs);
196 if (
auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
201 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
208 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
211 if (ancestorOp->getBlock() != block)
213 return !isa<AffineMapAccessInterface>(*user);
222 for (
Operation *storeOp : node->stores) {
223 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
224 if (escapingMemRefs.count(memref))
227 escapingMemRefs.insert(memref);
237 assert(isa<AffineForOp>(node->
op));
239 node->
op = newRootForOp;
248 unsigned dstLoopDepth,
249 std::optional<unsigned> fastMemorySpace,
250 uint64_t localBufSizeThreshold) {
251 Operation *forInst = forOp.getOperation();
258 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
259 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
260 unsigned rank = oldMemRefType.getRank();
266 assert(validRegion &&
"unexpected memref region failure");
268 std::vector<SmallVector<int64_t, 4>> lbs;
273 std::optional<int64_t> numElements =
275 assert(numElements &&
"non-constant number of elts in local buffer");
286 offsets.reserve(rank);
287 for (
unsigned d = 0; d < rank; ++d) {
288 assert(lbs[d].size() == cst->
getNumCols() - rank &&
"incorrect bound size");
291 for (
unsigned j = 0, e = cst->
getNumCols() - rank - 1;
j < e;
j++) {
294 assert(lbDivisors[d] > 0);
297 offsets.push_back(offset);
303 assert(eltSize &&
"memrefs with size elt types expected");
304 uint64_t bufSize = *eltSize * *numElements;
305 unsigned newMemSpace;
306 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
307 newMemSpace = *fastMemorySpace;
309 newMemSpace = oldMemRefType.getMemorySpaceAsInt();
311 auto newMemRefType =
MemRefType::get(newShape, oldMemRefType.getElementType(),
320 Value newMemRef = top.
create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
324 remapExprs.reserve(rank);
325 for (
unsigned i = 0; i < rank; i++) {
330 remapExprs.push_back(remapExpr);
334 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
341 &*forOp.getBody()->begin());
343 "replaceAllMemrefUsesWith should always succeed here");
354 auto *srcNode = mdg->
getNode(srcId);
355 auto *dstNode = mdg->
getNode(dstId);
360 for (
auto &idAndNode : mdg->
nodes) {
363 if (srcNode->op->isBeforeInBlock(op) && op->
isBeforeInBlock(dstNode->op)) {
368 if (isa<AffineMapAccessInterface>(*user))
371 if (llvm::is_contained(users, user))
375 if (walkResult.wasInterrupted())
388 auto *srcNode = mdg->
getNode(srcId);
389 llvm::SmallDenseSet<Value, 2> memRefValues;
392 if (isa<AffineForOp>(op))
396 if (isa<MemRefType>(v.getType()))
397 memRefValues.insert(v);
401 return llvm::any_of(memRefValues, [&](
Value memref) {
446 AffineForOp dstForOp,
448 unsigned maxLegalFusionDepth,
449 unsigned *dstLoopDepth,
450 double computeToleranceThreshold) {
452 llvm::dbgs() <<
"Checking whether fusion is profitable between src op:\n";
453 llvm::dbgs() <<
' ' << *srcOpInst <<
" and destination loop:\n";
454 llvm::dbgs() << dstForOp <<
"\n";
457 if (maxLegalFusionDepth == 0) {
458 LLVM_DEBUG(llvm::dbgs() <<
"Can't fuse: maxLegalFusionDepth is 0\n");
483 double maxStorageReduction = 0.0;
484 std::optional<uint64_t> sliceMemEstimate;
487 std::optional<unsigned> bestDstLoopDepth;
490 uint64_t srcLoopNestCost =
getComputeCost(srcLoopIVs[0], srcLoopNestStats);
495 LLVM_DEBUG(llvm::dbgs()
496 <<
"Unable to compute MemRefRegion for source operation\n");
500 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
502 if (!maybeSrcWriteRegionSizeBytes.has_value())
504 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
507 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
511 for (
unsigned i = maxLegalFusionDepth; i >= 1; --i) {
517 int64_t fusedLoopNestComputeCost;
519 dstLoopNestStats, slice,
520 &fusedLoopNestComputeCost)) {
521 LLVM_DEBUG(llvm::dbgs() <<
"Unable to compute fusion compute cost\n");
525 double additionalComputeFraction =
526 fusedLoopNestComputeCost /
527 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
536 LLVM_DEBUG(llvm::dbgs()
537 <<
"Failed to compute slice write region at loopDepth: " << i
542 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
544 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
545 *maybeSliceWriteRegionSizeBytes == 0) {
546 LLVM_DEBUG(llvm::dbgs()
547 <<
"Failed to get slice write region size at loopDepth: " << i
551 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
557 if (srcOpInst != srcStoreOpInst &&
558 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
561 double storageReduction =
static_cast<double>(srcWriteRegionSizeBytes) /
562 static_cast<double>(sliceWriteRegionSizeBytes);
565 std::stringstream msg;
566 msg <<
" evaluating fusion profitability at depth : " << i <<
"\n"
567 << std::fixed << std::setprecision(2)
568 <<
" additional compute fraction: "
569 << 100.0 * additionalComputeFraction <<
"%\n"
570 <<
" storage reduction factor: " << storageReduction <<
"x\n"
571 <<
" fused nest cost: " << fusedLoopNestComputeCost <<
"\n"
572 <<
" src write region size: " << srcWriteRegionSizeBytes <<
"\n"
573 <<
" slice write region size: " << sliceWriteRegionSizeBytes
575 llvm::dbgs() << msg.str();
582 if ((storageReduction > maxStorageReduction) &&
583 (additionalComputeFraction < computeToleranceThreshold)) {
584 maxStorageReduction = storageReduction;
585 bestDstLoopDepth = i;
586 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
587 sliceMemEstimate = sliceWriteRegionSizeBytes;
593 if (!bestDstLoopDepth) {
596 <<
"All fusion choices involve more than the threshold amount of "
597 "redundant computation; NOT fusing.\n");
601 if (!bestDstLoopDepth) {
602 LLVM_DEBUG(llvm::dbgs() <<
"no fusion depth could be evaluated.\n");
607 *dstLoopDepth = *bestDstLoopDepth;
610 llvm::dbgs() <<
" LoopFusion fusion stats:"
611 <<
"\n best loop depth: " << bestDstLoopDepth
612 <<
"\n src loop nest compute cost: " << srcLoopNestCost
613 <<
"\n dst loop nest compute cost: " << dstLoopNestCost
614 <<
"\n fused loop nest compute cost: "
615 << minFusedLoopNestComputeCost <<
"\n");
620 std::optional<double> storageReduction;
622 if (!dstMemSize || !srcMemSize) {
623 LLVM_DEBUG(llvm::dbgs()
624 <<
" fusion memory benefit cannot be evaluated; NOT fusing.\n");
628 auto srcMemSizeVal = *srcMemSize;
629 auto dstMemSizeVal = *dstMemSize;
631 assert(sliceMemEstimate &&
"expected value");
632 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
634 LLVM_DEBUG(llvm::dbgs() <<
" src mem: " << srcMemSizeVal <<
"\n"
635 <<
" dst mem: " << dstMemSizeVal <<
"\n"
636 <<
" fused mem: " << fusedMem <<
"\n"
637 <<
" slice mem: " << sliceMemEstimate <<
"\n");
639 if (
static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
640 LLVM_DEBUG(llvm::dbgs() <<
"Fusion is not profitable; NOT fusing.\n");
645 (1.0 - fusedMem / (
static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
647 double additionalComputeFraction =
648 100.0 * (minFusedLoopNestComputeCost /
649 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
651 (void)additionalComputeFraction;
653 std::stringstream msg;
654 msg <<
" fusion is most profitable at depth " << *dstLoopDepth <<
" with "
655 << std::setprecision(2) << additionalComputeFraction
656 <<
"% redundant computation and a ";
657 msg << (storageReduction ? std::to_string(*storageReduction) :
"<unknown>");
658 msg <<
"% storage reduction.\n";
659 llvm::dbgs() << msg.str();
713 struct GreedyFusion {
720 unsigned localBufSizeThreshold;
722 std::optional<unsigned> fastMemorySpace;
728 double computeToleranceThreshold;
733 std::optional<unsigned> fastMemorySpace,
bool maximalFusion,
734 double computeToleranceThreshold)
735 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
736 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
737 computeToleranceThreshold(computeToleranceThreshold) {}
744 for (
auto &idAndNode : mdg->
nodes) {
745 const Node &node = idAndNode.second;
746 worklist.push_back(node.id);
750 void runSiblingFusionOnly() {
752 eraseUnusedMemRefAllocations();
756 void runProducerConsumerFusionOnly() {
757 fuseProducerConsumerNodes(
759 eraseUnusedMemRefAllocations();
767 void runGreedyFusion() {
769 fuseProducerConsumerNodes(1);
771 fuseProducerConsumerNodes(
773 eraseUnusedMemRefAllocations();
778 bool canCreatePrivateMemRef(
Value memref,
780 unsigned producerId,
unsigned consumerId,
781 bool removeSrcNode) {
782 const Node *consumerNode = mdg->
getNode(consumerId);
790 if (srcEscapingMemRefs.count(memref) > 0 &&
791 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
804 any_of(mdg->
outEdges[producerId], [&](
const auto &edge) {
805 return edge.value == memref && edge.id != consumerId;
815 void performFusionsIntoDest(
unsigned dstId,
unsigned maxSrcUserCount) {
816 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
818 if (mdg->
nodes.count(dstId) == 0)
821 auto *dstNode = mdg->
getNode(dstId);
823 if (!isa<AffineForOp>(dstNode->op))
827 if (dstNode->op->getNumResults() > 0)
830 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
837 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
848 dstNodeChanged =
false;
852 for (
unsigned srcId : llvm::reverse(srcIdCandidates)) {
854 auto *srcNode = mdg->
getNode(srcId);
855 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
856 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating src loop " << srcId
857 <<
" for dst loop " << dstId <<
"\n");
861 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
866 producerConsumerMemrefs);
870 if (any_of(producerConsumerMemrefs, [&](
Value memref) {
888 if (!srcEscapingMemRefs.empty() &&
890 LLVM_DEBUG(llvm::dbgs()
891 <<
"Can't fuse: non-affine users in between the loops\n");
899 if (fusedLoopInsPoint ==
nullptr)
910 unsigned numSurroundingLoops = surroundingLoops.size();
916 if (producerConsumerMemrefs.count(
917 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
918 dstMemrefOps.push_back(op);
920 if (producerConsumerMemrefs.count(
921 cast<AffineWriteOpInterface>(op).getMemRef()))
922 dstMemrefOps.push_back(op);
923 unsigned dstLoopDepthTest =
928 unsigned maxLegalFusionDepth = 0;
930 depthSliceUnions.resize(dstLoopDepthTest);
932 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
935 i + numSurroundingLoops,
936 &depthSliceUnions[i - 1], strategy);
939 maxLegalFusionDepth = i;
942 if (maxLegalFusionDepth == 0) {
943 LLVM_DEBUG(llvm::dbgs()
944 <<
"Can't fuse: fusion is not legal at any depth\n");
951 unsigned bestDstLoopDepth = maxLegalFusionDepth;
952 if (!maximalFusion) {
956 if (producerConsumerMemrefs.count(
957 cast<AffineWriteOpInterface>(op).getMemRef()))
958 producerStores.push_back(op);
966 assert(!producerStores.empty() &&
"Expected producer store");
967 if (producerStores.size() > 1)
968 LLVM_DEBUG(llvm::dbgs() <<
"Skipping profitability analysis. Not "
969 "supported for this case\n");
971 dstAffineForOp, depthSliceUnions,
972 maxLegalFusionDepth, &bestDstLoopDepth,
973 computeToleranceThreshold))
977 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
979 depthSliceUnions[bestDstLoopDepth - 1];
980 assert(!bestSlice.
isEmpty() &&
"Missing slice union for depth");
986 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
990 for (
Value memref : producerConsumerMemrefs) {
991 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
994 LLVM_DEBUG(llvm::dbgs()
995 <<
"Creating private memref for " << memref <<
'\n');
997 privateMemrefs.insert(memref);
1002 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1003 dstNodeChanged =
true;
1005 LLVM_DEBUG(llvm::dbgs()
1006 <<
"Fused src loop " << srcId <<
" into dst loop " << dstId
1007 <<
" at depth " << bestDstLoopDepth <<
":\n"
1008 << dstAffineForOp <<
"\n");
1011 if (fusedLoopInsPoint != dstAffineForOp)
1012 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1015 mdg->
updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1019 if (!privateMemrefs.empty()) {
1022 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1023 Value storeMemRef = storeOp.getMemRef();
1024 if (privateMemrefs.count(storeMemRef) > 0)
1025 privateMemRefToStores[storeMemRef].push_back(storeOp);
1032 for (
auto &memrefToStoresPair : privateMemRefToStores) {
1036 memrefToStoresPair.second;
1038 dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
1039 fastMemorySpace, localBufSizeThreshold);
1043 mdg->
addEdge(newMemRefNodeId, dstId, newMemRef);
1048 dstNode = mdg->
getNode(dstId);
1053 dstLoopCollector.
collect(dstAffineForOp);
1060 if (removeSrcNode) {
1061 LLVM_DEBUG(llvm::dbgs()
1062 <<
"Removing src loop " << srcId <<
" after fusion\n");
1064 srcAffineForOp.erase();
1069 }
while (dstNodeChanged);
1076 void fuseProducerConsumerNodes(
unsigned maxSrcUserCount) {
1077 LLVM_DEBUG(llvm::dbgs() <<
"--- Producer/Consumer Fusion ---\n");
1079 while (!worklist.empty()) {
1080 unsigned dstId = worklist.back();
1081 worklist.pop_back();
1082 performFusionsIntoDest(dstId, maxSrcUserCount);
1088 void fuseSiblingNodes() {
1089 LLVM_DEBUG(llvm::dbgs() <<
"--- Sibling Fusion ---\n");
1091 while (!worklist.empty()) {
1092 unsigned dstId = worklist.back();
1093 worklist.pop_back();
1096 if (mdg->
nodes.count(dstId) == 0)
1099 auto *dstNode = mdg->
getNode(dstId);
1101 if (!isa<AffineForOp>(dstNode->op))
1104 fuseWithSiblingNodes(dstNode);
1109 void fuseWithSiblingNodes(
Node *dstNode) {
1111 std::pair<unsigned, Value> idAndMemref;
1112 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1114 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1115 unsigned sibId = idAndMemref.first;
1116 Value memref = idAndMemref.second;
1119 auto *sibNode = mdg->
getNode(sibId);
1122 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1127 if (insertPointInst ==
nullptr)
1134 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1136 assert(sibLoadOpInsts.size() == 1);
1137 Operation *sibLoadOpInst = sibLoadOpInsts[0];
1141 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1151 unsigned numSurroundingLoops = surroundingLoops.size();
1154 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1155 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1159 depthSliceUnions.resize(dstLoopDepthTest);
1160 unsigned maxLegalFusionDepth = 0;
1162 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1165 i + numSurroundingLoops,
1166 &depthSliceUnions[i - 1], strategy);
1169 maxLegalFusionDepth = i;
1172 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1173 << maxLegalFusionDepth <<
'\n');
1176 if (maxLegalFusionDepth == 0)
1179 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1180 if (!maximalFusion) {
1186 depthSliceUnions, maxLegalFusionDepth,
1187 &bestDstLoopDepth, computeToleranceThreshold))
1191 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1192 assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1193 "Fusion depth has no computed slice union");
1197 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1200 depthSliceUnions[bestDstLoopDepth - 1],
1201 isInnermostInsertion);
1203 auto dstForInst = cast<AffineForOp>(dstNode->op);
1205 if (insertPointInst != dstForInst) {
1206 dstForInst->moveBefore(insertPointInst);
1209 updateStateAfterSiblingFusion(sibNode, dstNode);
1217 bool findSiblingNodeToFuse(
Node *dstNode,
1219 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1222 auto canFuseWithSibNode = [&](
Node *sibNode,
Value memref) {
1225 if (sibNode->getLoadOpCount(memref) != 1)
1235 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1236 if (llvm::any_of(loadAndStoreMemrefSet, [=](
Value memref) {
1243 for (
auto *storeOpInst : sibNode->stores) {
1244 storeMemrefs.insert(
1245 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1247 if (storeMemrefs.size() > 1)
1259 Block *block = dstNode->op->getBlock();
1262 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1271 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1272 return loop->getBlock() == &mdg->
block;
1275 if (it == loops.end())
1278 assert(sibNode !=
nullptr);
1280 if (sibNode->id == dstNode->id)
1283 if (visitedSibNodeIds->count(sibNode->id) > 0)
1286 auto memref = loadOp.getMemRef();
1287 if (dstNode->getLoadOpCount(memref) == 0)
1290 if (canFuseWithSibNode(sibNode, memref)) {
1291 visitedSibNodeIds->insert(sibNode->id);
1292 idAndMemrefToFuse->first = sibNode->id;
1293 idAndMemrefToFuse->second = memref;
1305 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1306 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1307 inEdges.push_back(inEdge);
1312 for (
auto &inEdge : inEdges) {
1317 unsigned sibNodeId = outEdge.id;
1318 if (visitedSibNodeIds->count(sibNodeId) > 0)
1321 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1323 auto *sibNode = mdg->getNode(sibNodeId);
1324 if (!isa<AffineForOp>(sibNode->op))
1327 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1329 outEdges.push_back(outEdge);
1334 if (!outEdges.empty()) {
1335 visitedSibNodeIds->insert(outEdges[0].
id);
1336 idAndMemrefToFuse->first = outEdges[0].id;
1337 idAndMemrefToFuse->second = outEdges[0].value;
1346 void updateStateAfterSiblingFusion(
Node *sibNode,
Node *dstNode) {
1351 auto dstForInst = cast<AffineForOp>(dstNode->op);
1353 dstLoopCollector.
collect(dstForInst);
1368 void eraseUnusedMemRefAllocations() {
1370 if (pair.second > 0)
1372 auto memref = pair.first;
1378 if (isa_and_nonnull<memref::AllocOp>(op))
1387 void LoopFusion::runOnBlock(
Block *block) {
1390 LLVM_DEBUG(llvm::dbgs() <<
"MDG init failed\n");
1394 std::optional<unsigned> fastMemorySpaceOpt;
1395 if (fastMemorySpace.hasValue())
1396 fastMemorySpaceOpt = fastMemorySpace;
1397 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1398 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1399 maximalFusion, computeToleranceThreshold);
1402 fusion.runProducerConsumerFusionOnly();
1404 fusion.runSiblingFusionOnly();
1406 fusion.runGreedyFusion();
1409 void LoopFusion::runOnOperation() {
1412 getOperation()->walk([&](
Operation *op) {
1414 for (Block &block : region.getBlocks()) {
1415 auto affineFors = block.getOps<AffineForOp>();
1416 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1424 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1425 bool maximalFusion,
enum FusionMode affineFusionMode) {
1426 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1427 maximalFusion, affineFusionMode);
MemRefDependenceGraph::Node Node
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, MemRefDependenceGraph *mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg)
Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 'dstId'), if there is any non-aff...
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, uint64_t localBufSizeThreshold)
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Region * getParentRegion()
Returns the region to which the instruction belongs.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumVars() const
unsigned getNumCols() const
Returns the number of columns in the constraint system.
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Include the generated interface declarations.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
bool isEmpty() const
Returns true if the computation slice is empty.
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > loadOpInsts
void collect(Operation *opToWalk)
SmallVector< Operation *, 4 > storeOpInsts
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Block & block
The block for which this graph is created to perform fusion.
unsigned addNode(Operation *op)
unsigned getIncomingMemRefAccesses(unsigned id, Value memref)
void addEdge(unsigned srcId, unsigned dstId, Value value)
Node * getForOpNode(AffineForOp forOp)
DenseMap< unsigned, Node > nodes
bool hasDependencePath(unsigned srcId, unsigned dstId)
void clearNodeLoadAndStores(unsigned id)
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
void addToNode(unsigned id, const SmallVectorImpl< Operation * > &loads, const SmallVectorImpl< Operation * > &stores)
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr)
void removeNode(unsigned id)
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId)
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
Node * getNode(unsigned id)
DenseMap< Value, unsigned > memrefEdgeCount
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
FlatAffineValueConstraints * getConstraints()
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, std::vector< SmallVector< int64_t, 4 >> *lbs=nullptr, SmallVectorImpl< int64_t > *lbDivisors=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.