MLIR  19.0.0git
BufferDeallocationOpInterface.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/IR/Value.h"
17 #include "llvm/ADT/SetOperations.h"
18 
19 //===----------------------------------------------------------------------===//
20 // BufferDeallocationOpInterface
21 //===----------------------------------------------------------------------===//
22 
23 namespace mlir {
24 namespace bufferization {
25 
26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
27 
28 } // namespace bufferization
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace bufferization;
33 
34 //===----------------------------------------------------------------------===//
35 // Helpers
36 //===----------------------------------------------------------------------===//
37 
38 static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
39  return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
40 }
41 
42 static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
43 
44 //===----------------------------------------------------------------------===//
45 // Ownership
46 //===----------------------------------------------------------------------===//
47 
49  : indicator(indicator), state(State::Unique) {}
50 
52  Ownership unknown;
53  unknown.indicator = Value();
54  unknown.state = State::Unknown;
55  return unknown;
56 }
57 Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
59 
61  return state == State::Uninitialized;
62 }
63 bool Ownership::isUnique() const { return state == State::Unique; }
64 bool Ownership::isUnknown() const { return state == State::Unknown; }
65 
67  assert(isUnique() && "must have unique ownership to get the indicator");
68  return indicator;
69 }
70 
72  if (other.isUninitialized())
73  return *this;
74  if (isUninitialized())
75  return other;
76 
77  if (!isUnique() || !other.isUnique())
78  return getUnknown();
79 
80  // Since we create a new constant i1 value for (almost) each use-site, we
81  // should compare the actual value rather than just the SSA Value to avoid
82  // unnecessary invalidations.
83  if (isEqualConstantIntOrValue(indicator, other.indicator))
84  return *this;
85 
86  // Return the join of the lattice if the indicator of both ownerships cannot
87  // be merged.
88  return getUnknown();
89 }
90 
91 void Ownership::combine(Ownership other) { *this = getCombined(other); }
92 
93 //===----------------------------------------------------------------------===//
94 // DeallocationState
95 //===----------------------------------------------------------------------===//
96 
98 
100  Block *block) {
101  // In most cases we care about the block where the value is defined.
102  if (block == nullptr)
103  block = memref.getParentBlock();
104 
105  // Update ownership of current memref itself.
106  ownershipMap[{memref, block}].combine(ownership);
107 }
108 
110  for (Value val : memrefs)
111  ownershipMap[{val, block}] = Ownership::getUninitialized();
112 }
113 
115  return ownershipMap.lookup({memref, block});
116 }
117 
119  memrefsToDeallocatePerBlock[block].push_back(memref);
120 }
121 
123  llvm::erase(memrefsToDeallocatePerBlock[block], memref);
124 }
125 
127  SmallVectorImpl<Value> &memrefs) {
128  SmallVector<Value> liveMemrefs(
129  llvm::make_filter_range(liveness.getLiveIn(block), isMemref));
130  llvm::sort(liveMemrefs, ValueComparator());
131  memrefs.append(liveMemrefs);
132 }
133 
134 std::pair<Value, Value>
136  Value memref, Block *block) {
137  auto iter = ownershipMap.find({memref, block});
138  assert(iter != ownershipMap.end() &&
139  "Value must already have been registered in the ownership map");
140 
141  Ownership ownership = iter->second;
142  if (ownership.isUnique())
143  return {memref, ownership.getIndicator()};
144 
145  // Instead of inserting a clone operation we could also insert a dealloc
146  // operation earlier in the block and use the updated ownerships returned by
147  // the op for the retained values. Alternatively, we could insert code to
148  // check aliasing at runtime and use this information to combine two unique
149  // ownerships more intelligently to not end up with an 'Unknown' ownership in
150  // the first place.
151  auto cloneOp =
152  builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
153  Value condition = buildBoolValue(builder, memref.getLoc(), true);
154  Value newMemref = cloneOp.getResult();
155  updateOwnership(newMemref, condition);
156  memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
157  return {newMemref, condition};
158 }
159 
161  Block *fromBlock, Block *toBlock, ValueRange destOperands,
162  SmallVectorImpl<Value> &toRetain) const {
163  for (Value operand : destOperands) {
164  if (!isMemref(operand))
165  continue;
166  toRetain.push_back(operand);
167  }
168 
169  SmallPtrSet<Value, 16> liveOut;
170  for (auto val : liveness.getLiveOut(fromBlock))
171  if (isMemref(val))
172  liveOut.insert(val);
173 
174  if (toBlock)
175  llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
176 
177  // liveOut has non-deterministic order because it was constructed by iterating
178  // over a hash-set.
179  SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
180  std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
181  ValueComparator());
182  toRetain.append(retainedByLiveness);
183 }
184 
186  OpBuilder &builder, Location loc, Block *block,
187  SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
188 
189  for (auto [i, memref] :
190  llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
191  Ownership ownership = ownershipMap.lookup({memref, block});
192  if (!ownership.isUnique())
193  return emitError(memref.getLoc(),
194  "MemRef value does not have valid ownership");
195 
196  // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
197  // that we can call extract_strided_metadata on it.
198  if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
199  memref = builder.create<memref::ReinterpretCastOp>(
200  loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
202 
203  // Use the `memref.extract_strided_metadata` operation to get the base
204  // memref. This is needed because the same MemRef that was produced by the
205  // alloc operation has to be passed to the dealloc operation. Passing
206  // subviews, etc. to a dealloc operation is not allowed.
207  memrefs.push_back(
208  builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
209  .getResult(0));
210  conditions.push_back(ownership.getIndicator());
211  }
212 
213  return success();
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // ValueComparator
218 //===----------------------------------------------------------------------===//
219 
220 bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
221  if (lhs == rhs)
222  return false;
223 
224  // Block arguments are less than results.
225  bool lhsIsBBArg = isa<BlockArgument>(lhs);
226  if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
227  return lhsIsBBArg;
228  }
229 
230  Region *lhsRegion;
231  Region *rhsRegion;
232  if (lhsIsBBArg) {
233  auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
234  auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
235  if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
236  return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
237  }
238  lhsRegion = lhsBBArg.getParentRegion();
239  rhsRegion = rhsBBArg.getParentRegion();
240  assert(lhsRegion != rhsRegion &&
241  "lhsRegion == rhsRegion implies lhs == rhs");
242  } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
243  return llvm::cast<OpResult>(lhs).getResultNumber() <
244  llvm::cast<OpResult>(rhs).getResultNumber();
245  } else {
246  lhsRegion = lhs.getDefiningOp()->getParentRegion();
247  rhsRegion = rhs.getDefiningOp()->getParentRegion();
248  if (lhsRegion == rhsRegion) {
249  return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
250  }
251  }
252 
253  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
254  // - have different heights
255  // - or there's a spot where their region numbers differ
256  // - or their parent regions are the same and their parent ops are
257  // different.
258  while (lhsRegion && rhsRegion) {
259  if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
260  return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
261  }
262  if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
263  return lhsRegion->getParentOp()->isBeforeInBlock(
264  rhsRegion->getParentOp());
265  }
266  lhsRegion = lhsRegion->getParentRegion();
267  rhsRegion = rhsRegion->getParentRegion();
268  }
269  if (rhsRegion)
270  return true;
271  assert(lhsRegion && "this should only happen if lhs == rhs");
272  return false;
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Implementation utilities
277 //===----------------------------------------------------------------------===//
278 
280  DeallocationState &state, Operation *op, ValueRange operands,
281  SmallVectorImpl<Value> &updatedOperandOwnerships) {
282  assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
283  assert(!op->hasSuccessors() && "must not have any successors");
284  // Collect the values to deallocate and retain and use them to create the
285  // dealloc operation.
286  OpBuilder builder(op);
287  Block *block = op->getBlock();
288  SmallVector<Value> memrefs, conditions, toRetain;
289  if (failed(state.getMemrefsAndConditionsToDeallocate(
290  builder, op->getLoc(), block, memrefs, conditions)))
291  return failure();
292 
293  state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain);
294  if (memrefs.empty() && toRetain.empty())
295  return op;
296 
297  auto deallocOp = builder.create<bufferization::DeallocOp>(
298  op->getLoc(), memrefs, conditions, toRetain);
299 
300  // We want to replace the current ownership of the retained values with the
301  // result values of the dealloc operation as they are always unique.
302  state.resetOwnerships(deallocOp.getRetained(), block);
303  for (auto [retained, ownership] :
304  llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
305  state.updateOwnership(retained, ownership, block);
306 
307  unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
308  auto newOperandOwnerships =
309  deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
310  updatedOperandOwnerships.append(newOperandOwnerships.begin(),
311  newOperandOwnerships.end());
312 
313  return op;
314 }
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
Block represents an ordered list of Operations.
Definition: Block.h:30
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
const ValueSetT & getLiveOut(Block *block) const
Returns a reference to a set containing live-out values (unordered).
Definition: Liveness.cpp:236
const ValueSetT & getLiveIn(Block *block) const
Returns a reference to a set containing live-in values (unordered).
Definition: Liveness.cpp:231
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
bool hasSuccessors()
Definition: Operation.h:701
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
Given an SSA value of MemRef type, this function queries the ownership and if it is not already in th...
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block',...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator paramete...
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const