MLIR 22.0.0git
BufferDeallocationOpInterfaceImpl.cpp
Go to the documentation of this file.
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

14#include "mlir/IR/Dialect.h"
15#include "mlir/IR/Operation.h"
16
17using namespace mlir;
18using namespace mlir::bufferization;
19
20static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
21
22namespace {
23/// While CondBranchOp also implement the BranchOpInterface, we add a
24/// special-case implementation here because the BranchOpInterface does not
25/// offer all of the functionallity we need to insert dealloc oeprations in an
26/// efficient way. More precisely, there is no way to extract the branch
27/// condition without casting to CondBranchOp specifically. It is still
28/// possible to implement deallocation for cases where we don't know to which
29/// successor the terminator branches before the actual branch happens by
30/// inserting auxiliary blocks and putting the dealloc op there, however, this
31/// can lead to less efficient code.
32/// This function inserts two dealloc operations (one for each successor) and
33/// adjusts the dealloc conditions according to the branch condition, then the
34/// ownerships of the retained MemRefs are updated by combining the result
35/// values of the two dealloc operations.
36///
37/// Example:
38/// ```
39/// ^bb1:
40/// <more ops...>
41/// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb2>)
42/// ```
43/// becomes
44/// ```
45/// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
46/// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
47/// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
48/// ^bb1:
49/// <more ops...>
50/// let thenCond = map(c, (c) -> arith.andi cond, c)
51/// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
52/// o0 = bufferization.dealloc m if thenCond retain r0
53/// o1 = bufferization.dealloc m if elseCond retain r1
54/// // replace ownership(r0) with o0 element-wise
55/// // replace ownership(r1) with o1 element-wise
56/// // let ownership0 := (r) -> o in o0 corresponding to r
57/// // let ownership1 := (r) -> o in o1 corresponding to r
58/// // let cmn := intersection(r0, r1)
59/// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
60/// forall r in r0: replace ownership0(r) with arith.select cond, a, b)
61/// forall r in r1: replace ownership1(r) with arith.select cond, a, b)
62/// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
63/// ```
64struct CondBranchOpInterface
65 : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
66 cf::CondBranchOp> {
67 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
68 const DeallocationOptions &options) const {
69 OpBuilder builder(op);
70 auto condBr = cast<cf::CondBranchOp>(op);
71
72 // The list of memrefs to deallocate in this block is independent of which
73 // branch is taken.
74 SmallVector<Value> memrefs, conditions;
76 builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions)))
77 return failure();
78
79 // Helper lambda to factor out common logic for inserting the dealloc
80 // operations for each successor.
81 auto insertDeallocForBranch =
82 [&](Block *target, MutableOperandRange destOperands,
83 const std::function<Value(Value)> &conditionModifier,
84 DenseMap<Value, Value> &mapping) -> DeallocOp {
85 SmallVector<Value> toRetain;
86 state.getMemrefsToRetain(condBr->getBlock(), target,
87 destOperands.getAsOperandRange(), toRetain);
88 SmallVector<Value> adaptedConditions(
89 llvm::map_range(conditions, conditionModifier));
90 auto deallocOp = bufferization::DeallocOp::create(
91 builder, condBr.getLoc(), memrefs, adaptedConditions, toRetain);
92 state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock());
93 for (auto [retained, ownership] : llvm::zip(
94 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
95 state.updateOwnership(retained, ownership, condBr->getBlock());
96 mapping[retained] = ownership;
97 }
98 SmallVector<Value> replacements, ownerships;
99 for (OpOperand &operand : destOperands) {
100 replacements.push_back(operand.get());
101 if (isMemref(operand.get())) {
102 assert(mapping.contains(operand.get()) &&
103 "Should be contained at this point");
104 ownerships.push_back(mapping[operand.get()]);
105 }
106 }
107 replacements.append(ownerships);
108 destOperands.assign(replacements);
109 return deallocOp;
110 };
111
112 // Call the helper lambda and make sure the dealloc conditions are properly
113 // modified to reflect the branch condition as well.
114 DenseMap<Value, Value> thenMapping, elseMapping;
115 DeallocOp thenTakenDeallocOp = insertDeallocForBranch(
116 condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(),
117 [&](Value cond) {
118 return arith::AndIOp::create(builder, condBr.getLoc(), cond,
119 condBr.getCondition());
120 },
121 thenMapping);
122 DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
123 condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(),
124 [&](Value cond) {
125 Value trueVal = arith::ConstantOp::create(builder, condBr.getLoc(),
126 builder.getBoolAttr(true));
127 Value negation = arith::XOrIOp::create(
128 builder, condBr.getLoc(), trueVal, condBr.getCondition());
129 return arith::AndIOp::create(builder, condBr.getLoc(), cond,
130 negation);
131 },
132 elseMapping);
133
134 // We specifically need to update the ownerships of values that are retained
135 // in both dealloc operations again to get a combined 'Unique' ownership
136 // instead of an 'Unknown' ownership.
137 SmallPtrSet<Value, 16> thenValues(llvm::from_range,
138 thenTakenDeallocOp.getRetained());
139 SetVector<Value> commonValues;
140 for (Value val : elseTakenDeallocOp.getRetained()) {
141 if (thenValues.contains(val))
142 commonValues.insert(val);
143 }
144
145 for (Value retained : commonValues) {
146 state.resetOwnerships(retained, condBr->getBlock());
147 Value combinedOwnership = arith::SelectOp::create(
148 builder, condBr.getLoc(), condBr.getCondition(),
149 thenMapping[retained], elseMapping[retained]);
150 state.updateOwnership(retained, combinedOwnership, condBr->getBlock());
151 }
152
153 return condBr.getOperation();
154 }
155};
156
157} // namespace
158
160 DialectRegistry &registry) {
161 registry.addExtension(+[](MLIRContext *ctx, ControlFlowDialect *dialect) {
162 CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
163 });
164}
static bool isMemref(Value v)
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void resetOwnerships(ValueRange memrefs, Block *block)
Removes ownerships associated with all values in the passed range for 'block'.
void updateOwnership(Value memref, Ownership ownership, Block *block=nullptr)
Small helper function to update the ownership map by taking the current ownership ('Uninitialized' st...
LogicalResult getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, Block *block, SmallVectorImpl< Value > &memrefs, SmallVectorImpl< Value > &conditions) const
For a given block, computes the list of MemRefs that potentially need to be deallocated at the end of...
void getMemrefsToRetain(Block *fromBlock, Block *toBlock, ValueRange destOperands, SmallVectorImpl< Value > &toRetain) const
Given two basic blocks and the values passed via block arguments to the destination block,...
void registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131