57 MemRefType srcType = cast<MemRefType>(val.
getType());
61 bool typeInferenceFailed =
false;
64 .Case([&](memref::SubViewOp subview) ->
Value {
66 memref::SubViewOp::inferRankReducedResultType(
67 subview.getType().getShape(), srcType,
68 subview.getStaticOffsets(), subview.getStaticSizes(),
69 subview.getStaticStrides());
70 return memref::SubViewOp::create(
71 rewriter, subview->getLoc(), newType, val,
72 subview.getMixedOffsets(), subview.getMixedSizes(),
73 subview.getMixedStrides());
75 .Case([&](memref::ExpandShapeOp expand) ->
Value {
76 FailureOr<MemRefType> newType =
77 memref::ExpandShapeOp::computeExpandedType(
78 srcType, expand.getResultType().getShape(),
79 expand.getReassociationIndices());
80 if (failed(newType)) {
81 typeInferenceFailed =
true;
84 return memref::ExpandShapeOp::create(
85 rewriter, expand->getLoc(), *newType, val,
86 expand.getReassociationIndices(),
87 expand.getMixedOutputShape());
89 .Case([&](memref::CollapseShapeOp collapse) ->
Value {
90 FailureOr<MemRefType> newType =
91 memref::CollapseShapeOp::computeCollapsedType(
92 srcType, collapse.getReassociationIndices());
93 if (failed(newType)) {
94 typeInferenceFailed =
true;
97 return memref::CollapseShapeOp::create(
98 rewriter, collapse->getLoc(), *newType, val,
99 collapse.getReassociationIndices());
101 .Case([&](memref::CastOp cast) ->
Value {
102 if (!memref::CastOp::areCastCompatible(srcType, cast.getType())) {
103 typeInferenceFailed =
true;
106 return memref::CastOp::create(rewriter, cast->getLoc(),
111 if (typeInferenceFailed) {
113 "failed to compute view-like result type after multi-buffering");
122 opsToErase.push_back(user);
145 unsigned multiBufferingFactor,
146 bool skipOverrideAnalysis) {
147 LLVM_DEBUG(
DBGS() <<
"Start multibuffering: " << allocOp <<
"\n");
149 LoopLikeOpInterface candidateLoop;
151 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
153 if (isa<memref::DeallocOp>(user)) {
159 LLVM_DEBUG(
DBGS() <<
"--no parent loop -> fail\n");
160 LLVM_DEBUG(
DBGS() <<
"----due to user: " << *user <<
"\n");
163 if (!skipOverrideAnalysis) {
166 LLVM_DEBUG(
DBGS() <<
"--Skip user: found loop-carried dependence\n");
170 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
171 return !dom.dominates(user, otherUser);
174 DBGS() <<
"--Skip user: does not dominate all other users\n");
178 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
179 return !isa<memref::DeallocOp>(otherUser) &&
180 !parentLoop->isProperAncestor(otherUser);
184 <<
"--Skip user: not all other users are in the parent loop\n");
188 candidateLoop = parentLoop;
192 if (!candidateLoop) {
193 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no candidate loop\n");
197 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
198 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
199 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
200 if (!inductionVar || !lowerBound || !singleStep ||
201 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
202 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no single iv, lb, step or region\n");
206 if (!dom.
dominates(allocOp.getOperation(), candidateLoop)) {
207 LLVM_DEBUG(
DBGS() <<
"Skip alloc: does not dominate candidate loop\n");
211 LLVM_DEBUG(
DBGS() <<
"Start multibuffering loop: " << candidateLoop <<
"\n");
216 llvm::append_range(multiBufferedShape, originalShape);
217 LLVM_DEBUG(
DBGS() <<
"--original type: " << allocOp.getType() <<
"\n");
221 LLVM_DEBUG(
DBGS() <<
"--multi-buffered type: " << mbMemRefType <<
"\n");
227 auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType,
229 LLVM_DEBUG(
DBGS() <<
"--multi-buffered alloc: " << mbAlloc <<
"\n");
234 &candidateLoop.getLoopRegions().front()->front());
235 Value ivVal = *inductionVar;
241 rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
242 {ivVal, lbVal, stepVal});
243 LLVM_DEBUG(
DBGS() <<
"--multi-buffered indexing: " << bufferIndex <<
"\n");
247 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
254 offsets.front() = bufferIndex;
256 for (
int64_t i = 0, e = originalShape.size(); i != e; ++i)
259 MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
260 originalShape, mbMemRefType, offsets, sizes, strides);
261 Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc,
262 offsets, sizes, strides);
263 LLVM_DEBUG(
DBGS() <<
"--multi-buffered slice: " << subview <<
"\n");
267 for (
OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
268 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
274 memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc);
276 LLVM_DEBUG(
DBGS() <<
"----Created dealloc: " << newDeallocOp <<
"\n");