20 #include "llvm/ADT/SmallSet.h"
23 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGPASS
24 #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTINGPASS
25 #include "mlir/Dialect/Async/Passes.h.inc"
28 #define DEBUG_TYPE "async-runtime-ref-counting"
63 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
67 <<
"async operations must be lowered to async runtime operations";
77 if (failed(addRefCounting(arg)))
90 if (failed(addRefCounting(op->
getResult(i))))
108 class AsyncRuntimeRefCountingPass
109 :
public impl::AsyncRuntimeRefCountingPassBase<
110 AsyncRuntimeRefCountingPass> {
112 AsyncRuntimeRefCountingPass() =
default;
113 void runOnOperation()
override;
149 LogicalResult addAutomaticRefCounting(
Value value);
162 LogicalResult addDropRefAfterLastUse(
Value value);
167 LogicalResult addAddRefBeforeFunctionCall(
Value value);
225 LogicalResult addDropRefInDivergentLivenessSuccessor(
Value value);
230 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(
Value value) {
235 auto &liveness = getAnalysis<Liveness>();
251 llvm::SmallSet<Operation *, 4> lastUsers;
261 Block *userBlock = user->getBlock();
264 assert(ancestor &&
"ancestor block must be not null");
265 assert(usersInTheBlocks[ancestor] &&
"ancestor op must be not null");
277 for (
auto &blockAndUser : usersInTheBlocks) {
278 Block *block = blockAndUser.getFirst();
279 Operation *userInTheBlock = blockAndUser.getSecond();
284 assert(blockLiveness->
isLiveIn(value) ||
294 assert(lastUsers.count(lastUser) == 0 &&
"last users must be unique");
295 lastUsers.insert(lastUser);
307 return lastUser->
emitError() <<
"async reference counting can't handle "
308 "terminators that are not ReturnLike";
311 builder.setInsertionPointAfter(lastUser);
312 RuntimeDropRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1));
319 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(
Value value) {
324 if (!isa<func::CallOp>(user))
329 builder.setInsertionPoint(user);
330 RuntimeAddRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1));
337 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
345 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
348 auto &liveness = getAnalysis<Liveness>();
359 if (!blockLiveness || !blockLiveness->
isLiveOut(value))
362 BlockSet liveInSuccessors;
363 BlockSet noLiveInSuccessors;
368 if (succLiveness && succLiveness->
isLiveIn(value))
369 liveInSuccessors.insert(successor);
371 noLiveInSuccessors.insert(successor);
375 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
376 divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
381 for (
auto kv : divergentLivenessBlocks) {
382 Block *block = kv.getFirst();
383 BlockSet &successors = kv.getSecond();
388 if (isa<CoroSuspendOp>(terminator))
393 if (llvm::any_of(successors, hasArgs))
395 <<
"successor have different `liveIn` property of the reference "
400 for (
Block *successor : successors) {
406 Block *refCountingBlock =
nullptr;
408 if (successor->getUniquePredecessor() == block) {
409 refCountingBlock = successor;
414 cf::BranchOp::create(builder, value.
getLoc(), successor);
418 RuntimeDropRefOp::create(builder, value.
getLoc(), value,
422 if (successor == refCountingBlock)
427 if (pair.value() == successor)
428 terminator->
setSuccessor(refCountingBlock, pair.index());
436 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(
Value value) {
442 if (failed(addDropRefAfterLastUse(value)))
446 if (failed(addAddRefBeforeFunctionCall(value)))
450 if (failed(addDropRefInDivergentLivenessSuccessor(value)))
456 void AsyncRuntimeRefCountingPass::runOnOperation() {
457 auto functor = [&](
Value value) {
return addAutomaticRefCounting(value); };
468 class AsyncRuntimePolicyBasedRefCountingPass
469 :
public impl::AsyncRuntimePolicyBasedRefCountingPassBase<
470 AsyncRuntimePolicyBasedRefCountingPass> {
472 AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
474 void runOnOperation()
override;
479 LogicalResult addRefCounting(
Value value);
481 void initializeDefaultPolicy();
489 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(
Value value) {
498 Location loc = operand.getOwner()->getLoc();
500 for (
auto &func : policy) {
501 FailureOr<int> refCount = func(operand);
502 if (failed(refCount))
509 b.setInsertionPoint(operand.getOwner());
510 RuntimeAddRefOp::create(b, loc, value, b.getI64IntegerAttr(cnt));
515 b.setInsertionPointAfter(operand.getOwner());
516 RuntimeDropRefOp::create(b, loc, value, b.getI64IntegerAttr(-cnt));
524 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
525 policy.push_back([](
OpOperand &operand) -> FailureOr<int> {
529 bool isToken = isa<TokenType>(type);
530 bool isGroup = isa<GroupType>(type);
531 bool isValue = isa<ValueType>(type);
534 if (isa<RuntimeIsErrorOp>(op))
535 return (isToken || isGroup) ? -1 : 0;
538 if (isa<RuntimeLoadOp>(op))
539 return isValue ? -1 : 0;
542 if (isa<RuntimeAddToGroupOp>(op))
543 return isToken ? -1 : 0;
549 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
550 auto functor = [&](
Value value) {
return addRefCounting(value); };
static LogicalResult walkReferenceCountedValues(Operation *op, llvm::function_ref< LogicalResult(Value)> addRefCounting)
static LogicalResult dropRefIfNoUses(Value value, unsigned count=1)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
SuccessorRange getSuccessors()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
IntegerAttr getI64IntegerAttr(int64_t value)
IRValueT get() const
Return the current value being used by this operand.
This class represents liveness information on block level.
Block * getBlock() const
Returns the underlying block.
bool isLiveIn(Value value) const
Returns true if the given value is in the live-in set.
bool isLiveOut(Value value) const
Returns true if the given value is in the live-out set.
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_range getResultTypes()
void setSuccessor(Block *block, unsigned index)
SuccessorRange getSuccessors()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Block * findAncestorBlockInRegion(Block &block)
Returns 'block' if 'block' lies in this region, or otherwise finds the ancestor of 'block' that lies ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
This trait indicates that a terminator operation is "return-like".