23 #include "llvm/ADT/SmallSet.h"
26 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING
27 #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING
28 #include "mlir/Dialect/Async/Passes.h.inc"
31 #define DEBUG_TYPE "async-runtime-ref-counting"
66 if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
70 <<
"async operations must be lowered to async runtime operations";
80 if (failed(addRefCounting(arg)))
93 if (failed(addRefCounting(op->
getResult(i))))
111 class AsyncRuntimeRefCountingPass
112 :
public impl::AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
114 AsyncRuntimeRefCountingPass() =
default;
115 void runOnOperation()
override;
151 LogicalResult addAutomaticRefCounting(
Value value);
164 LogicalResult addDropRefAfterLastUse(
Value value);
169 LogicalResult addAddRefBeforeFunctionCall(
Value value);
227 LogicalResult addDropRefInDivergentLivenessSuccessor(
Value value);
232 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(
Value value) {
237 auto &liveness = getAnalysis<Liveness>();
253 llvm::SmallSet<Operation *, 4> lastUsers;
263 Block *userBlock = user->getBlock();
266 assert(ancestor &&
"ancestor block must be not null");
267 assert(usersInTheBlocks[ancestor] &&
"ancestor op must be not null");
279 for (
auto &blockAndUser : usersInTheBlocks) {
280 Block *block = blockAndUser.getFirst();
281 Operation *userInTheBlock = blockAndUser.getSecond();
286 assert(blockLiveness->
isLiveIn(value) ||
296 assert(lastUsers.count(lastUser) == 0 &&
"last users must be unique");
297 lastUsers.insert(lastUser);
309 return lastUser->
emitError() <<
"async reference counting can't handle "
310 "terminators that are not ReturnLike";
313 builder.setInsertionPointAfter(lastUser);
314 builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
321 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(
Value value) {
326 if (!isa<func::CallOp>(user))
331 builder.setInsertionPoint(user);
332 builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
339 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
347 llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
350 auto &liveness = getAnalysis<Liveness>();
361 if (!blockLiveness || !blockLiveness->
isLiveOut(value))
364 BlockSet liveInSuccessors;
365 BlockSet noLiveInSuccessors;
370 if (succLiveness && succLiveness->
isLiveIn(value))
371 liveInSuccessors.insert(successor);
373 noLiveInSuccessors.insert(successor);
377 if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
378 divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
383 for (
auto kv : divergentLivenessBlocks) {
384 Block *block = kv.getFirst();
385 BlockSet &successors = kv.getSecond();
390 if (isa<CoroSuspendOp>(terminator))
395 if (llvm::any_of(successors, hasArgs))
397 <<
"successor have different `liveIn` property of the reference "
402 for (
Block *successor : successors) {
408 Block *refCountingBlock =
nullptr;
410 if (successor->getUniquePredecessor() == block) {
411 refCountingBlock = successor;
416 builder.
create<cf::BranchOp>(value.
getLoc(), successor);
424 if (successor == refCountingBlock)
429 if (pair.value() == successor)
430 terminator->
setSuccessor(refCountingBlock, pair.index());
438 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(
Value value) {
444 if (failed(addDropRefAfterLastUse(value)))
448 if (failed(addAddRefBeforeFunctionCall(value)))
452 if (failed(addDropRefInDivergentLivenessSuccessor(value)))
458 void AsyncRuntimeRefCountingPass::runOnOperation() {
459 auto functor = [&](
Value value) {
return addAutomaticRefCounting(value); };
470 class AsyncRuntimePolicyBasedRefCountingPass
471 :
public impl::AsyncRuntimePolicyBasedRefCountingBase<
472 AsyncRuntimePolicyBasedRefCountingPass> {
474 AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
476 void runOnOperation()
override;
481 LogicalResult addRefCounting(
Value value);
483 void initializeDefaultPolicy();
491 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(
Value value) {
500 Location loc = operand.getOwner()->getLoc();
502 for (
auto &func : policy) {
503 FailureOr<int> refCount = func(operand);
504 if (failed(refCount))
511 b.setInsertionPoint(operand.getOwner());
512 b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
517 b.setInsertionPointAfter(operand.getOwner());
518 b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
526 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
527 policy.push_back([](
OpOperand &operand) -> FailureOr<int> {
531 bool isToken = isa<TokenType>(type);
532 bool isGroup = isa<GroupType>(type);
533 bool isValue = isa<ValueType>(type);
536 if (dyn_cast<RuntimeIsErrorOp>(op))
537 return (isToken || isGroup) ? -1 : 0;
540 if (dyn_cast<RuntimeLoadOp>(op))
541 return isValue ? -1 : 0;
544 if (dyn_cast<RuntimeAddToGroupOp>(op))
545 return isToken ? -1 : 0;
551 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
552 auto functor = [&](
Value value) {
return addRefCounting(value); };
560 return std::make_unique<AsyncRuntimeRefCountingPass>();
564 return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
static LogicalResult walkReferenceCountedValues(Operation *op, llvm::function_ref< LogicalResult(Value)> addRefCounting)
static LogicalResult dropRefIfNoUses(Value value, unsigned count=1)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies in this block.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
SuccessorRange getSuccessors()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
IntegerAttr getI64IntegerAttr(int64_t value)
IRValueT get() const
Return the current value being used by this operand.
This class represents liveness information on block level.
Block * getBlock() const
Returns the underlying block.
bool isLiveIn(Value value) const
Returns true if the given value is in the live-in set.
bool isLiveOut(Value value) const
Returns true if the given value is in the live-out set.
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in this block).
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
result_type_range getResultTypes()
void setSuccessor(Block *block, unsigned index)
SuccessorRange getSuccessors()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Block * findAncestorBlockInRegion(Block &block)
Returns 'block' if 'block' lies in this region, or otherwise finds the ancestor of 'block' that lies in this region.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Operation * getOwner() const
Return the owner of this operand.
bool isRefCounted(Type type)
Returns true if the type is reference counted at runtime.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::unique_ptr< Pass > createAsyncRuntimePolicyBasedRefCountingPass()
std::unique_ptr< Pass > createAsyncRuntimeRefCountingPass()
This trait indicates that a terminator operation is "return-like".