30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
37 #include "mlir/Dialect/Async/Passes.h.inc"
43 #define DEBUG_TYPE "async-to-async-runtime"
49 class AsyncToAsyncRuntimePass
50 :
public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
52 AsyncToAsyncRuntimePass() =
default;
53 void runOnOperation()
override;
60 class AsyncFuncToAsyncRuntimePass
61 :
public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
63 AsyncFuncToAsyncRuntimePass() =
default;
64 void runOnOperation()
override;
75 struct CoroMachinery {
91 std::optional<Value> asyncToken;
96 std::optional<Block *> setError;
119 Block *cleanupForDestroy;
125 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
172 assert(!func.getBlocks().empty() &&
"Function must have an entry block");
175 Block *entryBlock = &func.getBlocks().
front();
176 Block *originalEntryBlock =
186 bool isStateful = isa<TokenType>(func.getResultTypes().front());
188 std::optional<Value> retToken;
190 retToken.emplace(builder.create<RuntimeCreateOp>(
TokenType::get(ctx)));
194 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195 for (
auto resType : resValueTypes)
196 retValues.emplace_back(
197 builder.create<RuntimeCreateOp>(resType).getResult());
205 builder.create<cf::BranchOp>(originalEntryBlock);
207 Block *cleanupBlock = func.addBlock();
208 Block *cleanupBlockForDestroy = func.addBlock();
209 Block *suspendBlock = func.addBlock();
214 auto buildCleanupBlock = [&](
Block *cb) {
215 builder.setInsertionPointToStart(cb);
216 builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
219 builder.create<cf::BranchOp>(suspendBlock);
221 buildCleanupBlock(cleanupBlock);
222 buildCleanupBlock(cleanupBlockForDestroy);
228 builder.setInsertionPointToStart(suspendBlock);
231 builder.create<CoroEndOp>(coroHdlOp.getHandle());
237 ret.push_back(*retToken);
238 ret.insert(ret.end(), retValues.begin(), retValues.end());
239 builder.create<func::ReturnOp>(ret);
246 func->setAttr(
"passthrough", builder.getArrayAttr(
249 CoroMachinery machinery;
250 machinery.func = func;
251 machinery.asyncToken = retToken;
252 machinery.returnValues = retValues;
253 machinery.coroHandle = coroHdlOp.getHandle();
254 machinery.entry = entryBlock;
255 machinery.setError = std::nullopt;
256 machinery.cleanup = cleanupBlock;
257 machinery.cleanupForDestroy = cleanupBlockForDestroy;
258 machinery.suspend = suspendBlock;
266 return *coro.setError;
268 coro.setError = coro.func.addBlock();
269 (*coro.setError)->moveBefore(coro.cleanup);
276 builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
278 for (
Value retValue : coro.returnValues)
279 builder.create<RuntimeSetErrorOp>(retValue);
282 builder.create<cf::BranchOp>(coro.cleanup);
284 return *coro.setError;
295 static std::pair<func::FuncOp, CoroMachinery>
297 ModuleOp module = execute->getParentOfType<ModuleOp>();
308 execute.getDependencies().end());
309 functionInputs.insert(execute.getBodyOperands().begin(),
310 execute.getBodyOperands().end());
314 auto typesRange = llvm::map_range(
315 functionInputs, [](
Value value) {
return value.
getType(); });
317 auto outputTypes = execute.getResultTypes();
333 size_t numDependencies = execute.getDependencies().size();
334 size_t numOperands = execute.getBodyOperands().size();
337 for (
size_t i = 0; i < numDependencies; ++i)
338 builder.create<AwaitOp>(func.getArgument(i));
342 for (
size_t i = 0; i < numOperands; ++i) {
343 Value operand = func.getArgument(numDependencies + i);
344 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
350 valueMapping.
map(functionInputs, func.getArguments());
351 valueMapping.
map(execute.getBodyRegion().getArguments(), unwrappedOperands);
355 for (
Operation &op : execute.getBodyRegion().getOps())
356 builder.clone(op, valueMapping);
366 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
367 builder.setInsertionPointToEnd(coro.entry);
375 builder.create<RuntimeResumeOp>(coro.coroHandle);
378 builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
379 branch.getDest(), coro.cleanupForDestroy);
387 auto callOutlinedFunc = callBuilder.
create<func::CallOp>(
388 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
389 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
406 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
425 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
451 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
456 rewriter.
create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
461 for (
const auto &namedAttr : op->getAttrs()) {
463 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
470 (*coros)[newFuncOp] = coro;
491 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
494 op, op.getCallee(), op.getResultTypes(), op.getOperands());
509 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
511 auto func = op->template getParentOfType<func::FuncOp>();
512 auto funcCoro = coros->find(func);
513 if (funcCoro == coros->end())
515 op,
"operation is not inside the async coroutine function");
518 const CoroMachinery &coro = funcCoro->getSecond();
523 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524 Value returnValue = std::get<0>(tuple);
525 Value asyncValue = std::get<1>(tuple);
526 rewriter.
create<RuntimeStoreOp>(loc, returnValue, asyncValue);
527 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
532 rewriter.
create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
535 rewriter.
create<cf::BranchOp>(loc, coro.cleanup);
550 template <
typename AwaitType,
typename AwaitableType>
552 using AwaitAdaptor =
typename AwaitType::Adaptor;
556 bool shouldLowerBlockingWait)
558 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
561 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
565 if (!isa<AwaitableType>(op.getOperand().getType()))
569 auto func = op->template getParentOfType<func::FuncOp>();
570 auto funcCoro = coros->find(func);
571 const bool isInCoroutine = funcCoro != coros->end();
574 Value operand = adaptor.getOperand();
579 if (!isInCoroutine && !shouldLowerBlockingWait)
584 if (!isInCoroutine) {
586 builder.create<RuntimeAwaitOp>(loc, operand);
589 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
590 Value notError = builder.create<arith::XOrIOp>(
591 isError, builder.create<arith::ConstantOp>(
592 loc, i1, builder.getIntegerAttr(i1, 1)));
594 builder.create<cf::AssertOp>(notError,
595 "Awaited async operand is in error state");
601 CoroMachinery &coro = funcCoro->getSecond();
602 Block *suspended = op->getBlock();
611 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
617 builder.setInsertionPointToEnd(suspended);
618 builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
619 coro.cleanupForDestroy);
625 builder.setInsertionPointToStart(resume);
626 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
627 builder.create<cf::CondBranchOp>(isError,
639 if (
Value replaceWith = getReplacementValue(op, operand, rewriter))
647 virtual Value getReplacementValue(AwaitType op,
Value operand,
654 bool shouldLowerBlockingWait;
658 class AwaitTokenOpLowering :
public AwaitOpLoweringBase<AwaitOp, TokenType> {
659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
666 class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
673 getReplacementValue(AwaitOp op,
Value operand,
676 auto valueType = cast<ValueType>(operand.
getType()).getValueType();
677 return rewriter.
create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
682 class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
704 auto func = op->template getParentOfType<func::FuncOp>();
705 auto funcCoro = coros->find(func);
706 if (funcCoro == coros->end())
708 op,
"operation is not inside the async coroutine function");
711 const CoroMachinery &coro = funcCoro->getSecond();
715 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716 Value yieldValue = std::get<0>(tuple);
717 Value asyncValue = std::get<1>(tuple);
718 rewriter.
create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
719 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
724 rewriter.
create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
727 rewriter.
create<cf::BranchOp>(loc, coro.cleanup);
749 auto func = op->template getParentOfType<func::FuncOp>();
750 auto funcCoro = coros->find(func);
751 if (funcCoro == coros->end())
753 op,
"operation is not inside the async coroutine function");
756 CoroMachinery &coro = funcCoro->getSecond();
760 rewriter.
create<cf::CondBranchOp>(loc, adaptor.getArg(),
775 void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
784 module.walk([&](ExecuteOp execute) {
789 llvm::dbgs() <<
"Outlined " << coros->size()
790 <<
" functions built from async.execute operations\n";
794 auto isInCoroutine = [&](
Operation *op) ->
bool {
795 auto parentFunc = op->getParentOfType<func::FuncOp>();
796 return coros->find(parentFunc) != coros->end();
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
827 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](
Operation *op) {
828 auto walkResult = op->walk([&](
Operation *nested) {
829 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
833 return !walkResult.wasInterrupted();
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
839 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840 [&](cf::AssertOp op) ->
bool {
841 auto func = op->getParentOfType<func::FuncOp>();
842 return !coros->contains(func);
846 std::move(asyncPatterns)))) {
858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
861 patterns.add<AsyncCallOpLowering>(ctx);
862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
870 auto exec = op->getParentOfType<ExecuteOp>();
871 auto func = op->getParentOfType<func::FuncOp>();
872 return exec || !coros->contains(func);
876 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
895 std::move(asyncPatterns)))) {
902 return std::make_unique<AsyncToAsyncRuntimePass>();
905 std::unique_ptr<OperationPass<ModuleOp>>
907 return std::make_unique<AsyncFuncToAsyncRuntimePass>();
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
typename SourceOp::Adaptor OpAdaptor
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
std::unique_ptr< OperationPass< ModuleOp > > createAsyncFuncToAsyncRuntimePass()