30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
37 #include "mlir/Dialect/Async/Passes.h.inc"
43 #define DEBUG_TYPE "async-to-async-runtime"
49 class AsyncToAsyncRuntimePass
50 :
public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
52 AsyncToAsyncRuntimePass() =
default;
53 void runOnOperation()
override;
60 class AsyncFuncToAsyncRuntimePass
61 :
public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
63 AsyncFuncToAsyncRuntimePass() =
default;
64 void runOnOperation()
override;
75 struct CoroMachinery {
91 std::optional<Value> asyncToken;
96 std::optional<Block *> setError;
119 Block *cleanupForDestroy;
125 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
172 assert(!func.getBlocks().empty() &&
"Function must have an entry block");
175 Block *entryBlock = &func.getBlocks().
front();
176 Block *originalEntryBlock =
186 bool isStateful = isa<TokenType>(func.getResultTypes().front());
188 std::optional<Value> retToken;
190 retToken.emplace(builder.create<RuntimeCreateOp>(
TokenType::get(ctx)));
194 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195 for (
auto resType : resValueTypes)
196 retValues.emplace_back(
197 builder.create<RuntimeCreateOp>(resType).getResult());
205 builder.create<cf::BranchOp>(originalEntryBlock);
207 Block *cleanupBlock = func.addBlock();
208 Block *cleanupBlockForDestroy = func.addBlock();
209 Block *suspendBlock = func.addBlock();
214 auto buildCleanupBlock = [&](
Block *cb) {
215 builder.setInsertionPointToStart(cb);
216 builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
219 builder.create<cf::BranchOp>(suspendBlock);
221 buildCleanupBlock(cleanupBlock);
222 buildCleanupBlock(cleanupBlockForDestroy);
228 builder.setInsertionPointToStart(suspendBlock);
231 builder.create<CoroEndOp>(coroHdlOp.getHandle());
237 ret.push_back(*retToken);
238 ret.insert(ret.end(), retValues.begin(), retValues.end());
239 builder.create<func::ReturnOp>(ret);
246 func->setAttr(
"passthrough", builder.getArrayAttr(
249 CoroMachinery machinery;
250 machinery.func = func;
251 machinery.asyncToken = retToken;
252 machinery.returnValues = retValues;
253 machinery.coroHandle = coroHdlOp.getHandle();
254 machinery.entry = entryBlock;
255 machinery.setError = std::nullopt;
256 machinery.cleanup = cleanupBlock;
257 machinery.cleanupForDestroy = cleanupBlockForDestroy;
258 machinery.suspend = suspendBlock;
266 return *coro.setError;
268 coro.setError = coro.func.addBlock();
269 (*coro.setError)->moveBefore(coro.cleanup);
276 builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
278 for (
Value retValue : coro.returnValues)
279 builder.create<RuntimeSetErrorOp>(retValue);
282 builder.create<cf::BranchOp>(coro.cleanup);
284 return *coro.setError;
295 static std::pair<func::FuncOp, CoroMachinery>
297 ModuleOp module = execute->getParentOfType<ModuleOp>();
308 execute.getDependencies().end());
309 functionInputs.insert_range(execute.getBodyOperands());
313 auto typesRange = llvm::map_range(
314 functionInputs, [](
Value value) {
return value.
getType(); });
316 auto outputTypes = execute.getResultTypes();
332 size_t numDependencies = execute.getDependencies().size();
333 size_t numOperands = execute.getBodyOperands().size();
336 for (
size_t i = 0; i < numDependencies; ++i)
337 builder.create<AwaitOp>(func.getArgument(i));
341 for (
size_t i = 0; i < numOperands; ++i) {
342 Value operand = func.getArgument(numDependencies + i);
343 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
349 valueMapping.
map(functionInputs, func.getArguments());
350 valueMapping.
map(execute.getBodyRegion().getArguments(), unwrappedOperands);
354 for (
Operation &op : execute.getBodyRegion().getOps())
355 builder.clone(op, valueMapping);
365 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
366 builder.setInsertionPointToEnd(coro.entry);
374 builder.create<RuntimeResumeOp>(coro.coroHandle);
377 builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
378 branch.getDest(), coro.cleanupForDestroy);
386 auto callOutlinedFunc = callBuilder.
create<func::CallOp>(
387 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
388 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
405 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
424 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
450 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
455 rewriter.
create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
460 for (
const auto &namedAttr : op->getAttrs()) {
462 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
469 (*coros)[newFuncOp] = coro;
490 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
493 op, op.getCallee(), op.getResultTypes(), op.getOperands());
508 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
510 auto func = op->template getParentOfType<func::FuncOp>();
511 auto funcCoro = coros->find(func);
512 if (funcCoro == coros->end())
514 op,
"operation is not inside the async coroutine function");
517 const CoroMachinery &coro = funcCoro->getSecond();
522 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
523 Value returnValue = std::get<0>(tuple);
524 Value asyncValue = std::get<1>(tuple);
525 rewriter.
create<RuntimeStoreOp>(loc, returnValue, asyncValue);
526 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
531 rewriter.
create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
534 rewriter.
create<cf::BranchOp>(loc, coro.cleanup);
549 template <
typename AwaitType,
typename AwaitableType>
551 using AwaitAdaptor =
typename AwaitType::Adaptor;
555 bool shouldLowerBlockingWait)
557 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
560 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
564 if (!isa<AwaitableType>(op.getOperand().getType()))
568 auto func = op->template getParentOfType<func::FuncOp>();
569 auto funcCoro = coros->find(func);
570 const bool isInCoroutine = funcCoro != coros->end();
573 Value operand = adaptor.getOperand();
578 if (!isInCoroutine && !shouldLowerBlockingWait)
583 if (!isInCoroutine) {
585 builder.create<RuntimeAwaitOp>(loc, operand);
588 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
589 Value notError = builder.create<arith::XOrIOp>(
590 isError, builder.create<arith::ConstantOp>(
591 loc, i1, builder.getIntegerAttr(i1, 1)));
593 builder.create<cf::AssertOp>(notError,
594 "Awaited async operand is in error state");
600 CoroMachinery &coro = funcCoro->getSecond();
601 Block *suspended = op->getBlock();
610 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
616 builder.setInsertionPointToEnd(suspended);
617 builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
618 coro.cleanupForDestroy);
624 builder.setInsertionPointToStart(resume);
625 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
626 builder.create<cf::CondBranchOp>(isError,
638 if (
Value replaceWith = getReplacementValue(op, operand, rewriter))
646 virtual Value getReplacementValue(AwaitType op,
Value operand,
653 bool shouldLowerBlockingWait;
657 class AwaitTokenOpLowering :
public AwaitOpLoweringBase<AwaitOp, TokenType> {
658 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
665 class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
666 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
672 getReplacementValue(AwaitOp op,
Value operand,
675 auto valueType = cast<ValueType>(operand.
getType()).getValueType();
676 return rewriter.
create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
681 class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
682 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
703 auto func = op->template getParentOfType<func::FuncOp>();
704 auto funcCoro = coros->find(func);
705 if (funcCoro == coros->end())
707 op,
"operation is not inside the async coroutine function");
710 const CoroMachinery &coro = funcCoro->getSecond();
714 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
715 Value yieldValue = std::get<0>(tuple);
716 Value asyncValue = std::get<1>(tuple);
717 rewriter.
create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
718 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
723 rewriter.
create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
726 rewriter.
create<cf::BranchOp>(loc, coro.cleanup);
748 auto func = op->template getParentOfType<func::FuncOp>();
749 auto funcCoro = coros->find(func);
750 if (funcCoro == coros->end())
752 op,
"operation is not inside the async coroutine function");
755 CoroMachinery &coro = funcCoro->getSecond();
759 rewriter.
create<cf::CondBranchOp>(loc, adaptor.getArg(),
774 void AsyncToAsyncRuntimePass::runOnOperation() {
775 ModuleOp module = getOperation();
781 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783 module.walk([&](ExecuteOp execute) {
788 llvm::dbgs() <<
"Outlined " << coros->size()
789 <<
" functions built from async.execute operations\n";
793 auto isInCoroutine = [&](
Operation *op) ->
bool {
794 auto parentFunc = op->getParentOfType<func::FuncOp>();
795 return coros->find(parentFunc) != coros->end();
810 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
813 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
821 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
822 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
823 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
826 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](
Operation *op) {
827 auto walkResult = op->walk([&](
Operation *nested) {
828 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
832 return !walkResult.wasInterrupted();
834 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
835 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
838 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
839 [&](cf::AssertOp op) ->
bool {
840 auto func = op->getParentOfType<func::FuncOp>();
841 return !coros->contains(func);
845 std::move(asyncPatterns)))) {
857 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
860 patterns.add<AsyncCallOpLowering>(ctx);
861 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
863 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
869 auto exec = op->getParentOfType<ExecuteOp>();
870 auto func = op->getParentOfType<func::FuncOp>();
871 return exec || !coros->contains(func);
875 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
876 ModuleOp module = getOperation();
887 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
888 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
891 cf::BranchOp, cf::CondBranchOp>();
894 std::move(asyncPatterns)))) {
901 return std::make_unique<AsyncToAsyncRuntimePass>();
904 std::unique_ptr<OperationPass<ModuleOp>>
906 return std::make_unique<AsyncFuncToAsyncRuntimePass>();
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
std::unique_ptr< OperationPass< ModuleOp > > createAsyncFuncToAsyncRuntimePass()