MLIR  19.0.0git
BufferDeallocationOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 
20 static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
21 
22 namespace {
23 /// While CondBranchOp also implements the BranchOpInterface, we add a
24 /// special-case implementation here because the BranchOpInterface does not
25 /// offer all of the functionality we need to insert dealloc operations in an
26 /// efficient way. More precisely, there is no way to extract the branch
27 /// condition without casting to CondBranchOp specifically. It is still
28 /// possible to implement deallocation for cases where we don't know to which
29 /// successor the terminator branches before the actual branch happens by
30 /// inserting auxiliary blocks and putting the dealloc op there, however, this
31 /// can lead to less efficient code.
32 /// This function inserts two dealloc operations (one for each successor) and
33 /// adjusts the dealloc conditions according to the branch condition, then the
34 /// ownerships of the retained MemRefs are updated by combining the result
35 /// values of the two dealloc operations.
36 ///
37 /// Example:
38 /// ```
39 /// ^bb1:
40 /// <more ops...>
41 /// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb3>)
42 /// ```
43 /// becomes
44 /// ```
45 /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
46 /// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
47 /// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
48 /// ^bb1:
49 /// <more ops...>
50 /// let thenCond = map(c, (c) -> arith.andi cond, c)
51 /// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
52 /// o0 = bufferization.dealloc m if thenCond retain r0
53 /// o1 = bufferization.dealloc m if elseCond retain r1
54 /// // replace ownership(r0) with o0 element-wise
55 /// // replace ownership(r1) with o1 element-wise
56 /// // let ownership0 := (r) -> o in o0 corresponding to r
57 /// // let ownership1 := (r) -> o in o1 corresponding to r
58 /// // let cmn := intersection(r0, r1)
59 /// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
60 /// forall r in r0: replace ownership0(r) with arith.select cond, a, b
61 /// forall r in r1: replace ownership1(r) with arith.select cond, a, b
62 /// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
63 /// ```
64 struct CondBranchOpInterface
65  : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
66  cf::CondBranchOp> {
68  const DeallocationOptions &options) const {
  // Inserts one bufferization.dealloc per successor of the cf.cond_br, with
  // the dealloc conditions masked by the branch condition (true dest) or its
  // negation (false dest). The resulting ownership indicators are appended as
  // extra successor operands. Returns failure if the memrefs/conditions to
  // deallocate cannot be computed; otherwise returns the cond_br itself.
69  OpBuilder builder(op);
70  auto condBr = cast<cf::CondBranchOp>(op);
71 
72  // The list of memrefs to deallocate in this block is independent of which
73  // branch is taken.
74  SmallVector<Value> memrefs, conditions;
75  if (failed(state.getMemrefsAndConditionsToDeallocate(
76  builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions)))
77  return failure();
78 
79  // Helper lambda to factor out common logic for inserting the dealloc
80  // operations for each successor.
  // For `target`, the lambda does the following:
  //  - collects the values to retain across the edge;
  //  - creates a dealloc of `memrefs` whose conditions are
  //    `conditionModifier` applied to each entry of `conditions`;
  //  - records each retained value's updated condition both in `state` and
  //    in `mapping`;
  //  - rewrites `destOperands` to the original operands followed by one
  //    ownership value per memref-typed operand.
81  auto insertDeallocForBranch =
82  [&](Block *target, MutableOperandRange destOperands,
83  const std::function<Value(Value)> &conditionModifier,
84  DenseMap<Value, Value> &mapping) -> DeallocOp {
85  SmallVector<Value> toRetain;
86  state.getMemrefsToRetain(condBr->getBlock(), target,
87  destOperands.getAsOperandRange(), toRetain);
  // map_range invokes `conditionModifier` once per entry of `conditions`.
88  SmallVector<Value> adaptedConditions(
89  llvm::map_range(conditions, conditionModifier));
90  auto deallocOp = builder.create<bufferization::DeallocOp>(
91  condBr.getLoc(), memrefs, adaptedConditions, toRetain);
92  state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock());
93  for (auto [retained, ownership] : llvm::zip(
94  deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
95  state.updateOwnership(retained, ownership, condBr->getBlock());
96  mapping[retained] = ownership;
97  }
  // Keep the existing operands in order and append the ownership indicators
  // of the memref-typed ones after them. The assert expects every memref
  // operand to be among the retained values recorded in `mapping` above.
98  SmallVector<Value> replacements, ownerships;
99  for (OpOperand &operand : destOperands) {
100  replacements.push_back(operand.get());
101  if (isMemref(operand.get())) {
102  assert(mapping.contains(operand.get()) &&
103  "Should be contained at this point");
104  ownerships.push_back(mapping[operand.get()]);
105  }
106  }
107  replacements.append(ownerships);
108  destOperands.assign(replacements);
109  return deallocOp;
110  };
111 
112  // Call the helper lambda and make sure the dealloc conditions are properly
113  // modified to reflect the branch condition as well.
114  DenseMap<Value, Value> thenMapping, elseMapping;
  // True successor: each dealloc condition becomes `cond AND branchCond`.
115  DeallocOp thenTakenDeallocOp = insertDeallocForBranch(
116  condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(),
117  [&](Value cond) {
118  return builder.create<arith::AndIOp>(condBr.getLoc(), cond,
119  condBr.getCondition());
120  },
121  thenMapping);
  // False successor: each dealloc condition becomes
  // `cond AND (true XOR branchCond)`, i.e. `cond AND NOT branchCond`.
  // NOTE(review): a new `arith.constant true` and `arith.xori` are created for
  // every entry of `conditions`. They could be built once and reused.
122  DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
123  condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(),
124  [&](Value cond) {
125  Value trueVal = builder.create<arith::ConstantOp>(
126  condBr.getLoc(), builder.getBoolAttr(true));
127  Value negation = builder.create<arith::XOrIOp>(
128  condBr.getLoc(), trueVal, condBr.getCondition());
129  return builder.create<arith::AndIOp>(condBr.getLoc(), cond, negation);
130  },
131  elseMapping);
132 
133  // We specifically need to update the ownerships of values that are retained
134  // in both dealloc operations again to get a combined 'Unique' ownership
135  // instead of an 'Unknown' ownership.
  // The SetVector is filled in the order of the else-dealloc's retained list,
  // so the select ops below are emitted in a deterministic order.
136  SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(),
137  thenTakenDeallocOp.getRetained().end());
138  SetVector<Value> commonValues;
139  for (Value val : elseTakenDeallocOp.getRetained()) {
140  if (thenValues.contains(val))
141  commonValues.insert(val);
142  }
143 
  // Both mappings hold an entry for every value retained by their dealloc.
  // Common values are therefore present in both, and operator[] does not
  // default-insert here.
144  for (Value retained : commonValues) {
145  state.resetOwnerships(retained, condBr->getBlock());
146  Value combinedOwnership = builder.create<arith::SelectOp>(
147  condBr.getLoc(), condBr.getCondition(), thenMapping[retained],
148  elseMapping[retained]);
149  state.updateOwnership(retained, combinedOwnership, condBr->getBlock());
150  }
151 
152  return condBr.getOperation();
153  }
154 };
155 
156 } // namespace
157 
159  DialectRegistry &registry) {
160  registry.addExtension(+[](MLIRContext *ctx, ControlFlowDialect *dialect) {
161  CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
162  });
163 }
static bool isMemref(Value v)
static llvm::ManagedStatic< PassManagerOptions > options
Block represents an ordered list of Operations.
Definition: Block.h:31
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
This class helps build Operations.
Definition: Builders.h:209
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
This class collects all the state that we need to perform the buffer deallocation pass with associate...
void registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Options for BufferDeallocationOpInterface-based buffer deallocation.