24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/Debug.h"
29 #define DEBUG_TYPE "memref-transforms"
30 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31 #define DBGSNL() (llvm::dbgs() << "\n")
35 auto copyOp = dyn_cast<memref::CopyOp>(op);
38 return copyOp.getTarget() == buffer;
54 auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
56 operandsToReplace.push_back(&use);
63 Type newType = memref::SubViewOp::inferRankReducedResultType(
64 subviewUse.getType().getShape(), cast<MemRefType>(val.
getType()),
65 subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66 subviewUse.getStaticStrides());
67 Value newSubview = rewriter.
create<memref::SubViewOp>(
68 subviewUse->getLoc(), cast<MemRefType>(newType), val,
69 subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70 subviewUse.getMixedStrides());
75 opsToDelete.push_back(use.getOwner());
80 for (
OpOperand *operand : operandsToReplace) {
98 FailureOr<memref::AllocOp>
100 unsigned multiBufferingFactor,
101 bool skipOverrideAnalysis) {
102 LLVM_DEBUG(
DBGS() <<
"Start multibuffering: " << allocOp <<
"\n");
104 LoopLikeOpInterface candidateLoop;
106 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
108 if (isa<memref::DeallocOp>(user)) {
114 LLVM_DEBUG(
DBGS() <<
"--no parent loop -> fail\n");
115 LLVM_DEBUG(
DBGS() <<
"----due to user: " << *user <<
"\n");
118 if (!skipOverrideAnalysis) {
121 LLVM_DEBUG(
DBGS() <<
"--Skip user: found loop-carried dependence\n");
125 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
126 return !dom.dominates(user, otherUser);
129 DBGS() <<
"--Skip user: does not dominate all other users\n");
133 if (llvm::any_of(allocOp->getUsers(), [&](
Operation *otherUser) {
134 return !isa<memref::DeallocOp>(otherUser) &&
135 !parentLoop->isProperAncestor(otherUser);
139 <<
"--Skip user: not all other users are in the parent loop\n");
143 candidateLoop = parentLoop;
147 if (!candidateLoop) {
148 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no candidate loop\n");
152 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
153 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
154 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
155 if (!inductionVar || !lowerBound || !singleStep ||
156 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
157 LLVM_DEBUG(
DBGS() <<
"Skip alloc: no single iv, lb, step or region\n");
161 if (!dom.
dominates(allocOp.getOperation(), candidateLoop)) {
162 LLVM_DEBUG(
DBGS() <<
"Skip alloc: does not dominate candidate loop\n");
166 LLVM_DEBUG(
DBGS() <<
"Start multibuffering loop: " << candidateLoop <<
"\n");
171 llvm::append_range(multiBufferedShape, originalShape);
172 LLVM_DEBUG(
DBGS() <<
"--original type: " << allocOp.getType() <<
"\n");
175 .setLayout(MemRefLayoutAttrInterface());
176 LLVM_DEBUG(
DBGS() <<
"--multi-buffered type: " << mbMemRefType <<
"\n");
182 auto mbAlloc = rewriter.
create<memref::AllocOp>(
183 loc, mbMemRefType,
ValueRange{}, allocOp->getAttrs());
184 LLVM_DEBUG(
DBGS() <<
"--multi-buffered alloc: " << mbAlloc <<
"\n");
189 &candidateLoop.getLoopRegions().front()->front());
190 Value ivVal = *inductionVar;
196 rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
197 {ivVal, lbVal, stepVal});
198 LLVM_DEBUG(
DBGS() <<
"--multi-buffered indexing: " << bufferIndex <<
"\n");
202 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
209 offsets.front() = bufferIndex;
211 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
215 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
216 originalShape, mbMemRefType, offsets, sizes, strides));
217 Value subview = rewriter.
create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218 offsets, sizes, strides);
219 LLVM_DEBUG(
DBGS() <<
"--multi-buffered slice: " << subview <<
"\n");
223 for (
OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
224 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
230 rewriter.
create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
232 LLVM_DEBUG(
DBGS() <<
"----Created dealloc: " << newDeallocOp <<
"\n");
245 FailureOr<memref::AllocOp>
247 unsigned multiBufferingFactor,
248 bool skipOverrideAnalysis) {
250 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
251 skipOverrideAnalysis);
static bool overrideBuffer(Operation *op, Value buffer)
Return true if the op fully overwrites the given buffer value.
static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val)
Replace the uses of oldOp with the given val and for subview uses propagate the type change.
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocation between consecutive loop iterations.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.