MLIR 22.0.0git
BufferDeallocationOpInterface.cpp
Go to the documentation of this file.
1//===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/AsmState.h"
13#include "mlir/IR/Operation.h"
15#include "mlir/IR/Value.h"
16#include "llvm/ADT/SetOperations.h"
17
18//===----------------------------------------------------------------------===//
19// BufferDeallocationOpInterface
20//===----------------------------------------------------------------------===//
21
22namespace mlir {
23namespace bufferization {
24
25#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
26
27} // namespace bufferization
28} // namespace mlir
29
30using namespace mlir;
31using namespace bufferization;
32
33//===----------------------------------------------------------------------===//
34// Helpers
35//===----------------------------------------------------------------------===//
36
37static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
38 return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value));
39}
40
41static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
42
43//===----------------------------------------------------------------------===//
44// Ownership
45//===----------------------------------------------------------------------===//
46
48 : indicator(indicator), state(State::Unique) {}
49
51 Ownership unknown;
52 unknown.indicator = Value();
53 unknown.state = State::Unknown;
54 return unknown;
55}
56Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
58
60 return state == State::Uninitialized;
61}
62bool Ownership::isUnique() const { return state == State::Unique; }
63bool Ownership::isUnknown() const { return state == State::Unknown; }
64
66 assert(isUnique() && "must have unique ownership to get the indicator");
67 return indicator;
68}
69
71 if (other.isUninitialized())
72 return *this;
73 if (isUninitialized())
74 return other;
75
76 if (!isUnique() || !other.isUnique())
77 return getUnknown();
78
79 // Since we create a new constant i1 value for (almost) each use-site, we
80 // should compare the actual value rather than just the SSA Value to avoid
81 // unnecessary invalidations.
82 if (isEqualConstantIntOrValue(indicator, other.indicator))
83 return *this;
84
85 // Return the join of the lattice if the indicator of both ownerships cannot
86 // be merged.
87 return getUnknown();
88}
89
90void Ownership::combine(Ownership other) { *this = getCombined(other); }
91
92//===----------------------------------------------------------------------===//
93// DeallocationState
94//===----------------------------------------------------------------------===//
95
97 SymbolTableCollection &symbolTables)
98 : symbolTable(symbolTables), liveness(op) {}
99
101 Block *block) {
102 // In most cases we care about the block where the value is defined.
103 if (block == nullptr)
104 block = memref.getParentBlock();
105
106 // Update ownership of current memref itself.
107 ownershipMap[{memref, block}].combine(ownership);
108}
109
111 for (Value val : memrefs)
112 ownershipMap[{val, block}] = Ownership::getUninitialized();
113}
114
116 return ownershipMap.lookup({memref, block});
117}
118
120 memrefsToDeallocatePerBlock[block].push_back(memref);
121}
122
124 llvm::erase(memrefsToDeallocatePerBlock[block], memref);
125}
126
128 SmallVectorImpl<Value> &memrefs) {
129 SmallVector<Value> liveMemrefs(
130 llvm::make_filter_range(liveness.getLiveIn(block), isMemref));
131 llvm::sort(liveMemrefs, ValueComparator());
132 memrefs.append(liveMemrefs);
133}
134
135std::pair<Value, Value>
137 Value memref, Block *block) {
138 auto iter = ownershipMap.find({memref, block});
139 assert(iter != ownershipMap.end() &&
140 "Value must already have been registered in the ownership map");
141
142 Ownership ownership = iter->second;
143 if (ownership.isUnique())
144 return {memref, ownership.getIndicator()};
145
146 // Instead of inserting a clone operation we could also insert a dealloc
147 // operation earlier in the block and use the updated ownerships returned by
148 // the op for the retained values. Alternatively, we could insert code to
149 // check aliasing at runtime and use this information to combine two unique
150 // ownerships more intelligently to not end up with an 'Unknown' ownership in
151 // the first place.
152 auto cloneOp =
153 bufferization::CloneOp::create(builder, memref.getLoc(), memref);
154 Value condition = buildBoolValue(builder, memref.getLoc(), true);
155 Value newMemref = cloneOp.getResult();
156 updateOwnership(newMemref, condition);
157 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref);
158 return {newMemref, condition};
159}
160
162 Block *fromBlock, Block *toBlock, ValueRange destOperands,
163 SmallVectorImpl<Value> &toRetain) const {
164 for (Value operand : destOperands) {
165 if (!isMemref(operand))
166 continue;
167 toRetain.push_back(operand);
168 }
169
171 for (auto val : liveness.getLiveOut(fromBlock))
172 if (isMemref(val))
173 liveOut.insert(val);
174
175 if (toBlock)
176 llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock));
177
178 // liveOut has non-deterministic order because it was constructed by iterating
179 // over a hash-set.
180 SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
181 llvm::sort(retainedByLiveness, ValueComparator());
182 toRetain.append(retainedByLiveness);
183}
184
186 OpBuilder &builder, Location loc, Block *block,
187 SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
188
189 for (auto [i, memref] :
190 llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) {
191 Ownership ownership = ownershipMap.lookup({memref, block});
192 if (!ownership.isUnique())
193 return emitError(memref.getLoc(),
194 "MemRef value does not have valid ownership");
195
196 // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
197 // that we can call extract_strided_metadata on it.
198 if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
199 memref = memref::ReinterpretCastOp::create(
200 builder, loc, memref,
201 /*offset=*/builder.getIndexAttr(0),
202 /*sizes=*/ArrayRef<OpFoldResult>{},
203 /*strides=*/ArrayRef<OpFoldResult>{});
204
205 // Use the `memref.extract_strided_metadata` operation to get the base
206 // memref. This is needed because the same MemRef that was produced by the
207 // alloc operation has to be passed to the dealloc operation. Passing
208 // subviews, etc. to a dealloc operation is not allowed.
209 memrefs.push_back(
210 memref::ExtractStridedMetadataOp::create(builder, loc, memref)
211 .getResult(0));
212 conditions.push_back(ownership.getIndicator());
213 }
214
215 return success();
216}
217
218//===----------------------------------------------------------------------===//
219// ValueComparator
220//===----------------------------------------------------------------------===//
221
222bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
223 if (lhs == rhs)
224 return false;
225
226 // Block arguments are less than results.
227 bool lhsIsBBArg = isa<BlockArgument>(lhs);
228 if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
229 return lhsIsBBArg;
230 }
231
232 Region *lhsRegion;
233 Region *rhsRegion;
234 if (lhsIsBBArg) {
235 auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
236 auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
237 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
238 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
239 }
240 lhsRegion = lhsBBArg.getParentRegion();
241 rhsRegion = rhsBBArg.getParentRegion();
242 assert(lhsRegion != rhsRegion &&
243 "lhsRegion == rhsRegion implies lhs == rhs");
244 } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
245 return llvm::cast<OpResult>(lhs).getResultNumber() <
246 llvm::cast<OpResult>(rhs).getResultNumber();
247 } else {
248 lhsRegion = lhs.getDefiningOp()->getParentRegion();
249 rhsRegion = rhs.getDefiningOp()->getParentRegion();
250 if (lhsRegion == rhsRegion) {
251 return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
252 }
253 }
254
255 // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
256 // - have different heights
257 // - or there's a spot where their region numbers differ
258 // - or their parent regions are the same and their parent ops are
259 // different.
260 while (lhsRegion && rhsRegion) {
261 if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
262 return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
263 }
264 if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
265 return lhsRegion->getParentOp()->isBeforeInBlock(
266 rhsRegion->getParentOp());
267 }
268 lhsRegion = lhsRegion->getParentRegion();
269 rhsRegion = rhsRegion->getParentRegion();
270 }
271 if (rhsRegion)
272 return true;
273 assert(lhsRegion && "this should only happen if lhs == rhs");
274 return false;
275}
276
277//===----------------------------------------------------------------------===//
278// Implementation utilities
279//===----------------------------------------------------------------------===//
280
282 DeallocationState &state, Operation *op, ValueRange operands,
283 SmallVectorImpl<Value> &updatedOperandOwnerships) {
284 assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
285 assert(!op->hasSuccessors() && "must not have any successors");
286 // Collect the values to deallocate and retain and use them to create the
287 // dealloc operation.
288 OpBuilder builder(op);
289 Block *block = op->getBlock();
290 SmallVector<Value> memrefs, conditions, toRetain;
292 builder, op->getLoc(), block, memrefs, conditions)))
293 return failure();
294
295 state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain);
296 if (memrefs.empty() && toRetain.empty())
297 return op;
298
299 auto deallocOp = bufferization::DeallocOp::create(
300 builder, op->getLoc(), memrefs, conditions, toRetain);
301
302 // We want to replace the current ownership of the retained values with the
303 // result values of the dealloc operation as they are always unique.
304 state.resetOwnerships(deallocOp.getRetained(), block);
305 for (auto [retained, ownership] :
306 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
307 state.updateOwnership(retained, ownership, block);
308
309 unsigned numMemrefOperands = llvm::count_if(operands, isMemref);
310 auto newOperandOwnerships =
311 deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
312 updatedOperandOwnerships.append(newOperandOwnerships.begin(),
313 newOperandOwnerships.end());
314
315 return op;
316}
return success()
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
lhs
Block represents an ordered list of Operations.
Definition Block.h:33
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
bool hasSuccessors()
Definition Operation.h:705
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
DeallocationState(Operation *op, SymbolTableCollection &symbolTables)
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
Given an SSA value of MemRef type, this function queries the ownership and if it is not already in th...
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block',...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator paramete...
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
Include the generated interface declarations.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const