23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
// Debug-type tag for this file's LLVM_DEBUG(...) output (see llvm/Support/Debug.h).
28 #define DEBUG_TYPE "memref-transforms"
// DBGS(): writes the "[memref-transforms]: " prefix to llvm::dbgs(); the
// LLVM_DEBUG(...) messages in this file start with it.
29 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// DBGSNL(): writes a single newline to llvm::dbgs().
30 #define DBGSNL() (llvm::dbgs() << "\n")
34 auto copyOp = dyn_cast<memref::CopyOp>(op);
37 return copyOp.getTarget() == buffer;
49 if (
auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
53 MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
54 subviewUse.getType().getShape(), cast<MemRefType>(val.
getType()),
55 subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
56 subviewUse.getStaticStrides());
57 Value newSubview = memref::SubViewOp::create(
58 rewriter, subviewUse->getLoc(), newType, val,
59 subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
60 subviewUse.getMixedStrides());
81 FailureOr<memref::AllocOp>
83 unsigned multiBufferingFactor,
84 bool skipOverrideAnalysis) {
85 LLVM_DEBUG(
DBGS() <<
"Start multibuffering: " << allocOp <<
"\n");
87 LoopLikeOpInterface candidateLoop;
89 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
91 if (isa<memref::DeallocOp>(user)) {
97 LLVM_DEBUG(
DBGS() <<
"--no parent loop -> fail\n");
98 LLVM_DEBUG(
DBGS() <<
"----due to user: " << *user <<
"\n");
101 if (!skipOverrideAnalysis) {
104 LLVM_DEBUG(
DBGS() <<
"--Skip user: found loop-carried dependence\n");
108 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
109 return !dom.dominates(user, otherUser);
112 DBGS() <<
"--Skip user: does not dominate all other users\n");
116 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
117 return !isa<memref::DeallocOp>(otherUser) &&
118 !parentLoop->isProperAncestor(otherUser);
122 <<
"--Skip user: not all other users are in the parent loop\n");
126 candidateLoop = parentLoop;
130 if (!candidateLoop) {
131 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no candidate loop\n");
135 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
136 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
137 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
138 if (!inductionVar || !lowerBound || !singleStep ||
139 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
140 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no single iv, lb, step or region\n");
144 if (!dom.
dominates(allocOp.getOperation(), candidateLoop)) {
145 LLVM_DEBUG(
DBGS() <<
"Skip alloc: does not dominate candidate loop\n");
149 LLVM_DEBUG(
DBGS() <<
"Start multibuffering loop: " << candidateLoop <<
"\n");
154 llvm::append_range(multiBufferedShape, originalShape);
155 LLVM_DEBUG(
DBGS() <<
"--original type: " << allocOp.getType() <<
"\n");
158 .setLayout(MemRefLayoutAttrInterface());
159 LLVM_DEBUG(
DBGS() <<
"--multi-buffered type: " << mbMemRefType <<
"\n");
165 auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType,
167 LLVM_DEBUG(
DBGS() <<
"--multi-buffered alloc: " << mbAlloc <<
"\n");
172 &candidateLoop.getLoopRegions().front()->front());
173 Value ivVal = *inductionVar;
179 rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
180 {ivVal, lbVal, stepVal});
181 LLVM_DEBUG(
DBGS() <<
"--multi-buffered indexing: " << bufferIndex <<
"\n");
185 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
192 offsets.front() = bufferIndex;
194 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
197 MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
198 originalShape, mbMemRefType, offsets, sizes, strides);
199 Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc,
200 offsets, sizes, strides);
201 LLVM_DEBUG(
DBGS() <<
"--multi-buffered slice: " << subview <<
"\n");
205 for (
OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
206 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
212 memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc);
214 LLVM_DEBUG(
DBGS() <<
"----Created dealloc: " << newDeallocOp <<
"\n");
227 FailureOr<memref::AllocOp>
229 unsigned multiBufferingFactor,
230 bool skipOverrideAnalysis) {
232 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
233 skipOverrideAnalysis);
static bool overrideBuffer(Operation *op, Value buffer)
Returns true if the op fully overwrites the given buffer value.
static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val)
Replaces the uses of oldOp with the given val and, for subview uses, propagates the resulting type change.
Base type for affine expressions.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocation between consecutive loop iterations.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references, in order, to DimExpr at positions [0 .. sizeof...(exprs) - 1].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.