23 #include "llvm/ADT/SmallSet.h"
26 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGPASS
27 #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTINGPASS
28 #include "mlir/Dialect/Async/Passes.h.inc"
31 #define DEBUG_TYPE "async-runtime-ref-counting"
66 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
70 <<
"async operations must be lowered to async runtime operations";
80 if (failed(addRefCounting(arg)))
93 if (failed(addRefCounting(op->
getResult(i))))
111 class AsyncRuntimeRefCountingPass
112 :
public impl::AsyncRuntimeRefCountingPassBase<
113 AsyncRuntimeRefCountingPass> {
115 AsyncRuntimeRefCountingPass() =
default;
116 void runOnOperation()
override;
152 LogicalResult addAutomaticRefCounting(
Value value);
165 LogicalResult addDropRefAfterLastUse(
Value value);
170 LogicalResult addAddRefBeforeFunctionCall(
Value value);
228 LogicalResult addDropRefInDivergentLivenessSuccessor(
Value value);
233 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(
Value value) {
238 auto &liveness = getAnalysis<Liveness>();
254 llvm::SmallSet<Operation *, 4> lastUsers;
264 Block *userBlock = user->getBlock();
267 assert(ancestor &&
"ancestor block must be not null");
268 assert(usersInTheBlocks[ancestor] &&
"ancestor op must be not null");
280 for (
auto &blockAndUser : usersInTheBlocks) {
281 Block *block = blockAndUser.getFirst();
282 Operation *userInTheBlock = blockAndUser.getSecond();
287 assert(blockLiveness->
isLiveIn(value) ||
297 assert(lastUsers.count(lastUser) == 0 &&
"last users must be unique");
298 lastUsers.insert(lastUser);
310 return lastUser->
emitError() <<
"async reference counting can't handle "
311 "terminators that are not ReturnLike";
314 builder.setInsertionPointAfter(lastUser);
315 builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
322 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(
Value value) {
327 if (!isa<func::CallOp>(user))
332 builder.setInsertionPoint(user);
333 builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
340 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
348 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
351 auto &liveness = getAnalysis<Liveness>();
362 if (!blockLiveness || !blockLiveness->
isLiveOut(value))
365 BlockSet liveInSuccessors;
366 BlockSet noLiveInSuccessors;
371 if (succLiveness && succLiveness->
isLiveIn(value))
372 liveInSuccessors.insert(successor);
374 noLiveInSuccessors.insert(successor);
378 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
379 divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
384 for (
auto kv : divergentLivenessBlocks) {
385 Block *block = kv.getFirst();
386 BlockSet &successors = kv.getSecond();
391 if (isa<CoroSuspendOp>(terminator))
396 if (llvm::any_of(successors, hasArgs))
398 <<
"successor have different `liveIn` property of the reference "
403 for (
Block *successor : successors) {
409 Block *refCountingBlock =
nullptr;
411 if (successor->getUniquePredecessor() == block) {
412 refCountingBlock = successor;
417 builder.
create<cf::BranchOp>(value.
getLoc(), successor);
425 if (successor == refCountingBlock)
430 if (pair.value() == successor)
431 terminator->
setSuccessor(refCountingBlock, pair.index());
439 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(
Value value) {
445 if (failed(addDropRefAfterLastUse(value)))
449 if (failed(addAddRefBeforeFunctionCall(value)))
453 if (failed(addDropRefInDivergentLivenessSuccessor(value)))
459 void AsyncRuntimeRefCountingPass::runOnOperation() {
460 auto functor = [&](
Value value) {
return addAutomaticRefCounting(value); };
471 class AsyncRuntimePolicyBasedRefCountingPass
472 :
public impl::AsyncRuntimePolicyBasedRefCountingPassBase<
473 AsyncRuntimePolicyBasedRefCountingPass> {
475 AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
477 void runOnOperation()
override;
482 LogicalResult addRefCounting(
Value value);
484 void initializeDefaultPolicy();
492 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(
Value value) {
501 Location loc = operand.getOwner()->getLoc();
503 for (
auto &func : policy) {
504 FailureOr<int> refCount = func(operand);
505 if (failed(refCount))
512 b.setInsertionPoint(operand.getOwner());
513 b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
518 b.setInsertionPointAfter(operand.getOwner());
519 b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
527 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
528 policy.push_back([](
OpOperand &operand) -> FailureOr<int> {
532 bool isToken = isa<TokenType>(type);
533 bool isGroup = isa<GroupType>(type);
534 bool isValue = isa<ValueType>(type);
537 if (dyn_cast<RuntimeIsErrorOp>(op))
538 return (isToken || isGroup) ? -1 : 0;
541 if (dyn_cast<RuntimeLoadOp>(op))
542 return isValue ? -1 : 0;
545 if (dyn_cast<RuntimeAddToGroupOp>(op))
546 return isToken ? -1 : 0;
552 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
553 auto functor = [&](
Value value) {
return addRefCounting(value); };
static LogicalResult walkReferenceCountedValues(Operation *op, llvm::function_ref< LogicalResult(Value)> addRefCounting)
static LogicalResult dropRefIfNoUses(Value value, unsigned count=1)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block. Returns nullptr if the latter fails.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
SuccessorRange getSuccessors()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
IntegerAttr getI64IntegerAttr(int64_t value)
IRValueT get() const
Return the current value being used by this operand.
This class represents liveness information on block level.
Block * getBlock() const
Returns the underlying block.
bool isLiveIn(Value value) const
Returns true if the given value is in the live-in set.
bool isLiveOut(Value value) const
Returns true if the given value is in the live-out set.
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in this block).
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
result_type_range getResultTypes()
void setSuccessor(Block *block, unsigned index)
SuccessorRange getSuccessors()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Block * findAncestorBlockInRegion(Block &block)
Returns 'block' if 'block' lies in this region, or otherwise finds the ancestor of 'block' that lies in this region. Returns nullptr if the latter fails.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
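A utility result that is used to signal how to proceed with an ongoing walk: interrupt (the walk stops and nothing further is visited), advance (the walk continues), or skip (the not-yet-visited nested elements of the current operation, region or block are skipped and the walk continues with the next one).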
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
This trait indicates that a terminator operation is "return-like".