MLIR  19.0.0git
MultiBuffer.cpp
Go to the documentation of this file.
1 //===----------- MultiBuffer.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the memref multi-buffering transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/Dominance.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/ValueRange.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/Debug.h"
26 
27 using namespace mlir;
28 
// Debug-output helpers used with LLVM_DEBUG in this file: DBGS() starts a
// debug line tagged with DEBUG_TYPE, DBGSNL() emits a bare newline.
// NOTE(review): DBGSNL() does not appear to be used in this file — confirm
// before removing.
29 #define DEBUG_TYPE "memref-transforms"
30 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31 #define DBGSNL() (llvm::dbgs() << "\n")
32 
33 /// Return true if the op fully overwrite the given `buffer` value.
34 static bool overrideBuffer(Operation *op, Value buffer) {
35  auto copyOp = dyn_cast<memref::CopyOp>(op);
36  if (!copyOp)
37  return false;
38  return copyOp.getTarget() == buffer;
39 }
40 
41 /// Replace the uses of `oldOp` with the given `val` and for subview uses
42 /// propagate the type change. Changing the memref type may require propagating
43 /// it through subview ops so we cannot just do a replaceAllUse but need to
44 /// propagate the type change and erase old subview ops.
46  Operation *oldOp, Value val) {
47  SmallVector<Operation *> opsToDelete;
48  SmallVector<OpOperand *> operandsToReplace;
49 
50  // Save the operand to replace / delete later (avoid iterator invalidation).
51  // TODO: can we use an early_inc iterator?
52  for (OpOperand &use : oldOp->getUses()) {
53  // Non-subview ops will be replaced by `val`.
54  auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
55  if (!subviewUse) {
56  operandsToReplace.push_back(&use);
57  continue;
58  }
59 
60  // `subview(old_op)` is replaced by a new `subview(val)`.
61  OpBuilder::InsertionGuard g(rewriter);
62  rewriter.setInsertionPoint(subviewUse);
63  Type newType = memref::SubViewOp::inferRankReducedResultType(
64  subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
65  subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66  subviewUse.getStaticStrides());
67  Value newSubview = rewriter.create<memref::SubViewOp>(
68  subviewUse->getLoc(), cast<MemRefType>(newType), val,
69  subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70  subviewUse.getMixedStrides());
71 
72  // Ouch recursion ... is this really necessary?
73  replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
74 
75  opsToDelete.push_back(use.getOwner());
76  }
77 
78  // Perform late replacement.
79  // TODO: can we use an early_inc iterator?
80  for (OpOperand *operand : operandsToReplace) {
81  Operation *op = operand->getOwner();
82  rewriter.startOpModification(op);
83  operand->set(val);
84  rewriter.finalizeOpModification(op);
85  }
86 
87  // Perform late op erasure.
88  // TODO: can we use an early_inc iterator?
89  for (Operation *op : opsToDelete)
90  rewriter.eraseOp(op);
91 }
92 
93 // Transformation to do multi-buffering/array expansion to remove dependencies
94 // on the temporary allocation between consecutive loop iterations.
95 // Returns success if the transformation happened and failure otherwise.
96 // This is not a pattern as it requires propagating the new memref type to its
97 // uses and requires updating subview ops.
99 mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
100  unsigned multiBufferingFactor,
101  bool skipOverrideAnalysis) {
102  LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
103  DominanceInfo dom(allocOp->getParentOp());
104  LoopLikeOpInterface candidateLoop;
105  for (Operation *user : allocOp->getUsers()) {
106  auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
107  if (!parentLoop) {
108  if (isa<memref::DeallocOp>(user)) {
109  // Allow dealloc outside of any loop.
110  // TODO: The whole precondition function here is very brittle and will
111  // need to rethought an isolated into a cleaner analysis.
112  continue;
113  }
114  LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
115  LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
116  return failure();
117  }
118  if (!skipOverrideAnalysis) {
119  /// Make sure there is no loop-carried dependency on the allocation.
120  if (!overrideBuffer(user, allocOp.getResult())) {
121  LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
122  continue;
123  }
124  // If this user doesn't dominate all the other users keep looking.
125  if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
126  return !dom.dominates(user, otherUser);
127  })) {
128  LLVM_DEBUG(
129  DBGS() << "--Skip user: does not dominate all other users\n");
130  continue;
131  }
132  } else {
133  if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
134  return !isa<memref::DeallocOp>(otherUser) &&
135  !parentLoop->isProperAncestor(otherUser);
136  })) {
137  LLVM_DEBUG(
138  DBGS()
139  << "--Skip user: not all other users are in the parent loop\n");
140  continue;
141  }
142  }
143  candidateLoop = parentLoop;
144  break;
145  }
146 
147  if (!candidateLoop) {
148  LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
149  return failure();
150  }
151 
152  std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
153  std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
154  std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
155  if (!inductionVar || !lowerBound || !singleStep ||
156  !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
157  LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
158  return failure();
159  }
160 
161  if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
162  LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
163  return failure();
164  }
165 
166  LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
167 
168  // 1. Construct the multi-buffered memref type.
169  ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
170  SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
171  llvm::append_range(multiBufferedShape, originalShape);
172  LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
173  MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
174  .setShape(multiBufferedShape)
175  .setLayout(MemRefLayoutAttrInterface());
176  LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
177 
178  // 2. Create the multi-buffered alloc.
179  Location loc = allocOp->getLoc();
180  OpBuilder::InsertionGuard g(rewriter);
181  rewriter.setInsertionPoint(allocOp);
182  auto mbAlloc = rewriter.create<memref::AllocOp>(
183  loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
184  LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
185 
186  // 3. Within the loop, build the modular leading index (i.e. each loop
187  // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
188  rewriter.setInsertionPointToStart(
189  &candidateLoop.getLoopRegions().front()->front());
190  Value ivVal = *inductionVar;
191  Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
192  Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
193  AffineExpr iv, lb, step;
194  bindDims(rewriter.getContext(), iv, lb, step);
196  rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
197  {ivVal, lbVal, stepVal});
198  LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
199 
200  // 4. Build the subview accessing the particular slice, taking modular
201  // rotation into account.
202  int64_t mbMemRefTypeRank = mbMemRefType.getRank();
203  IntegerAttr zero = rewriter.getIndexAttr(0);
204  IntegerAttr one = rewriter.getIndexAttr(1);
205  SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
206  SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
207  SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
208  // Offset is [bufferIndex, 0 ... 0 ].
209  offsets.front() = bufferIndex;
210  // Sizes is [1, original_size_0 ... original_size_n ].
211  for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
212  sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
213  // Strides is [1, 1 ... 1 ].
214  auto dstMemref =
215  cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
216  originalShape, mbMemRefType, offsets, sizes, strides));
217  Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218  offsets, sizes, strides);
219  LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
220 
221  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
222  // handle dealloc uses separately..
223  for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
224  auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
225  if (!deallocOp)
226  continue;
227  OpBuilder::InsertionGuard g(rewriter);
228  rewriter.setInsertionPoint(deallocOp);
229  auto newDeallocOp =
230  rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
231  (void)newDeallocOp;
232  LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
233  rewriter.eraseOp(deallocOp);
234  }
235 
236  // 6. RAUW with the particular slice, taking modular rotation into account.
237  replaceUsesAndPropagateType(rewriter, allocOp, subview);
238 
239  // 7. Finally, erase the old allocOp.
240  rewriter.eraseOp(allocOp);
241 
242  return mbAlloc;
243 }
244 
246 mlir::memref::multiBuffer(memref::AllocOp allocOp,
247  unsigned multiBufferingFactor,
248  bool skipOverrideAnalysis) {
249  IRRewriter rewriter(allocOp->getContext());
250  return multiBuffer(rewriter, allocOp, multiBufferingFactor,
251  skipOverrideAnalysis);
252 }
static bool overrideBuffer(Operation *op, Value buffer)
Return true if the op fully overwrites the given buffer value.
Definition: MultiBuffer.cpp:34
static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val)
Replace the uses of oldOp with the given val and for subview uses propagate the type change.
Definition: MultiBuffer.cpp:45
#define DBGS()
Definition: MultiBuffer.cpp:30
Base type for affine expression.
Definition: AffineExpr.h:69
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
MLIRContext * getContext() const
Definition: Builders.h:55
A class for computing basic dominance information.
Definition: Dominance.h:136
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:201
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:212
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1138
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:99
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
Definition: MathExtras.h:33
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41