MLIR  20.0.0git
BufferDeallocationOpInterface.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/IR/Value.h"
17 #include "llvm/ADT/SetOperations.h"
18 
19 //===----------------------------------------------------------------------===//
20 // BufferDeallocationOpInterface
21 //===----------------------------------------------------------------------===//
22 
23 namespace mlir {
24 namespace bufferization {
25 
26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
27 
28 } // namespace bufferization
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace bufferization;
33 
34 //===----------------------------------------------------------------------===//
35 // Helpers
36 //===----------------------------------------------------------------------===//
37 
38 static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
39  return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
40 }
41 
42 static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
43 
44 //===----------------------------------------------------------------------===//
45 // Ownership
46 //===----------------------------------------------------------------------===//
47 
49  : indicator(indicator), state(State::Unique) {}
50 
52  Ownership unknown;
53  unknown.indicator = Value();
54  unknown.state = State::Unknown;
55  return unknown;
56 }
57 Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
59 
61  return state == State::Uninitialized;
62 }
63 bool Ownership::isUnique() const { return state == State::Unique; }
64 bool Ownership::isUnknown() const { return state == State::Unknown; }
65 
67  assert(isUnique() && "must have unique ownership to get the indicator");
68  return indicator;
69 }
70 
72  if (other.isUninitialized())
73  return *this;
74  if (isUninitialized())
75  return other;
76 
77  if (!isUnique() || !other.isUnique())
78  return getUnknown();
79 
80  // Since we create a new constant i1 value for (almost) each use-site, we
81  // should compare the actual value rather than just the SSA Value to avoid
82  // unnecessary invalidations.
83  if (isEqualConstantIntOrValue(indicator, other.indicator))
84  return *this;
85 
86  // Return the join of the lattice if the indicator of both ownerships cannot
87  // be merged.
88  return getUnknown();
89 }
90 
91 void Ownership::combine(Ownership other) { *this = getCombined(other); }
92 
93 //===----------------------------------------------------------------------===//
94 // DeallocationState
95 //===----------------------------------------------------------------------===//
96 
98 
100  Block *block) {
101  // In most cases we care about the block where the value is defined.
102  if (block == nullptr)
103  block = memref.getParentBlock();
104 
105  // Update ownership of current memref itself.
106  ownershipMap[{memref, block}].combine(ownership);
107 }
108 
110  for (Value val : memrefs)
111  ownershipMap[{val, block}] = Ownership::getUninitialized();
112 }
113 
115  return ownershipMap.lookup({memref, block});
116 }
117 
119  memrefsToDeallocatePerBlock[block].push_back(memref);
120 }
121 
123  llvm::erase(memrefsToDeallocatePerBlock[block], memref);
124 }
125 
127  SmallVectorImpl<Value> &memrefs) {
128  SmallVector<Value> liveMemrefs(
129  llvm::make_filter_range(liveness.getLiveIn(block), isMemref));
130  llvm::sort(liveMemrefs, ValueComparator());
131  memrefs.append(liveMemrefs);
132 }
133 
134 std::pair<Value, Value>
136  Value memref, Block *block) {
137  auto iter = ownershipMap.find({memref, block});
138  assert(iter != ownershipMap.end() &&
139  "Value must already have been registered in the ownership map");
140 
141  Ownership ownership = iter->second;
142  if (ownership.isUnique())
143  return {memref, ownership.getIndicator()};
144 
145  // Instead of inserting a clone operation we could also insert a dealloc
146  // operation earlier in the block and use the updated ownerships returned by
147  // the op for the retained values. Alternatively, we could insert code to
148  // check aliasing at runtime and use this information to combine two unique
149  // ownerships more intelligently to not end up with an 'Unknown' ownership in
150  // the first place.
151  auto cloneOp =
152  builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
153  Value condition = buildBoolValue(builder, memref.getLoc(), true);
154  Value newMemref = cloneOp.getResult();
155  updateOwnership(newMemref, condition);
156  memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
157  return {newMemref, condition};
158 }
159 
161  Block *fromBlock, Block *toBlock, ValueRange destOperands,
162  SmallVectorImpl<Value> &toRetain) const {
163  for (Value operand : destOperands) {
164  if (!isMemref(operand))
165  continue;
166  toRetain.push_back(operand);
167  }
168 
169  SmallPtrSet<Value, 16> liveOut;
170  for (auto val : liveness.getLiveOut(fromBlock))
171  if (isMemref(val))
172  liveOut.insert(val);
173 
174  if (toBlock)
175  llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
176 
177  // liveOut has non-deterministic order because it was constructed by iterating
178  // over a hash-set.
179  SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
180  std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
181  ValueComparator());
182  toRetain.append(retainedByLiveness);
183 }
184 
186  OpBuilder &builder, Location loc, Block *block,
187  SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
188 
189  for (auto [i, memref] :
190  llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
191  Ownership ownership = ownershipMap.lookup({memref, block});
192  if (!ownership.isUnique())
193  return emitError(memref.getLoc(),
194  "MemRef value does not have valid ownership");
195 
196  // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
197  // that we can call extract_strided_metadata on it.
198  if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
199  memref = builder.create<memref::ReinterpretCastOp>(
200  loc, memref,
201  /*offset=*/builder.getIndexAttr(0),
202  /*sizes=*/ArrayRef<OpFoldResult>{},
203  /*strides=*/ArrayRef<OpFoldResult>{});
204 
205  // Use the `memref.extract_strided_metadata` operation to get the base
206  // memref. This is needed because the same MemRef that was produced by the
207  // alloc operation has to be passed to the dealloc operation. Passing
208  // subviews, etc. to a dealloc operation is not allowed.
209  memrefs.push_back(
210  builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
211  .getResult(0));
212  conditions.push_back(ownership.getIndicator());
213  }
214 
215  return success();
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // ValueComparator
220 //===----------------------------------------------------------------------===//
221 
222 bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
223  if (lhs == rhs)
224  return false;
225 
226  // Block arguments are less than results.
227  bool lhsIsBBArg = isa<BlockArgument>(lhs);
228  if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
229  return lhsIsBBArg;
230  }
231 
232  Region *lhsRegion;
233  Region *rhsRegion;
234  if (lhsIsBBArg) {
235  auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
236  auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
237  if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
238  return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
239  }
240  lhsRegion = lhsBBArg.getParentRegion();
241  rhsRegion = rhsBBArg.getParentRegion();
242  assert(lhsRegion != rhsRegion &&
243  "lhsRegion == rhsRegion implies lhs == rhs");
244  } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
245  return llvm::cast<OpResult>(lhs).getResultNumber() <
246  llvm::cast<OpResult>(rhs).getResultNumber();
247  } else {
248  lhsRegion = lhs.getDefiningOp()->getParentRegion();
249  rhsRegion = rhs.getDefiningOp()->getParentRegion();
250  if (lhsRegion == rhsRegion) {
251  return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
252  }
253  }
254 
255  // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
256  // - have different heights
257  // - or there's a spot where their region numbers differ
258  // - or their parent regions are the same and their parent ops are
259  // different.
260  while (lhsRegion && rhsRegion) {
261  if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
262  return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
263  }
264  if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
265  return lhsRegion->getParentOp()->isBeforeInBlock(
266  rhsRegion->getParentOp());
267  }
268  lhsRegion = lhsRegion->getParentRegion();
269  rhsRegion = rhsRegion->getParentRegion();
270  }
271  if (rhsRegion)
272  return true;
273  assert(lhsRegion && "this should only happen if lhs == rhs");
274  return false;
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // Implementation utilities
279 //===----------------------------------------------------------------------===//
280 
282  DeallocationState &state, Operation *op, ValueRange operands,
283  SmallVectorImpl<Value> &updatedOperandOwnerships) {
284  assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
285  assert(!op->hasSuccessors() && "must not have any successors");
286  // Collect the values to deallocate and retain and use them to create the
287  // dealloc operation.
288  OpBuilder builder(op);
289  Block *block = op->getBlock();
290  SmallVector<Value> memrefs, conditions, toRetain;
291  if (failed(state.getMemrefsAndConditionsToDeallocate(
292  builder, op->getLoc(), block, memrefs, conditions)))
293  return failure();
294 
295  state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain);
296  if (memrefs.empty() && toRetain.empty())
297  return op;
298 
299  auto deallocOp = builder.create<bufferization::DeallocOp>(
300  op->getLoc(), memrefs, conditions, toRetain);
301 
302  // We want to replace the current ownership of the retained values with the
303  // result values of the dealloc operation as they are always unique.
304  state.resetOwnerships(deallocOp.getRetained(), block);
305  for (auto [retained, ownership] :
306  llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
307  state.updateOwnership(retained, ownership, block);
308 
309  unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
310  auto newOperandOwnerships =
311  deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
312  updatedOperandOwnerships.append(newOperandOwnerships.begin(),
313  newOperandOwnerships.end());
314 
315  return op;
316 }
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
Block represents an ordered list of Operations.
Definition: Block.h:31
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
const ValueSetT & getLiveOut(Block *block) const
Returns a reference to a set containing live-out values (unordered).
Definition: Liveness.cpp:240
const ValueSetT & getLiveIn(Block *block) const
Returns a reference to a set containing live-in values (unordered).
Definition: Liveness.cpp:235
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OpTrait::IsTerminator>()`.
Definition: Operation.h:745
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
bool hasSuccessors()
Definition: Operation.h:701
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Definition: Region.cpp:45
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
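This class collects all the state that we need to perform the buffer deallocation pass with associated helper functions, so that it is easily accessible from the BufferDeallocationOpInterface implementations and the BufferDeallocation pass.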
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
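Small helper function to update the ownership map by taking the current ownership ('Uninitialized' state if not yet present), computing the join with the passed ownership and storing this new value in the map. If no 'block' is given, the block in which 'memref' is defined is used.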
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
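Given an SSA value of MemRef type, this function queries the ownership and, if it is not already in the 'Unique' state, inserts a clone with 'Unique' ownership; it returns the value to use instead of 'memref' together with its ownership indicator.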
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
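For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of that block, together with their ownership indicators as deallocation conditions; fails if one of them does not have 'Unique' ownership.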
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block', possibly because it already gets deallocated before the end of the block.
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
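Given two basic blocks and the values passed via block arguments to the destination block, compute the list of MemRefs that have to be retained in 'fromBlock' so that they are not deallocated while still in use.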
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator parameter. Calling it in any other state violates an assertion.
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
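Insert a bufferization.dealloc operation right before op, which has to be a terminator without any successors. The MemRef values in operands are retained and their updated ownership indicators are appended to updatedOperandOwnerships. Returns failure if a MemRef to deallocate does not have 'Unique' ownership, and op otherwise.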
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const