26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
38 #define GEN_PASS_DEF_AFFINELOOPFUSION
39 #include "mlir/Dialect/Affine/Passes.h.inc"
43 #define DEBUG_TYPE "affine-fusion"
54 struct LoopFusion :
public affine::impl::AffineLoopFusionBase<LoopFusion> {
55 LoopFusion() =
default;
56 LoopFusion(
unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
57 bool maximalFusion,
enum FusionMode affineFusionMode) {
58 this->fastMemorySpace = fastMemorySpace;
59 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
60 this->maximalFusion = maximalFusion;
61 this->affineFusionMode = affineFusionMode;
64 void runOnBlock(
Block *block);
65 void runOnOperation()
override;
83 bool hasOutDepsAfterFusion =
false;
85 for (
auto &outEdge : mdg.
outEdges.lookup(srcId)) {
88 if (depNodeOp == dstNodeOp)
98 if (fusedLoopInsPoint != depNodeOp &&
100 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: dst loop doesn't "
101 "dominate dependence\n");
105 hasOutDepsAfterFusion =
true;
111 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
112 std::optional<bool> isMaximal = fusionSlice.
isMaximal();
114 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: can't determine "
115 "if fusion is maximal\n");
120 LLVM_DEBUG(llvm::dbgs()
121 <<
"Src loop can't be removed: fusion is not maximal\n");
140 if (mdg.
inEdges.count(dstId) == 0)
144 auto *dstNode = mdg.
getNode(dstId);
147 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
151 for (
const auto &srcEdge : mdg.
inEdges.lookup(dstId)) {
152 const auto *srcNode = mdg.
getNode(srcEdge.id);
154 if (!isa<AffineForOp>(srcNode->op))
157 if (any_of(srcNode->stores, [&](
Operation *op) {
158 auto storeOp = cast<AffineWriteOpInterface>(op);
159 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
161 srcIdCandidates.push_back(srcNode->id);
164 llvm::sort(srcIdCandidates);
165 srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
174 auto *dstNode = mdg.
getNode(dstId);
175 auto *srcNode = mdg.
getNode(srcId);
177 producerConsumerMemrefs);
194 if (
auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
199 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
206 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
209 if (ancestorOp->getBlock() != block)
211 return !isa<AffineMapAccessInterface>(*user);
220 for (
Operation *storeOp : node->stores) {
221 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222 if (escapingMemRefs.count(memref))
225 escapingMemRefs.insert(memref);
235 assert(isa<AffineForOp>(node->
op));
237 node->
op = newRootForOp;
247 assert(!producerStores.empty() &&
"expected producer store");
253 Block *commonBlock =
nullptr;
255 for (
Operation *store : producerStores) {
257 !commonBlock ? &*sliceInsertionBlock->
begin() : &*commonBlock->
begin();
260 assert(commonBlock &&
261 "common block of producer stores and slice should exist");
266 for (
Operation *store : producerStores) {
268 assert(ancestor &&
"producer store should be contained in common block");
269 firstAncestor = !firstAncestor || ancestor->
isBeforeInBlock(firstAncestor)
273 return firstAncestor;
281 AffineForOp srcForOp, AffineForOp dstForOp,
unsigned depth,
283 int64_t &fusedLoopNestComputeCost) {
284 LLVM_DEBUG(llvm::dbgs() <<
"Determining additional compute fraction...\n";);
289 LLVM_DEBUG(llvm::dbgs() <<
"Failed to get source loop nest stats.\n");
296 LLVM_DEBUG(llvm::dbgs() <<
"Failed to get destination loop nest stats.\n");
301 uint64_t srcLoopNestCost =
getComputeCost(srcForOp, srcLoopNestStats);
304 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
309 LLVM_DEBUG(llvm::dbgs() <<
"Slice wasn't computed.\n");
314 dstLoopNestStats, slice,
315 &fusedLoopNestComputeCost)) {
316 LLVM_DEBUG(llvm::dbgs() <<
"Unable to compute fusion compute cost\n");
320 double additionalComputeFraction =
321 fusedLoopNestComputeCost /
322 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
325 return additionalComputeFraction;
336 unsigned dstLoopDepth,
337 std::optional<unsigned> fastMemorySpace,
338 Block *sliceInsertionBlock,
339 uint64_t localBufSizeThreshold) {
340 assert(!storeOps.empty() &&
"no source stores supplied");
346 if (storeOps.size() > 1 &&
347 !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
349 MemRefAccess aM(cast<AffineWriteOpInterface>(a));
350 MemRefAccess bM(cast<AffineWriteOpInterface>(b));
353 LLVM_DEBUG(llvm::dbgs()
354 <<
"Private memref creation unsupported for multiple producer "
355 "stores with different access functions.\n");
366 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
367 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
368 unsigned rank = oldMemRefType.getRank();
372 bool validRegion = succeeded(
373 region.
compute(srcStoreOp, dstLoopDepth,
nullptr,
377 assert(validRegion &&
"unexpected memref region failure");
383 std::optional<int64_t> numElements =
385 assert(numElements &&
"non-constant number of elts in local buffer");
396 offsets.reserve(rank);
401 for (
unsigned j = 0, e = lbs[0].getNumSymbols();
j < e; ++
j)
403 for (
unsigned d = 0; d < rank; ++d) {
404 assert(lbs[d].getNumResults() == 1 &&
405 "invalid private memref bound calculation");
406 offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
412 assert(eltSize &&
"memrefs with size elt types expected");
413 uint64_t bufSize = *eltSize * *numElements;
415 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
416 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
418 newMemSpace = oldMemRefType.getMemorySpace();
420 auto newMemRefType =
MemRefType::get(newShape, oldMemRefType.getElementType(),
429 Value newMemRef = top.
create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
433 remapExprs.reserve(rank);
434 for (
unsigned i = 0; i < rank; i++) {
435 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
439 remapExprs.push_back(remapExpr);
443 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
449 oldMemRef, newMemRef, {}, indexRemap,
452 assert(succeeded(res) &&
453 "replaceAllMemrefUsesWith should always succeed here");
455 LLVM_DEBUG(llvm::dbgs() <<
"Created private memref of type: " << newMemRefType
500 AffineForOp dstForOp,
502 unsigned maxLegalFusionDepth,
503 unsigned *dstLoopDepth,
504 double computeToleranceThreshold) {
507 <<
"Checking whether fusion is profitable between source nest:\n";
508 llvm::dbgs() <<
' ' << srcForOp <<
" and destination nest:\n";
509 llvm::dbgs() << dstForOp <<
"\n";
512 if (maxLegalFusionDepth == 0) {
513 LLVM_DEBUG(llvm::dbgs() <<
"Can't fuse: maxLegalFusionDepth is 0\n");
536 if (producerStores.size() > 1) {
537 LLVM_DEBUG(llvm::dbgs() <<
"Limited profitability analysis. Not "
538 "supported for multiple producer store case.\n");
540 int64_t fusedLoopNestComputeCost;
544 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
545 fusedLoopNestComputeCost);
546 if (!fraction || fraction > computeToleranceThreshold) {
547 LLVM_DEBUG(llvm::dbgs() <<
"Additional computation exceeds "
548 "compute tolerance. Not fusing.\n");
551 LLVM_DEBUG(llvm::dbgs()
552 <<
"Considering fusion profitable at max legal depth.\n");
556 Operation *srcStoreOp = producerStores.front();
565 double maxStorageReduction = 0.0;
566 std::optional<uint64_t> sliceMemEstimate;
569 std::optional<unsigned> bestDstLoopDepth;
573 if (failed(srcWriteRegion.
compute(srcStoreOp, 0))) {
574 LLVM_DEBUG(llvm::dbgs()
575 <<
"Unable to compute MemRefRegion for source operation\n");
579 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
581 if (!maybeSrcWriteRegionSizeBytes.has_value())
583 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
586 uint64_t srcLoopNestCost =
getComputeCost(srcForOp, srcLoopNestStats);
589 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
593 for (
unsigned i = maxLegalFusionDepth; i >= 1; --i) {
603 int64_t fusedLoopNestComputeCost;
605 auto mayAdditionalComputeFraction =
607 sliceCost, fusedLoopNestComputeCost);
608 if (!mayAdditionalComputeFraction) {
609 LLVM_DEBUG(llvm::dbgs()
610 <<
"Can't determine additional compute fraction.\n");
613 double additionalComputeFraction = *mayAdditionalComputeFraction;
619 if (failed(sliceWriteRegion.
compute(srcStoreOp, 0, &slice))) {
620 LLVM_DEBUG(llvm::dbgs()
621 <<
"Failed to compute slice write region at loopDepth: " << i
626 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
628 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
629 *maybeSliceWriteRegionSizeBytes == 0) {
630 LLVM_DEBUG(llvm::dbgs()
631 <<
"Failed to get slice write region size at loopDepth: " << i
635 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
637 double storageReduction =
static_cast<double>(srcWriteRegionSizeBytes) /
638 static_cast<double>(sliceWriteRegionSizeBytes);
641 std::stringstream msg;
642 msg <<
" evaluating fusion profitability at depth : " << i <<
"\n"
643 << std::fixed << std::setprecision(2)
644 <<
" additional compute fraction: "
645 << 100.0 * additionalComputeFraction <<
"%\n"
646 <<
" storage reduction factor: " << storageReduction <<
"x\n"
647 <<
" fused nest cost: " << fusedLoopNestComputeCost <<
"\n"
648 <<
" src write region size: " << srcWriteRegionSizeBytes <<
"\n"
649 <<
" slice write region size: " << sliceWriteRegionSizeBytes
651 llvm::dbgs() << msg.str();
658 if ((storageReduction > maxStorageReduction) &&
659 (additionalComputeFraction <= computeToleranceThreshold)) {
660 maxStorageReduction = storageReduction;
661 bestDstLoopDepth = i;
662 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
663 sliceMemEstimate = sliceWriteRegionSizeBytes;
669 if (!bestDstLoopDepth) {
672 <<
"All fusion choices involve more than the threshold amount of "
673 "redundant computation; NOT fusing.\n");
677 if (!bestDstLoopDepth) {
678 LLVM_DEBUG(llvm::dbgs() <<
"no fusion depth could be evaluated.\n");
683 *dstLoopDepth = *bestDstLoopDepth;
686 llvm::dbgs() <<
" LoopFusion fusion stats:"
687 <<
"\n best loop depth: " << bestDstLoopDepth
688 <<
"\n src loop nest compute cost: " << srcLoopNestCost
689 <<
"\n dst loop nest compute cost: " << dstLoopNestCost
690 <<
"\n fused loop nest compute cost: "
691 << minFusedLoopNestComputeCost <<
"\n");
696 std::optional<double> storageReduction;
698 if (!dstMemSize || !srcMemSize) {
699 LLVM_DEBUG(llvm::dbgs()
700 <<
" fusion memory benefit cannot be evaluated; NOT fusing.\n");
704 auto srcMemSizeVal = *srcMemSize;
705 auto dstMemSizeVal = *dstMemSize;
707 assert(sliceMemEstimate &&
"expected value");
708 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
710 LLVM_DEBUG(llvm::dbgs() <<
" src mem: " << srcMemSizeVal <<
"\n"
711 <<
" dst mem: " << dstMemSizeVal <<
"\n"
712 <<
" fused mem: " << fusedMem <<
"\n"
713 <<
" slice mem: " << sliceMemEstimate <<
"\n");
715 if (
static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
716 LLVM_DEBUG(llvm::dbgs() <<
"Fusion is not profitable; NOT fusing.\n");
721 (1.0 - fusedMem / (
static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
723 double additionalComputeFraction =
724 100.0 * (minFusedLoopNestComputeCost /
725 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
727 (void)additionalComputeFraction;
729 std::stringstream msg;
730 msg <<
" fusion is most profitable at depth " << *dstLoopDepth <<
" with "
731 << std::setprecision(2) << additionalComputeFraction
732 <<
"% redundant computation and a ";
733 msg << (storageReduction ? std::to_string(*storageReduction) :
"<unknown>");
734 msg <<
"% storage reduction.\n";
735 llvm::dbgs() << msg.str();
789 struct GreedyFusion {
796 unsigned localBufSizeThreshold;
798 std::optional<unsigned> fastMemorySpace;
804 double computeToleranceThreshold;
809 std::optional<unsigned> fastMemorySpace,
bool maximalFusion,
810 double computeToleranceThreshold)
811 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
812 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
813 computeToleranceThreshold(computeToleranceThreshold) {}
820 for (
auto &idAndNode : mdg->
nodes) {
821 const Node &node = idAndNode.second;
822 worklist.push_back(node.id);
826 void runSiblingFusionOnly() {
828 eraseUnusedMemRefAllocations();
832 void runProducerConsumerFusionOnly() {
833 fuseProducerConsumerNodes(
835 eraseUnusedMemRefAllocations();
843 void runGreedyFusion() {
845 fuseProducerConsumerNodes(1);
847 fuseProducerConsumerNodes(
849 eraseUnusedMemRefAllocations();
854 bool canCreatePrivateMemRef(
Value memref,
856 unsigned producerId,
unsigned consumerId,
857 bool removeSrcNode) {
861 const Node *consumerNode = mdg->
getNode(consumerId);
869 if (srcEscapingMemRefs.count(memref) > 0 &&
870 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
883 any_of(mdg->
outEdges[producerId], [&](
const auto &edge) {
884 return edge.value == memref && edge.id != consumerId;
894 void performFusionsIntoDest(
unsigned dstId,
unsigned maxSrcUserCount) {
895 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
897 if (mdg->
nodes.count(dstId) == 0)
900 auto *dstNode = mdg->
getNode(dstId);
902 if (!isa<AffineForOp>(dstNode->op))
906 if (dstNode->op->getNumResults() > 0)
909 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
916 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
927 dstNodeChanged =
false;
931 for (
unsigned srcId : llvm::reverse(srcIdCandidates)) {
933 auto *srcNode = mdg->
getNode(srcId);
934 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
936 LLVM_DEBUG(llvm::dbgs()
937 <<
"Trying to fuse producer loop nest " << srcId
938 <<
" with consumer loop nest " << dstId <<
"\n");
939 LLVM_DEBUG(llvm::dbgs() <<
"Compute tolerance threshold: "
940 << computeToleranceThreshold <<
'\n');
941 LLVM_DEBUG(llvm::dbgs()
942 <<
"Producer loop nest:\n"
943 << *srcNode->op <<
"\n and consumer loop nest:\n"
944 << *dstNode->op <<
'\n');
946 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating src loop " << srcId
947 <<
" for dst loop " << dstId <<
"\n");
951 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
956 producerConsumerMemrefs);
960 if (any_of(producerConsumerMemrefs, [&](
Value memref) {
976 if (fusedLoopInsPoint ==
nullptr)
987 unsigned numSurroundingLoops = surroundingLoops.size();
993 if (producerConsumerMemrefs.count(
994 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
995 dstMemrefOps.push_back(op);
997 if (producerConsumerMemrefs.count(
998 cast<AffineWriteOpInterface>(op).getMemRef()))
999 dstMemrefOps.push_back(op);
1000 unsigned dstLoopDepthTest =
1005 unsigned maxLegalFusionDepth = 0;
1007 depthSliceUnions.resize(dstLoopDepthTest);
1009 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1012 i + numSurroundingLoops,
1013 &depthSliceUnions[i - 1], strategy);
1015 maxLegalFusionDepth = i;
1016 LLVM_DEBUG(llvm::dbgs()
1017 <<
"Found valid slice for depth: " << i <<
'\n');
1021 if (maxLegalFusionDepth == 0) {
1022 LLVM_DEBUG(llvm::dbgs()
1023 <<
"Can't fuse: fusion is not legal at any depth\n");
1027 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1028 << maxLegalFusionDepth <<
'\n');
1030 double computeToleranceThresholdToUse = computeToleranceThreshold;
1038 LLVM_DEBUG(llvm::dbgs() <<
"Source nest has a cyclic dependence.\n");
1042 if (maximalFusion) {
1043 auto srcForOp = cast<AffineForOp>(srcNode->op);
1044 auto dstForOp = cast<AffineForOp>(dstNode->op);
1046 int64_t fusedLoopNestComputeCost;
1048 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1049 sliceCost, fusedLoopNestComputeCost);
1050 if (!fraction || fraction > 0) {
1053 <<
"Can't perform maximal fusion with a cyclic dependence "
1054 "and non-zero additional compute.\n");
1060 LLVM_DEBUG(llvm::dbgs()
1061 <<
"Setting compute tolerance to zero since "
1062 "source has a cylic dependence.\n");
1063 computeToleranceThresholdToUse = 0;
1070 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1071 if (!maximalFusion) {
1075 if (producerConsumerMemrefs.count(
1076 cast<AffineWriteOpInterface>(op).getMemRef()))
1077 producerStores.push_back(op);
1079 assert(!producerStores.empty() &&
"Expected producer store");
1081 dstAffineForOp, depthSliceUnions,
1082 maxLegalFusionDepth, &bestDstLoopDepth,
1083 computeToleranceThresholdToUse)) {
1088 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1090 depthSliceUnions[bestDstLoopDepth - 1];
1091 assert(!bestSlice.
isEmpty() &&
"Missing slice union for depth");
1097 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1101 for (
Value memref : producerConsumerMemrefs) {
1102 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1105 LLVM_DEBUG(llvm::dbgs()
1106 <<
"Creating private memref for " << memref <<
'\n');
1108 privateMemrefs.insert(memref);
1113 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1114 dstNodeChanged =
true;
1116 LLVM_DEBUG(llvm::dbgs()
1117 <<
"Fused src loop " << srcId <<
" into dst loop " << dstId
1118 <<
" at depth " << bestDstLoopDepth <<
":\n"
1119 << dstAffineForOp <<
"\n");
1122 if (fusedLoopInsPoint != dstAffineForOp)
1123 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1126 mdg->
updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1130 if (!privateMemrefs.empty()) {
1137 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1138 Value storeMemRef = storeOp.getMemRef();
1139 if (privateMemrefs.count(storeMemRef) > 0)
1140 privateMemRefToStores[storeMemRef].push_back(storeOp);
1147 for (
auto &memrefToStoresPair : privateMemRefToStores) {
1150 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1151 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1157 mdg->
addEdge(newMemRefNodeId, dstId, newMemRef);
1162 dstNode = mdg->
getNode(dstId);
1167 dstLoopCollector.
collect(dstAffineForOp);
1176 if (removeSrcNode) {
1177 LLVM_DEBUG(llvm::dbgs()
1178 <<
"Removing src loop " << srcId <<
" after fusion\n");
1180 srcAffineForOp.erase();
1185 }
while (dstNodeChanged);
1192 void fuseProducerConsumerNodes(
unsigned maxSrcUserCount) {
1193 LLVM_DEBUG(llvm::dbgs() <<
"--- Producer/Consumer Fusion ---\n");
1195 while (!worklist.empty()) {
1196 unsigned dstId = worklist.back();
1197 worklist.pop_back();
1198 performFusionsIntoDest(dstId, maxSrcUserCount);
1204 void fuseSiblingNodes() {
1205 LLVM_DEBUG(llvm::dbgs() <<
"--- Sibling Fusion ---\n");
1207 while (!worklist.empty()) {
1208 unsigned dstId = worklist.back();
1209 worklist.pop_back();
1212 if (mdg->
nodes.count(dstId) == 0)
1215 auto *dstNode = mdg->
getNode(dstId);
1217 if (!isa<AffineForOp>(dstNode->op))
1220 fuseWithSiblingNodes(dstNode);
1225 void fuseWithSiblingNodes(
Node *dstNode) {
1227 std::pair<unsigned, Value> idAndMemref;
1228 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1230 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1231 unsigned sibId = idAndMemref.first;
1232 Value memref = idAndMemref.second;
1235 auto *sibNode = mdg->
getNode(sibId);
1238 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1243 if (insertPointInst ==
nullptr)
1250 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1252 Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
1256 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1266 unsigned numSurroundingLoops = surroundingLoops.size();
1269 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1270 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1274 depthSliceUnions.resize(dstLoopDepthTest);
1275 unsigned maxLegalFusionDepth = 0;
1277 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1280 i + numSurroundingLoops,
1281 &depthSliceUnions[i - 1], strategy);
1284 maxLegalFusionDepth = i;
1287 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1288 << maxLegalFusionDepth <<
'\n');
1291 if (maxLegalFusionDepth == 0)
1294 double computeToleranceThresholdToUse = computeToleranceThreshold;
1302 LLVM_DEBUG(llvm::dbgs() <<
"Source nest has a cyclic dependence.\n");
1306 if (maximalFusion) {
1307 auto dstForOp = cast<AffineForOp>(dstNode->op);
1309 int64_t fusedLoopNestComputeCost;
1311 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1312 sliceCost, fusedLoopNestComputeCost);
1313 if (!fraction || fraction > 0) {
1316 <<
"Can't perform maximal fusion with a cyclic dependence "
1317 "and non-zero additional compute.\n");
1323 LLVM_DEBUG(llvm::dbgs() <<
"Setting compute tolerance to zero since "
1324 "source has a cyclic dependence.\n");
1325 computeToleranceThresholdToUse = 0.0;
1329 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1330 if (!maximalFusion) {
1336 depthSliceUnions, maxLegalFusionDepth,
1338 computeToleranceThresholdToUse))
1342 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1345 depthSliceUnions[bestDstLoopDepth - 1];
1346 assert(!bestSlice.
isEmpty() &&
1347 "Fusion depth has no computed slice union");
1353 if (!isMaximal.value_or(
false)) {
1354 LLVM_DEBUG(llvm::dbgs()
1355 <<
"Slice isn't maximal; not performing sibling fusion.\n");
1362 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1365 isInnermostInsertion);
1367 auto dstForInst = cast<AffineForOp>(dstNode->op);
1369 if (insertPointInst != dstForInst)
1370 dstForInst->moveBefore(insertPointInst);
1372 LLVM_DEBUG(llvm::dbgs()
1373 <<
"Fused sibling nest " << sibId <<
" into destination nest "
1374 << dstNode->id <<
" at depth " << bestDstLoopDepth <<
":\n"
1375 << dstAffineForOp <<
"\n");
1378 updateStateAfterSiblingFusion(sibNode, dstNode);
1392 bool findSiblingNodeToFuse(
Node *dstNode,
1394 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1397 auto canFuseWithSibNode = [&](
Node *sibNode,
Value memref) {
1400 if (sibNode->getLoadOpCount(memref) != 1)
1410 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1411 if (llvm::any_of(loadAndStoreMemrefSet, [=](
Value memref) {
1418 for (
auto *storeOpInst : sibNode->stores) {
1419 storeMemrefs.insert(
1420 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1422 return storeMemrefs.size() <= 1;
1426 Block *block = dstNode->op->getBlock();
1429 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1438 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1439 return loop->getBlock() == &mdg->
block;
1442 if (it == loops.end())
1445 assert(sibNode !=
nullptr);
1447 if (sibNode->id == dstNode->id)
1450 if (visitedSibNodeIds->count(sibNode->id) > 0)
1453 auto memref = loadOp.getMemRef();
1454 if (dstNode->getLoadOpCount(memref) == 0)
1457 if (canFuseWithSibNode(sibNode, memref)) {
1458 visitedSibNodeIds->insert(sibNode->id);
1459 idAndMemrefToFuse->first = sibNode->id;
1460 idAndMemrefToFuse->second = memref;
1472 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1473 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1474 inEdges.push_back(inEdge);
1479 for (
auto &inEdge : inEdges) {
1484 unsigned sibNodeId = outEdge.id;
1485 if (visitedSibNodeIds->count(sibNodeId) > 0)
1488 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1490 auto *sibNode = mdg->getNode(sibNodeId);
1491 if (!isa<AffineForOp>(sibNode->op))
1494 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1496 outEdges.push_back(outEdge);
1501 if (!outEdges.empty()) {
1502 visitedSibNodeIds->insert(outEdges[0].
id);
1503 idAndMemrefToFuse->first = outEdges[0].id;
1504 idAndMemrefToFuse->second = outEdges[0].value;
1513 void updateStateAfterSiblingFusion(
Node *sibNode,
Node *dstNode) {
1518 auto dstForInst = cast<AffineForOp>(dstNode->op);
1520 dstLoopCollector.
collect(dstForInst);
1529 void eraseUnusedMemRefAllocations() {
1531 if (pair.second > 0)
1533 auto memref = pair.first;
1539 if (isa_and_nonnull<memref::AllocOp>(op))
1548 void LoopFusion::runOnBlock(
Block *block) {
1551 LLVM_DEBUG(llvm::dbgs() <<
"MDG init failed\n");
1555 std::optional<unsigned> fastMemorySpaceOpt;
1556 if (fastMemorySpace.hasValue())
1557 fastMemorySpaceOpt = fastMemorySpace;
1558 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1559 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1560 maximalFusion, computeToleranceThreshold);
1563 fusion.runProducerConsumerFusionOnly();
1565 fusion.runSiblingFusionOnly();
1567 fusion.runGreedyFusion();
1570 void LoopFusion::runOnOperation() {
1573 getOperation()->walk([&](
Operation *op) {
1575 for (Block &block : region.getBlocks()) {
1576 auto affineFors = block.getOps<AffineForOp>();
1577 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1585 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1586 bool maximalFusion,
enum FusionMode affineFusionMode) {
1587 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1588 maximalFusion, affineFusionMode);
MemRefDependenceGraph::Node Node
static Operation * getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock, ArrayRef< Operation * > producerStores)
Get the operation that should act as a dominance filter while replacing memref uses with a private me...
static void getProducerCandidates(unsigned dstId, const MemRefDependenceGraph &mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
static bool isFusionProfitable(AffineForOp srcForOp, ArrayRef< Operation * > producerStores, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, const MemRefDependenceGraph &mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
static Value createPrivateMemRef(AffineForOp forOp, ArrayRef< Operation * > storeOps, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, Block *sliceInsertionBlock, uint64_t localBufSizeThreshold)
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
static std::optional< double > getAdditionalComputeFraction(AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, ArrayRef< ComputationSliceState > depthSliceUnions, int64_t &sliceCost, int64_t &fusedLoopNestComputeCost)
Returns the amount of additional (redundant) computation that will be done as a fraction of the total...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumDimAndSymbolVars() const
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
bool hasCyclicDependence(AffineForOp root)
Returns true if the affine nest rooted at root has a cyclic dependence among its affine memory access...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
bool isEmpty() const
Returns true if the computation slice is empty.
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Block::iterator insertPoint
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > memrefFrees
SmallVector< Operation *, 4 > loadOpInsts
SmallVector< Operation *, 4 > memrefStores
void collect(Operation *opToWalk)
SmallVector< Operation *, 4 > memrefLoads
SmallVector< Operation *, 4 > storeOpInsts
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Block & block
The block for which this graph is created to perform fusion.
unsigned addNode(Operation *op)
void addEdge(unsigned srcId, unsigned dstId, Value value)
DenseMap< unsigned, Node > nodes
bool hasDependencePath(unsigned srcId, unsigned dstId) const
void clearNodeLoadAndStores(unsigned id)
const Node * getForOpNode(AffineForOp forOp) const
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) const
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr) const
const Node * getNode(unsigned id) const
void removeNode(unsigned id)
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const
DenseMap< Value, unsigned > memrefEdgeCount
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
FlatAffineValueConstraints * getConstraints()
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.