64struct CondBranchOpInterface
65 :
public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
67 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
68 const DeallocationOptions &
options)
const {
69 OpBuilder builder(op);
70 auto condBr = cast<cf::CondBranchOp>(op);
74 SmallVector<Value> memrefs, conditions;
76 builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions)))
81 auto insertDeallocForBranch =
82 [&](
Block *
target, MutableOperandRange destOperands,
83 const std::function<Value(Value)> &conditionModifier,
84 DenseMap<Value, Value> &mapping) -> DeallocOp {
85 SmallVector<Value> toRetain;
87 destOperands.getAsOperandRange(), toRetain);
88 SmallVector<Value> adaptedConditions(
89 llvm::map_range(conditions, conditionModifier));
90 auto deallocOp = bufferization::DeallocOp::create(
91 builder, condBr.getLoc(), memrefs, adaptedConditions, toRetain);
93 for (
auto [retained, ownership] : llvm::zip(
94 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
96 mapping[retained] = ownership;
98 SmallVector<Value> replacements, ownerships;
99 for (OpOperand &operand : destOperands) {
100 replacements.push_back(operand.get());
102 assert(mapping.contains(operand.get()) &&
103 "Should be contained at this point");
104 ownerships.push_back(mapping[operand.get()]);
107 replacements.append(ownerships);
108 destOperands.assign(replacements);
114 DenseMap<Value, Value> thenMapping, elseMapping;
115 DeallocOp thenTakenDeallocOp = insertDeallocForBranch(
116 condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(),
118 return arith::AndIOp::create(builder, condBr.getLoc(), cond,
119 condBr.getCondition());
122 DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
123 condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(),
125 Value trueVal = arith::ConstantOp::create(builder, condBr.getLoc(),
126 builder.getBoolAttr(true));
127 Value negation = arith::XOrIOp::create(
128 builder, condBr.getLoc(), trueVal, condBr.getCondition());
129 return arith::AndIOp::create(builder, condBr.getLoc(), cond,
137 SmallPtrSet<Value, 16> thenValues(llvm::from_range,
138 thenTakenDeallocOp.getRetained());
140 for (Value val : elseTakenDeallocOp.getRetained()) {
141 if (thenValues.contains(val))
142 commonValues.insert(val);
145 for (Value retained : commonValues) {
147 Value combinedOwnership = arith::SelectOp::create(
148 builder, condBr.getLoc(), condBr.getCondition(),
149 thenMapping[retained], elseMapping[retained]);
150 state.
updateOwnership(retained, combinedOwnership, condBr->getBlock());
153 return condBr.getOperation();
162 CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
void registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector