MLIR  22.0.0git
AsyncRuntimeRefCountingOpt.cpp
Go to the documentation of this file.
1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Optimize Async dialect reference counting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
17 #include "llvm/Support/Debug.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPTPASS
21 #include "mlir/Dialect/Async/Passes.h.inc"
22 } // namespace mlir
23 
24 #define DEBUG_TYPE "async-ref-counting"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
29 namespace {
30 
// Pass that looks for `add_ref` / `drop_ref` operations on reference counted
// values that cancel each other out and erases them in pairs.
31 class AsyncRuntimeRefCountingOptPass
32  : public impl::AsyncRuntimeRefCountingOptPassBase<
33  AsyncRuntimeRefCountingOptPass> {
34 public:
35  AsyncRuntimeRefCountingOptPass() = default;
// Walks the operation the pass runs on, collects cancellable pairs for every
// reference counted block argument and operation result, then erases them.
36  void runOnOperation() override;
37 
38 private:
// Analyzes the users of a single `value` and records cancellable pairs in
// `cancellable`, keyed by the `drop_ref` operation and mapped to the
// matching `add_ref` operation. Does not erase anything itself.
39  LogicalResult optimizeReferenceCounting(
40  Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
41 };
42 
43 } // namespace
44 
// Groups the users of `value` by block and records in `cancellable` the pairs
// of `add_ref` / `drop_ref` operations on `value` that can be erased together.
//
// `cancellable` maps `dropRef.getOperation()` to `addRef.getOperation()`. A
// `drop_ref` that is already a key in the map is never paired a second time.
// Nothing is erased here; the caller is expected to erase the recorded pairs.
// The visible body has a single return statement and it returns `success()`.
45 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
46  Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
47  Region *definingRegion = value.getParentRegion();
48 
49  // Find all users of the `value` inside each block, including operations that
50  // do not use `value` directly, but have a direct use inside nested region(s).
51  //
52  // Example:
53  //
54  //   ^bb1:
55  //     %token = ...
56  //     scf.if %cond {
57  //     ^bb2:
58  //       async.runtime.await %token : !async.token
59  //     }
60  //
61  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
62  // (`scf.if`).
63 
// Per-block bookkeeping of the operations that use `value`.
// NOTE(review): the field declarations (listing lines 65-67) are missing from
// this rendering; the code below relies on members named `users`, `addRefs`
// and `dropRefs` -- confirm their types against the upstream file.
64  struct BlockUsersInfo {
68  };
69 
// NOTE(review): the declaration of `blockUsers` (listing line 70) is missing
// from this rendering; below it is indexed with `user->getBlock()` and yields
// a `BlockUsersInfo` -- confirm against the upstream file.
71 
// Records `user` in the info of the block that contains it, and additionally
// in the `addRefs` / `dropRefs` lists when it is one of those operations.
72  auto updateBlockUsersInfo = [&](Operation *user) {
73  BlockUsersInfo &info = blockUsers[user->getBlock()];
74  info.users.push_back(user);
75 
76  if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
77  info.addRefs.push_back(addRef);
78  if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
79  info.dropRefs.push_back(dropRef);
80  };
81 
// For every direct user, climb through its parent operations until an
// operation located in the defining region of `value` is reached, recording
// each operation on the way (and the final ancestor) as a user of its block.
82  for (Operation *user : value.getUsers()) {
83  while (user->getParentRegion() != definingRegion) {
84  updateBlockUsersInfo(user);
85  user = user->getParentOp();
86  assert(user != nullptr && "value user lies outside of the value region");
87  }
88 
89  updateBlockUsersInfo(user);
90  }
91 
92  // Order the `add_ref`s, `drop_ref`s and users by their position in the block.
93  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
94  auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
95  return a->isBeforeInBlock(b);
96  };
97  llvm::sort(info.addRefs, isBeforeInBlock);
98  llvm::sort(info.dropRefs, isBeforeInBlock);
99  llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
100  return isBeforeInBlock(a, b);
101  });
102 
103  return info;
104  };
105 
106  // Find matching pairs of `add_ref` / `drop_ref` operations in the blocks
107  // that modify the reference count of the `value`, and mark them cancellable.
108  for (auto &kv : blockUsers) {
109  BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
110 
111  for (RuntimeAddRefOp addRef : info.addRefs) {
112  for (RuntimeDropRefOp dropRef : info.dropRefs) {
113  // Only consider a `drop_ref` that follows the `add_ref` with equal count.
114  if (dropRef.getCount() != addRef.getCount() ||
115  dropRef->isBeforeInBlock(addRef.getOperation()))
116  continue;
117 
118  // When a reference counted value is passed to a function as an argument,
119  // the function takes ownership of a +1 reference and will drop it before
120  // returning.
121  //
122  // Example:
123  //
124  // %token = ... : !async.token
125  //
126  // async.runtime.add_ref %token {count = 1 : i64} : !async.token
127  // call @pass_token(%token: !async.token, ...)
128  //
129  // async.await %token : !async.token
130  // async.runtime.drop_ref %token {count = 1 : i64} : !async.token
131  //
132  // In this example, if we cancel the pair of reference counting
133  // operations, we might end up with a deallocated token by the time we
134  // reach the `async.await` operation.
135  Operation *firstFunctionCallUser = nullptr;
136  Operation *lastNonFunctionCallUser = nullptr;
137 
// Scan the users located strictly between `addRef` and `dropRef`. The early
// `break` relies on `info.users` having been sorted by block position above.
138  for (Operation *user : info.users) {
139  // `user` operation lies after `addRef` ...
140  if (user == addRef || user->isBeforeInBlock(addRef))
141  continue;
142  // ... and before `dropRef`.
143  if (user == dropRef || dropRef->isBeforeInBlock(user))
144  break;
145 
146  // Find the first function call user of the reference counted value.
147  Operation *functionCall = dyn_cast<func::CallOp>(user);
148  if (functionCall &&
149  (!firstFunctionCallUser ||
150  functionCall->isBeforeInBlock(firstFunctionCallUser))) {
151  firstFunctionCallUser = functionCall;
152  continue;
153  }
154 
155  // Find the last regular user of the reference counted value.
156  if (!functionCall &&
157  (!lastNonFunctionCallUser ||
158  lastNonFunctionCallUser->isBeforeInBlock(user))) {
159  lastNonFunctionCallUser = user;
160  continue;
161  }
162  }
163 
164  // Do not cancel the pair when a non function call user comes after a
165  // function call user of the reference counted value.
166  if (firstFunctionCallUser && lastNonFunctionCallUser &&
167  firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
168  continue;
169 
170  // Try to cancel the pair of `add_ref` and `drop_ref` operations.
171  auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
172  addRef.getOperation());
173 
174  if (!emplaced.second) // `drop_ref` was already marked for removal
175  continue; // go to the next `drop_ref`
176 
// NOTE(review): this condition is always true here because the negated case
// has just been handled by the `continue` above.
177  if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
178  break; // go to the next `add_ref`
179  }
180  }
181  }
182 
183  return success();
184 }
185 
186 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
187  Operation *op = getOperation();
188 
189  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
190  //
191  // Find all cancellable pairs of operation and erase them in the end to keep
192  // all iterators valid while we are walking the function operations.
193  llvm::SmallDenseMap<Operation *, Operation *> cancellable;
194 
195  // Optimize reference counting for values defined by block arguments.
196  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
197  for (BlockArgument arg : block->getArguments())
198  if (isRefCounted(arg.getType()))
199  if (failed(optimizeReferenceCounting(arg, cancellable)))
200  return WalkResult::interrupt();
201 
202  return WalkResult::advance();
203  });
204 
205  if (blockWalk.wasInterrupted())
206  signalPassFailure();
207 
208  // Optimize reference counting for values defined by operation results.
209  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
210  for (unsigned i = 0; i < op->getNumResults(); ++i)
211  if (isRefCounted(op->getResultTypes()[i]))
212  if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
213  return WalkResult::interrupt();
214 
215  return WalkResult::advance();
216  });
217 
218  if (opWalk.wasInterrupted())
219  signalPassFailure();
220 
221  LLVM_DEBUG({
222  llvm::dbgs() << "Found " << cancellable.size()
223  << " cancellable reference counting operations\n";
224  });
225 
226  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
227  for (auto &kv : cancellable) {
228  kv.first->erase();
229  kv.second->erase();
230  }
231 }
static bool isBeforeInBlock(Block *block, Block::iterator a, Block::iterator b)
Given two iterators into the same block, return "true" if `a` is before `b`.
Definition: Dominance.cpp:241
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:218
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
Definition: Async.h:52
Include the generated interface declarations.