MLIR  19.0.0git
AsyncRuntimeRefCountingOpt.cpp
Go to the documentation of this file.
1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Optimize Async dialect reference counting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Async/Passes.h"

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
22 #include "mlir/Dialect/Async/Passes.h.inc"
23 } // namespace mlir
24 
25 #define DEBUG_TYPE "async-ref-counting"
26 
27 using namespace mlir;
28 using namespace mlir::async;
29 
30 namespace {
31 
32 class AsyncRuntimeRefCountingOptPass
33  : public impl::AsyncRuntimeRefCountingOptBase<
34  AsyncRuntimeRefCountingOptPass> {
35 public:
36  AsyncRuntimeRefCountingOptPass() = default;
37  void runOnOperation() override;
38 
39 private:
40  LogicalResult optimizeReferenceCounting(
41  Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
42 };
43 
44 } // namespace
45 
46 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
47  Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
48  Region *definingRegion = value.getParentRegion();
49 
50  // Find all users of the `value` inside each block, including operations that
51  // do not use `value` directly, but have a direct use inside nested region(s).
52  //
53  // Example:
54  //
55  // ^bb1:
56  // %token = ...
57  // scf.if %cond {
58  // ^bb2:
59  // async.runtime.await %token : !async.token
60  // }
61  //
62  // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
63  // (`scf.if`).
64 
65  struct BlockUsersInfo {
69  };
70 
72 
73  auto updateBlockUsersInfo = [&](Operation *user) {
74  BlockUsersInfo &info = blockUsers[user->getBlock()];
75  info.users.push_back(user);
76 
77  if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
78  info.addRefs.push_back(addRef);
79  if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
80  info.dropRefs.push_back(dropRef);
81  };
82 
83  for (Operation *user : value.getUsers()) {
84  while (user->getParentRegion() != definingRegion) {
85  updateBlockUsersInfo(user);
86  user = user->getParentOp();
87  assert(user != nullptr && "value user lies outside of the value region");
88  }
89 
90  updateBlockUsersInfo(user);
91  }
92 
93  // Sort all operations found in the block.
94  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
95  auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
96  return a->isBeforeInBlock(b);
97  };
98  llvm::sort(info.addRefs, isBeforeInBlock);
99  llvm::sort(info.dropRefs, isBeforeInBlock);
100  llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
101  return isBeforeInBlock(a, b);
102  });
103 
104  return info;
105  };
106 
107  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
108  // blocks that modify the reference count of the `value`.
109  for (auto &kv : blockUsers) {
110  BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111 
112  for (RuntimeAddRefOp addRef : info.addRefs) {
113  for (RuntimeDropRefOp dropRef : info.dropRefs) {
114  // `drop_ref` operation after the `add_ref` with matching count.
115  if (dropRef.getCount() != addRef.getCount() ||
116  dropRef->isBeforeInBlock(addRef.getOperation()))
117  continue;
118 
119  // When reference counted value passed to a function as an argument,
120  // function takes ownership of +1 reference and it will drop it before
121  // returning.
122  //
123  // Example:
124  //
125  // %token = ... : !async.token
126  //
127  // async.runtime.add_ref %token {count = 1 : i64} : !async.token
128  // call @pass_token(%token: !async.token, ...)
129  //
130  // async.await %token : !async.token
131  // async.runtime.drop_ref %token {count = 1 : i64} : !async.token
132  //
133  // In this example if we'll cancel a pair of reference counting
134  // operations we might end up with a deallocated token when we'll
135  // reach `async.await` operation.
136  Operation *firstFunctionCallUser = nullptr;
137  Operation *lastNonFunctionCallUser = nullptr;
138 
139  for (Operation *user : info.users) {
140  // `user` operation lies after `addRef` ...
141  if (user == addRef || user->isBeforeInBlock(addRef))
142  continue;
143  // ... and before `dropRef`.
144  if (user == dropRef || dropRef->isBeforeInBlock(user))
145  break;
146 
147  // Find the first function call user of the reference counted value.
148  Operation *functionCall = dyn_cast<func::CallOp>(user);
149  if (functionCall &&
150  (!firstFunctionCallUser ||
151  functionCall->isBeforeInBlock(firstFunctionCallUser))) {
152  firstFunctionCallUser = functionCall;
153  continue;
154  }
155 
156  // Find the last regular user of the reference counted value.
157  if (!functionCall &&
158  (!lastNonFunctionCallUser ||
159  lastNonFunctionCallUser->isBeforeInBlock(user))) {
160  lastNonFunctionCallUser = user;
161  continue;
162  }
163  }
164 
165  // Non function call user after the function call user of the reference
166  // counted value.
167  if (firstFunctionCallUser && lastNonFunctionCallUser &&
168  firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
169  continue;
170 
171  // Try to cancel the pair of `add_ref` and `drop_ref` operations.
172  auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
173  addRef.getOperation());
174 
175  if (!emplaced.second) // `drop_ref` was already marked for removal
176  continue; // go to the next `drop_ref`
177 
178  if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
179  break; // go to the next `add_ref`
180  }
181  }
182  }
183 
184  return success();
185 }
186 
187 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
188  Operation *op = getOperation();
189 
190  // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
191  //
192  // Find all cancellable pairs of operation and erase them in the end to keep
193  // all iterators valid while we are walking the function operations.
194  llvm::SmallDenseMap<Operation *, Operation *> cancellable;
195 
196  // Optimize reference counting for values defined by block arguments.
197  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
198  for (BlockArgument arg : block->getArguments())
199  if (isRefCounted(arg.getType()))
200  if (failed(optimizeReferenceCounting(arg, cancellable)))
201  return WalkResult::interrupt();
202 
203  return WalkResult::advance();
204  });
205 
206  if (blockWalk.wasInterrupted())
207  signalPassFailure();
208 
209  // Optimize reference counting for values defined by operation results.
210  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
211  for (unsigned i = 0; i < op->getNumResults(); ++i)
212  if (isRefCounted(op->getResultTypes()[i]))
213  if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
214  return WalkResult::interrupt();
215 
216  return WalkResult::advance();
217  });
218 
219  if (opWalk.wasInterrupted())
220  signalPassFailure();
221 
222  LLVM_DEBUG({
223  llvm::dbgs() << "Found " << cancellable.size()
224  << " cancellable reference counting operations\n";
225  });
226 
227  // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
228  for (auto &kv : cancellable) {
229  kv.first->erase();
230  kv.second->erase();
231  }
232 }
233 
235  return std::make_unique<AsyncRuntimeRefCountingOptPass>();
236 }
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:84
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
result_type_range getResultTypes()
Definition: Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:224
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
Definition: Async.h:52
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createAsyncRuntimeRefCountingOptPass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26