27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/Debug.h" 33 #define DEBUG_TYPE "async-to-async-runtime" 39 class AsyncToAsyncRuntimePass
40 :
public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
42 AsyncToAsyncRuntimePass() =
default;
43 void runOnOperation()
override;
58 struct CoroMachinery {
128 assert(!func.getBlocks().empty() &&
"Function must have an entry block");
131 Block *entryBlock = &func.getBlocks().
front();
132 Block *originalEntryBlock =
139 auto retToken = builder.
create<RuntimeCreateOp>(TokenType::get(ctx)).result();
142 for (
auto resType : func.getCallableResults().drop_front())
143 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
148 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
150 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
151 builder.create<cf::BranchOp>(originalEntryBlock);
153 Block *cleanupBlock = func.addBlock();
154 Block *suspendBlock = func.addBlock();
159 builder.setInsertionPointToStart(cleanupBlock);
160 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
163 builder.create<cf::BranchOp>(suspendBlock);
169 builder.setInsertionPointToStart(suspendBlock);
172 builder.create<CoroEndOp>(coroHdlOp.handle());
177 ret.insert(ret.end(), retValues.begin(), retValues.end());
178 builder.create<func::ReturnOp>(ret);
183 for (
Block &block : func.getBody().getBlocks()) {
184 if (&block == entryBlock || &block == cleanupBlock ||
185 &block == suspendBlock)
187 Operation *terminator = block.getTerminator();
188 if (
auto yield = dyn_cast<YieldOp>(terminator)) {
189 builder.setInsertionPointToEnd(&block);
190 builder.
create<cf::BranchOp>(cleanupBlock);
196 func->
setAttr(
"passthrough", builder.getArrayAttr(
197 StringAttr::get(ctx,
"presplitcoroutine")));
199 CoroMachinery machinery;
200 machinery.func = func;
201 machinery.asyncToken = retToken;
202 machinery.returnValues = retValues;
203 machinery.coroHandle = coroHdlOp.handle();
204 machinery.entry = entryBlock;
205 machinery.setError =
nullptr;
206 machinery.cleanup = cleanupBlock;
207 machinery.suspend = suspendBlock;
215 return coro.setError;
217 coro.setError = coro.func.addBlock();
218 coro.setError->moveBefore(coro.cleanup);
224 builder.
create<RuntimeSetErrorOp>(coro.asyncToken);
225 for (
Value retValue : coro.returnValues)
226 builder.create<RuntimeSetErrorOp>(retValue);
229 builder.create<cf::BranchOp>(coro.cleanup);
231 return coro.setError;
238 static std::pair<func::FuncOp, CoroMachinery>
240 ModuleOp module = execute->getParentOfType<ModuleOp>();
251 execute.dependencies().end());
252 functionInputs.insert(execute.operands().begin(), execute.operands().end());
256 auto typesRange = llvm::map_range(
259 auto outputTypes = execute.getResultTypes();
261 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
275 size_t numDependencies = execute.dependencies().size();
276 size_t numOperands = execute.operands().size();
279 for (
size_t i = 0; i < numDependencies; ++i)
280 builder.
create<AwaitOp>(func.getArgument(i));
284 for (
size_t i = 0; i < numOperands; ++i) {
285 Value operand = func.getArgument(numDependencies + i);
286 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
292 valueMapping.
map(functionInputs, func.getArguments());
293 valueMapping.
map(execute.body().getArguments(), unwrappedOperands);
297 for (
Operation &op : execute.body().getOps())
298 builder.clone(op, valueMapping);
308 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
309 builder.setInsertionPointToEnd(coro.entry);
313 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
317 builder.create<RuntimeResumeOp>(coro.coroHandle);
320 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
321 branch.getDest(), coro.cleanup);
329 auto callOutlinedFunc = callBuilder.
create<func::CallOp>(
330 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
331 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
348 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
351 op, GroupType::get(op->getContext()), adaptor.getOperands());
367 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
382 template <
typename AwaitType,
typename AwaitableType>
384 using AwaitAdaptor =
typename AwaitType::Adaptor;
391 outlinedFunctions(outlinedFunctions) {}
394 matchAndRewrite(AwaitType op,
typename AwaitType::Adaptor adaptor,
398 if (!op.operand().getType().template isa<AwaitableType>())
402 auto func = op->template getParentOfType<func::FuncOp>();
403 auto outlined = outlinedFunctions.find(func);
404 const bool isInCoroutine = outlined != outlinedFunctions.end();
407 Value operand = adaptor.operand();
413 if (!isInCoroutine) {
415 builder.
create<RuntimeAwaitOp>(loc, operand);
418 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
419 Value notError = builder.create<arith::XOrIOp>(
420 isError, builder.create<arith::ConstantOp>(
421 loc, i1, builder.getIntegerAttr(i1, 1)));
423 builder.create<cf::AssertOp>(notError,
424 "Awaited async operand is in error state");
430 CoroMachinery &coro = outlined->getSecond();
431 Block *suspended = op->getBlock();
439 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
440 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
446 builder.setInsertionPointToEnd(suspended);
447 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
454 builder.setInsertionPointToStart(resume);
455 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
456 builder.create<cf::CondBranchOp>(isError,
468 if (
Value replaceWith = getReplacementValue(op, operand, rewriter))
476 virtual Value getReplacementValue(AwaitType op,
Value operand,
486 class AwaitTokenOpLowering :
public AwaitOpLoweringBase<AwaitOp, TokenType> {
487 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
494 class AwaitValueOpLowering :
public AwaitOpLoweringBase<AwaitOp, ValueType> {
495 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
501 getReplacementValue(AwaitOp op,
Value operand,
504 auto valueType = operand.
getType().
cast<ValueType>().getValueType();
505 return rewriter.
create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
510 class AwaitAllOpLowering :
public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
511 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
529 outlinedFunctions(outlinedFunctions) {}
535 auto func = op->template getParentOfType<func::FuncOp>();
536 auto outlined = outlinedFunctions.find(func);
537 if (outlined == outlinedFunctions.end())
539 op,
"operation is not inside the async coroutine function");
542 const CoroMachinery &coro = outlined->getSecond();
546 for (
auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
547 Value yieldValue = std::get<0>(tuple);
548 Value asyncValue = std::get<1>(tuple);
549 rewriter.
create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
550 rewriter.
create<RuntimeSetAvailableOp>(loc, asyncValue);
573 outlinedFunctions(outlinedFunctions) {}
579 auto func = op->template getParentOfType<func::FuncOp>();
580 auto outlined = outlinedFunctions.find(func);
581 if (outlined == outlinedFunctions.end())
583 op,
"operation is not inside the async coroutine function");
586 CoroMachinery &coro = outlined->getSecond();
590 rewriter.
create<cf::CondBranchOp>(loc, adaptor.getArg(),
612 auto *ctx = func->getContext();
613 auto loc = func.getLoc();
615 resultTypes.reserve(func.getCallableResults().size());
616 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
617 [](
Type type) {
return ValueType::get(type); });
619 FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes));
620 func.insertResult(0, TokenType::get(ctx), {});
621 for (
Block &block : func.getBlocks()) {
622 Operation *terminator = block.getTerminator();
623 if (
auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) {
625 builder.
create<YieldOp>(returnOp.getOperands());
638 auto loc = func.getLoc();
640 auto newCall = callBuilder.
create<func::CallOp>(
641 func.getName(), func.getCallableResults(), oldCall.getArgOperands());
644 callBuilder.
create<AwaitOp>(loc, newCall.getResults().front());
646 unwrappedResults.reserve(newCall->getResults().size() - 1);
647 for (
Value result : newCall.getResults().drop_front())
648 unwrappedResults.push_back(
649 callBuilder.
create<AwaitOp>(loc, result).result());
652 oldCall.replaceAllUsesWith(unwrappedResults);
657 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
674 auto addToWorklist = [&](func::FuncOp func) {
682 outlinedFunctions.find(func) == outlinedFunctions.end()) {
683 for (
Operation &op : func.getBody().getOps()) {
684 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
685 funcWorklist.push_back(func);
693 for (func::FuncOp func : module.getOps<func::FuncOp>())
700 while (!funcWorklist.empty()) {
701 auto func = funcWorklist.pop_back_val();
702 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
703 if (!insertion.second)
710 symbolUserMap.
getUsers(func).end());
718 return blockA > blockB || (blockA == blockB && !a->
isBeforeInBlock(b));
722 if (func::CallOp call = dyn_cast<func::CallOp>(*op)) {
723 func::FuncOp caller = call->getParentOfType<func::FuncOp>();
725 addToWorklist(caller);
727 op->emitError(
"Unexpected reference to func referenced by symbol");
736 void AsyncToAsyncRuntimePass::runOnOperation() {
737 ModuleOp module = getOperation();
743 module.walk([&](ExecuteOp execute) {
748 llvm::dbgs() <<
"Outlined " << outlinedFunctions.size()
749 <<
" functions built from async.execute operations\n";
753 auto isInCoroutine = [&](
Operation *op) ->
bool {
754 auto parentFunc = op->getParentOfType<func::FuncOp>();
755 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
758 if (eliminateBlockingAwaitOps &&
776 asyncPatterns.
add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
777 asyncPatterns.
add<AwaitTokenOpLowering, AwaitValueOpLowering,
786 runtimeTarget.addLegalDialect<AsyncDialect>();
787 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
788 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
791 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](
Operation *op) {
792 auto walkResult = op->walk([&](
Operation *nested) {
793 bool isAsync = isa<async::AsyncDialect>(nested->
getDialect());
797 return !walkResult.wasInterrupted();
799 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
800 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
803 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
804 [&](cf::AssertOp op) ->
bool {
805 auto func = op->getParentOfType<func::FuncOp>();
806 return outlinedFunctions.find(func) == outlinedFunctions.end();
809 if (eliminateBlockingAwaitOps)
810 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
811 [&](RuntimeAwaitOp op) ->
bool {
816 std::move(asyncPatterns)))) {
823 return std::make_unique<AsyncToAsyncRuntimePass>();
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, ArrayRef< NamedAttribute > attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
typename cf::AssertOp ::Adaptor OpAdaptor
Operation is a basic unit of execution within MLIR.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Block represents an ordered list of Operations.
OpListType & getOperations()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
typename async::YieldOp ::Adaptor OpAdaptor
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Block * getBlock()
Returns the operation block that contains this operation.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
YieldOpLowering(MLIRContext *ctx, const llvm::DenseMap< func::FuncOp, CoroMachinery > &outlinedFunctions)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
This class represents a collection of SymbolTables.
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpListType::iterator iterator
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
static void rewriteCallsiteForCoroutine(func::CallOp oldCall, func::FuncOp func)
Rewrites a call into a function that has been rewritten as a coroutine.
static LogicalResult funcsToCoroutines(ModuleOp module, llvm::DenseMap< func::FuncOp, CoroMachinery > &outlinedFunctions)
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
static WalkResult advance()
static constexpr const char kAsyncFnPrefix[]
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener...
static WalkResult interrupt()
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
static bool isAllowedToBlock(func::FuncOp func)
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
This class represents a map of symbols to users, and provides efficient implementations of symbol que...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
ArrayRef< Operation * > getUsers(Operation *symbol) const
Return the users of the provided symbol operation.
static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func)
Rewrite a func as a coroutine by: 1) Wrapping the results into async.value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AssertOpLowering(MLIRContext *ctx, llvm::DenseMap< func::FuncOp, CoroMachinery > &outlinedFunctions)
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
Type getType() const
Return the type of this value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
MLIRContext is the top-level object for a collection of MLIR operations.
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class implements a pattern rewriter for use with ConversionPatterns.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
value !async value< T > ***static CoroMachinery setupCoroMachinery(func::FuncOp func)
This class describes a specific conversion target.
static Block * setupSetErrorBlock(CoroMachinery &coro)
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.