25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/raw_ostream.h"
37 #define GEN_PASS_DEF_AFFINELOOPFUSION
38 #include "mlir/Dialect/Affine/Passes.h.inc"
42 #define DEBUG_TYPE "affine-fusion"
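// Overview: this pass performs greedy fusion of affine loop nests. It builds
// a memref dependence graph over the top-level affine.for ops of each block
// and fuses producer-consumer and sibling loop nests when fusion is legal
// and, unless maximal fusion is requested, profitable under the
// compute-tolerance threshold.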
53 struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
54 LoopFusion() = default;
55 LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
56 bool maximalFusion, enum FusionMode affineFusionMode) {
57 this->fastMemorySpace = fastMemorySpace;
58 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
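// The localBufSizeThreshold pass option is interpreted in KiB; the
// constructor takes a byte count, hence the conversion here (runOnBlock
// converts back to bytes before driving GreedyFusion).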
59 this->maximalFusion = maximalFusion;
60 this->affineFusionMode = affineFusionMode;
63 void runOnBlock(Block *block);
64 void runOnOperation() override;
82 bool hasOutDepsAfterFusion = false;
84 for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
87 if (depNodeOp == dstNodeOp)
97 if (fusedLoopInsPoint != depNodeOp &&
99 LDBG() << "Src loop can't be removed: dst loop doesn't "
100 << "dominate dependence";
104 hasOutDepsAfterFusion = true;
110 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
111 std::optional<bool> isMaximal = fusionSlice.isMaximal();
113 LDBG() << "Src loop can't be removed: can't determine "
114 << "if fusion is maximal";
119 LDBG() << "Src loop can't be removed: fusion is not maximal";
138 if (mdg.inEdges.count(dstId) == 0)
142 auto *dstNode = mdg.getNode(dstId);
145 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
149 for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
150 const auto *srcNode = mdg.getNode(srcEdge.id);
152 if (!isa<AffineForOp>(srcNode->op))
155 if (any_of(srcNode->stores, [&](Operation *op) {
156 auto storeOp = cast<AffineWriteOpInterface>(op);
157 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
159 srcIdCandidates.push_back(srcNode->id);
162 llvm::sort(srcIdCandidates);
163 srcIdCandidates.erase(llvm::unique(srcIdCandidates), srcIdCandidates.end());
172 auto *dstNode = mdg.getNode(dstId);
173 auto *srcNode = mdg.getNode(srcId);
175 producerConsumerMemrefs);
192 if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
193 if (memref == viewOp.getViewDest() &&
198 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
205 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
208 if (ancestorOp->getBlock() != block)
210 return !isa<AffineMapAccessInterface>(*user);
219 for (Operation *storeOp : node->stores) {
220 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
221 if (escapingMemRefs.count(memref))
224 escapingMemRefs.insert(memref);
234 assert(isa<AffineForOp>(node->op));
236 node->op = newRootForOp;
246 assert(!producerStores.empty() && "expected producer store");
252 Block *commonBlock = nullptr;
254 for (Operation *store : producerStores) {
256 !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
259 assert(commonBlock &&
260 "common block of producer stores and slice should exist");
265 for (Operation *store : producerStores) {
267 assert(ancestor && "producer store should be contained in common block");
268 firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
272 return firstAncestor;
280 AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282 int64_t &fusedLoopNestComputeCost) {
283 LDBG() << "Determining additional compute fraction...";
288 LDBG() << "Failed to get source loop nest stats.";
295 LDBG() << "Failed to get destination loop nest stats.";
300 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
303 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
308 LDBG() << "Slice wasn't computed.";
313 dstLoopNestStats, slice,
314 &fusedLoopNestComputeCost)) {
315 LDBG() << "Unable to compute fusion compute cost";
319 double additionalComputeFraction =
320 fusedLoopNestComputeCost /
321 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
324 return additionalComputeFraction;
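// Worked example with illustrative numbers: srcLoopNestCost = 100,
// dstLoopNestCost = 100 and fusedLoopNestComputeCost = 250 give
// 250 / (100 + 100) - 1 = 0.25, i.e. 25% additional (redundant) compute.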
335 unsigned dstLoopDepth,
336 std::optional<unsigned> fastMemorySpace,
337 Block *sliceInsertionBlock,
338 uint64_t localBufSizeThreshold) {
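// Illustrative sketch of the effect (not taken from this file): when a
// producer that fills a whole buffer is fused into its only consumer at
// depth 1, the full buffer can often be replaced by a small private one:
//   %priv = memref.alloc() : memref<1xf32>
//   affine.for %i = 0 to 100 {
//     affine.store %v, %priv[0] : memref<1xf32>   // producer slice
//     %r = affine.load %priv[0] : memref<1xf32>   // consumer use
//   }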
339 assert(!storeOps.empty() && "no source stores supplied");
345 if (storeOps.size() > 1 &&
346 !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
348 MemRefAccess aM(cast<AffineWriteOpInterface>(a));
349 MemRefAccess bM(cast<AffineWriteOpInterface>(b));
352 LDBG() << "Private memref creation unsupported for multiple producer "
353 << "stores with different access functions.";
364 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
365 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
366 unsigned rank = oldMemRefType.getRank();
370 bool validRegion = succeeded(
371 region.compute(srcStoreOp, dstLoopDepth, nullptr,
375 assert(validRegion && "unexpected memref region failure");
381 std::optional<int64_t> numElements =
383 assert(numElements && "non-constant number of elts in local buffer");
394 offsets.reserve(rank);
399 for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
401 for (unsigned d = 0; d < rank; ++d) {
402 assert(lbs[d].getNumResults() == 1 &&
403 "invalid private memref bound calculation");
404 offsets.push_back(lbs[d].getResult(0).replaceSymbols(replacements));
410 assert(eltSize && "memrefs with size elt types expected");
411 uint64_t bufSize = *eltSize * *numElements;
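// Place the private buffer in the fast memory space only if it fits under
// the local buffer size threshold; otherwise keep the original memref's
// memory space.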
413 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
414 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
416 newMemSpace = oldMemRefType.getMemorySpace();
418 auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
427 Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
431 remapExprs.reserve(rank);
432 for (unsigned i = 0; i < rank; i++) {
433 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
437 remapExprs.push_back(remapExpr);
441 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
446 auto userFilterFn = [&](Operation *user) {
447 auto domInfo = std::make_unique<DominanceInfo>(
449 return domInfo->dominates(domFilter, user);
452 oldMemRef, newMemRef, {}, indexRemap,
455 assert(succeeded(res) &&
456 "replaceAllMemrefUsesWith should always succeed here");
458 LDBG() << "Created private memref of type: " << newMemRefType;
502 AffineForOp dstForOp,
504 unsigned maxLegalFusionDepth,
505 unsigned *dstLoopDepth,
506 double computeToleranceThreshold) {
507 LDBG() << "Checking whether fusion is profitable between source nest:";
508 LDBG() << ' ' << srcForOp << " and destination nest:";
511 if (maxLegalFusionDepth == 0) {
512 LDBG() << "Can't fuse: maxLegalFusionDepth is 0";
535 if (producerStores.size() > 1) {
536 LDBG() << "Limited profitability analysis. Not "
537 << "supported for multiple producer store case.";
539 int64_t fusedLoopNestComputeCost;
543 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
544 fusedLoopNestComputeCost);
545 if (!fraction || fraction > computeToleranceThreshold) {
546 LDBG() << "Additional computation exceeds "
547 << "compute tolerance. Not fusing.";
550 LDBG() << "Considering fusion profitable at max legal depth.";
554 Operation *srcStoreOp = producerStores.front();
563 double maxStorageReduction = 0.0;
564 std::optional<uint64_t> sliceMemEstimate;
567 std::optional<unsigned> bestDstLoopDepth;
572 LDBG() << "Unable to compute MemRefRegion for source operation";
576 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
578 if (!maybeSrcWriteRegionSizeBytes.has_value())
580 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
583 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
586 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
590 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
600 int64_t fusedLoopNestComputeCost;
602 auto mayAdditionalComputeFraction =
604 sliceCost, fusedLoopNestComputeCost);
605 if (!mayAdditionalComputeFraction) {
606 LDBG() << "Can't determine additional compute fraction.";
609 double additionalComputeFraction = *mayAdditionalComputeFraction;
615 if (failed(sliceWriteRegion.compute(srcStoreOp, 0, &slice))) {
616 LDBG() << "Failed to compute slice write region at loopDepth: " << i;
620 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
622 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
623 *maybeSliceWriteRegionSizeBytes == 0) {
624 LDBG() << "Failed to get slice write region size at loopDepth: " << i;
627 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
629 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
630 static_cast<double>(sliceWriteRegionSizeBytes);
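// Illustrative example: if the source nest writes 4096 bytes in total but the
// slice at this depth writes only 64 bytes, storageReduction = 4096 / 64 = 64.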
633 std::stringstream msg;
634 msg << " evaluating fusion profitability at depth : " << i << "\n"
635 << std::fixed << std::setprecision(2)
636 << " additional compute fraction: "
637 << 100.0 * additionalComputeFraction << "%\n"
638 << " storage reduction factor: " << storageReduction << "x\n"
639 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
640 << " src write region size: " << srcWriteRegionSizeBytes << "\n"
641 << " slice write region size: " << sliceWriteRegionSizeBytes;
649 if ((storageReduction > maxStorageReduction) &&
650 (additionalComputeFraction <= computeToleranceThreshold)) {
651 maxStorageReduction = storageReduction;
652 bestDstLoopDepth = i;
653 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
654 sliceMemEstimate = sliceWriteRegionSizeBytes;
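// The updates above implement the greedy choice: among depths whose
// additional compute stays within the tolerance threshold, keep the one with
// the largest storage reduction.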
660 if (!bestDstLoopDepth) {
661 LDBG() << "All fusion choices involve more than the threshold amount of "
662 << "redundant computation; NOT fusing.";
666 if (!bestDstLoopDepth) {
667 LDBG() << "no fusion depth could be evaluated.";
672 *dstLoopDepth = *bestDstLoopDepth;
674 LDBG() << " LoopFusion fusion stats:";
675 LDBG() << " best loop depth: " << bestDstLoopDepth;
676 LDBG() << " src loop nest compute cost: " << srcLoopNestCost;
677 LDBG() << " dst loop nest compute cost: " << dstLoopNestCost;
678 LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost;
683 std::optional<double> storageReduction;
685 if (!dstMemSize || !srcMemSize) {
686 LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing.";
690 auto srcMemSizeVal = *srcMemSize;
691 auto dstMemSizeVal = *dstMemSize;
693 assert(sliceMemEstimate && "expected value");
694 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
696 LDBG() << " src mem: " << srcMemSizeVal;
697 LDBG() << " dst mem: " << dstMemSizeVal;
698 LDBG() << " fused mem: " << fusedMem;
699 LDBG() << " slice mem: " << sliceMemEstimate;
701 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
702 LDBG() << "Fusion is not profitable; NOT fusing.";
707 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
709 double additionalComputeFraction =
710 100.0 * (minFusedLoopNestComputeCost /
711 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
713 (void)additionalComputeFraction;
715 std::stringstream msg;
716 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
717 << std::setprecision(2) << additionalComputeFraction
718 << "% redundant computation and a ";
719 msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
720 msg << "% storage reduction.";
775 struct GreedyFusion {
782 unsigned localBufSizeThreshold;
784 std::optional<unsigned> fastMemorySpace;
790 double computeToleranceThreshold;
795 std::optional<unsigned> fastMemorySpace, bool maximalFusion,
796 double computeToleranceThreshold)
797 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
798 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
799 computeToleranceThreshold(computeToleranceThreshold) {}
806 for (auto &idAndNode : mdg->nodes) {
807 const Node &node = idAndNode.second;
808 worklist.push_back(node.id);
812 void runSiblingFusionOnly() {
814 eraseUnusedMemRefAllocations();
818 void runProducerConsumerFusionOnly() {
819 fuseProducerConsumerNodes(
821 eraseUnusedMemRefAllocations();
829 void runGreedyFusion() {
831 fuseProducerConsumerNodes(1);
833 fuseProducerConsumerNodes(
835 eraseUnusedMemRefAllocations();
840 bool canCreatePrivateMemRef(Value memref,
842 unsigned producerId, unsigned consumerId,
843 bool removeSrcNode) {
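// An escaping memref is not privatized when the producer nest is going to be
// removed or the consumer also stores to it, since the original buffer stays
// visible outside the fused nest.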
847 const Node *consumerNode = mdg->getNode(consumerId);
855 if (srcEscapingMemRefs.count(memref) > 0 &&
856 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
869 any_of(mdg->outEdges[producerId], [&](const auto &edge) {
870 return edge.value == memref && edge.id != consumerId;
880 void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
881 LDBG() << "Evaluating dst loop " << dstId;
883 if (mdg->nodes.count(dstId) == 0)
886 auto *dstNode = mdg->getNode(dstId);
888 if (!isa<AffineForOp>(dstNode->op))
892 if (dstNode->op->getNumResults() > 0)
895 LDBG() << "Evaluating dst loop " << dstId;
902 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
913 dstNodeChanged = false;
917 for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
919 auto *srcNode = mdg->getNode(srcId);
920 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
922 LDBG() << "Trying to fuse producer loop nest " << srcId
923 << " with consumer loop nest " << dstId;
924 LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold;
925 LDBG() << "Producer loop nest:";
926 LDBG() << *srcNode->op << " and consumer loop nest:";
927 LDBG() << *dstNode->op;
929 LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId;
933 if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
938 producerConsumerMemrefs);
942 if (any_of(producerConsumerMemrefs, [&](Value memref) {
958 if (fusedLoopInsPoint == nullptr)
969 unsigned numSurroundingLoops = surroundingLoops.size();
975 if (producerConsumerMemrefs.count(
976 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
977 dstMemrefOps.push_back(op);
979 if (producerConsumerMemrefs.count(
980 cast<AffineWriteOpInterface>(op).getMemRef()))
981 dstMemrefOps.push_back(op);
982 if (dstMemrefOps.empty())
984 unsigned dstLoopDepthTest =
989 unsigned maxLegalFusionDepth = 0;
991 depthSliceUnions.resize(dstLoopDepthTest);
993 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
996 i + numSurroundingLoops,
997 &depthSliceUnions[i - 1], strategy);
999 maxLegalFusionDepth = i;
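// The deepest depth at which canFuseLoops succeeds is recorded as
// maxLegalFusionDepth; the profitability analysis below only considers
// depths at or below it.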
1000 LDBG() << "Found valid slice for depth: " << i;
1004 if (maxLegalFusionDepth == 0) {
1005 LDBG() << "Can't fuse: fusion is not legal at any depth";
1009 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1011 double computeToleranceThresholdToUse = computeToleranceThreshold;
1019 LDBG() << "Source nest has a cyclic dependence.";
1023 if (maximalFusion) {
1024 auto srcForOp = cast<AffineForOp>(srcNode->op);
1025 auto dstForOp = cast<AffineForOp>(dstNode->op);
1027 int64_t fusedLoopNestComputeCost;
1029 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1030 sliceCost, fusedLoopNestComputeCost);
1031 if (!fraction || fraction > 0) {
1032 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1033 << "and non-zero additional compute.";
1039 LDBG() << "Setting compute tolerance to zero since "
1040 << "source has a cyclic dependence.";
1041 computeToleranceThresholdToUse = 0;
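// Forcing the tolerance to zero above means a source nest with a cyclic
// dependence is fused only if doing so adds no redundant computation.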
1048 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1049 if (!maximalFusion) {
1053 if (producerConsumerMemrefs.count(
1054 cast<AffineWriteOpInterface>(op).getMemRef()))
1055 producerStores.push_back(op);
1057 assert(!producerStores.empty() && "Expected producer store");
1059 dstAffineForOp, depthSliceUnions,
1060 maxLegalFusionDepth, &bestDstLoopDepth,
1061 computeToleranceThresholdToUse)) {
1066 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1068 depthSliceUnions[bestDstLoopDepth - 1];
1069 assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1075 srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
1079 for (Value memref : producerConsumerMemrefs) {
1080 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
1083 LDBG() << "Creating private memref for " << memref;
1085 privateMemrefs.insert(memref);
1090 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1091 dstNodeChanged = true;
1093 LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId
1094 << " at depth " << bestDstLoopDepth << ":";
1095 LDBG() << dstAffineForOp;
1098 if (fusedLoopInsPoint != dstAffineForOp)
1099 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1102 mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
1106 if (!privateMemrefs.empty()) {
1113 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1114 Value storeMemRef = storeOp.getMemRef();
1115 if (privateMemrefs.count(storeMemRef) > 0)
1116 privateMemRefToStores[storeMemRef].push_back(storeOp);
1123 for (auto &memrefToStoresPair : privateMemRefToStores) {
1126 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1127 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1133 mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
1138 dstNode = mdg->getNode(dstId);
1143 dstLoopCollector.collect(dstAffineForOp);
1152 if (removeSrcNode) {
1153 LDBG() << "Removing src loop " << srcId << " after fusion";
1155 srcAffineForOp.erase();
1160 } while (dstNodeChanged);
1167 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1168 LDBG() << "--- Producer/Consumer Fusion ---";
1170 while (!worklist.empty()) {
1171 unsigned dstId = worklist.back();
1172 worklist.pop_back();
1173 performFusionsIntoDest(dstId, maxSrcUserCount);
1179 void fuseSiblingNodes() {
1180 LDBG() << "--- Sibling Fusion ---";
1182 while (!worklist.empty()) {
1183 unsigned dstId = worklist.back();
1184 worklist.pop_back();
1187 if (mdg->nodes.count(dstId) == 0)
1190 auto *dstNode = mdg->getNode(dstId);
1192 if (!isa<AffineForOp>(dstNode->op))
1195 fuseWithSiblingNodes(dstNode);
1200 void fuseWithSiblingNodes(Node *dstNode) {
1202 std::pair<unsigned, Value> idAndMemref;
1203 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1205 while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
1206 unsigned sibId = idAndMemref.first;
1207 Value memref = idAndMemref.second;
1210 auto *sibNode = mdg->getNode(sibId);
1213 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1218 if (insertPointInst == nullptr)
1225 sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
1227 Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
1231 dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
1241 unsigned numSurroundingLoops = surroundingLoops.size();
1244 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1245 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1249 depthSliceUnions.resize(dstLoopDepthTest);
1250 unsigned maxLegalFusionDepth = 0;
1252 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1255 i + numSurroundingLoops,
1256 &depthSliceUnions[i - 1], strategy);
1259 maxLegalFusionDepth = i;
1262 LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth;
1265 if (maxLegalFusionDepth == 0)
1268 double computeToleranceThresholdToUse = computeToleranceThreshold;
1276 LDBG() << "Source nest has a cyclic dependence.";
1280 if (maximalFusion) {
1281 auto dstForOp = cast<AffineForOp>(dstNode->op);
1283 int64_t fusedLoopNestComputeCost;
1285 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1286 sliceCost, fusedLoopNestComputeCost);
1287 if (!fraction || fraction > 0) {
1288 LDBG() << "Can't perform maximal fusion with a cyclic dependence "
1289 << "and non-zero additional compute.";
1295 LDBG() << "Setting compute tolerance to zero since "
1296 << "source has a cyclic dependence.";
1297 computeToleranceThresholdToUse = 0.0;
1301 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1302 if (!maximalFusion) {
1308 depthSliceUnions, maxLegalFusionDepth,
1310 computeToleranceThresholdToUse))
1314 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1317 depthSliceUnions[bestDstLoopDepth - 1];
1318 assert(!bestSlice.isEmpty() &&
1319 "Fusion depth has no computed slice union");
1325 if (!isMaximal.value_or(false)) {
1326 LDBG() << "Slice isn't maximal; not performing sibling fusion.";
1333 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1336 isInnermostInsertion);
1338 auto dstForInst = cast<AffineForOp>(dstNode->op);
1340 if (insertPointInst != dstForInst)
1341 dstForInst->moveBefore(insertPointInst);
1343 LDBG() << "Fused sibling nest " << sibId << " into destination nest "
1344 << dstNode->id << " at depth " << bestDstLoopDepth << ":";
1345 LDBG() << dstAffineForOp;
1348 updateStateAfterSiblingFusion(sibNode, dstNode);
1362 bool findSiblingNodeToFuse(Node *dstNode,
1364 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1367 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1370 if (sibNode->getLoadOpCount(memref) != 1)
1380 sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
1381 if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
1388 for (auto *storeOpInst : sibNode->stores) {
1389 storeMemrefs.insert(
1390 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1392 return storeMemrefs.size() <= 1;
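// The lambda above accepts a sibling candidate only if it loads 'memref'
// exactly once and stores to at most one memref.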
1396 Block *block = dstNode->op->getBlock();
1399 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1408 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1409 return loop->getBlock() == &mdg->block;
1412 if (it == loops.end())
1415 assert(sibNode != nullptr);
1417 if (sibNode->id == dstNode->id)
1420 if (visitedSibNodeIds->count(sibNode->id) > 0)
1423 auto memref = loadOp.getMemRef();
1424 if (dstNode->getLoadOpCount(memref) == 0)
1427 if (canFuseWithSibNode(sibNode, memref)) {
1428 visitedSibNodeIds->insert(sibNode->id);
1429 idAndMemrefToFuse->first = sibNode->id;
1430 idAndMemrefToFuse->second = memref;
1443 if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1444 (mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
1445 inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
1446 inEdges.push_back(inEdge);
1451 for (auto &inEdge : inEdges) {
1456 unsigned sibNodeId = outEdge.id;
1457 if (visitedSibNodeIds->count(sibNodeId) > 0)
1460 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1462 auto *sibNode = mdg->getNode(sibNodeId);
1463 if (!isa<AffineForOp>(sibNode->op))
1466 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1468 outEdges.push_back(outEdge);
1473 if (!outEdges.empty()) {
1474 visitedSibNodeIds->insert(outEdges[0].id);
1475 idAndMemrefToFuse->first = outEdges[0].id;
1476 idAndMemrefToFuse->second = outEdges[0].value;
1485 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1490 auto dstForInst = cast<AffineForOp>(dstNode->op);
1492 dstLoopCollector.collect(dstForInst);
1501 void eraseUnusedMemRefAllocations() {
1503 if (pair.second > 0)
1505 auto memref = pair.first;
1511 if (isa_and_nonnull<memref::AllocOp>(op))
1520 void LoopFusion::runOnBlock(Block *block) {
1523 LDBG() << "MDG init failed";
1527 std::optional<unsigned> fastMemorySpaceOpt;
1528 if (fastMemorySpace.hasValue())
1529 fastMemorySpaceOpt = fastMemorySpace;
1530 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1531 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1532 maximalFusion, computeToleranceThreshold);
1535 fusion.runProducerConsumerFusionOnly();
1537 fusion.runSiblingFusionOnly();
1539 fusion.runGreedyFusion();
1542 void LoopFusion::runOnOperation() {
1545 getOperation()->walk([&](Operation *op) {
1547 for (Block &block : region.getBlocks()) {
1548 auto affineFors = block.getOps<AffineForOp>();
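// Only blocks containing at least two top-level affine.for ops are
// candidates for fusion.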
1549 if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
1557 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1558 bool maximalFusion, enum FusionMode affineFusionMode) {
1559 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
1560 maximalFusion, affineFusionMode);
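// Typical usage (sketch; exact option spellings depend on the MLIR version):
//   mlir-opt input.mlir -affine-loop-fusion
// or, in a pass pipeline:
//   pm.addPass(mlir::affine::createLoopFusionPass());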