29 #include "llvm/Support/Debug.h"
33 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
34 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
35 #include "mlir/Dialect/Async/Passes.h.inc"
41 #define DEBUG_TYPE "async-to-async-runtime"
47 class AsyncToAsyncRuntimePass
48 :
public impl::AsyncToAsyncRuntimePassBase<AsyncToAsyncRuntimePass> {
50 AsyncToAsyncRuntimePass() =
default;
51 void runOnOperation()
override;
58 class AsyncFuncToAsyncRuntimePass
59 :
public impl::AsyncFuncToAsyncRuntimePassBase<
60 AsyncFuncToAsyncRuntimePass> {
62 AsyncFuncToAsyncRuntimePass() =
default;
63 void runOnOperation()
override;
74 struct CoroMachinery {
90 std::optional<Value> asyncToken;
95 std::optional<Block *> setError;
118 Block *cleanupForDestroy;
124 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
171 assert(!func.getBlocks().empty() &&
"Function must have an entry block");
174 Block *entryBlock = &func.getBlocks().
front();
175 Block *originalEntryBlock =
185 bool isStateful = isa<TokenType>(func.getResultTypes().front());
187 std::optional<Value> retToken;
189 retToken.emplace(RuntimeCreateOp::create(builder,
TokenType::get(ctx)));
193 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
194 for (
auto resType : resValueTypes)
195 retValues.emplace_back(
196 RuntimeCreateOp::create(builder, resType).getResult());
204 cf::BranchOp::create(builder, originalEntryBlock);
206 Block *cleanupBlock = func.addBlock();
207 Block *cleanupBlockForDestroy = func.addBlock();
208 Block *suspendBlock = func.addBlock();
213 auto buildCleanupBlock = [&](
Block *cb) {
214 builder.setInsertionPointToStart(cb);
215 CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
218 cf::BranchOp::create(builder, suspendBlock);
220 buildCleanupBlock(cleanupBlock);
221 buildCleanupBlock(cleanupBlockForDestroy);
227 builder.setInsertionPointToStart(suspendBlock);
230 CoroEndOp::create(builder, coroHdlOp.getHandle());
236 ret.push_back(*retToken);
237 llvm::append_range(ret, retValues);
238 func::ReturnOp::create(builder, ret);
245 func->setAttr(
"passthrough", builder.getArrayAttr(
248 CoroMachinery machinery;
249 machinery.func = func;
250 machinery.asyncToken = retToken;
251 machinery.returnValues = retValues;
252 machinery.coroHandle = coroHdlOp.getHandle();
253 machinery.entry = entryBlock;
254 machinery.setError = std::nullopt;
255 machinery.cleanup = cleanupBlock;
256 machinery.cleanupForDestroy = cleanupBlockForDestroy;
257 machinery.suspend = suspendBlock;
265 return *coro.setError;
267 coro.setError = coro.func.addBlock();
268 (*coro.setError)->moveBefore(coro.cleanup);
275 RuntimeSetErrorOp::create(builder, *coro.asyncToken);
277 for (
Value retValue : coro.returnValues)
278 RuntimeSetErrorOp::create(builder, retValue);
281 cf::BranchOp::create(builder, coro.cleanup);
283 return *coro.setError;
294 static std::pair<func::FuncOp, CoroMachinery>
296 ModuleOp module = execute->getParentOfType<ModuleOp>();
307 execute.getDependencies());
308 functionInputs.insert_range(execute.getBodyOperands());
312 auto typesRange = llvm::map_range(
313 functionInputs, [](
Value value) {
return value.
getType(); });
315 auto outputTypes = execute.getResultTypes();
331 size_t numDependencies = execute.getDependencies().size();
332 size_t numOperands = execute.getBodyOperands().size();
335 for (
size_t i = 0; i < numDependencies; ++i)
336 AwaitOp::create(builder, func.getArgument(i));
340 for (
size_t i = 0; i < numOperands; ++i) {
341 Value operand = func.getArgument(numDependencies + i);
342 unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
348 valueMapping.
map(functionInputs, func.getArguments());
349 valueMapping.
map(execute.getBodyRegion().getArguments(), unwrappedOperands);
353 for (
Operation &op : execute.getBodyRegion().getOps())
354 builder.clone(op, valueMapping);
364 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
365 builder.setInsertionPointToEnd(coro.entry);
373 RuntimeResumeOp::create(builder, coro.coroHandle);
376 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
377 branch.getDest(), coro.cleanupForDestroy);
385 auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(),
386 execute.getResultTypes(),
387 functionInputs.getArrayRef());
388 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
405 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
424 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
450 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
455 func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
460 for (
const auto &namedAttr : op->getAttrs()) {
462 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
469 (*coros)[newFuncOp] = coro;
490 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
493 op, op.getCallee(), op.getResultTypes(), op.getOperands());
508 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
510 auto func = op->template getParentOfType<func::FuncOp>();
511 auto funcCoro = coros->find(func);
512 if (funcCoro == coros->end())
514 op,
"operation is not inside the async coroutine function");
517 const CoroMachinery &coro = funcCoro->getSecond();
522 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
523 Value returnValue = std::get<0>(tuple);
524 Value asyncValue = std::get<1>(tuple);
525 RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
526 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
531 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
534 cf::BranchOp::create(rewriter, loc, coro.cleanup);
549 template <
typename AwaitType,
typename AwaitableType>
551 using AwaitAdaptor =
typename AwaitType::Adaptor;
555 bool shouldLowerBlockingWait)
557 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
560 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
564 if (!isa<AwaitableType>(op.getOperand().getType()))
568 auto func = op->template getParentOfType<func::FuncOp>();
569 auto funcCoro = coros->find(func);
570 const bool isInCoroutine = funcCoro != coros->end();
573 Value operand = adaptor.getOperand();
578 if (!isInCoroutine && !shouldLowerBlockingWait)
583 if (!isInCoroutine) {
585 RuntimeAwaitOp::create(builder, loc, operand);
588 Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
589 Value notError = arith::XOrIOp::create(
591 arith::ConstantOp::create(builder, loc, i1,
592 builder.getIntegerAttr(i1, 1)));
594 cf::AssertOp::create(builder, notError,
595 "Awaited async operand is in error state");
601 CoroMachinery &coro = funcCoro->getSecond();
602 Block *suspended = op->getBlock();
611 RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
617 builder.setInsertionPointToEnd(suspended);
618 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
619 resume, coro.cleanupForDestroy);
625 builder.setInsertionPointToStart(resume);
626 auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
627 cf::CondBranchOp::create(builder, isError,
639 if (
Value replaceWith = getReplacementValue(op, operand, rewriter))
647 virtual Value getReplacementValue(AwaitType op,
Value operand,
654 bool shouldLowerBlockingWait;
658 class AwaitTokenOpLowering :
public AwaitOpLoweringBase<AwaitOp, TokenType> {
659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
666 class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
673 getReplacementValue(AwaitOp op,
Value operand,
676 auto valueType = cast<ValueType>(operand.
getType()).getValueType();
677 return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
682 class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
704 auto func = op->template getParentOfType<func::FuncOp>();
705 auto funcCoro = coros->find(func);
706 if (funcCoro == coros->end())
708 op,
"operation is not inside the async coroutine function");
711 const CoroMachinery &coro = funcCoro->getSecond();
715 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716 Value yieldValue = std::get<0>(tuple);
717 Value asyncValue = std::get<1>(tuple);
718 RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
719 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
724 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
726 cf::BranchOp::create(rewriter, loc, coro.cleanup);
749 auto func = op->template getParentOfType<func::FuncOp>();
750 auto funcCoro = coros->find(func);
751 if (funcCoro == coros->end())
753 op,
"operation is not inside the async coroutine function");
756 CoroMachinery &coro = funcCoro->getSecond();
760 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
775 void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
784 module.walk([&](ExecuteOp execute) {
789 llvm::dbgs() <<
"Outlined " << coros->size()
790 <<
" functions built from async.execute operations\n";
794 auto isInCoroutine = [&](
Operation *op) ->
bool {
795 auto parentFunc = op->getParentOfType<func::FuncOp>();
796 return coros->contains(parentFunc);
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
827 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](
Operation *op) {
828 auto walkResult = op->walk([&](
Operation *nested) {
829 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
833 return !walkResult.wasInterrupted();
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
839 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840 [&](cf::AssertOp op) ->
bool {
841 auto func = op->getParentOfType<func::FuncOp>();
842 return !coros->contains(func);
846 std::move(asyncPatterns)))) {
858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
861 patterns.add<AsyncCallOpLowering>(ctx);
862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
870 auto exec = op->getParentOfType<ExecuteOp>();
871 auto func = op->getParentOfType<func::FuncOp>();
872 return exec || !coros->contains(func);
876 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
895 std::move(asyncPatterns)))) {
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region ®ion)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.