17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/Support/Debug.h"
21 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
22 #include "mlir/Dialect/Async/Passes.h.inc"
25 #define DEBUG_TYPE "async-ref-counting"
32 class AsyncRuntimeRefCountingOptPass
33 :
public impl::AsyncRuntimeRefCountingOptBase<
34 AsyncRuntimeRefCountingOptPass> {
36 AsyncRuntimeRefCountingOptPass() =
default;
37 void runOnOperation()
override;
40 LogicalResult optimizeReferenceCounting(
41 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
46 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
47 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
65 struct BlockUsersInfo {
73 auto updateBlockUsersInfo = [&](
Operation *user) {
74 BlockUsersInfo &info = blockUsers[user->getBlock()];
75 info.users.push_back(user);
77 if (
auto addRef = dyn_cast<RuntimeAddRefOp>(user))
78 info.addRefs.push_back(addRef);
79 if (
auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
80 info.dropRefs.push_back(dropRef);
84 while (user->getParentRegion() != definingRegion) {
85 updateBlockUsersInfo(user);
86 user = user->getParentOp();
87 assert(user !=
nullptr &&
"value user lies outside of the value region");
90 updateBlockUsersInfo(user);
94 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
98 llvm::sort(info.addRefs, isBeforeInBlock);
99 llvm::sort(info.dropRefs, isBeforeInBlock);
101 return isBeforeInBlock(a, b);
109 for (
auto &kv : blockUsers) {
110 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
112 for (RuntimeAddRefOp addRef : info.addRefs) {
113 for (RuntimeDropRefOp dropRef : info.dropRefs) {
115 if (dropRef.getCount() != addRef.getCount() ||
116 dropRef->isBeforeInBlock(addRef.getOperation()))
136 Operation *firstFunctionCallUser =
nullptr;
137 Operation *lastNonFunctionCallUser =
nullptr;
144 if (user == dropRef || dropRef->isBeforeInBlock(user))
148 Operation *functionCall = dyn_cast<func::CallOp>(user);
150 (!firstFunctionCallUser ||
152 firstFunctionCallUser = functionCall;
158 (!lastNonFunctionCallUser ||
160 lastNonFunctionCallUser = user;
167 if (firstFunctionCallUser && lastNonFunctionCallUser &&
172 auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
173 addRef.getOperation());
175 if (!emplaced.second)
187 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
194 llvm::SmallDenseMap<Operation *, Operation *> cancellable;
200 if (failed(optimizeReferenceCounting(arg, cancellable)))
213 if (failed(optimizeReferenceCounting(op->
getResult(i), cancellable)))
223 llvm::dbgs() <<
"Found " << cancellable.size()
224 <<
" cancellable reference counting operations\n";
228 for (
auto &kv : cancellable) {
235 return std::make_unique<AsyncRuntimeRefCountingOptPass>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
Region * getParentRegion()
Return the Region in which this Value is defined.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
Include the generated interface declarations.
std::unique_ptr< Pass > createAsyncRuntimeRefCountingOptPass()