24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/Debug.h"
29 #define DEBUG_TYPE "memref-transforms"
30 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31 #define DBGSNL() (llvm::dbgs() << "\n")
35 auto copyOp = dyn_cast<memref::CopyOp>(op);
38 return copyOp.getTarget() == buffer;
54 auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
56 operandsToReplace.push_back(&use);
63 MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
64 subviewUse.getType().getShape(), cast<MemRefType>(val.
getType()),
65 subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66 subviewUse.getStaticStrides());
67 Value newSubview = rewriter.
create<memref::SubViewOp>(
68 subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
69 subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
74 opsToDelete.push_back(use.getOwner());
79 for (
OpOperand *operand : operandsToReplace) {
97 FailureOr<memref::AllocOp>
99 unsigned multiBufferingFactor,
100 bool skipOverrideAnalysis) {
101 LLVM_DEBUG(
DBGS() <<
"Start multibuffering: " << allocOp <<
"\n");
103 LoopLikeOpInterface candidateLoop;
105 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
107 if (isa<memref::DeallocOp>(user)) {
113 LLVM_DEBUG(
DBGS() <<
"--no parent loop -> fail\n");
114 LLVM_DEBUG(
DBGS() <<
"----due to user: " << *user <<
"\n");
117 if (!skipOverrideAnalysis) {
120 LLVM_DEBUG(
DBGS() <<
"--Skip user: found loop-carried dependence\n");
124 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
125 return !dom.dominates(user, otherUser);
128 DBGS() <<
"--Skip user: does not dominate all other users\n");
132 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
133 return !isa<memref::DeallocOp>(otherUser) &&
134 !parentLoop->isProperAncestor(otherUser);
138 <<
"--Skip user: not all other users are in the parent loop\n");
142 candidateLoop = parentLoop;
146 if (!candidateLoop) {
147 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no candidate loop\n");
151 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
152 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
153 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
154 if (!inductionVar || !lowerBound || !singleStep ||
155 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
156 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no single iv, lb, step or region\n");
160 if (!dom.
dominates(allocOp.getOperation(), candidateLoop)) {
161 LLVM_DEBUG(
DBGS() <<
"Skip alloc: does not dominate candidate loop\n");
165 LLVM_DEBUG(
DBGS() <<
"Start multibuffering loop: " << candidateLoop <<
"\n");
170 llvm::append_range(multiBufferedShape, originalShape);
171 LLVM_DEBUG(
DBGS() <<
"--original type: " << allocOp.getType() <<
"\n");
174 .setLayout(MemRefLayoutAttrInterface());
175 LLVM_DEBUG(
DBGS() <<
"--multi-buffered type: " << mbMemRefType <<
"\n");
181 auto mbAlloc = rewriter.
create<memref::AllocOp>(
182 loc, mbMemRefType,
ValueRange{}, allocOp->getAttrs());
183 LLVM_DEBUG(
DBGS() <<
"--multi-buffered alloc: " << mbAlloc <<
"\n");
188 &candidateLoop.getLoopRegions().front()->front());
189 Value ivVal = *inductionVar;
195 rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
196 {ivVal, lbVal, stepVal});
197 LLVM_DEBUG(
DBGS() <<
"--multi-buffered indexing: " << bufferIndex <<
"\n");
201 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
208 offsets.front() = bufferIndex;
210 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
213 MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
214 originalShape, mbMemRefType, offsets, sizes, strides);
215 Value subview = rewriter.
create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
216 offsets, sizes, strides);
217 LLVM_DEBUG(
DBGS() <<
"--multi-buffered slice: " << subview <<
"\n");
221 for (
OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
222 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
228 rewriter.
create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
230 LLVM_DEBUG(
DBGS() <<
"----Created dealloc: " << newDeallocOp <<
"\n");
243 FailureOr<memref::AllocOp>
245 unsigned multiBufferingFactor,
246 bool skipOverrideAnalysis) {
248 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
249 skipOverrideAnalysis);
static bool overrideBuffer(Operation *op, Value buffer)
Return true if the op fully overwrites the given buffer value.
static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val)
Replace the uses of oldOp with the given val and, for subview uses, propagate the type change.
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.