MLIR 23.0.0git
MultiBuffer.cpp
Go to the documentation of this file.
1//===----------- MultiBuffering.cpp ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements multi buffering transformation.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/Dominance.h"
21#include "mlir/IR/ValueRange.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/Debug.h"
26
27using namespace mlir;
28
// Tag string spliced into the prefix produced by DBGS() below.
29#define DEBUG_TYPE "memref-transforms"
// Debug stream with the "[memref-transforms]: " prefix already written; the
// caller appends the actual message with further `<<`.
30#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Writes a bare newline to the debug stream.
31#define DBGSNL() (llvm::dbgs() << "\n")
32
33/// Return true if the op fully overwrite the given `buffer` value.
34static bool overrideBuffer(Operation *op, Value buffer) {
35 auto copyOp = dyn_cast<memref::CopyOp>(op);
36 if (!copyOp)
37 return false;
38 return copyOp.getTarget() == buffer;
39}
40
41/// Replace the uses of `oldOp` with the given `val` and for view-like uses
42/// propagate the type change. Changing the memref type may require propagating
43/// it through view-like ops (subview, expand_shape, collapse_shape, cast) so
44/// we need to propagate the type change and erase old view ops.
45///
46/// Only view-like ops whose result type can be recomputed from the new source
47/// type and existing op attributes are handled here. Other ops fall back to
48/// operand replacement without type propagation.
49static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter,
50 Operation *oldOp, Value val) {
51 SmallVector<Operation *> opsToErase;
52 // Iterate with early_inc to erase current user inside the loop.
53 for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
54 Operation *user = use.getOwner();
55 OpBuilder::InsertionGuard g(rewriter);
56 rewriter.setInsertionPoint(user);
57 MemRefType srcType = cast<MemRefType>(val.getType());
58
59 // Try to create a new view-like op with updated result type.
60 // Each view-like op has its own method to compute the result type.
61 bool typeInferenceFailed = false;
64 .Case([&](memref::SubViewOp subview) -> Value {
65 MemRefType newType =
66 memref::SubViewOp::inferRankReducedResultType(
67 subview.getType().getShape(), srcType,
68 subview.getStaticOffsets(), subview.getStaticSizes(),
69 subview.getStaticStrides());
70 return memref::SubViewOp::create(
71 rewriter, subview->getLoc(), newType, val,
72 subview.getMixedOffsets(), subview.getMixedSizes(),
73 subview.getMixedStrides());
74 })
75 .Case([&](memref::ExpandShapeOp expand) -> Value {
76 FailureOr<MemRefType> newType =
77 memref::ExpandShapeOp::computeExpandedType(
78 srcType, expand.getResultType().getShape(),
79 expand.getReassociationIndices());
80 if (failed(newType)) {
81 typeInferenceFailed = true;
82 return Value();
83 }
84 return memref::ExpandShapeOp::create(
85 rewriter, expand->getLoc(), *newType, val,
86 expand.getReassociationIndices(),
87 expand.getMixedOutputShape());
88 })
89 .Case([&](memref::CollapseShapeOp collapse) -> Value {
90 FailureOr<MemRefType> newType =
91 memref::CollapseShapeOp::computeCollapsedType(
92 srcType, collapse.getReassociationIndices());
93 if (failed(newType)) {
94 typeInferenceFailed = true;
95 return Value();
96 }
97 return memref::CollapseShapeOp::create(
98 rewriter, collapse->getLoc(), *newType, val,
99 collapse.getReassociationIndices());
100 })
101 .Case([&](memref::CastOp cast) -> Value {
102 if (!memref::CastOp::areCastCompatible(srcType, cast.getType())) {
103 typeInferenceFailed = true;
104 return Value();
105 }
106 return memref::CastOp::create(rewriter, cast->getLoc(),
107 cast.getType(), val);
108 })
109 .Default([&](Operation *) -> Value { return Value(); });
110
111 if (typeInferenceFailed) {
112 user->emitOpError(
113 "failed to compute view-like result type after multi-buffering");
114 return failure();
115 }
116
117 if (replacement) {
118 // Recursively propagate through view-like ops and mark old op for
119 // erasure.
120 if (failed(replaceUsesAndPropagateType(rewriter, user, replacement)))
121 return failure();
122 opsToErase.push_back(user);
123 } else {
124 // Not a view-like op: just replace operand.
125 rewriter.startOpModification(user);
126 use.set(val);
127 rewriter.finalizeOpModification(user);
128 }
129 }
130
131 for (Operation *op : opsToErase) {
132 rewriter.eraseOp(op);
133 }
134
135 return success();
136}
137
138// Transformation to do multi-buffering/array expansion to remove dependencies
139// on the temporary allocation between consecutive loop iterations.
140// Returns success if the transformation happened and failure otherwise.
141// This is not a pattern as it requires propagating the new memref type to its
142// uses and requires updating subview ops.
143FailureOr<memref::AllocOp>
144mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
145 unsigned multiBufferingFactor,
146 bool skipOverrideAnalysis) {
147 LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
148 DominanceInfo dom(allocOp->getParentOp());
149 LoopLikeOpInterface candidateLoop;
150 for (Operation *user : allocOp->getUsers()) {
151 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
152 if (!parentLoop) {
153 if (isa<memref::DeallocOp>(user)) {
154 // Allow dealloc outside of any loop.
155 // TODO: The whole precondition function here is very brittle and will
156 // need to rethought an isolated into a cleaner analysis.
157 continue;
158 }
159 LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
160 LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
161 return failure();
162 }
163 if (!skipOverrideAnalysis) {
164 /// Make sure there is no loop-carried dependency on the allocation.
165 if (!overrideBuffer(user, allocOp.getResult())) {
166 LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
167 continue;
168 }
169 // If this user doesn't dominate all the other users keep looking.
170 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
171 return !dom.dominates(user, otherUser);
172 })) {
173 LLVM_DEBUG(
174 DBGS() << "--Skip user: does not dominate all other users\n");
175 continue;
176 }
177 } else {
178 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
179 return !isa<memref::DeallocOp>(otherUser) &&
180 !parentLoop->isProperAncestor(otherUser);
181 })) {
182 LLVM_DEBUG(
183 DBGS()
184 << "--Skip user: not all other users are in the parent loop\n");
185 continue;
186 }
187 }
188 candidateLoop = parentLoop;
189 break;
190 }
191
192 if (!candidateLoop) {
193 LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
194 return failure();
195 }
196
197 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
198 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
199 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
200 if (!inductionVar || !lowerBound || !singleStep ||
201 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
202 LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
203 return failure();
204 }
205
206 if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
207 LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
208 return failure();
209 }
210
211 LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
212
213 // 1. Construct the multi-buffered memref type.
214 ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
215 SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
216 llvm::append_range(multiBufferedShape, originalShape);
217 LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
218 MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
219 .setShape(multiBufferedShape)
220 .setLayout(MemRefLayoutAttrInterface());
221 LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
222
223 // 2. Create the multi-buffered alloc.
224 Location loc = allocOp->getLoc();
225 OpBuilder::InsertionGuard g(rewriter);
226 rewriter.setInsertionPoint(allocOp);
227 auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType,
228 ValueRange{}, allocOp->getAttrs());
229 LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
230
231 // 3. Within the loop, build the modular leading index (i.e. each loop
232 // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
234 &candidateLoop.getLoopRegions().front()->front());
235 Value ivVal = *inductionVar;
236 Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
237 Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
238 AffineExpr iv, lb, step;
239 bindDims(rewriter.getContext(), iv, lb, step);
241 rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
242 {ivVal, lbVal, stepVal});
243 LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
244
245 // 4. Build the subview accessing the particular slice, taking modular
246 // rotation into account.
247 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
248 IntegerAttr zero = rewriter.getIndexAttr(0);
249 IntegerAttr one = rewriter.getIndexAttr(1);
250 SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
251 SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
252 SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
253 // Offset is [bufferIndex, 0 ... 0 ].
254 offsets.front() = bufferIndex;
255 // Sizes is [1, original_size_0 ... original_size_n ].
256 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
257 sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
258 // Strides is [1, 1 ... 1 ].
259 MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
260 originalShape, mbMemRefType, offsets, sizes, strides);
261 Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc,
262 offsets, sizes, strides);
263 LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
264
265 // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need
266 // to handle dealloc uses separately..
267 for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
268 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
269 if (!deallocOp)
270 continue;
271 OpBuilder::InsertionGuard g(rewriter);
272 rewriter.setInsertionPoint(deallocOp);
273 auto newDeallocOp =
274 memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc);
275 (void)newDeallocOp;
276 LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
277 rewriter.eraseOp(deallocOp);
278 }
279
280 // 6. RAUW with the particular slice, taking modular rotation into account.
281 if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview)))
282 return failure();
283
284 // 7. Finally, erase the old allocOp.
285 rewriter.eraseOp(allocOp);
286
287 return mbAlloc;
288}
289
290FailureOr<memref::AllocOp>
291mlir::memref::multiBuffer(memref::AllocOp allocOp,
292 unsigned multiBufferingFactor,
293 bool skipOverrideAnalysis) {
294 IRRewriter rewriter(allocOp->getContext());
295 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
296 skipOverrideAnalysis);
297}
return success()
[...] if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing [...] should be [...] the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end` [...]
static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val)
Replace the uses of oldOp with the given val and for view-like uses propagate the type change.
static bool overrideBuffer(Operation *op, Value buffer)
Return true if the op fully overwrites the given buffer value.
#define DBGS()
Base type for affine expression.
Definition AffineExpr.h:68
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
MLIRContext * getContext() const
Definition Builders.h:56
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition Operation.h:846
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112