29#include "llvm/Support/Debug.h"
33#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
34#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
35#include "mlir/Dialect/Async/Passes.h.inc"
41#define DEBUG_TYPE "async-to-async-runtime"
47class AsyncToAsyncRuntimePass
50 AsyncToAsyncRuntimePass() =
default;
58class AsyncFuncToAsyncRuntimePass
60 AsyncFuncToAsyncRuntimePass> {
62 AsyncFuncToAsyncRuntimePass() =
default;
90 std::optional<Value> asyncToken;
96 std::optional<Block *> setError;
124 std::optional<Block *> cleanupForDestroy;
130 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
177 assert(!
func.getBlocks().empty() &&
"Function must have an entry block");
180 Block *entryBlock = &
func.getBlocks().front();
181 Block *originalEntryBlock =
191 bool isStateful = isa<async::TokenType>(
func.getResultTypes().front());
193 std::optional<Value> retToken;
196 RuntimeCreateOp::create(builder, async::TokenType::get(ctx)));
200 isStateful ?
func.getResultTypes().drop_front() :
func.getResultTypes();
201 for (
auto resType : resValueTypes)
202 retValues.emplace_back(
203 RuntimeCreateOp::create(builder, resType).getResult());
208 auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
210 CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
211 cf::BranchOp::create(builder, originalEntryBlock);
221 builder.setInsertionPointToStart(cleanupBlock);
222 CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
223 cf::BranchOp::create(builder, suspendBlock);
229 builder.setInsertionPointToStart(suspendBlock);
232 CoroEndOp::create(builder, coroHdlOp.getHandle());
238 ret.push_back(*retToken);
239 llvm::append_range(ret, retValues);
240 func::ReturnOp::create(builder, ret);
247 func->setAttr(
"llvm.passthrough", builder.getArrayAttr(StringAttr::get(
248 ctx,
"presplitcoroutine")));
250 CoroMachinery machinery;
251 machinery.func =
func;
252 machinery.asyncToken = retToken;
253 machinery.returnValues = retValues;
254 machinery.coroId = coroIdOp.getId();
255 machinery.coroHandle = coroHdlOp.getHandle();
256 machinery.entry = entryBlock;
257 machinery.setError = std::nullopt;
258 machinery.cleanup = cleanupBlock;
259 machinery.cleanupForDestroy = std::nullopt;
260 machinery.suspend = suspendBlock;
268 return *coro.setError;
270 coro.setError = coro.func.addBlock();
271 (*coro.setError)->moveBefore(coro.cleanup);
278 RuntimeSetErrorOp::create(builder, *coro.asyncToken);
280 for (
Value retValue : coro.returnValues)
281 RuntimeSetErrorOp::create(builder, retValue);
284 cf::BranchOp::create(builder, coro.cleanup);
286 return *coro.setError;
293 CoroMachinery &coro) {
294 if (coro.cleanupForDestroy)
295 return *coro.cleanupForDestroy;
297 coro.cleanupForDestroy = builder.
createBlock(coro.suspend);
298 CoroFreeOp::create(builder, coro.coroId, coro.coroHandle);
299 cf::BranchOp::create(builder, coro.suspend);
300 return *coro.cleanupForDestroy;
311static std::pair<func::FuncOp, CoroMachinery>
313 ModuleOp module = execute->getParentOfType<ModuleOp>();
324 execute.getDependencies());
325 functionInputs.insert_range(execute.getBodyOperands());
329 auto typesRange = llvm::map_range(
330 functionInputs, [](
Value value) {
return value.
getType(); });
332 auto outputTypes = execute.getResultTypes();
334 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
348 size_t numDependencies = execute.getDependencies().size();
349 size_t numOperands = execute.getBodyOperands().size();
352 for (
size_t i = 0; i < numDependencies; ++i)
353 AwaitOp::create(builder,
func.getArgument(i));
357 for (
size_t i = 0; i < numOperands; ++i) {
358 Value operand =
func.getArgument(numDependencies + i);
359 unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
365 valueMapping.
map(functionInputs,
func.getArguments());
366 valueMapping.
map(execute.getBodyRegion().getArguments(), unwrappedOperands);
370 for (
Operation &op : execute.getBodyRegion().getOps())
371 builder.clone(op, valueMapping);
381 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->
getTerminator());
382 builder.setInsertionPointToEnd(coro.entry);
386 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
390 RuntimeResumeOp::create(builder, coro.coroHandle);
394 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
395 branch.getDest(), destroy);
403 auto callOutlinedFunc = func::CallOp::create(callBuilder,
func.getName(),
404 execute.getResultTypes(),
405 functionInputs.getArrayRef());
406 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
418class CreateGroupOpLowering :
public OpConversionPattern<CreateGroupOp> {
420 using OpConversionPattern::OpConversionPattern;
423 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter)
const override {
425 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
426 op, GroupType::get(op->getContext()), adaptor.getOperands());
437class AddToGroupOpLowering :
public OpConversionPattern<AddToGroupOp> {
439 using OpConversionPattern::OpConversionPattern;
442 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
443 ConversionPatternRewriter &rewriter)
const override {
444 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
445 op, rewriter.getIndexType(), adaptor.getOperands());
462class AsyncFuncOpLowering :
public OpConversionPattern<async::FuncOp> {
465 : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
468 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter)
const override {
473 func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
478 for (
const auto &namedAttr : op->getAttrs()) {
480 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
483 rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
487 (*coros)[newFuncOp] = coro;
490 rewriter.eraseOp(op);
502class AsyncCallOpLowering :
public OpConversionPattern<async::CallOp> {
504 AsyncCallOpLowering(MLIRContext *ctx)
505 : OpConversionPattern<async::CallOp>(ctx) {}
508 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter)
const override {
510 rewriter.replaceOpWithNewOp<func::CallOp>(
511 op, op.getCallee(), op.getResultTypes(), op.getOperands());
520class AsyncReturnOpLowering :
public OpConversionPattern<async::ReturnOp> {
523 : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
526 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
527 ConversionPatternRewriter &rewriter)
const override {
528 auto func = op->template getParentOfType<func::FuncOp>();
529 auto funcCoro = coros->find(func);
530 if (funcCoro == coros->end())
531 return rewriter.notifyMatchFailure(
532 op,
"operation is not inside the async coroutine function");
534 Location loc = op->getLoc();
535 const CoroMachinery &coro = funcCoro->getSecond();
536 rewriter.setInsertionPointAfter(op);
540 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
541 Value returnValue = std::get<0>(tuple);
542 Value asyncValue = std::get<1>(tuple);
543 RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
544 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
549 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
551 rewriter.eraseOp(op);
552 cf::BranchOp::create(rewriter, loc, coro.cleanup);
567template <
typename AwaitType,
typename AwaitableType>
568class AwaitOpLoweringBase :
public OpConversionPattern<AwaitType> {
569 using AwaitAdaptor =
typename AwaitType::Adaptor;
573 bool shouldLowerBlockingWait)
574 : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
575 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
578 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
579 ConversionPatternRewriter &rewriter)
const override {
582 if (!isa<AwaitableType>(op.getOperand().getType()))
583 return rewriter.notifyMatchFailure(op,
"unsupported awaitable type");
586 auto func = op->template getParentOfType<func::FuncOp>();
587 auto funcCoro = coros->find(func);
588 const bool isInCoroutine = funcCoro != coros->end();
590 Location loc = op->getLoc();
591 Value operand = adaptor.getOperand();
593 Type i1 = rewriter.getI1Type();
596 if (!isInCoroutine && !shouldLowerBlockingWait)
601 if (!isInCoroutine) {
602 ImplicitLocOpBuilder builder(loc, rewriter);
603 RuntimeAwaitOp::create(builder, loc, operand);
606 Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
607 Value notError = arith::XOrIOp::create(
609 arith::ConstantOp::create(builder, loc, i1,
610 builder.getIntegerAttr(i1, 1)));
612 cf::AssertOp::create(builder, notError,
613 "Awaited async operand is in error state");
619 CoroMachinery &coro = funcCoro->getSecond();
620 Block *suspended = op->getBlock();
622 ImplicitLocOpBuilder builder(loc, rewriter);
623 MLIRContext *ctx = op->getContext();
628 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
629 RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
636 builder.setInsertionPointToEnd(suspended);
637 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
644 builder.setInsertionPointToStart(resume);
645 auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
646 cf::CondBranchOp::create(builder, isError,
654 rewriter.setInsertionPointToStart(continuation);
658 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
659 rewriter.replaceOp(op, replaceWith);
661 rewriter.eraseOp(op);
666 virtual Value getReplacementValue(AwaitType op, Value operand,
667 ConversionPatternRewriter &rewriter)
const {
673 bool shouldLowerBlockingWait;
677class AwaitTokenOpLowering
678 :
public AwaitOpLoweringBase<AwaitOp, async::TokenType> {
679 using Base = AwaitOpLoweringBase<AwaitOp, async::TokenType>;
686class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
687 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
693 getReplacementValue(AwaitOp op, Value operand,
694 ConversionPatternRewriter &rewriter)
const override {
696 auto valueType = cast<ValueType>(operand.
getType()).getValueType();
697 return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
702class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
703 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
718 : OpConversionPattern<
async::YieldOp>(ctx), coros(std::move(coros)) {}
722 ConversionPatternRewriter &rewriter)
const override {
724 auto func = op->template getParentOfType<func::FuncOp>();
725 auto funcCoro = coros->find(
func);
726 if (funcCoro == coros->end())
727 return rewriter.notifyMatchFailure(
728 op,
"operation is not inside the async coroutine function");
731 const CoroMachinery &coro = funcCoro->getSecond();
735 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
736 Value yieldValue = std::get<0>(tuple);
737 Value asyncValue = std::get<1>(tuple);
738 RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
739 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
744 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
746 cf::BranchOp::create(rewriter, loc, coro.cleanup);
747 rewriter.eraseOp(op);
763 : OpConversionPattern<
cf::AssertOp>(ctx), coros(std::move(coros)) {}
767 ConversionPatternRewriter &rewriter)
const override {
769 auto func = op->template getParentOfType<func::FuncOp>();
770 auto funcCoro = coros->find(
func);
771 if (funcCoro == coros->end())
772 return rewriter.notifyMatchFailure(
773 op,
"operation is not inside the async coroutine function");
776 CoroMachinery &coro = funcCoro->getSecond();
779 rewriter.setInsertionPointToEnd(cont->getPrevNode());
780 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
785 rewriter.eraseOp(op);
795void AsyncToAsyncRuntimePass::runOnOperation() {
796 ModuleOp module = getOperation();
797 SymbolTable symbolTable(module);
802 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
804 module.walk([&](ExecuteOp execute) {
809 llvm::dbgs() <<
"Outlined " << coros->size()
810 <<
" functions built from async.execute operations\n";
814 auto isInCoroutine = [&](Operation *op) ->
bool {
815 auto parentFunc = op->getParentOfType<func::FuncOp>();
816 return coros->contains(parentFunc);
820 MLIRContext *ctx =
module->getContext();
821 RewritePatternSet asyncPatterns(ctx);
831 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
834 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
838 asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
841 ConversionTarget runtimeTarget(*ctx);
842 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
843 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
844 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
847 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
848 auto walkResult = op->walk([&](Operation *nested) {
849 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
851 : WalkResult::advance();
853 return !walkResult.wasInterrupted();
855 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
856 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
859 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
860 [&](cf::AssertOp op) ->
bool {
861 auto func = op->getParentOfType<func::FuncOp>();
862 return !coros->contains(func);
865 if (
failed(applyPartialConversion(module, runtimeTarget,
866 std::move(asyncPatterns)))) {
878 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
881 patterns.
add<AsyncCallOpLowering>(ctx);
882 patterns.
add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
884 patterns.
add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
888 target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
890 auto exec = op->getParentOfType<ExecuteOp>();
891 auto func = op->getParentOfType<func::FuncOp>();
892 return exec || !coros->contains(
func);
896void AsyncFuncToAsyncRuntimePass::runOnOperation() {
897 ModuleOp module = getOperation();
908 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
909 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
911 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
912 cf::BranchOp, cf::CondBranchOp>();
914 if (failed(applyPartialConversion(module, runtimeTarget,
915 std::move(asyncPatterns)))) {
static Block * setupCleanupForDestroyBlock(ImplicitLocOpBuilder &builder, CoroMachinery &coro)
static constexpr const char kAsyncFnPrefix[]
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static Block * setupSetErrorBlock(CoroMachinery &coro)
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
typename cf::AssertOp::Adaptor OpAdaptor
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)