26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
38 #define GEN_PASS_DEF_AFFINELOOPFUSION
39 #include "mlir/Dialect/Affine/Passes.h.inc"
43 #define DEBUG_TYPE "affine-fusion"
54 struct LoopFusion :
public affine::impl::AffineLoopFusionBase<LoopFusion> {
55 LoopFusion() =
default;
56 LoopFusion(
unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
57 bool maximalFusion,
enum FusionMode affineFusionMode) {
58 this->fastMemorySpace = fastMemorySpace;
59 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
60 this->maximalFusion = maximalFusion;
61 this->affineFusionMode = affineFusionMode;
64 void runOnBlock(
Block *block);
65 void runOnOperation()
override;
83 bool hasOutDepsAfterFusion =
false;
85 for (
auto &outEdge : mdg.
outEdges.lookup(srcId)) {
88 if (depNodeOp == dstNodeOp)
98 if (fusedLoopInsPoint != depNodeOp &&
100 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: dst loop doesn't "
101 "dominate dependence\n");
105 hasOutDepsAfterFusion =
true;
111 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
112 std::optional<bool> isMaximal = fusionSlice.
isMaximal();
114 LLVM_DEBUG(llvm::dbgs() <<
"Src loop can't be removed: can't determine "
115 "if fusion is maximal\n");
120 LLVM_DEBUG(llvm::dbgs()
121 <<
"Src loop can't be removed: fusion is not maximal\n");
140 if (mdg.
inEdges.count(dstId) == 0)
144 auto *dstNode = mdg.
getNode(dstId);
147 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
151 for (
const auto &srcEdge : mdg.
inEdges.lookup(dstId)) {
152 const auto *srcNode = mdg.
getNode(srcEdge.id);
154 if (!isa<AffineForOp>(srcNode->op))
157 if (any_of(srcNode->stores, [&](
Operation *op) {
158 auto storeOp = cast<AffineWriteOpInterface>(op);
159 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
161 srcIdCandidates.push_back(srcNode->id);
164 llvm::sort(srcIdCandidates);
165 srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
174 auto *dstNode = mdg.
getNode(dstId);
175 auto *srcNode = mdg.
getNode(srcId);
177 producerConsumerMemrefs);
194 if (
auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
199 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
206 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
209 if (ancestorOp->getBlock() != block)
211 return !isa<AffineMapAccessInterface>(*user);
220 for (
Operation *storeOp : node->stores) {
221 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222 if (escapingMemRefs.count(memref))
225 escapingMemRefs.insert(memref);
235 assert(isa<AffineForOp>(node->
op));
237 node->
op = newRootForOp;
247 assert(!producerStores.empty() &&
"expected producer store");
253 Block *commonBlock =
nullptr;
255 for (
Operation *store : producerStores) {
257 !commonBlock ? &*sliceInsertionBlock->
begin() : &*commonBlock->
begin();
260 assert(commonBlock &&
261 "common block of producer stores and slice should exist");
266 for (
Operation *store : producerStores) {
268 assert(ancestor &&
"producer store should be contained in common block");
269 firstAncestor = !firstAncestor || ancestor->
isBeforeInBlock(firstAncestor)
273 return firstAncestor;
281 AffineForOp srcForOp, AffineForOp dstForOp,
unsigned depth,
283 int64_t &fusedLoopNestComputeCost) {
284 LLVM_DEBUG(llvm::dbgs() <<
"Determining additional compute fraction...\n";);
289 LLVM_DEBUG(llvm::dbgs() <<
"Failed to get source loop nest stats.\n");
296 LLVM_DEBUG(llvm::dbgs() <<
"Failed to get destination loop nest stats.\n");
301 uint64_t srcLoopNestCost =
getComputeCost(srcForOp, srcLoopNestStats);
304 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
309 LLVM_DEBUG(llvm::dbgs() <<
"Slice wasn't computed.\n");
314 dstLoopNestStats, slice,
315 &fusedLoopNestComputeCost)) {
316 LLVM_DEBUG(llvm::dbgs() <<
"Unable to compute fusion compute cost\n");
320 double additionalComputeFraction =
321 fusedLoopNestComputeCost /
322 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
325 return additionalComputeFraction;
336 unsigned dstLoopDepth,
337 std::optional<unsigned> fastMemorySpace,
338 Block *sliceInsertionBlock,
339 uint64_t localBufSizeThreshold) {
340 assert(!storeOps.empty() &&
"no source stores supplied");
346 if (storeOps.size() > 1 &&
347 !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
349 MemRefAccess aM(cast<AffineWriteOpInterface>(a));
350 MemRefAccess bM(cast<AffineWriteOpInterface>(b));
353 LLVM_DEBUG(llvm::dbgs()
354 <<
"Private memref creation unsupported for multiple producer "
355 "stores with different access functions.\n");
366 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
367 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
368 unsigned rank = oldMemRefType.getRank();
372 bool validRegion = succeeded(
373 region.
compute(srcStoreOp, dstLoopDepth,
nullptr,
377 assert(validRegion &&
"unexpected memref region failure");
383 std::optional<int64_t> numElements =
385 assert(numElements &&
"non-constant number of elts in local buffer");
396 offsets.reserve(rank);
401 for (
unsigned j = 0, e = lbs[0].getNumSymbols();
j < e; ++
j)
403 for (
unsigned d = 0; d < rank; ++d) {
404 assert(lbs[d].getNumResults() == 1 &&
405 "invalid private memref bound calculation");
406 offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
412 assert(eltSize &&
"memrefs with size elt types expected");
413 uint64_t bufSize = *eltSize * *numElements;
415 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
416 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
418 newMemSpace = oldMemRefType.getMemorySpace();
420 auto newMemRefType =
MemRefType::get(newShape, oldMemRefType.getElementType(),
429 Value newMemRef = top.
create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
433 remapExprs.reserve(rank);
434 for (
unsigned i = 0; i < rank; i++) {
435 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
439 remapExprs.push_back(remapExpr);
443 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
449 oldMemRef, newMemRef, {}, indexRemap,
452 assert(succeeded(res) &&
453 "replaceAllMemrefUsesWith should always succeed here");
455 LLVM_DEBUG(llvm::dbgs() <<
"Created private memref of type: " << newMemRefType
500 AffineForOp dstForOp,
502 unsigned maxLegalFusionDepth,
503 unsigned *dstLoopDepth,
504 double computeToleranceThreshold) {
507 <<
"Checking whether fusion is profitable between source nest:\n";
508 llvm::dbgs() <<
' ' << srcForOp <<
" and destination nest:\n";
509 llvm::dbgs() << dstForOp <<
"\n";
512 if (maxLegalFusionDepth == 0) {
513 LLVM_DEBUG(llvm::dbgs() <<
"Can't fuse: maxLegalFusionDepth is 0\n");
536 if (producerStores.size() > 1) {
537 LLVM_DEBUG(llvm::dbgs() <<
"Limited profitability analysis. Not "
538 "supported for multiple producer store case.\n");
540 int64_t fusedLoopNestComputeCost;
544 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
545 fusedLoopNestComputeCost);
546 if (!fraction || fraction > computeToleranceThreshold) {
547 LLVM_DEBUG(llvm::dbgs() <<
"Additional computation exceeds "
548 "compute tolerance. Not fusing.\n");
551 LLVM_DEBUG(llvm::dbgs()
552 <<
"Considering fusion profitable at max legal depth.\n");
556 Operation *srcStoreOp = producerStores.front();
565 double maxStorageReduction = 0.0;
566 std::optional<uint64_t> sliceMemEstimate;
569 std::optional<unsigned> bestDstLoopDepth;
573 if (failed(srcWriteRegion.
compute(srcStoreOp, 0))) {
574 LLVM_DEBUG(llvm::dbgs()
575 <<
"Unable to compute MemRefRegion for source operation\n");
579 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
581 if (!maybeSrcWriteRegionSizeBytes.has_value())
583 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
586 uint64_t srcLoopNestCost =
getComputeCost(srcForOp, srcLoopNestStats);
589 uint64_t dstLoopNestCost =
getComputeCost(dstForOp, dstLoopNestStats);
593 for (
unsigned i = maxLegalFusionDepth; i >= 1; --i) {
603 int64_t fusedLoopNestComputeCost;
605 auto mayAdditionalComputeFraction =
607 sliceCost, fusedLoopNestComputeCost);
608 if (!mayAdditionalComputeFraction) {
609 LLVM_DEBUG(llvm::dbgs()
610 <<
"Can't determine additional compute fraction.\n");
613 double additionalComputeFraction = *mayAdditionalComputeFraction;
619 if (failed(sliceWriteRegion.
compute(srcStoreOp, 0, &slice))) {
620 LLVM_DEBUG(llvm::dbgs()
621 <<
"Failed to compute slice write region at loopDepth: " << i
626 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
628 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
629 *maybeSliceWriteRegionSizeBytes == 0) {
630 LLVM_DEBUG(llvm::dbgs()
631 <<
"Failed to get slice write region size at loopDepth: " << i
635 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
637 double storageReduction =
static_cast<double>(srcWriteRegionSizeBytes) /
638 static_cast<double>(sliceWriteRegionSizeBytes);
641 std::stringstream msg;
642 msg <<
" evaluating fusion profitability at depth : " << i <<
"\n"
643 << std::fixed << std::setprecision(2)
644 <<
" additional compute fraction: "
645 << 100.0 * additionalComputeFraction <<
"%\n"
646 <<
" storage reduction factor: " << storageReduction <<
"x\n"
647 <<
" fused nest cost: " << fusedLoopNestComputeCost <<
"\n"
648 <<
" src write region size: " << srcWriteRegionSizeBytes <<
"\n"
649 <<
" slice write region size: " << sliceWriteRegionSizeBytes
651 llvm::dbgs() << msg.str();
658 if ((storageReduction > maxStorageReduction) &&
659 (additionalComputeFraction <= computeToleranceThreshold)) {
660 maxStorageReduction = storageReduction;
661 bestDstLoopDepth = i;
662 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
663 sliceMemEstimate = sliceWriteRegionSizeBytes;
669 if (!bestDstLoopDepth) {
672 <<
"All fusion choices involve more than the threshold amount of "
673 "redundant computation; NOT fusing.\n");
677 if (!bestDstLoopDepth) {
678 LLVM_DEBUG(llvm::dbgs() <<
"no fusion depth could be evaluated.\n");
683 *dstLoopDepth = *bestDstLoopDepth;
686 llvm::dbgs() <<
" LoopFusion fusion stats:"
687 <<
"\n best loop depth: " << bestDstLoopDepth
688 <<
"\n src loop nest compute cost: " << srcLoopNestCost
689 <<
"\n dst loop nest compute cost: " << dstLoopNestCost
690 <<
"\n fused loop nest compute cost: "
691 << minFusedLoopNestComputeCost <<
"\n");
696 std::optional<double> storageReduction;
698 if (!dstMemSize || !srcMemSize) {
699 LLVM_DEBUG(llvm::dbgs()
700 <<
" fusion memory benefit cannot be evaluated; NOT fusing.\n");
704 auto srcMemSizeVal = *srcMemSize;
705 auto dstMemSizeVal = *dstMemSize;
707 assert(sliceMemEstimate &&
"expected value");
708 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
710 LLVM_DEBUG(llvm::dbgs() <<
" src mem: " << srcMemSizeVal <<
"\n"
711 <<
" dst mem: " << dstMemSizeVal <<
"\n"
712 <<
" fused mem: " << fusedMem <<
"\n"
713 <<
" slice mem: " << sliceMemEstimate <<
"\n");
715 if (
static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
716 LLVM_DEBUG(llvm::dbgs() <<
"Fusion is not profitable; NOT fusing.\n");
721 (1.0 - fusedMem / (
static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
723 double additionalComputeFraction =
724 100.0 * (minFusedLoopNestComputeCost /
725 (
static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
727 (void)additionalComputeFraction;
729 std::stringstream msg;
730 msg <<
" fusion is most profitable at depth " << *dstLoopDepth <<
" with "
731 << std::setprecision(2) << additionalComputeFraction
732 <<
"% redundant computation and a ";
733 msg << (storageReduction ? std::to_string(*storageReduction) :
"<unknown>");
734 msg <<
"% storage reduction.\n";
735 llvm::dbgs() << msg.str();
789 struct GreedyFusion {
796 unsigned localBufSizeThreshold;
798 std::optional<unsigned> fastMemorySpace;
804 double computeToleranceThreshold;
809 std::optional<unsigned> fastMemorySpace,
bool maximalFusion,
810 double computeToleranceThreshold)
811 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
812 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
813 computeToleranceThreshold(computeToleranceThreshold) {}
820 for (
auto &idAndNode : mdg->
nodes) {
821 const Node &node = idAndNode.second;
822 worklist.push_back(node.id);
826 void runSiblingFusionOnly() {
828 eraseUnusedMemRefAllocations();
832 void runProducerConsumerFusionOnly() {
833 fuseProducerConsumerNodes(
835 eraseUnusedMemRefAllocations();
843 void runGreedyFusion() {
845 fuseProducerConsumerNodes(1);
847 fuseProducerConsumerNodes(
849 eraseUnusedMemRefAllocations();
854 bool canCreatePrivateMemRef(
Value memref,
856 unsigned producerId,
unsigned consumerId,
857 bool removeSrcNode) {
861 const Node *consumerNode = mdg->
getNode(consumerId);
869 if (srcEscapingMemRefs.count(memref) > 0 &&
870 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
883 any_of(mdg->
outEdges[producerId], [&](
const auto &edge) {
884 return edge.value == memref && edge.id != consumerId;
894 void performFusionsIntoDest(
unsigned dstId,
unsigned maxSrcUserCount) {
895 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
897 if (mdg->
nodes.count(dstId) == 0)
900 auto *dstNode = mdg->
getNode(dstId);
902 if (!isa<AffineForOp>(dstNode->op))
906 if (dstNode->op->getNumResults() > 0)
909 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating dst loop " << dstId <<
"\n");
916 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
927 dstNodeChanged =
false;
931 for (
unsigned srcId : llvm::reverse(srcIdCandidates)) {
933 auto *srcNode = mdg->
getNode(srcId);
934 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
936 LLVM_DEBUG(llvm::dbgs()
937 <<
"Trying to fuse producer loop nest " << srcId
938 <<
" with consumer loop nest " << dstId <<
"\n");
939 LLVM_DEBUG(llvm::dbgs() <<
"Compute tolerance threshold: "
940 << computeToleranceThreshold <<
'\n');
941 LLVM_DEBUG(llvm::dbgs()
942 <<
"Producer loop nest:\n"
943 << *srcNode->op <<
"\n and consumer loop nest:\n"
944 << *dstNode->op <<
'\n');
946 LLVM_DEBUG(llvm::dbgs() <<
"Evaluating src loop " << srcId
947 <<
" for dst loop " << dstId <<
"\n");
951 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
956 producerConsumerMemrefs);
960 if (any_of(producerConsumerMemrefs, [&](
Value memref) {
976 if (fusedLoopInsPoint ==
nullptr)
987 unsigned numSurroundingLoops = surroundingLoops.size();
993 if (producerConsumerMemrefs.count(
994 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
995 dstMemrefOps.push_back(op);
997 if (producerConsumerMemrefs.count(
998 cast<AffineWriteOpInterface>(op).getMemRef()))
999 dstMemrefOps.push_back(op);
1000 unsigned dstLoopDepthTest =
1005 unsigned maxLegalFusionDepth = 0;
1007 depthSliceUnions.resize(dstLoopDepthTest);
1009 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1012 i + numSurroundingLoops,
1013 &depthSliceUnions[i - 1], strategy);
1015 maxLegalFusionDepth = i;
1016 LLVM_DEBUG(llvm::dbgs()
1017 <<
"Found valid slice for depth: " << i <<
'\n');
1021 if (maxLegalFusionDepth == 0) {
1022 LLVM_DEBUG(llvm::dbgs()
1023 <<
"Can't fuse: fusion is not legal at any depth\n");
1027 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1028 << maxLegalFusionDepth <<
'\n');
1030 double computeToleranceThresholdToUse = computeToleranceThreshold;
1038 LLVM_DEBUG(llvm::dbgs() <<
"Source nest has a cyclic dependence.\n");
1042 if (maximalFusion) {
1043 auto srcForOp = cast<AffineForOp>(srcNode->op);
1044 auto dstForOp = cast<AffineForOp>(dstNode->op);
1046 int64_t fusedLoopNestComputeCost;
1048 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1049 sliceCost, fusedLoopNestComputeCost);
1050 if (!fraction || fraction > 0) {
1053 <<
"Can't perform maximal fusion with a cyclic dependence "
1054 "and non-zero additional compute.\n");
1060 LLVM_DEBUG(llvm::dbgs()
1061 <<
"Setting compute tolerance to zero since "
1062 "source has a cylic dependence.\n");
1063 computeToleranceThresholdToUse = 0;
1070 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1071 if (!maximalFusion) {
1075 if (producerConsumerMemrefs.count(
1076 cast<AffineWriteOpInterface>(op).getMemRef()))
1077 producerStores.push_back(op);
1079 assert(!producerStores.empty() &&
"Expected producer store");
1081 dstAffineForOp, depthSliceUnions,
1082 maxLegalFusionDepth, &bestDstLoopDepth,
1083 computeToleranceThresholdToUse)) {
1088 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1090 depthSliceUnions[bestDstLoopDepth - 1];
1091 assert(!bestSlice.
isEmpty() &&
"Missing slice union for depth");
1097 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1101 for (
Value memref : producerConsumerMemrefs) {
1102 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1105 LLVM_DEBUG(llvm::dbgs()
1106 <<
"Creating private memref for " << memref <<
'\n');
1108 privateMemrefs.insert(memref);
1113 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1114 dstNodeChanged =
true;
1116 LLVM_DEBUG(llvm::dbgs()
1117 <<
"Fused src loop " << srcId <<
" into dst loop " << dstId
1118 <<
" at depth " << bestDstLoopDepth <<
":\n"
1119 << dstAffineForOp <<
"\n");
1122 if (fusedLoopInsPoint != dstAffineForOp)
1123 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1126 mdg->
updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1130 if (!privateMemrefs.empty()) {
1137 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1138 Value storeMemRef = storeOp.getMemRef();
1139 if (privateMemrefs.count(storeMemRef) > 0)
1140 privateMemRefToStores[storeMemRef].push_back(storeOp);
1147 for (
auto &memrefToStoresPair : privateMemRefToStores) {
1150 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1151 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1157 mdg->
addEdge(newMemRefNodeId, dstId, newMemRef);
1162 dstNode = mdg->
getNode(dstId);
1167 dstLoopCollector.
collect(dstAffineForOp);
1176 if (removeSrcNode) {
1177 LLVM_DEBUG(llvm::dbgs()
1178 <<
"Removing src loop " << srcId <<
" after fusion\n");
1180 srcAffineForOp.erase();
1185 }
while (dstNodeChanged);
1192 void fuseProducerConsumerNodes(
unsigned maxSrcUserCount) {
1193 LLVM_DEBUG(llvm::dbgs() <<
"--- Producer/Consumer Fusion ---\n");
1195 while (!worklist.empty()) {
1196 unsigned dstId = worklist.back();
1197 worklist.pop_back();
1198 performFusionsIntoDest(dstId, maxSrcUserCount);
1204 void fuseSiblingNodes() {
1205 LLVM_DEBUG(llvm::dbgs() <<
"--- Sibling Fusion ---\n");
1207 while (!worklist.empty()) {
1208 unsigned dstId = worklist.back();
1209 worklist.pop_back();
1212 if (mdg->
nodes.count(dstId) == 0)
1215 auto *dstNode = mdg->
getNode(dstId);
1217 if (!isa<AffineForOp>(dstNode->op))
1220 fuseWithSiblingNodes(dstNode);
1225 void fuseWithSiblingNodes(
Node *dstNode) {
1227 std::pair<unsigned, Value> idAndMemref;
1228 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1230 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1231 unsigned sibId = idAndMemref.first;
1232 Value memref = idAndMemref.second;
1235 auto *sibNode = mdg->
getNode(sibId);
1238 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1243 if (insertPointInst ==
nullptr)
1250 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1252 assert(sibLoadOpInsts.size() == 1);
1253 Operation *sibLoadOpInst = sibLoadOpInsts[0];
1257 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1267 unsigned numSurroundingLoops = surroundingLoops.size();
1270 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1271 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1275 depthSliceUnions.resize(dstLoopDepthTest);
1276 unsigned maxLegalFusionDepth = 0;
1278 for (
unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1281 i + numSurroundingLoops,
1282 &depthSliceUnions[i - 1], strategy);
1285 maxLegalFusionDepth = i;
1288 LLVM_DEBUG(llvm::dbgs() <<
"Max legal depth for fusion: "
1289 << maxLegalFusionDepth <<
'\n');
1292 if (maxLegalFusionDepth == 0)
1295 double computeToleranceThresholdToUse = computeToleranceThreshold;
1303 LLVM_DEBUG(llvm::dbgs() <<
"Source nest has a cyclic dependence.\n");
1307 if (maximalFusion) {
1308 auto dstForOp = cast<AffineForOp>(dstNode->op);
1310 int64_t fusedLoopNestComputeCost;
1312 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1313 sliceCost, fusedLoopNestComputeCost);
1314 if (!fraction || fraction > 0) {
1317 <<
"Can't perform maximal fusion with a cyclic dependence "
1318 "and non-zero additional compute.\n");
1324 LLVM_DEBUG(llvm::dbgs() <<
"Setting compute tolerance to zero since "
1325 "source has a cyclic dependence.\n");
1326 computeToleranceThresholdToUse = 0.0;
1330 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1331 if (!maximalFusion) {
1337 depthSliceUnions, maxLegalFusionDepth,
1339 computeToleranceThresholdToUse))
1343 assert(bestDstLoopDepth > 0 &&
"Unexpected loop fusion depth");
1346 depthSliceUnions[bestDstLoopDepth - 1];
1347 assert(!bestSlice.
isEmpty() &&
1348 "Fusion depth has no computed slice union");
1354 if (!isMaximal.value_or(
false)) {
1355 LLVM_DEBUG(llvm::dbgs()
1356 <<
"Slice isn't maximal; not performing sibling fusion.\n");
1363 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1366 isInnermostInsertion);
1368 auto dstForInst = cast<AffineForOp>(dstNode->op);
1370 if (insertPointInst != dstForInst)
1371 dstForInst->moveBefore(insertPointInst);
1373 LLVM_DEBUG(llvm::dbgs()
1374 <<
"Fused sibling nest " << sibId <<
" into destination nest "
1375 << dstNode->id <<
" at depth " << bestDstLoopDepth <<
":\n"
1376 << dstAffineForOp <<
"\n");
1379 updateStateAfterSiblingFusion(sibNode, dstNode);
1393 bool findSiblingNodeToFuse(
Node *dstNode,
1395 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1398 auto canFuseWithSibNode = [&](
Node *sibNode,
Value memref) {
1401 if (sibNode->getLoadOpCount(memref) != 1)
1411 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1412 if (llvm::any_of(loadAndStoreMemrefSet, [=](
Value memref) {
1419 for (
auto *storeOpInst : sibNode->stores) {
1420 storeMemrefs.insert(
1421 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1423 return storeMemrefs.size() <= 1;
1427 Block *block = dstNode->op->getBlock();
1430 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1439 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1440 return loop->getBlock() == &mdg->
block;
1443 if (it == loops.end())
1446 assert(sibNode !=
nullptr);
1448 if (sibNode->id == dstNode->id)
1451 if (visitedSibNodeIds->count(sibNode->id) > 0)
1454 auto memref = loadOp.getMemRef();
1455 if (dstNode->getLoadOpCount(memref) == 0)
1458 if (canFuseWithSibNode(sibNode, memref)) {
1459 visitedSibNodeIds->insert(sibNode->id);
1460 idAndMemrefToFuse->first = sibNode->id;
1461 idAndMemrefToFuse->second = memref;
1473 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1474 mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1475 inEdges.push_back(inEdge);
1480 for (
auto &inEdge : inEdges) {
1485 unsigned sibNodeId = outEdge.id;
1486 if (visitedSibNodeIds->count(sibNodeId) > 0)
1489 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1491 auto *sibNode = mdg->getNode(sibNodeId);
1492 if (!isa<AffineForOp>(sibNode->op))
1495 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1497 outEdges.push_back(outEdge);
1502 if (!outEdges.empty()) {
1503 visitedSibNodeIds->insert(outEdges[0].
id);
1504 idAndMemrefToFuse->first = outEdges[0].id;
1505 idAndMemrefToFuse->second = outEdges[0].value;
1514 void updateStateAfterSiblingFusion(
Node *sibNode,
Node *dstNode) {
1519 auto dstForInst = cast<AffineForOp>(dstNode->op);
1521 dstLoopCollector.
collect(dstForInst);
1530 void eraseUnusedMemRefAllocations() {
1532 if (pair.second > 0)
1534 auto memref = pair.first;
1540 if (isa_and_nonnull<memref::AllocOp>(op))
1549 void LoopFusion::runOnBlock(
Block *block) {
1552 LLVM_DEBUG(llvm::dbgs() <<
"MDG init failed\n");
1556 std::optional<unsigned> fastMemorySpaceOpt;
1557 if (fastMemorySpace.hasValue())
1558 fastMemorySpaceOpt = fastMemorySpace;
1559 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1560 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1561 maximalFusion, computeToleranceThreshold);
1564 fusion.runProducerConsumerFusionOnly();
1566 fusion.runSiblingFusionOnly();
1568 fusion.runGreedyFusion();
1571 void LoopFusion::runOnOperation() {
1574 getOperation()->walk([&](
Operation *op) {
1576 for (Block &block : region.getBlocks()) {
1577 auto affineFors = block.getOps<AffineForOp>();
1578 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1586 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1587 bool maximalFusion,
enum FusionMode affineFusionMode) {
1588 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1589 maximalFusion, affineFusionMode);
MemRefDependenceGraph::Node Node
static Operation * getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock, ArrayRef< Operation * > producerStores)
Get the operation that should act as a dominance filter while replacing memref uses with a private me...
static void getProducerCandidates(unsigned dstId, const MemRefDependenceGraph &mdg, SmallVectorImpl< unsigned > &srcIdCandidates)
Returns in 'srcIdCandidates' the producer fusion candidates for consumer 'dstId'.
static bool isFusionProfitable(AffineForOp srcForOp, ArrayRef< Operation * > producerStores, AffineForOp dstForOp, ArrayRef< ComputationSliceState > depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold)
static bool isEscapingMemref(Value memref, Block *block)
A memref escapes in the context of the fusion pass if either:
static bool canRemoveSrcNodeAfterFusion(unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet< Value > &escapingMemRefs, const MemRefDependenceGraph &mdg)
Returns true if node 'srcId' can be removed after fusing it with node 'dstId'.
static Value createPrivateMemRef(AffineForOp forOp, ArrayRef< Operation * > storeOps, unsigned dstLoopDepth, std::optional< unsigned > fastMemorySpace, Block *sliceInsertionBlock, uint64_t localBufSizeThreshold)
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg, DenseSet< Value > &escapingMemRefs)
Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' that escape the block or ...
static std::optional< double > getAdditionalComputeFraction(AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, ArrayRef< ComputationSliceState > depthSliceUnions, int64_t &sliceCost, int64_t &fusedLoopNestComputeCost)
Returns the amount of additional (redundant) computation that will be done as a fraction of the total...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void getValues(unsigned start, unsigned end, SmallVectorImpl< Value > *values) const
Returns the Values associated with variables in range [start, end).
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
Describes the fusion strategy to be used in the Affine loop fusion utilities.
unsigned getNumDimAndSymbolVars() const
bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, const ComputationSliceState &slice, int64_t *computeCost)
Computes and returns in 'computeCost', the total compute cost of fusing the 'slice' of the loop nest ...
void gatherProducerConsumerMemrefs(ArrayRef< Operation * > srcOps, ArrayRef< Operation * > dstOps, DenseSet< Value > &producerConsumerMemrefs)
Returns in 'producerConsumerMemrefs' the memrefs involved in a producer-consumer dependence between w...
int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats)
Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, const ComputationSliceState &srcSlice, bool isInnermostSiblingInsertionFusion=false)
Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point and source slice loop bo...
void getAffineForIVs(Operation &op, SmallVectorImpl< AffineForOp > *loops)
Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered from the outermost 'affine....
std::optional< int64_t > getMemoryFootprintBytes(AffineForOp forOp, int memorySpace=-1)
Gets the memory footprint of all data touched in the specified memory space in bytes; if the memory s...
std::unique_ptr< Pass > createLoopFusionPass(unsigned fastMemorySpace=0, uint64_t localBufSizeThreshold=0, bool maximalFusion=false, enum FusionMode fusionMode=FusionMode::Greedy)
Creates a loop fusion pass which fuses affine loop nests at the top-level of the operation the pass i...
FusionMode
Fusion mode to attempt.
unsigned getInnermostCommonLoopDepth(ArrayRef< Operation * > ops, SmallVectorImpl< AffineForOp > *surroundingLoops=nullptr)
Returns the innermost common loop depth for the set of operations in 'ops'.
bool getLoopNestStats(AffineForOp forOp, LoopNestStats *stats)
Collect loop nest statistics (eg.
AffineForOp sinkSequentialLoops(AffineForOp forOp)
bool hasCyclicDependence(AffineForOp root)
Returns true if the affine nest rooted at root has a cyclic dependence among its affine memory access...
mlir::Block * findInnermostCommonBlockInScope(mlir::Operation *a, mlir::Operation *b)
Find the innermost common Block of a and b in the affine scope that a and b are part of.
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, ComputationSliceState *srcSlice, FusionStrategy fusionStrategy=FusionStrategy::Generic)
Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the loop nest rooted at 'dst...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
std::optional< int64_t > getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType)
Returns the memref's element type's size in bytes where the elemental type is an int or float or a ve...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their associated operands for a ...
bool isEmpty() const
Returns true if the computation slice is empty.
std::optional< bool > isMaximal() const
Returns true if the computation slice encloses all the iterations of the sliced loop nest.
Block::iterator insertPoint
enum mlir::affine::FusionResult::ResultEnum value
SmallVector< Operation *, 4 > memrefFrees
SmallVector< Operation *, 4 > loadOpInsts
SmallVector< Operation *, 4 > memrefStores
void collect(Operation *opToWalk)
SmallVector< Operation *, 4 > memrefLoads
SmallVector< Operation *, 4 > storeOpInsts
LoopNestStats aggregates various per-loop statistics (eg.
DenseMap< unsigned, SmallVector< Edge, 2 > > outEdges
Block & block
The block for which this graph is created to perform fusion.
unsigned addNode(Operation *op)
void addEdge(unsigned srcId, unsigned dstId, Value value)
DenseMap< unsigned, Node > nodes
bool hasDependencePath(unsigned srcId, unsigned dstId) const
void clearNodeLoadAndStores(unsigned id)
const Node * getForOpNode(AffineForOp forOp) const
Operation * getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) const
void updateEdges(unsigned srcId, unsigned dstId, const DenseSet< Value > &privateMemRefs, bool removeSrcId)
DenseMap< unsigned, SmallVector< Edge, 2 > > inEdges
void forEachMemRefInputEdge(unsigned id, const std::function< void(Edge)> &callback)
unsigned getOutEdgeCount(unsigned id, Value memref=nullptr) const
const Node * getNode(unsigned id) const
void removeNode(unsigned id)
void forEachMemRefOutputEdge(unsigned id, const std::function< void(Edge)> &callback)
void addToNode(unsigned id, ArrayRef< Operation * > loads, ArrayRef< Operation * > stores, ArrayRef< Operation * > memrefLoads, ArrayRef< Operation * > memrefStores, ArrayRef< Operation * > memrefFrees)
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const
DenseMap< Value, unsigned > memrefEdgeCount
A region of a memref's data space; this is typically constructed by analyzing load/store op's on this...
std::optional< int64_t > getConstantBoundingSizeAndShape(SmallVectorImpl< int64_t > *shape=nullptr, SmallVectorImpl< AffineMap > *lbs=nullptr) const
Returns a constant upper bound on the number of elements in this region if bounded by a known constan...
FlatAffineValueConstraints * getConstraints()
LogicalResult compute(Operation *op, unsigned loopDepth, const ComputationSliceState *sliceState=nullptr, bool addMemRefDimBounds=true, bool dropLocalVars=true, bool dropOuterIVs=true)
Computes the memory region accessed by this memref with the region represented as constraints symboli...
std::optional< int64_t > getRegionSize()
Returns the size of this MemRefRegion in bytes.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.