17 #include "llvm/ADT/SetOperations.h"
24 namespace bufferization {
26 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
32 using namespace bufferization;
49 : indicator(indicator), state(State::Unique) {}
53 unknown.indicator =
Value();
54 unknown.state = State::Unknown;
61 return state == State::Uninitialized;
67 assert(
isUnique() &&
"must have unique ownership to get the indicator");
102 if (block ==
nullptr)
106 ownershipMap[{memref, block}].combine(ownership);
110 for (
Value val : memrefs)
115 return ownershipMap.lookup({memref, block});
119 memrefsToDeallocatePerBlock[block].push_back(memref);
123 llvm::erase(memrefsToDeallocatePerBlock[block], memref);
131 memrefs.append(liveMemrefs);
134 std::pair<Value, Value>
137 auto iter = ownershipMap.find({memref, block});
138 assert(iter != ownershipMap.end() &&
139 "Value must already have been registered in the ownership map");
152 builder.
create<bufferization::CloneOp>(memref.
getLoc(), memref);
154 Value newMemref = cloneOp.getResult();
156 memrefsToDeallocatePerBlock[newMemref.
getParentBlock()].push_back(newMemref);
157 return {newMemref, condition};
163 for (
Value operand : destOperands) {
166 toRetain.push_back(operand);
170 for (
auto val : liveness.
getLiveOut(fromBlock))
175 llvm::set_intersect(liveOut, liveness.
getLiveIn(toBlock));
180 std::sort(retainedByLiveness.begin(), retainedByLiveness.end(),
182 toRetain.append(retainedByLiveness);
189 for (
auto [i, memref] :
191 Ownership ownership = ownershipMap.lookup({memref, block});
194 "MemRef value does not have valid ownership");
198 if (
auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
199 memref = builder.
create<memref::ReinterpretCastOp>(
210 builder.
create<memref::ExtractStridedMetadataOp>(loc, memref)
227 bool lhsIsBBArg = isa<BlockArgument>(lhs);
228 if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
235 auto lhsBBArg = llvm::cast<BlockArgument>(lhs);
236 auto rhsBBArg = llvm::cast<BlockArgument>(rhs);
237 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
238 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
242 assert(lhsRegion != rhsRegion &&
243 "lhsRegion == rhsRegion implies lhs == rhs");
245 return llvm::cast<OpResult>(lhs).getResultNumber() <
246 llvm::cast<OpResult>(rhs).getResultNumber();
250 if (lhsRegion == rhsRegion) {
260 while (lhsRegion && rhsRegion) {
273 assert(lhsRegion &&
"this should only happen if lhs == rhs");
285 assert(!op->
hasSuccessors() &&
"must not have any successors");
291 if (failed(state.getMemrefsAndConditionsToDeallocate(
292 builder, op->
getLoc(), block, memrefs, conditions)))
295 state.getMemrefsToRetain(block,
nullptr, operands, toRetain);
296 if (memrefs.empty() && toRetain.empty())
299 auto deallocOp = builder.
create<bufferization::DeallocOp>(
300 op->
getLoc(), memrefs, conditions, toRetain);
304 state.resetOwnerships(deallocOp.getRetained(), block);
305 for (
auto [retained, ownership] :
306 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
307 state.updateOwnership(retained, ownership, block);
309 unsigned numMemrefOperands = llvm::count_if(operands,
isMemref);
310 auto newOperandOwnerships =
311 deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
312 updatedOperandOwnerships.append(newOperandOwnerships.begin(),
313 newOperandOwnerships.end());
static bool isMemref(Value v)
static Value buildBoolValue(OpBuilder &builder, Location loc, bool value)
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
BoolAttr getBoolAttr(bool value)
const ValueSetT & getLiveOut(Block *block) const
Returns a reference to a set containing live-out values (unordered).
const ValueSetT & getLiveIn(Block *block) const
Returns a reference to a set containing live-in values (unordered).
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void addMemrefToDeallocate(Value memref, Block *block)
Remember the given 'memref' to deallocate it at the end of the 'block'.
Ownership getOwnership(Value memref, Block *block) const
Returns the ownership of 'memref' for the given basic block.
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
std::pair< Value, Value > getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block)
Given an SSA value of MemRef type, this function queries the ownership and if it is not already in th...
DeallocationState(Operation *op)
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getLiveMemrefsIn(Block *block, SmallVectorImpl< Value > &memrefs)
Return a sorted list of MemRef values which are live at the start of the given block.
void dropMemrefToDeallocate(Value memref, Block *block)
Forget about a MemRef that we originally wanted to deallocate at the end of 'block',...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
This class is used to track the ownership of values.
static Ownership getUnique(Value indicator)
Get an ownership value in 'Unique' state with 'indicator' as parameter.
Ownership getCombined(Ownership other) const
Get the join of the two-element subset {this,other}.
void combine(Ownership other)
Modify 'this' ownership to be the join of the current 'this' and 'other'.
Ownership()=default
Constructor that creates an 'Uninitialized' ownership.
bool isUnknown() const
Check if this ownership value is in the 'Unknown' state.
bool isUnique() const
Check if this ownership value is in the 'Unique' state.
static Ownership getUnknown()
Get an ownership value in 'Unknown' state.
Value getIndicator() const
If this ownership value is in 'Unique' state, this function can be used to get the indicator paramete...
bool isUninitialized() const
Check if this ownership value is in the 'Uninitialized' state.
static Ownership getUninitialized()
Get an ownership value in 'Uninitialized' state.
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Compare two SSA values in a deterministic manner.
bool operator()(const Value &lhs, const Value &rhs) const