MLIR  21.0.0git
BufferDeallocationOpInterface.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/IR/Value.h"
17 #include "llvm/ADT/SetOperations.h"
18 
19 //===----------------------------------------------------------------------===//
20 // BufferDeallocationOpInterface
21 //===----------------------------------------------------------------------===//
22 
23 namespace mlir {
24 namespace bufferization {
25 
26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
27 
28 } // namespace bufferization
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace bufferization;
33 
34 //===----------------------------------------------------------------------===//
35 // Helpers
36 //===----------------------------------------------------------------------===//
37 
38 static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
39  return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
40 }
41 
42 static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
43 
44 //===----------------------------------------------------------------------===//
45 // Ownership
46 //===----------------------------------------------------------------------===//
47 
49  : indicator(indicator), state(State::Unique) {}
50 
52  Ownership unknown;
53  unknown.indicator = Value();
54  unknown.state = State::Unknown;
55  return unknown;
56 }
57 Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
59 
61  return state == State::Uninitialized;
62 }
63 bool Ownership::isUnique() const { return state == State::Unique; }
64 bool Ownership::isUnknown() const { return state == State::Unknown; }
65 
67  assert(isUnique() && "must have unique ownership to get the indicator");
68  return indicator;
69 }
70 
72  if (other.isUninitialized())
73  return *this;
74  if (isUninitialized())
75  return other;
76 
77  if (!isUnique() || !other.isUnique())
78  return getUnknown();
79 
80  // Since we create a new constant i1 value for (almost) each use-site, we
81  // should compare the actual value rather than just the SSA Value to avoid
82  // unnecessary invalidations.
83  if (isEqualConstantIntOrValue(indicator, other.indicator))
84  return *this;
85 
86  // Return the join of the lattice if the indicator of both ownerships cannot
87  // be merged.
88  return getUnknown();
89 }
90 
91 void Ownership::combine(Ownership other) { *this = getCombined(other); }
92 
93 //===----------------------------------------------------------------------===//
94 // DeallocationState
95 //===----------------------------------------------------------------------===//
96 
98  SymbolTableCollection &symbolTables)
99  : symbolTable(symbolTables), liveness(op) {}
100 
102  Block *block) {
103  // In most cases we care about the block where the value is defined.
104  if (block == nullptr)
105  block = memref.getParentBlock();
106 
107  // Update ownership of current memref itself.
108  ownershipMap[{memref, block}].combine(ownership);
109 }
110 
112  for (Value val : memrefs)
113  ownershipMap[{val, block}] = Ownership::getUninitialized();
114 }
115 
117  return ownershipMap.lookup({memref, block});
118 }
119 
121  memrefsToDeallocatePerBlock[block].push_back(memref);
122 }
123 
125  llvm::erase(memrefsToDeallocatePerBlock[block], memref);
126 }
127 
129  SmallVectorImpl<Value> &memrefs) {
130  SmallVector<Value> liveMemrefs(
131  llvm::make_filter_range(liveness.getLiveIn(block), isMemref));
132  llvm::sort(liveMemrefs, ValueComparator());
133  memrefs.append(liveMemrefs);
134 }
135 
136 std::pair<Value, Value>
138  Value memref, Block *block) {
139  auto iter = ownershipMap.find({memref, block});
140  assert(iter != ownershipMap.end() &&
141  "Value must already have been registered in the ownership map");
142 
143  Ownership ownership = iter->second;
144  if (ownership.isUnique())
145  return {memref, ownership.getIndicator()};
146 
147  // Instead of inserting a clone operation we could also insert a dealloc
148  // operation earlier in the block and use the updated ownerships returned by
149  // the op for the retained values. Alternatively, we could insert code to
150  // check aliasing at runtime and use this information to combine two unique
151  // ownerships more intelligently to not end up with an 'Unknown' ownership in
152  // the first place.
153  auto cloneOp =
154  builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
155  Value condition = buildBoolValue(builder, memref.getLoc(), true);
156  Value newMemref = cloneOp.getResult();
157  updateOwnership(newMemref, condition);
158  memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
159  return {newMemref, condition};
160 }
161 
163  Block *fromBlock, Block *toBlock, ValueRange destOperands,
164  SmallVectorImpl<Value> &toRetain) const {
165  for (Value operand : destOperands) {
166  if (!isMemref(operand))
167  continue;
168  toRetain.push_back(operand);
169  }
170 
171  SmallPtrSet<Value, 16> liveOut;
172  for (auto val : liveness.getLiveOut(fromBlock))
173  if (isMemref(val))
174  liveOut.insert(val);
175 
176  if (toBlock)
177  llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
178 
179  // liveOut has non-deterministic order because it was constructed by iterating
180  // over a hash-set.
181  SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
182  llvm::sort(retainedByLiveness, ValueComparator());
183  toRetain.append(retainedByLiveness);
184 }
185 
187  OpBuilder &builder, Location loc, Block *block,
188  SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
189 
190  for (auto [i, memref] :
191  llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
192  Ownership ownership = ownershipMap.lookup({memref, block});
193  if (!ownership.isUnique())
194  return emitError(memref.getLoc(),
195  "MemRef value does not have valid ownership");
196 
197  // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
198  // that we can call extract_strided_metadata on it.
199  if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
200  memref = builder.create<memref::ReinterpretCastOp>(
201  loc, memref,
202  /*offset=*/builder.getIndexAttr(0),
203  /*sizes=*/ArrayRef<OpFoldResult>{},
204  /*strides=*/ArrayRef<OpFoldResult>{});
205 
206  // Use the `memref.extract_strided_metadata` operation to get the base
207  // memref. This is needed because the same MemRef that was produced by the
208  // alloc operation has to be passed to the dealloc operation. Passing
209  // subviews, etc. to a dealloc operation is not allowed.
210  memrefs.push_back(
211  builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
212  .getResult(0));
213  conditions.push_back(ownership.getIndicator());
214  }
215 
216  return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // ValueComparator
221 //===----------------------------------------------------------------------===//
222 
223 bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
224  if (lhs == rhs)
225  return false;
226 
227  // Block arguments are less than results.
228  bool lhsIsBBArg = isa<BlockArgument>(lhs);
229  if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
230  return lhsIsBBArg;
231  }
232 
233  Region *lhsRegion;
234  Region *rhsRegion;
235  if (lhsIsBBArg) {
236  auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
237  auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
238  if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
239  return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
240  }
241  lhsRegion = lhsBBArg.getParentRegion();
242  rhsRegion = rhsBBArg.getParentRegion();
243  assert(lhsRegion != rhsRegion &&
244  "lhsRegion == rhsRegion implies lhs == rhs");
245  } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
246  return llvm::cast<OpResult>(lhs).getResultNumber() <
247  llvm::cast<OpResult>(rhs).getResultNumber();
248  } else {
249  lhsRegion = lhs.getDefiningOp()->getParentRegion();
250  rhsRegion = rhs.getDefiningOp()->getParentRegion();
251  if (lhsRegion == rhsRegion) {
252  return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
253  }
254  }
255 
256  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
257  // - have different heights
258  // - or there's a spot where their region numbers differ
259  // - or their parent regions are the same and their parent ops are
260  // different.
261  while (lhsRegion && rhsRegion) {
262  if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
263  return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
264  }
265  if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
266  return lhsRegion->getParentOp()->isBeforeInBlock(
267  rhsRegion->getParentOp());
268  }
269  lhsRegion = lhsRegion->getParentRegion();
270  rhsRegion = rhsRegion->getParentRegion();
271  }
272  if (rhsRegion)
273  return true;
274  assert(lhsRegion && "this should only happen if lhs == rhs");
275  return false;
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // Implementation utilities
280 //===----------------------------------------------------------------------===//
281 
283  DeallocationState &state, Operation *op, ValueRange operands,
284  SmallVectorImpl<Value> &updatedOperandOwnerships) {
285  assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
286  assert(!op->hasSuccessors() && "must not have any successors");
287  // Collect the values to deallocate and retain and use them to create the
288  // dealloc operation.
289  OpBuilder builder(op);
290  Block *block = op->getBlock();
291  SmallVector<Value> memrefs, conditions, toRetain;
292  if (failed(state.getMemrefsAndConditionsToDeallocate(
293  builder, op->getLoc(), block, memrefs, conditions)))
294  return failure();
295 
296  state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain);
297  if (memrefs.empty() && toRetain.empty())
298  return op;
299 
300  auto deallocOp = builder.create<bufferization::DeallocOp>(
301  op->getLoc(), memrefs, conditions, toRetain);
302 
303  // We want to replace the current ownership of the retained values with the
304  // result values of the dealloc operation as they are always unique.
305  state.resetOwnerships(deallocOp.getRetained(), block);
306  for (auto [retained, ownership] :
307  llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
308  state.updateOwnership(retained, ownership, block);
309 
310  unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
311  auto newOperandOwnerships =
312  deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
313  updatedOperandOwnerships.append(newOperandOwnerships.begin(),
314  newOperandOwnerships.end());
315 
316  return op;
317 }
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
const ValueSetT & getLiveOut(Block *block) const
Returns a reference to a set containing live-out values (unordered).
Definition: Liveness.cpp:236
const ValueSetT & getLiveIn(Block *block) const
Returns a reference to a set containing live-in values (unordered).
Definition: Liveness.cpp:231
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:772
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
bool hasSuccessors()
Definition: Operation.h:705
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
DeallocationState(Operation *op, SymbolTableCollection &symbolTables)
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
Given an SSA value of MemRef type, this function queries the ownership and if it is not already in th...
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block',...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator paramete...
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const