MLIR 23.0.0git
BufferDeallocationOpInterface.cpp
Go to the documentation of this file.
1//===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/AsmState.h"
13#include "mlir/IR/Operation.h"
15#include "mlir/IR/Value.h"
16#include "llvm/ADT/SetOperations.h"
17
18//===----------------------------------------------------------------------===//
19// BufferDeallocationOpInterface
20//===----------------------------------------------------------------------===//
21
22namespace mlir {
23namespace bufferization {
24
25#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
26
27} // namespace bufferization
28} // namespace mlir
29
30using namespace mlir;
31using namespace bufferization;
32
33//===----------------------------------------------------------------------===//
34// Helpers
35//===----------------------------------------------------------------------===//
36
37static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
38 return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value));
39}
40
41static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
42
43//===----------------------------------------------------------------------===//
44// Ownership
45//===----------------------------------------------------------------------===//
46
48 : indicator(indicator), state(State::Unique) {}
49
51 Ownership unknown;
52 unknown.indicator = Value();
53 unknown.state = State::Unknown;
54 return unknown;
55}
56Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
58
60 return state == State::Uninitialized;
61}
62bool Ownership::isUnique() const { return state == State::Unique; }
63bool Ownership::isUnknown() const { return state == State::Unknown; }
64
66 assert(isUnique() && "must have unique ownership to get the indicator");
67 return indicator;
68}
69
71 if (other.isUninitialized())
72 return *this;
73 if (isUninitialized())
74 return other;
75
76 if (!isUnique() || !other.isUnique())
77 return getUnknown();
78
79 // Since we create a new constant i1 value for (almost) each use-site, we
80 // should compare the actual value rather than just the SSA Value to avoid
81 // unnecessary invalidations.
82 if (isEqualConstantIntOrValue(indicator, other.indicator))
83 return *this;
84
85 // Return the join of the lattice if the indicator of both ownerships cannot
86 // be merged.
87 return getUnknown();
88}
89
90void Ownership::combine(Ownership other) { *this = getCombined(other); }
91
92//===----------------------------------------------------------------------===//
93// DeallocationState
94//===----------------------------------------------------------------------===//
95
97 SymbolTableCollection &symbolTables)
98 : symbolTable(symbolTables), liveness(op) {}
99
101 Block *block) {
102 // In most cases we care about the block where the value is defined.
103 if (block == nullptr)
104 block = memref.getParentBlock();
105
106 // Update ownership of current memref itself.
107 ownershipMap[{memref, block}].combine(ownership);
108}
109
111 for (Value val : memrefs)
112 ownershipMap[{val, block}] = Ownership::getUninitialized();
113}
114
116 return ownershipMap.lookup({memref, block});
117}
118
120 memrefsToDeallocatePerBlock[block].push_back(memref);
121}
122
124 llvm::erase(memrefsToDeallocatePerBlock[block], memref);
125}
126
127void DeallocationState::mapValue(Value oldValue, Value newValue) {
128 valueMapping[oldValue] = newValue;
129}
130
132 SmallVectorImpl<Value> &memrefs) {
133 SmallVector<Value> liveMemrefs;
134 for (Value val : liveness.getLiveIn(block)) {
135 // Translate any value that was replaced (e.g., by appendOpResults) to its
136 // current equivalent before checking whether it is a MemRef.
137 if (Value mapped = valueMapping.lookup(val))
138 val = mapped;
139 if (isMemref(val))
140 liveMemrefs.push_back(val);
141 }
142 llvm::sort(liveMemrefs, ValueComparator());
143 memrefs.append(liveMemrefs);
144}
145
146std::pair<Value, Value>
148 Value memref, Block *block) {
149 auto iter = ownershipMap.find({memref, block});
150 assert(iter != ownershipMap.end() &&
151 "Value must already have been registered in the ownership map");
152
153 Ownership ownership = iter->second;
154 if (ownership.isUnique())
155 return {memref, ownership.getIndicator()};
156
157 // Instead of inserting a clone operation we could also insert a dealloc
158 // operation earlier in the block and use the updated ownerships returned by
159 // the op for the retained values. Alternatively, we could insert code to
160 // check aliasing at runtime and use this information to combine two unique
161 // ownerships more intelligently to not end up with an 'Unknown' ownership in
162 // the first place.
163 auto cloneOp =
164 bufferization::CloneOp::create(builder, memref.getLoc(), memref);
165 Value condition = buildBoolValue(builder, memref.getLoc(), true);
166 Value newMemref = cloneOp.getResult();
167 updateOwnership(newMemref, condition);
168 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
169 return {newMemref, condition};
170}
171
173 Block *fromBlock, Block *toBlock, ValueRange destOperands,
174 SmallVectorImpl<Value> &toRetain) const {
175 for (Value operand : destOperands) {
176 if (!isMemref(operand))
177 continue;
178 toRetain.push_back(operand);
179 }
180
181 // Translate any value replaced during the transformation (e.g., when an op
182 // was cloned with extra results via appendOpResults) before checking whether
183 // it is a MemRef. The liveness analysis is computed once and may contain
184 // stale values after IR modifications.
185 auto translateValue = [&](Value val) -> Value {
186 if (Value mapped = valueMapping.lookup(val))
187 return mapped;
188 return val;
189 };
190
192 for (auto val : liveness.getLiveOut(fromBlock)) {
193 val = translateValue(val);
194 if (isMemref(val))
195 liveOut.insert(val);
196 }
197
198 if (toBlock) {
200 for (auto val : liveness.getLiveIn(toBlock)) {
201 val = translateValue(val);
202 if (isMemref(val))
203 liveIn.insert(val);
204 }
205 llvm::set_intersect(liveOut, liveIn);
206 }
207
208 // liveOut has non-deterministic order because it was constructed by iterating
209 // over a hash-set.
210 SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
211 llvm::sort(retainedByLiveness, ValueComparator());
212 toRetain.append(retainedByLiveness);
213}
214
216 OpBuilder &builder, Location loc, Block *block,
217 SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
218
219 for (auto [i, memref] :
220 llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
221 Ownership ownership = ownershipMap.lookup({memref, block});
222 if (!ownership.isUnique())
223 return emitError(memref.getLoc(),
224 "MemRef value does not have valid ownership");
225
226 // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
227 // that we can call extract_strided_metadata on it.
228 if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
229 memref = memref::ReinterpretCastOp::create(
230 builder, loc, memref,
231 /*offset=*/builder.getIndexAttr(0),
232 /*sizes=*/ArrayRef<OpFoldResult>{},
233 /*strides=*/ArrayRef<OpFoldResult>{});
234
235 // Use the `memref.extract_strided_metadata` operation to get the base
236 // memref. This is needed because the same MemRef that was produced by the
237 // alloc operation has to be passed to the dealloc operation. Passing
238 // subviews, etc. to a dealloc operation is not allowed.
239 memrefs.push_back(
240 memref::ExtractStridedMetadataOp::create(builder, loc, memref)
241 .getResult(0));
242 conditions.push_back(ownership.getIndicator());
243 }
244
245 return success();
246}
247
248//===----------------------------------------------------------------------===//
249// ValueComparator
250//===----------------------------------------------------------------------===//
251
252bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
253 if (lhs == rhs)
254 return false;
255
256 // Block arguments are less than results.
257 bool lhsIsBBArg = isa<BlockArgument>(lhs);
258 if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
259 return lhsIsBBArg;
260 }
261
262 Region *lhsRegion;
263 Region *rhsRegion;
264 if (lhsIsBBArg) {
265 auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
266 auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
267 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
268 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
269 }
270 lhsRegion = lhsBBArg.getParentRegion();
271 rhsRegion = rhsBBArg.getParentRegion();
272 assert(lhsRegion != rhsRegion &&
273 "lhsRegion == rhsRegion implies lhs == rhs");
274 } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
275 return llvm::cast<OpResult>(lhs).getResultNumber() <
276 llvm::cast<OpResult>(rhs).getResultNumber();
277 } else {
278 lhsRegion = lhs.getDefiningOp()->getParentRegion();
279 rhsRegion = rhs.getDefiningOp()->getParentRegion();
280 if (lhsRegion == rhsRegion) {
281 Block *lhsBlock = lhs.getDefiningOp()->getBlock();
282 Block *rhsBlock = rhs.getDefiningOp()->getBlock();
283 if (lhsBlock == rhsBlock) {
284 return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
285 }
286 return lhsBlock->computeBlockNumber() < rhsBlock->computeBlockNumber();
287 }
288 }
289
290 // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
291 // - have different heights
292 // - or there's a spot where their region numbers differ
293 // - or their parent regions are the same and their parent ops are
294 // different.
295 while (lhsRegion && rhsRegion) {
296 if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
297 return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
298 }
299 if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
300 Block *lhsParentOpBlock = lhsRegion->getParentOp()->getBlock();
301 Block *rhsParentOpBlock = rhsRegion->getParentOp()->getBlock();
302 if (lhsParentOpBlock == rhsParentOpBlock) {
303 return lhsRegion->getParentOp()->isBeforeInBlock(
304 rhsRegion->getParentOp());
305 }
306 return lhsParentOpBlock->computeBlockNumber() <
307 rhsParentOpBlock->computeBlockNumber();
308 }
309 lhsRegion = lhsRegion->getParentRegion();
310 rhsRegion = rhsRegion->getParentRegion();
311 }
312 if (rhsRegion)
313 return true;
314 assert(lhsRegion && "this should only happen if lhs == rhs");
315 return false;
316}
317
318//===----------------------------------------------------------------------===//
319// Implementation utilities
320//===----------------------------------------------------------------------===//
321
323 DeallocationState &state, Operation *op, ValueRange operands,
324 SmallVectorImpl<Value> &updatedOperandOwnerships) {
325 assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
326 assert(!op->hasSuccessors() && "must not have any successors");
327 // Collect the values to deallocate and retain and use them to create the
328 // dealloc operation.
329 OpBuilder builder(op);
330 Block *block = op->getBlock();
331 SmallVector<Value> memrefs, conditions, toRetain;
333 builder, op->getLoc(), block, memrefs, conditions)))
334 return failure();
335
336 state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain);
337 if (memrefs.empty() && toRetain.empty())
338 return op;
339
340 auto deallocOp = bufferization::DeallocOp::create(
341 builder, op->getLoc(), memrefs, conditions, toRetain);
342
343 // We want to replace the current ownership of the retained values with the
344 // result values of the dealloc operation as they are always unique.
345 state.resetOwnerships(deallocOp.getRetained(), block);
346 for (auto [retained, ownership] :
347 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
348 state.updateOwnership(retained, ownership, block);
349
350 unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
351 auto newOperandOwnerships =
352 deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
353 updatedOperandOwnerships.append(newOperandOwnerships.begin(),
354 newOperandOwnerships.end());
355
356 return op;
357}
return success()
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
lhs
Block represents an ordered list of Operations.
Definition Block.h:33
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
Definition Block.cpp:144
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
bool hasSuccessors()
Definition Operation.h:734
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
DeallocationState(Operation *op, SymbolTableCollection &symbolTables)
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
Given an SSA value of MemRef type, this function queries the ownership and if it is not already in th...
void mapValue(Value oldValue, Value newValue)
Register that 'oldValue' has been replaced by 'newValue'.
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block',...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator paramete...
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
Include the generated interface declarations.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const