29#include "llvm/Support/Debug.h"
33#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
34#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
35#include "mlir/Dialect/Async/Passes.h.inc"
41#define DEBUG_TYPE "async-to-async-runtime"
47class AsyncToAsyncRuntimePass
50 AsyncToAsyncRuntimePass() =
default;
58class AsyncFuncToAsyncRuntimePass
60 AsyncFuncToAsyncRuntimePass> {
62 AsyncFuncToAsyncRuntimePass() =
default;
90 std::optional<Value> asyncToken;
95 std::optional<Block *> setError;
118 Block *cleanupForDestroy;
124 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
171 assert(!
func.getBlocks().empty() &&
"Function must have an entry block");
174 Block *entryBlock = &
func.getBlocks().front();
175 Block *originalEntryBlock =
185 bool isStateful = isa<TokenType>(
func.getResultTypes().front());
187 std::optional<Value> retToken;
189 retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
193 isStateful ?
func.getResultTypes().drop_front() :
func.getResultTypes();
194 for (
auto resType : resValueTypes)
195 retValues.emplace_back(
196 RuntimeCreateOp::create(builder, resType).getResult());
201 auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
203 CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
204 cf::BranchOp::create(builder, originalEntryBlock);
207 Block *cleanupBlockForDestroy =
func.addBlock();
213 auto buildCleanupBlock = [&](
Block *cb) {
214 builder.setInsertionPointToStart(cb);
215 CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
218 cf::BranchOp::create(builder, suspendBlock);
220 buildCleanupBlock(cleanupBlock);
221 buildCleanupBlock(cleanupBlockForDestroy);
227 builder.setInsertionPointToStart(suspendBlock);
230 CoroEndOp::create(builder, coroHdlOp.getHandle());
236 ret.push_back(*retToken);
237 llvm::append_range(ret, retValues);
238 func::ReturnOp::create(builder, ret);
245 func->setAttr(
"passthrough", builder.getArrayAttr(
246 StringAttr::get(ctx,
"presplitcoroutine")));
248 CoroMachinery machinery;
249 machinery.func =
func;
250 machinery.asyncToken = retToken;
251 machinery.returnValues = retValues;
252 machinery.coroHandle = coroHdlOp.getHandle();
253 machinery.entry = entryBlock;
254 machinery.setError = std::nullopt;
255 machinery.cleanup = cleanupBlock;
256 machinery.cleanupForDestroy = cleanupBlockForDestroy;
257 machinery.suspend = suspendBlock;
265 return *coro.setError;
267 coro.setError = coro.func.addBlock();
268 (*coro.setError)->moveBefore(coro.cleanup);
275 RuntimeSetErrorOp::create(builder, *coro.asyncToken);
277 for (
Value retValue : coro.returnValues)
278 RuntimeSetErrorOp::create(builder, retValue);
281 cf::BranchOp::create(builder, coro.cleanup);
283 return *coro.setError;
294static std::pair<func::FuncOp, CoroMachinery>
296 ModuleOp module = execute->getParentOfType<ModuleOp>();
307 execute.getDependencies());
308 functionInputs.insert_range(execute.getBodyOperands());
312 auto typesRange = llvm::map_range(
313 functionInputs, [](
Value value) {
return value.
getType(); });
315 auto outputTypes = execute.getResultTypes();
317 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
331 size_t numDependencies = execute.getDependencies().size();
332 size_t numOperands = execute.getBodyOperands().size();
335 for (
size_t i = 0; i < numDependencies; ++i)
336 AwaitOp::create(builder,
func.getArgument(i));
340 for (
size_t i = 0; i < numOperands; ++i) {
341 Value operand =
func.getArgument(numDependencies + i);
342 unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
348 valueMapping.
map(functionInputs,
func.getArguments());
349 valueMapping.
map(execute.getBodyRegion().getArguments(), unwrappedOperands);
353 for (
Operation &op : execute.getBodyRegion().getOps())
354 builder.clone(op, valueMapping);
364 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->
getTerminator());
365 builder.setInsertionPointToEnd(coro.entry);
369 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
373 RuntimeResumeOp::create(builder, coro.coroHandle);
376 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
377 branch.getDest(), coro.cleanupForDestroy);
385 auto callOutlinedFunc = func::CallOp::create(callBuilder,
func.getName(),
386 execute.getResultTypes(),
387 functionInputs.getArrayRef());
388 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
400class CreateGroupOpLowering :
public OpConversionPattern<CreateGroupOp> {
402 using OpConversionPattern::OpConversionPattern;
405 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter)
const override {
407 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
408 op, GroupType::get(op->getContext()), adaptor.getOperands());
419class AddToGroupOpLowering :
public OpConversionPattern<AddToGroupOp> {
421 using OpConversionPattern::OpConversionPattern;
424 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter)
const override {
426 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
427 op, rewriter.getIndexType(), adaptor.getOperands());
444class AsyncFuncOpLowering :
public OpConversionPattern<async::FuncOp> {
447 : OpConversionPattern<
async::FuncOp>(ctx), coros(std::move(coros)) {}
450 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter)
const override {
455 func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
460 for (
const auto &namedAttr : op->getAttrs()) {
462 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
465 rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
469 (*coros)[newFuncOp] = coro;
472 rewriter.eraseOp(op);
484class AsyncCallOpLowering :
public OpConversionPattern<async::CallOp> {
487 : OpConversionPattern<
async::CallOp>(ctx) {}
490 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter)
const override {
492 rewriter.replaceOpWithNewOp<func::CallOp>(
493 op, op.getCallee(), op.getResultTypes(), op.getOperands());
502class AsyncReturnOpLowering :
public OpConversionPattern<async::ReturnOp> {
505 : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
508 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter)
const override {
510 auto func = op->template getParentOfType<func::FuncOp>();
511 auto funcCoro = coros->find(func);
512 if (funcCoro == coros->end())
513 return rewriter.notifyMatchFailure(
514 op,
"operation is not inside the async coroutine function");
516 Location loc = op->getLoc();
517 const CoroMachinery &coro = funcCoro->getSecond();
518 rewriter.setInsertionPointAfter(op);
522 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
523 Value returnValue = std::get<0>(tuple);
524 Value asyncValue = std::get<1>(tuple);
525 RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
526 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
531 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
533 rewriter.eraseOp(op);
534 cf::BranchOp::create(rewriter, loc, coro.cleanup);
549template <
typename AwaitType,
typename AwaitableType>
550class AwaitOpLoweringBase :
public OpConversionPattern<AwaitType> {
551 using AwaitAdaptor =
typename AwaitType::Adaptor;
555 bool shouldLowerBlockingWait)
556 : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
557 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
560 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
561 ConversionPatternRewriter &rewriter)
const override {
564 if (!isa<AwaitableType>(op.getOperand().getType()))
565 return rewriter.notifyMatchFailure(op,
"unsupported awaitable type");
568 auto func = op->template getParentOfType<func::FuncOp>();
569 auto funcCoro = coros->find(func);
570 const bool isInCoroutine = funcCoro != coros->end();
572 Location loc = op->getLoc();
573 Value operand = adaptor.getOperand();
575 Type i1 = rewriter.getI1Type();
578 if (!isInCoroutine && !shouldLowerBlockingWait)
583 if (!isInCoroutine) {
584 ImplicitLocOpBuilder builder(loc, rewriter);
585 RuntimeAwaitOp::create(builder, loc, operand);
588 Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
589 Value notError = arith::XOrIOp::create(
591 arith::ConstantOp::create(builder, loc, i1,
592 builder.getIntegerAttr(i1, 1)));
594 cf::AssertOp::create(builder, notError,
595 "Awaited async operand is in error state");
601 CoroMachinery &coro = funcCoro->getSecond();
602 Block *suspended = op->getBlock();
604 ImplicitLocOpBuilder builder(loc, rewriter);
605 MLIRContext *ctx = op->getContext();
610 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
611 RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
617 builder.setInsertionPointToEnd(suspended);
618 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
619 resume, coro.cleanupForDestroy);
625 builder.setInsertionPointToStart(resume);
626 auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
627 cf::CondBranchOp::create(builder, isError,
635 rewriter.setInsertionPointToStart(continuation);
639 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
640 rewriter.replaceOp(op, replaceWith);
642 rewriter.eraseOp(op);
647 virtual Value getReplacementValue(AwaitType op, Value operand,
648 ConversionPatternRewriter &rewriter)
const {
654 bool shouldLowerBlockingWait;
658class AwaitTokenOpLowering :
public AwaitOpLoweringBase<AwaitOp, TokenType> {
659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
666class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
673 getReplacementValue(AwaitOp op, Value operand,
674 ConversionPatternRewriter &rewriter)
const override {
676 auto valueType = cast<ValueType>(operand.
getType()).getValueType();
677 return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
682class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
698 : OpConversionPattern<
async::YieldOp>(ctx), coros(std::move(coros)) {}
702 ConversionPatternRewriter &rewriter)
const override {
704 auto func = op->template getParentOfType<func::FuncOp>();
705 auto funcCoro = coros->find(
func);
706 if (funcCoro == coros->end())
707 return rewriter.notifyMatchFailure(
708 op,
"operation is not inside the async coroutine function");
711 const CoroMachinery &coro = funcCoro->getSecond();
715 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716 Value yieldValue = std::get<0>(tuple);
717 Value asyncValue = std::get<1>(tuple);
718 RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
719 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
724 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
726 cf::BranchOp::create(rewriter, loc, coro.cleanup);
727 rewriter.eraseOp(op);
743 : OpConversionPattern<
cf::AssertOp>(ctx), coros(std::move(coros)) {}
747 ConversionPatternRewriter &rewriter)
const override {
749 auto func = op->template getParentOfType<func::FuncOp>();
750 auto funcCoro = coros->find(
func);
751 if (funcCoro == coros->end())
752 return rewriter.notifyMatchFailure(
753 op,
"operation is not inside the async coroutine function");
756 CoroMachinery &coro = funcCoro->getSecond();
759 rewriter.setInsertionPointToEnd(cont->getPrevNode());
760 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
765 rewriter.eraseOp(op);
775void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
777 SymbolTable symbolTable(module);
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
784 module.walk([&](ExecuteOp execute) {
789 llvm::dbgs() <<
"Outlined " << coros->size()
790 <<
" functions built from async.execute operations\n";
794 auto isInCoroutine = [&](Operation *op) ->
bool {
795 auto parentFunc = op->getParentOfType<func::FuncOp>();
796 return coros->contains(parentFunc);
800 MLIRContext *ctx =
module->getContext();
801 RewritePatternSet asyncPatterns(ctx);
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
818 asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
821 ConversionTarget runtimeTarget(*ctx);
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
827 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828 auto walkResult = op->walk([&](Operation *nested) {
829 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
831 : WalkResult::advance();
833 return !walkResult.wasInterrupted();
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
839 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840 [&](cf::AssertOp op) ->
bool {
841 auto func = op->getParentOfType<func::FuncOp>();
842 return !coros->contains(func);
845 if (
failed(applyPartialConversion(module, runtimeTarget,
846 std::move(asyncPatterns)))) {
858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
861 patterns.add<AsyncCallOpLowering>(ctx);
862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
868 target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
870 auto exec = op->getParentOfType<ExecuteOp>();
871 auto func = op->getParentOfType<func::FuncOp>();
872 return exec || !coros->contains(
func);
876void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
894 if (failed(applyPartialConversion(module, runtimeTarget,
895 std::move(asyncPatterns)))) {
static constexpr const char kAsyncFnPrefix[]
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static Block * setupSetErrorBlock(CoroMachinery &coro)
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
typename cf::AssertOp::Adaptor OpAdaptor
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)