30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
37 #include "mlir/Dialect/Async/Passes.h.inc"
43 #define DEBUG_TYPE "async-to-async-runtime"
49 class AsyncToAsyncRuntimePass
50 :
public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
52 AsyncToAsyncRuntimePass() =
default;
53 void runOnOperation()
override;
60 class AsyncFuncToAsyncRuntimePass
61 :
public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
63 AsyncFuncToAsyncRuntimePass() =
default;
64 void runOnOperation()
override;
75 struct CoroMachinery {
91 std::optional<Value> asyncToken;
96 std::optional<Block *> setError;
103 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
150 assert(!func.getBlocks().empty() &&
"Function must have an entry block");
153 Block *entryBlock = &func.getBlocks().
front();
154 Block *originalEntryBlock =
164 bool isStateful = func.getCallableResults().front().isa<TokenType>();
166 std::optional<Value> retToken;
168 retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
172 ? func.getCallableResults().drop_front()
173 : func.getCallableResults();
174 for (
auto resType : resValueTypes)
175 retValues.emplace_back(
176 builder.create<RuntimeCreateOp>(resType).getResult());
181 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
183 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
184 builder.create<cf::BranchOp>(originalEntryBlock);
186 Block *cleanupBlock = func.addBlock();
187 Block *suspendBlock = func.addBlock();
192 builder.setInsertionPointToStart(cleanupBlock);
193 builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
196 builder.create<cf::BranchOp>(suspendBlock);
202 builder.setInsertionPointToStart(suspendBlock);
205 builder.create<CoroEndOp>(coroHdlOp.getHandle());
211 ret.push_back(*retToken);
212 ret.insert(ret.end(), retValues.begin(), retValues.end());
213 builder.create<func::ReturnOp>(ret);
220 func->setAttr(
"passthrough", builder.getArrayAttr(
221 StringAttr::get(ctx,
"presplitcoroutine")));
223 CoroMachinery machinery;
224 machinery.func = func;
225 machinery.asyncToken = retToken;
226 machinery.returnValues = retValues;
227 machinery.coroHandle = coroHdlOp.getHandle();
228 machinery.entry = entryBlock;
229 machinery.setError = std::nullopt;
230 machinery.cleanup = cleanupBlock;
231 machinery.suspend = suspendBlock;
239 return *coro.setError;
241 coro.setError = coro.func.addBlock();
242 (*coro.setError)->moveBefore(coro.cleanup);
249 builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
251 for (
Value retValue : coro.returnValues)
252 builder.create<RuntimeSetErrorOp>(retValue);
255 builder.create<cf::BranchOp>(coro.cleanup);
257 return *coro.setError;
268 static std::pair<func::FuncOp, CoroMachinery>
270 ModuleOp module = execute->getParentOfType<ModuleOp>();
281 execute.getDependencies().end());
282 functionInputs.insert(execute.getBodyOperands().begin(),
283 execute.getBodyOperands().end());
287 auto typesRange = llvm::map_range(
288 functionInputs, [](
Value value) {
return value.
getType(); });
290 auto outputTypes = execute.getResultTypes();
292 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
306 size_t numDependencies = execute.getDependencies().size();
307 size_t numOperands = execute.getBodyOperands().size();
310 for (
size_t i = 0; i < numDependencies; ++i)
311 builder.create<AwaitOp>(func.getArgument(i));
315 for (
size_t i = 0; i < numOperands; ++i) {
316 Value operand = func.getArgument(numDependencies + i);
317 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
323 valueMapping.
map(functionInputs, func.getArguments());
324 valueMapping.
map(execute.getBodyRegion().getArguments(), unwrappedOperands);
328 for (
Operation &op : execute.getBodyRegion().getOps())
329 builder.clone(op, valueMapping);
339 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
340 builder.setInsertionPointToEnd(coro.entry);
344 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
348 builder.create<RuntimeResumeOp>(coro.coroHandle);
351 builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
352 branch.getDest(), coro.cleanup);
360 auto callOutlinedFunc = callBuilder.
create<func::CallOp>(
361 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
362 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
379 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
382 op, GroupType::get(op->getContext()), adaptor.getOperands());
398 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
424 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
429 rewriter.
create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
434 for (
const auto &namedAttr : op->getAttrs()) {
436 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
443 (*coros)[newFuncOp] = coro;
464 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
467 op, op.getCallee(), op.getResultTypes(), op.getOperands());
482 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
484 auto func = op->template getParentOfType<func::FuncOp>();
485 auto funcCoro = coros->find(func);
486 if (funcCoro == coros->end())
488 op,
"operation is not inside the async coroutine function");
491 const CoroMachinery &coro = funcCoro->getSecond();
496 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
497 Value returnValue = std::get<0>(tuple);
498 Value asyncValue = std::get<1>(tuple);
499 rewriter.
create<RuntimeStoreOp>(loc, returnValue, asyncValue);
500 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
505 rewriter.
create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
508 rewriter.
create<cf::BranchOp>(loc, coro.cleanup);
523 template <
typename AwaitType,
typename AwaitableType>
525 using AwaitAdaptor =
typename AwaitType::Adaptor;
529 bool shouldLowerBlockingWait)
531 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
534 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
538 if (!op.getOperand().getType().template isa<AwaitableType>())
542 auto func = op->template getParentOfType<func::FuncOp>();
543 auto funcCoro = coros->find(func);
544 const bool isInCoroutine = funcCoro != coros->end();
547 Value operand = adaptor.getOperand();
552 if (!isInCoroutine && !shouldLowerBlockingWait)
557 if (!isInCoroutine) {
559 builder.create<RuntimeAwaitOp>(loc, operand);
562 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
563 Value notError = builder.create<arith::XOrIOp>(
564 isError, builder.create<arith::ConstantOp>(
565 loc, i1, builder.getIntegerAttr(i1, 1)));
567 builder.create<cf::AssertOp>(notError,
568 "Awaited async operand is in error state");
574 CoroMachinery &coro = funcCoro->getSecond();
575 Block *suspended = op->getBlock();
583 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
584 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
590 builder.setInsertionPointToEnd(suspended);
591 builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
598 builder.setInsertionPointToStart(resume);
599 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
600 builder.create<cf::CondBranchOp>(isError,
612 if (
Value replaceWith = getReplacementValue(op, operand, rewriter))
620 virtual Value getReplacementValue(AwaitType op,
Value operand,
627 bool shouldLowerBlockingWait;
631 class AwaitTokenOpLowering :
public AwaitOpLoweringBase<AwaitOp, TokenType> {
632 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
639 class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
640 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
646 getReplacementValue(AwaitOp op,
Value operand,
649 auto valueType = operand.
getType().
cast<ValueType>().getValueType();
650 return rewriter.
create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
655 class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
656 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
677 auto func = op->template getParentOfType<func::FuncOp>();
678 auto funcCoro = coros->find(func);
679 if (funcCoro == coros->end())
681 op,
"operation is not inside the async coroutine function");
684 const CoroMachinery &coro = funcCoro->getSecond();
688 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
689 Value yieldValue = std::get<0>(tuple);
690 Value asyncValue = std::get<1>(tuple);
691 rewriter.
create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
692 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
697 rewriter.
create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
700 rewriter.
create<cf::BranchOp>(loc, coro.cleanup);
722 auto func = op->template getParentOfType<func::FuncOp>();
723 auto funcCoro = coros->find(func);
724 if (funcCoro == coros->end())
726 op,
"operation is not inside the async coroutine function");
729 CoroMachinery &coro = funcCoro->getSecond();
733 rewriter.
create<cf::CondBranchOp>(loc, adaptor.getArg(),
748 void AsyncToAsyncRuntimePass::runOnOperation() {
749 ModuleOp module = getOperation();
755 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
757 module.walk([&](ExecuteOp execute) {
762 llvm::dbgs() <<
"Outlined " << coros->size()
763 <<
" functions built from async.execute operations\n";
767 auto isInCoroutine = [&](
Operation *op) ->
bool {
768 auto parentFunc = op->getParentOfType<func::FuncOp>();
769 return coros->find(parentFunc) != coros->end();
784 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
787 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
795 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
796 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
797 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
800 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](
Operation *op) {
801 auto walkResult = op->walk([&](
Operation *nested) {
802 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
806 return !walkResult.wasInterrupted();
808 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
809 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
812 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
813 [&](cf::AssertOp op) ->
bool {
814 auto func = op->getParentOfType<func::FuncOp>();
815 return coros->find(func) == coros->end();
819 std::move(asyncPatterns)))) {
831 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
834 patterns.
add<AsyncCallOpLowering>(ctx);
835 patterns.
add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
837 patterns.
add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
843 auto exec = op->getParentOfType<ExecuteOp>();
844 auto func = op->getParentOfType<func::FuncOp>();
845 return exec || coros->find(func) == coros->end();
849 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
850 ModuleOp module = getOperation();
861 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
862 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
864 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
865 cf::BranchOp, cf::CondBranchOp>();
868 std::move(asyncPatterns)))) {
875 return std::make_unique<AsyncToAsyncRuntimePass>();
878 std::unique_ptr<OperationPass<ModuleOp>>
880 return std::make_unique<AsyncFuncToAsyncRuntimePass>();
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
OpListType::iterator iterator
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
OpListType & getOperations()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
typename SourceOp::Adaptor OpAdaptor
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult interrupt()
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
std::unique_ptr< OperationPass< ModuleOp > > createAsyncFuncToAsyncRuntimePass()
This class represents an efficient way to signal success or failure.