MLIR 22.0.0git
AsyncRuntimeRefCountingOpt.cpp
Go to the documentation of this file.
1//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Optimize Async dialect reference counting operations.
10//
11//===----------------------------------------------------------------------===//
12
14
17#include "llvm/Support/Debug.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPTPASS
21#include "mlir/Dialect/Async/Passes.h.inc"
22} // namespace mlir
23
24#define DEBUG_TYPE "async-ref-counting"
25
26using namespace mlir;
27using namespace mlir::async;
28
29namespace {
30
31class AsyncRuntimeRefCountingOptPass
33 AsyncRuntimeRefCountingOptPass> {
34public:
35 AsyncRuntimeRefCountingOptPass() = default;
36 void runOnOperation() override;
37
38private:
39 LogicalResult optimizeReferenceCounting(
40 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
41};
42
43} // namespace
44
45LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
46 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
47 Region *definingRegion = value.getParentRegion();
48
49 // Find all users of the `value` inside each block, including operations that
50 // do not use `value` directly, but have a direct use inside nested region(s).
51 //
52 // Example:
53 //
54 // ^bb1:
55 // %token = ...
56 // scf.if %cond {
57 // ^bb2:
58 // async.runtime.await %token : !async.token
59 // }
60 //
61 // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
62 // (`scf.if`).
63
64 struct BlockUsersInfo {
65 llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
66 llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
67 llvm::SmallVector<Operation *, 4> users;
68 };
69
70 llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
71
72 auto updateBlockUsersInfo = [&](Operation *user) {
73 BlockUsersInfo &info = blockUsers[user->getBlock()];
74 info.users.push_back(user);
75
76 if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
77 info.addRefs.push_back(addRef);
78 if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
79 info.dropRefs.push_back(dropRef);
80 };
81
82 for (Operation *user : value.getUsers()) {
83 while (user->getParentRegion() != definingRegion) {
84 updateBlockUsersInfo(user);
85 user = user->getParentOp();
86 assert(user != nullptr && "value user lies outside of the value region");
87 }
88
89 updateBlockUsersInfo(user);
90 }
91
92 // Sort all operations found in the block.
93 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
94 auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
95 return a->isBeforeInBlock(b);
96 };
97 llvm::sort(info.addRefs, isBeforeInBlock);
98 llvm::sort(info.dropRefs, isBeforeInBlock);
99 llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
100 return isBeforeInBlock(a, b);
101 });
102
103 return info;
104 };
105
106 // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
107 // blocks that modify the reference count of the `value`.
108 for (auto &kv : blockUsers) {
109 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
110
111 for (RuntimeAddRefOp addRef : info.addRefs) {
112 for (RuntimeDropRefOp dropRef : info.dropRefs) {
113 // `drop_ref` operation after the `add_ref` with matching count.
114 if (dropRef.getCount() != addRef.getCount() ||
115 dropRef->isBeforeInBlock(addRef.getOperation()))
116 continue;
117
118 // When reference counted value passed to a function as an argument,
119 // function takes ownership of +1 reference and it will drop it before
120 // returning.
121 //
122 // Example:
123 //
124 // %token = ... : !async.token
125 //
126 // async.runtime.add_ref %token {count = 1 : i64} : !async.token
127 // call @pass_token(%token: !async.token, ...)
128 //
129 // async.await %token : !async.token
130 // async.runtime.drop_ref %token {count = 1 : i64} : !async.token
131 //
132 // In this example if we'll cancel a pair of reference counting
133 // operations we might end up with a deallocated token when we'll
134 // reach `async.await` operation.
135 Operation *firstFunctionCallUser = nullptr;
136 Operation *lastNonFunctionCallUser = nullptr;
137
138 for (Operation *user : info.users) {
139 // `user` operation lies after `addRef` ...
140 if (user == addRef || user->isBeforeInBlock(addRef))
141 continue;
142 // ... and before `dropRef`.
143 if (user == dropRef || dropRef->isBeforeInBlock(user))
144 break;
145
146 // Find the first function call user of the reference counted value.
147 Operation *functionCall = dyn_cast<func::CallOp>(user);
148 if (functionCall &&
149 (!firstFunctionCallUser ||
150 functionCall->isBeforeInBlock(firstFunctionCallUser))) {
151 firstFunctionCallUser = functionCall;
152 continue;
153 }
154
155 // Find the last regular user of the reference counted value.
156 if (!functionCall &&
157 (!lastNonFunctionCallUser ||
158 lastNonFunctionCallUser->isBeforeInBlock(user))) {
159 lastNonFunctionCallUser = user;
160 continue;
161 }
162 }
163
164 // Non function call user after the function call user of the reference
165 // counted value.
166 if (firstFunctionCallUser && lastNonFunctionCallUser &&
167 firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
168 continue;
169
170 // Try to cancel the pair of `add_ref` and `drop_ref` operations.
171 auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
172 addRef.getOperation());
173
174 if (!emplaced.second) // `drop_ref` was already marked for removal
175 continue; // go to the next `drop_ref`
176
177 if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
178 break; // go to the next `add_ref`
179 }
180 }
181 }
182
183 return success();
184}
185
186void AsyncRuntimeRefCountingOptPass::runOnOperation() {
187 Operation *op = getOperation();
188
189 // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
190 //
191 // Find all cancellable pairs of operation and erase them in the end to keep
192 // all iterators valid while we are walking the function operations.
193 llvm::SmallDenseMap<Operation *, Operation *> cancellable;
194
195 // Optimize reference counting for values defined by block arguments.
196 WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
197 for (BlockArgument arg : block->getArguments())
198 if (isRefCounted(arg.getType()))
199 if (failed(optimizeReferenceCounting(arg, cancellable)))
200 return WalkResult::interrupt();
201
202 return WalkResult::advance();
203 });
204
205 if (blockWalk.wasInterrupted())
206 signalPassFailure();
207
208 // Optimize reference counting for values defined by operation results.
209 WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
210 for (unsigned i = 0; i < op->getNumResults(); ++i)
211 if (isRefCounted(op->getResultTypes()[i]))
212 if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
213 return WalkResult::interrupt();
214
215 return WalkResult::advance();
216 });
217
218 if (opWalk.wasInterrupted())
219 signalPassFailure();
220
221 LLVM_DEBUG({
222 llvm::dbgs() << "Found " << cancellable.size()
223 << " cancellable reference counting operations\n";
224 });
225
226 // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
227 for (auto &kv : cancellable) {
228 kv.first->erase();
229 kv.second->erase();
230 }
231}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
BlockArgListType getArguments()
Definition Block.h:87
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
result_type_range getResultTypes()
Definition Operation.h:428
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
Definition Async.h:52
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.