17#include "llvm/Support/Debug.h"
20#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPTPASS
21#include "mlir/Dialect/Async/Passes.h.inc"
24#define DEBUG_TYPE "async-ref-counting"
31class AsyncRuntimeRefCountingOptPass
33 AsyncRuntimeRefCountingOptPass> {
35 AsyncRuntimeRefCountingOptPass() =
default;
36 void runOnOperation()
override;
39 LogicalResult optimizeReferenceCounting(
40 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
45LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
46 Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
64 struct BlockUsersInfo {
65 llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
66 llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
67 llvm::SmallVector<Operation *, 4> users;
70 llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
72 auto updateBlockUsersInfo = [&](Operation *user) {
73 BlockUsersInfo &info = blockUsers[user->getBlock()];
74 info.users.push_back(user);
76 if (
auto addRef = dyn_cast<RuntimeAddRefOp>(user))
77 info.addRefs.push_back(addRef);
78 if (
auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
79 info.dropRefs.push_back(dropRef);
82 for (Operation *user : value.
getUsers()) {
83 while (user->getParentRegion() != definingRegion) {
84 updateBlockUsersInfo(user);
85 user = user->getParentOp();
86 assert(user !=
nullptr &&
"value user lies outside of the value region");
89 updateBlockUsersInfo(user);
93 auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
94 auto isBeforeInBlock = [](Operation *a, Operation *
b) ->
bool {
97 llvm::sort(info.addRefs, isBeforeInBlock);
98 llvm::sort(info.dropRefs, isBeforeInBlock);
99 llvm::sort(info.users, [&](Operation *a, Operation *
b) ->
bool {
100 return isBeforeInBlock(a, b);
108 for (
auto &kv : blockUsers) {
109 BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111 for (RuntimeAddRefOp addRef : info.addRefs) {
112 for (RuntimeDropRefOp dropRef : info.dropRefs) {
114 if (dropRef.getCount() != addRef.getCount() ||
115 dropRef->isBeforeInBlock(addRef.getOperation()))
135 Operation *firstFunctionCallUser =
nullptr;
136 Operation *lastNonFunctionCallUser =
nullptr;
138 for (Operation *user : info.users) {
143 if (user == dropRef || dropRef->isBeforeInBlock(user))
147 Operation *functionCall = dyn_cast<func::CallOp>(user);
149 (!firstFunctionCallUser ||
151 firstFunctionCallUser = functionCall;
157 (!lastNonFunctionCallUser ||
159 lastNonFunctionCallUser = user;
166 if (firstFunctionCallUser && lastNonFunctionCallUser &&
171 auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
172 addRef.getOperation());
174 if (!emplaced.second)
186void AsyncRuntimeRefCountingOptPass::runOnOperation() {
187 Operation *op = getOperation();
193 llvm::SmallDenseMap<Operation *, Operation *> cancellable;
196 WalkResult blockWalk = op->
walk([&](
Block *block) -> WalkResult {
199 if (
failed(optimizeReferenceCounting(arg, cancellable)))
209 WalkResult opWalk = op->
walk([&](Operation *op) -> WalkResult {
222 llvm::dbgs() <<
"Found " << cancellable.size()
223 <<
" cancellable reference counting operations\n";
227 for (
auto &kv : cancellable) {
BlockArgListType getArguments()
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_type_range getResultTypes()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumResults()
Return the number of results held by this operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
Region * getParentRegion()
Return the Region in which this Value is defined.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
Include the generated interface declarations.