17 #include "llvm/Support/Debug.h"
20 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPTPASS
21 #include "mlir/Dialect/Async/Passes.h.inc"
24 #define DEBUG_TYPE "async-ref-counting"
31 class AsyncRuntimeRefCountingOptPass
32 :
public impl::AsyncRuntimeRefCountingOptPassBase<
33 AsyncRuntimeRefCountingOptPass> {
35 AsyncRuntimeRefCountingOptPass() =
default;
36 void runOnOperation()
override;
39 LogicalResult optimizeReferenceCounting(
40 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
45 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
46 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
64 struct BlockUsersInfo {
72 auto updateBlockUsersInfo = [&](
Operation *user) {
73 BlockUsersInfo &info = blockUsers[user->getBlock()];
74 info.users.push_back(user);
76 if (
auto addRef = dyn_cast<RuntimeAddRefOp>(user))
77 info.addRefs.push_back(addRef);
78 if (
auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
79 info.dropRefs.push_back(dropRef);
83 while (user->getParentRegion() != definingRegion) {
84 updateBlockUsersInfo(user);
85 user = user->getParentOp();
86 assert(user !=
nullptr &&
"value user lies outside of the value region");
89 updateBlockUsersInfo(user);
93 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
100 return isBeforeInBlock(a, b);
108 for (
auto &kv : blockUsers) {
109 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111 for (RuntimeAddRefOp addRef : info.addRefs) {
112 for (RuntimeDropRefOp dropRef : info.dropRefs) {
114 if (dropRef.getCount() != addRef.getCount() ||
115 dropRef->isBeforeInBlock(addRef.getOperation()))
135 Operation *firstFunctionCallUser =
nullptr;
136 Operation *lastNonFunctionCallUser =
nullptr;
143 if (user == dropRef || dropRef->isBeforeInBlock(user))
147 Operation *functionCall = dyn_cast<func::CallOp>(user);
149 (!firstFunctionCallUser ||
151 firstFunctionCallUser = functionCall;
157 (!lastNonFunctionCallUser ||
159 lastNonFunctionCallUser = user;
166 if (firstFunctionCallUser && lastNonFunctionCallUser &&
171 auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
172 addRef.getOperation());
174 if (!emplaced.second)
186 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
193 llvm::SmallDenseMap<Operation *, Operation *> cancellable;
199 if (failed(optimizeReferenceCounting(arg, cancellable)))
212 if (failed(optimizeReferenceCounting(op->
getResult(i), cancellable)))
222 llvm::dbgs() <<
"Found " << cancellable.size()
223 <<
" cancellable reference counting operations\n";
227 for (
auto &kv : cancellable) {
static bool isBeforeInBlock(Block *block, Block::iterator a, Block::iterator b)
Given two iterators into the same block, return "true" if `a` is before `b`.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
Region * getParentRegion()
Return the Region in which this Value is defined.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
Include the generated interface declarations.