MLIR  16.0.0git
MultiBuffer.cpp
Go to the documentation of this file.
1 //===----------- MultiBuffer.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the multi-buffering transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Dominance.h"
18 
19 using namespace mlir;
20 
21 /// Return true if the op fully overwrite the given `buffer` value.
22 static bool overrideBuffer(Operation *op, Value buffer) {
23  auto copyOp = dyn_cast<memref::CopyOp>(op);
24  if (!copyOp)
25  return false;
26  return copyOp.getTarget() == buffer;
27 }
28 
29 /// Replace the uses of `oldOp` with the given `val` and for subview uses
30 /// propagate the type change. Changing the memref type may require propagating
31 /// it through subview ops so we cannot just do a replaceAllUse but need to
32 /// propagate the type change and erase old subview ops.
34  OpBuilder &builder) {
35  SmallVector<Operation *> opToDelete;
36  SmallVector<OpOperand *> operandsToReplace;
37  for (OpOperand &use : oldOp->getUses()) {
38  auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
39  if (!subviewUse) {
40  // Save the operand to and replace outside the loop to not invalidate the
41  // iterator.
42  operandsToReplace.push_back(&use);
43  continue;
44  }
45  builder.setInsertionPoint(subviewUse);
46  Type newType = memref::SubViewOp::inferRankReducedResultType(
47  subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
48  extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
49  extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
50  extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
51  Value newSubview = builder.create<memref::SubViewOp>(
52  subviewUse->getLoc(), newType.cast<MemRefType>(), val,
53  subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
54  subviewUse.getMixedStrides());
55  replaceUsesAndPropagateType(subviewUse, newSubview, builder);
56  opToDelete.push_back(use.getOwner());
57  }
58  for (OpOperand *operand : operandsToReplace)
59  operand->set(val);
60  // Clean up old subview ops.
61  for (Operation *op : opToDelete)
62  op->erase();
63 }
64 
65 /// Helper to convert get a value from an OpFoldResult or create it at the
66 /// builder insert point.
68  Location loc) {
69  Value value = res.dyn_cast<Value>();
70  if (value)
71  return value;
72  return builder.create<arith::ConstantIndexOp>(
73  loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
74 }
75 
76 // Transformation to do multi-buffering/array expansion to remove dependencies
77 // on the temporary allocation between consecutive loop iterations.
78 // Returns success if the transformation happened and failure otherwise.
79 // This is not a pattern as it requires propagating the new memref type to its
80 // uses and requires updating subview ops.
82  unsigned multiplier) {
83  DominanceInfo dom(allocOp->getParentOp());
84  LoopLikeOpInterface candidateLoop;
85  for (Operation *user : allocOp->getUsers()) {
86  auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
87  if (!parentLoop)
88  return failure();
89  /// Make sure there is no loop carried dependency on the allocation.
90  if (!overrideBuffer(user, allocOp.getResult()))
91  continue;
92  // If this user doesn't dominate all the other users keep looking.
93  if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
94  return !dom.dominates(user, otherUser);
95  }))
96  continue;
97  candidateLoop = parentLoop;
98  break;
99  }
100  if (!candidateLoop)
101  return failure();
102  llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
103  llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
104  llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
105  if (!inductionVar || !lowerBound || !singleStep)
106  return failure();
107  OpBuilder builder(candidateLoop);
108  Value stepValue =
109  getOrCreateValue(*singleStep, builder, candidateLoop->getLoc());
110  Value lowerBoundValue =
111  getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc());
112  SmallVector<int64_t, 4> newShape(1, multiplier);
113  ArrayRef<int64_t> oldShape = allocOp.getType().getShape();
114  newShape.append(oldShape.begin(), oldShape.end());
115  auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(),
116  MemRefLayoutAttrInterface(),
117  allocOp.getType().getMemorySpace());
118  builder.setInsertionPoint(allocOp);
119  Location loc = allocOp->getLoc();
120  auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref);
121  builder.setInsertionPoint(&candidateLoop.getLoopBody().front(),
122  candidateLoop.getLoopBody().front().begin());
123  AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
124  AffineExpr init = getAffineDimExpr(1, allocOp.getContext());
125  AffineExpr step = getAffineDimExpr(2, allocOp.getContext());
126  AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
127  auto map = AffineMap::get(3, 0, expr);
128  std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue};
129  Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
130  SmallVector<OpFoldResult> offsets, sizes, strides;
131  offsets.push_back(bufferIndex);
132  offsets.append(oldShape.size(), builder.getIndexAttr(0));
133  strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
134  sizes.push_back(builder.getIndexAttr(1));
135  for (int64_t size : oldShape)
136  sizes.push_back(builder.getIndexAttr(size));
137  auto dstMemref =
138  memref::SubViewOp::inferRankReducedResultType(
139  allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
140  .cast<MemRefType>();
141  Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
142  offsets, sizes, strides);
143  replaceUsesAndPropagateType(allocOp, subview, builder);
144  allocOp.erase();
145  return newAlloc;
146 }
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:623
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:348
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
static bool overrideBuffer(Operation *op, Value buffer)
Return true if the op fully overwrites the given buffer value.
Definition: MultiBuffer.cpp:22
A class for computing basic dominance information.
Definition: Dominance.h:117
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:414
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
Definition: MathExtras.h:33
static void replaceUsesAndPropagateType(Operation *oldOp, Value val, OpBuilder &builder)
Replace the uses of oldOp with the given val and for subview uses propagate the type change...
Definition: MultiBuffer.cpp:33
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:488
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
FailureOr< memref::AllocOp > multiBuffer(memref::AllocOp allocOp, unsigned multiplier)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:81
Type getType() const
Return the type of this value.
Definition: Value.h:118
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:80
static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder, Location loc)
Helper to get a value from an OpFoldResult or create it at the builder insert point...
Definition: MultiBuffer.cpp:67
This class represents an operand of an operation.
Definition: Value.h:251
SmallVector< int64_t, 4 > extractFromI64ArrayAttr(Attribute attr)
Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:650
This class helps build Operations.
Definition: Builders.h:196
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:101
U cast() const
Definition: Types.h:278