MLIR 23.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering from high level async operations to async.coro
10// and async.runtime operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include <utility>
15
17
18#include "PassDetail.h"
25#include "mlir/IR/IRMapping.h"
29#include "llvm/Support/Debug.h"
30#include <optional>
31
32namespace mlir {
33#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
34#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
35#include "mlir/Dialect/Async/Passes.h.inc"
36} // namespace mlir
38using namespace mlir;
39using namespace mlir::async;
40
41#define DEBUG_TYPE "async-to-async-runtime"
42// Prefix for functions outlined from `async.execute` op regions.
43static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
44
45namespace {
46
47class AsyncToAsyncRuntimePass
48 : public impl::AsyncToAsyncRuntimePassBase<AsyncToAsyncRuntimePass> {
49public:
50 AsyncToAsyncRuntimePass() = default;
51 void runOnOperation() override;
52};
54} // namespace
55
56namespace {
57
58class AsyncFuncToAsyncRuntimePass
60 AsyncFuncToAsyncRuntimePass> {
61public:
62 AsyncFuncToAsyncRuntimePass() = default;
63 void runOnOperation() override;
64};
65
66} // namespace
67
68/// Function targeted for coroutine transformation has two additional blocks at
69/// the end: coroutine cleanup and coroutine suspension.
70///
71/// async.await op lowering additionally creates a resume block for each
72/// operation to enable non-blocking waiting via coroutine suspension.
73namespace {
74struct CoroMachinery {
75 func::FuncOp func;
76
77 // Async function returns an optional token, followed by some async values
78 //
79 // async.func @foo() -> !async.value<T> {
80 // %cst = arith.constant 42.0 : T
81 // return %cst: T
82 // }
83 // Async execute region returns a completion token, and an async value for
84 // each yielded value.
85 //
86 // %token, %result = async.execute -> !async.value<T> {
87 // %0 = arith.constant ... : T
88 // async.yield %0 : T
89 // }
90 std::optional<Value> asyncToken; // returned completion token
91 llvm::SmallVector<Value, 4> returnValues; // returned async values
92
93 Value coroId; // coroutine id (!async.coro.id value)
94 Value coroHandle; // coroutine handle (!async.coro.getHandle value)
95 Block *entry; // coroutine entry block
96 std::optional<Block *> setError; // set returned values to error state
97 Block *cleanup; // coroutine cleanup block
98
99 // Coroutine cleanup block for destroy after the coroutine is resumed,
100 // e.g. async.coro.suspend state, [suspend], [resume], [destroy]
101 //
102 // This cleanup block is a duplicate of the cleanup block followed by the
103 // resume block. The purpose of having a duplicate cleanup block for destroy
104 // is to make the CFG clear so that the control flow analysis won't confuse.
105 //
106 // The overall structure of the lowered CFG can be the following,
107 //
108 // Entry (calling async.coro.suspend)
109 // | \
110 // Resume Destroy (duplicate of Cleanup)
111 // | |
112 // Cleanup |
113 // | /
114 // End (ends the corontine)
115 //
116 // If there is resume-specific cleanup logic, it can go into the Cleanup
117 // block but not the destroy block. Otherwise, it can fail block dominance
118 // check.
119 //
120 // This block is created lazily by `setupCleanupForDestroyBlock` only when a
121 // suspension point needs a destroy successor, so that functions without any
122 // coroutine suspends (e.g. an `async.func` body with no `await`) don't end
123 // up with dead code.
124 std::optional<Block *> cleanupForDestroy;
125 Block *suspend; // coroutine suspension block
126};
127} // namespace
128
130 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
131
132/// Utility to partially update the regular function CFG to the coroutine CFG
133/// compatible with LLVM coroutines switched-resume lowering using
134/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
135/// that branches into preexisting entry block. Also inserts trailing blocks.
136///
137/// The result types of the passed `func` must start with an optional
138/// `async.token`, followed by some number of `async.value`s.
139///
140/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
141///
142/// - `entry` block sets up the coroutine.
143/// - `set_error` block sets completion token and async values state to error.
144/// - `cleanup` block cleans up the coroutine state.
145/// - `suspend` block after the @llvm.coro.end() defines what value will be
146/// returned to the initial caller of a coroutine. Everything before the
147/// @llvm.coro.end() will be executed at every suspension point.
148///
149/// Coroutine structure (only the important bits):
150///
151/// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
152/// {
153/// ^entry(<function-arguments>):
154/// %token = <async token> : !async.token // create async runtime token
155/// %value = <async value> : !async.value<T> // create async value
156/// %id = async.coro.getId // create a coroutine id
157/// %hdl = async.coro.begin %id // create a coroutine handle
158/// cf.br ^preexisting_entry_block
159///
160/// /* preexisting blocks modified to branch to the cleanup block */
161///
162/// ^set_error: // this block created lazily only if needed (see code below)
163/// async.runtime.set_error %token : !async.token
164/// async.runtime.set_error %value : !async.value<T>
165/// cf.br ^cleanup
166///
167/// ^cleanup:
168/// async.coro.free %hdl // delete the coroutine state
169/// cf.br ^suspend
170///
171/// ^suspend:
172/// async.coro.end %hdl // marks the end of a coroutine
173/// return %token, %value : !async.token, !async.value<T>
174/// }
175///
176static CoroMachinery setupCoroMachinery(func::FuncOp func) {
177 assert(!func.getBlocks().empty() && "Function must have an entry block");
178
179 MLIRContext *ctx = func.getContext();
180 Block *entryBlock = &func.getBlocks().front();
181 Block *originalEntryBlock =
182 entryBlock->splitBlock(entryBlock->getOperations().begin());
183 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
184
185 // ------------------------------------------------------------------------ //
186 // Allocate async token/values that we will return from a ramp function.
187 // ------------------------------------------------------------------------ //
188
189 // We treat TokenType as state update marker to represent side-effects of
190 // async computations
191 bool isStateful = isa<async::TokenType>(func.getResultTypes().front());
192
193 std::optional<Value> retToken;
194 if (isStateful)
195 retToken.emplace(
196 RuntimeCreateOp::create(builder, async::TokenType::get(ctx)));
197
199 ArrayRef<Type> resValueTypes =
200 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
201 for (auto resType : resValueTypes)
202 retValues.emplace_back(
203 RuntimeCreateOp::create(builder, resType).getResult());
204
205 // ------------------------------------------------------------------------ //
206 // Initialize coroutine: get coroutine id and coroutine handle.
207 // ------------------------------------------------------------------------ //
208 auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
209 auto coroHdlOp =
210 CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
211 cf::BranchOp::create(builder, originalEntryBlock);
212
213 Block *cleanupBlock = func.addBlock();
214 Block *suspendBlock = func.addBlock();
215
216 // ------------------------------------------------------------------------ //
217 // Coroutine cleanup block: deallocate coroutine frame, free the memory.
218 // ------------------------------------------------------------------------ //
219 // The matching "destroy" cleanup block is materialized lazily by
220 // `setupCleanupForDestroyBlock` only when a suspend point needs it.
221 builder.setInsertionPointToStart(cleanupBlock);
222 CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
223 cf::BranchOp::create(builder, suspendBlock);
224
225 // ------------------------------------------------------------------------ //
226 // Coroutine suspend block: mark the end of a coroutine and return allocated
227 // async token.
228 // ------------------------------------------------------------------------ //
229 builder.setInsertionPointToStart(suspendBlock);
230
231 // Mark the end of a coroutine: async.coro.end
232 CoroEndOp::create(builder, coroHdlOp.getHandle());
233
234 // Return created optional `async.token` and `async.values` from the suspend
235 // block. This will be the return value of a coroutine ramp function.
237 if (retToken)
238 ret.push_back(*retToken);
239 llvm::append_range(ret, retValues);
240 func::ReturnOp::create(builder, ret);
241
242 // `async.await` op lowering will create resume blocks for async
243 // continuations, and will conditionally branch to cleanup or suspend blocks.
244
245 // The switch-resumed API based coroutine should be marked with
246 // presplitcoroutine attribute to mark the function as a coroutine.
247 func->setAttr("llvm.passthrough", builder.getArrayAttr(StringAttr::get(
248 ctx, "presplitcoroutine")));
249
250 CoroMachinery machinery;
251 machinery.func = func;
252 machinery.asyncToken = retToken;
253 machinery.returnValues = retValues;
254 machinery.coroId = coroIdOp.getId();
255 machinery.coroHandle = coroHdlOp.getHandle();
256 machinery.entry = entryBlock;
257 machinery.setError = std::nullopt; // created lazily only if needed
258 machinery.cleanup = cleanupBlock;
259 machinery.cleanupForDestroy = std::nullopt; // created lazily only if needed
260 machinery.suspend = suspendBlock;
261 return machinery;
262}
263
264// Lazily creates `set_error` block only if it is required for lowering to the
265// runtime operations (see for example lowering of assert operation).
266static Block *setupSetErrorBlock(CoroMachinery &coro) {
267 if (coro.setError)
268 return *coro.setError;
269
270 coro.setError = coro.func.addBlock();
271 (*coro.setError)->moveBefore(coro.cleanup);
272
273 auto builder =
274 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
275
276 // Coroutine set_error block: set error on token and all returned values.
277 if (coro.asyncToken)
278 RuntimeSetErrorOp::create(builder, *coro.asyncToken);
279
280 for (Value retValue : coro.returnValues)
281 RuntimeSetErrorOp::create(builder, retValue);
282
283 // Branch into the cleanup block.
284 cf::BranchOp::create(builder, coro.cleanup);
285
286 return *coro.setError;
287}
288
289// Lazily creates the `cleanupForDestroy` block only if a suspension point
290// actually needs a destroy successor. This avoids leaving an unreachable
291// cleanup block behind in coroutines that never suspend.
293 CoroMachinery &coro) {
294 if (coro.cleanupForDestroy)
295 return *coro.cleanupForDestroy;
296 OpBuilder::InsertionGuard guard(builder);
297 coro.cleanupForDestroy = builder.createBlock(coro.suspend);
298 CoroFreeOp::create(builder, coro.coroId, coro.coroHandle);
299 cf::BranchOp::create(builder, coro.suspend);
300 return *coro.cleanupForDestroy;
301}
302
303//===----------------------------------------------------------------------===//
304// async.execute op outlining to the coroutine functions.
305//===----------------------------------------------------------------------===//
306
307/// Outline the body region attached to the `async.execute` op into a standalone
308/// function.
309///
310/// Note that this is not a reversible transformation.
311static std::pair<func::FuncOp, CoroMachinery>
312outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
313 ModuleOp module = execute->getParentOfType<ModuleOp>();
314
315 MLIRContext *ctx = module.getContext();
316 Location loc = execute.getLoc();
317
318 // Make sure that all constants will be inside the outlined async function to
319 // reduce the number of function arguments.
320 cloneConstantsIntoTheRegion(execute.getBodyRegion());
321
322 // Collect all outlined function inputs.
323 SetVector<mlir::Value> functionInputs(llvm::from_range,
324 execute.getDependencies());
325 functionInputs.insert_range(execute.getBodyOperands());
326 getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
327
328 // Collect types for the outlined function inputs and outputs.
329 auto typesRange = llvm::map_range(
330 functionInputs, [](Value value) { return value.getType(); });
331 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
332 auto outputTypes = execute.getResultTypes();
333
334 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
335 auto funcAttrs = ArrayRef<NamedAttribute>();
336
337 // TODO: Derive outlined function name from the parent FuncOp (support
338 // multiple nested async.execute operations).
339 func::FuncOp func =
340 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
341 symbolTable.insert(func);
342
344 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
345
346 // Prepare for coroutine conversion by creating the body of the function.
347 {
348 size_t numDependencies = execute.getDependencies().size();
349 size_t numOperands = execute.getBodyOperands().size();
350
351 // Await on all dependencies before starting to execute the body region.
352 for (size_t i = 0; i < numDependencies; ++i)
353 AwaitOp::create(builder, func.getArgument(i));
354
355 // Await on all async value operands and unwrap the payload.
356 SmallVector<Value, 4> unwrappedOperands(numOperands);
357 for (size_t i = 0; i < numOperands; ++i) {
358 Value operand = func.getArgument(numDependencies + i);
359 unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
360 }
361
362 // Map from function inputs defined above the execute op to the function
363 // arguments.
364 IRMapping valueMapping;
365 valueMapping.map(functionInputs, func.getArguments());
366 valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
367
368 // Clone all operations from the execute operation body into the outlined
369 // function body.
370 for (Operation &op : execute.getBodyRegion().getOps())
371 builder.clone(op, valueMapping);
372 }
373
374 // Adding entry/cleanup/suspend blocks.
375 CoroMachinery coro = setupCoroMachinery(func);
376
377 // Suspend async function at the end of an entry block, and resume it using
378 // Async resume operation (execution will be resumed in a thread managed by
379 // the async runtime).
380 {
381 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
382 builder.setInsertionPointToEnd(coro.entry);
383
384 // Save the coroutine state: async.coro.save
385 auto coroSaveOp =
386 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
387
388 // Pass coroutine to the runtime to be resumed on a runtime managed
389 // thread.
390 RuntimeResumeOp::create(builder, coro.coroHandle);
391
392 // Add async.coro.suspend as a suspended block terminator.
393 Block *destroy = setupCleanupForDestroyBlock(builder, coro);
394 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
395 branch.getDest(), destroy);
396
397 branch.erase();
398 }
399
400 // Replace the original `async.execute` with a call to outlined function.
401 {
402 ImplicitLocOpBuilder callBuilder(loc, execute);
403 auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(),
404 execute.getResultTypes(),
405 functionInputs.getArrayRef());
406 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
407 execute.erase();
408 }
409
410 return {func, coro};
411}
412
413//===----------------------------------------------------------------------===//
414// Convert async.create_group operation to async.runtime.create_group
415//===----------------------------------------------------------------------===//
416
417namespace {
418class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
419public:
420 using OpConversionPattern::OpConversionPattern;
421
422 LogicalResult
423 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter) const override {
425 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
426 op, GroupType::get(op->getContext()), adaptor.getOperands());
427 return success();
428 }
429};
430} // namespace
431
432//===----------------------------------------------------------------------===//
433// Convert async.add_to_group operation to async.runtime.add_to_group.
434//===----------------------------------------------------------------------===//
436namespace {
437class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
438public:
439 using OpConversionPattern::OpConversionPattern;
441 LogicalResult
442 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
443 ConversionPatternRewriter &rewriter) const override {
444 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
445 op, rewriter.getIndexType(), adaptor.getOperands());
446 return success();
448};
449} // namespace
451//===----------------------------------------------------------------------===//
452// Convert async.func, async.return and async.call operations to non-blocking
453// operations based on llvm coroutine
454//===----------------------------------------------------------------------===//
456namespace {
457
458//===----------------------------------------------------------------------===//
459// Convert async.func operation to func.func
460//===----------------------------------------------------------------------===//
462class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
463public:
464 AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
465 : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
467 LogicalResult
468 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override {
470 Location loc = op->getLoc();
472 auto newFuncOp =
473 func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
474
477 // Copy over all attributes other than the name.
478 for (const auto &namedAttr : op->getAttrs()) {
479 if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
480 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
481 }
482
483 rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
484 newFuncOp.end());
486 CoroMachinery coro = setupCoroMachinery(newFuncOp);
487 (*coros)[newFuncOp] = coro;
488 // no initial suspend, we should hot-start
489
490 rewriter.eraseOp(op);
491 return success();
492 }
493
494private:
495 FuncCoroMapPtr coros;
496};
497
498//===----------------------------------------------------------------------===//
499// Convert async.call operation to func.call
500//===----------------------------------------------------------------------===//
501
502class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
503public:
504 AsyncCallOpLowering(MLIRContext *ctx)
505 : OpConversionPattern<async::CallOp>(ctx) {}
506
507 LogicalResult
508 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 rewriter.replaceOpWithNewOp<func::CallOp>(
511 op, op.getCallee(), op.getResultTypes(), op.getOperands());
512 return success();
513 }
514};
515
516//===----------------------------------------------------------------------===//
517// Convert async.return operation to async.runtime operations.
518//===----------------------------------------------------------------------===//
519
520class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
521public:
522 AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
523 : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
524
525 LogicalResult
526 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
527 ConversionPatternRewriter &rewriter) const override {
528 auto func = op->template getParentOfType<func::FuncOp>();
529 auto funcCoro = coros->find(func);
530 if (funcCoro == coros->end())
531 return rewriter.notifyMatchFailure(
532 op, "operation is not inside the async coroutine function");
533
534 Location loc = op->getLoc();
535 const CoroMachinery &coro = funcCoro->getSecond();
536 rewriter.setInsertionPointAfter(op);
537
538 // Store return values into the async values storage and switch async
539 // values state to available.
540 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
541 Value returnValue = std::get<0>(tuple);
542 Value asyncValue = std::get<1>(tuple);
543 RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
544 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
545 }
546
547 if (coro.asyncToken)
548 // Switch the coroutine completion token to available state.
549 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
550
551 rewriter.eraseOp(op);
552 cf::BranchOp::create(rewriter, loc, coro.cleanup);
553 return success();
554 }
555
556private:
557 FuncCoroMapPtr coros;
558};
559} // namespace
560
561//===----------------------------------------------------------------------===//
562// Convert async.await and async.await_all operations to the async.runtime.await
563// or async.runtime.await_and_resume operations.
564//===----------------------------------------------------------------------===//
565
566namespace {
567template <typename AwaitType, typename AwaitableType>
568class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
569 using AwaitAdaptor = typename AwaitType::Adaptor;
570
571public:
572 AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
573 bool shouldLowerBlockingWait)
574 : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
575 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
576
577 LogicalResult
578 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
579 ConversionPatternRewriter &rewriter) const override {
580 // We can only await on one the `AwaitableType` (for `await` it can be
581 // a `token` or a `value`, for `await_all` it must be a `group`).
582 if (!isa<AwaitableType>(op.getOperand().getType()))
583 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
584
585 // Check if await operation is inside the coroutine function.
586 auto func = op->template getParentOfType<func::FuncOp>();
587 auto funcCoro = coros->find(func);
588 const bool isInCoroutine = funcCoro != coros->end();
589
590 Location loc = op->getLoc();
591 Value operand = adaptor.getOperand();
592
593 Type i1 = rewriter.getI1Type();
594
595 // Delay lowering to block wait in case await op is inside async.execute
596 if (!isInCoroutine && !shouldLowerBlockingWait)
597 return failure();
598
599 // Inside regular functions we use the blocking wait operation to wait for
600 // the async object (token, value or group) to become available.
601 if (!isInCoroutine) {
602 ImplicitLocOpBuilder builder(loc, rewriter);
603 RuntimeAwaitOp::create(builder, loc, operand);
604
605 // Assert that the awaited operands is not in the error state.
606 Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
607 Value notError = arith::XOrIOp::create(
608 builder, isError,
609 arith::ConstantOp::create(builder, loc, i1,
610 builder.getIntegerAttr(i1, 1)));
611
612 cf::AssertOp::create(builder, notError,
613 "Awaited async operand is in error state");
614 }
615
616 // Inside the coroutine we convert await operation into coroutine suspension
617 // point, and resume execution asynchronously.
618 if (isInCoroutine) {
619 CoroMachinery &coro = funcCoro->getSecond();
620 Block *suspended = op->getBlock();
621
622 ImplicitLocOpBuilder builder(loc, rewriter);
623 MLIRContext *ctx = op->getContext();
624
625 // Save the coroutine state and resume on a runtime managed thread when
626 // the operand becomes available.
627 auto coroSaveOp =
628 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
629 RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
630
631 // Split the entry block before the await operation.
632 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
633
634 // Add async.coro.suspend as a suspended block terminator.
635 Block *destroy = setupCleanupForDestroyBlock(builder, coro);
636 builder.setInsertionPointToEnd(suspended);
637 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
638 resume, destroy);
639
640 // Split the resume block into error checking and continuation.
641 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
642
643 // Check if the awaited value is in the error state.
644 builder.setInsertionPointToStart(resume);
645 auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
646 cf::CondBranchOp::create(builder, isError,
647 /*trueDest=*/setupSetErrorBlock(coro),
648 /*trueArgs=*/ArrayRef<Value>(),
649 /*falseDest=*/continuation,
650 /*falseArgs=*/ArrayRef<Value>());
651
652 // Make sure that replacement value will be constructed in the
653 // continuation block.
654 rewriter.setInsertionPointToStart(continuation);
655 }
656
657 // Erase or replace the await operation with the new value.
658 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
659 rewriter.replaceOp(op, replaceWith);
660 else
661 rewriter.eraseOp(op);
662
663 return success();
664 }
665
666 virtual Value getReplacementValue(AwaitType op, Value operand,
667 ConversionPatternRewriter &rewriter) const {
668 return Value();
669 }
670
671private:
672 FuncCoroMapPtr coros;
673 bool shouldLowerBlockingWait;
674};
675
676/// Lowering for `async.await` with a token operand.
677class AwaitTokenOpLowering
678 : public AwaitOpLoweringBase<AwaitOp, async::TokenType> {
679 using Base = AwaitOpLoweringBase<AwaitOp, async::TokenType>;
680
681public:
682 using Base::Base;
683};
684
685/// Lowering for `async.await` with a value operand.
686class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
687 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
688
689public:
690 using Base::Base;
691
692 Value
693 getReplacementValue(AwaitOp op, Value operand,
694 ConversionPatternRewriter &rewriter) const override {
695 // Load from the async value storage.
696 auto valueType = cast<ValueType>(operand.getType()).getValueType();
697 return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
698 }
699};
700
701/// Lowering for `async.await_all` operation.
702class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
703 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
704
705public:
706 using Base::Base;
707};
708
709} // namespace
710
711//===----------------------------------------------------------------------===//
712// Convert async.yield operation to async.runtime operations.
713//===----------------------------------------------------------------------===//
714
715class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
716public:
718 : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
719
720 LogicalResult
721 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
722 ConversionPatternRewriter &rewriter) const override {
723 // Check if yield operation is inside the async coroutine function.
724 auto func = op->template getParentOfType<func::FuncOp>();
725 auto funcCoro = coros->find(func);
726 if (funcCoro == coros->end())
727 return rewriter.notifyMatchFailure(
728 op, "operation is not inside the async coroutine function");
729
730 Location loc = op->getLoc();
731 const CoroMachinery &coro = funcCoro->getSecond();
732
733 // Store yielded values into the async values storage and switch async
734 // values state to available.
735 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
736 Value yieldValue = std::get<0>(tuple);
737 Value asyncValue = std::get<1>(tuple);
738 RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
739 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
740 }
741
742 if (coro.asyncToken)
743 // Switch the coroutine completion token to available state.
744 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
745
746 cf::BranchOp::create(rewriter, loc, coro.cleanup);
747 rewriter.eraseOp(op);
748
749 return success();
750 }
751
752private:
753 FuncCoroMapPtr coros;
754};
755
756//===----------------------------------------------------------------------===//
757// Convert cf.assert operation to cf.cond_br into `set_error` block.
758//===----------------------------------------------------------------------===//
759
760class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
761public:
763 : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
764
765 LogicalResult
766 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
767 ConversionPatternRewriter &rewriter) const override {
768 // Check if assert operation is inside the async coroutine function.
769 auto func = op->template getParentOfType<func::FuncOp>();
770 auto funcCoro = coros->find(func);
771 if (funcCoro == coros->end())
772 return rewriter.notifyMatchFailure(
773 op, "operation is not inside the async coroutine function");
774
775 Location loc = op->getLoc();
776 CoroMachinery &coro = funcCoro->getSecond();
777
778 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
779 rewriter.setInsertionPointToEnd(cont->getPrevNode());
780 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
781 /*trueDest=*/cont,
782 /*trueArgs=*/ArrayRef<Value>(),
783 /*falseDest=*/setupSetErrorBlock(coro),
784 /*falseArgs=*/ArrayRef<Value>());
785 rewriter.eraseOp(op);
786
787 return success();
788 }
789
790private:
791 FuncCoroMapPtr coros;
792};
793
794//===----------------------------------------------------------------------===//
// Pass entry point: outlines every `async.execute` in the module into a
// coroutine function, then runs a partial conversion that lowers the remaining
// high level async operations to `async.runtime` operations. Signals pass
// failure when the conversion fails.
795void AsyncToAsyncRuntimePass::runOnOperation() {
796 ModuleOp module = getOperation();
797 SymbolTable symbolTable(module);
798
799 // Functions with coroutine CFG setups, which are results of outlining
800 // `async.execute` body regions
801 FuncCoroMapPtr coros =
802 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
803
// Each outlined function is recorded together with its coroutine machinery.
804 module.walk([&](ExecuteOp execute) {
805 coros->insert(outlineExecuteOp(symbolTable, execute));
806 });
807
808 LLVM_DEBUG({
809 llvm::dbgs() << "Outlined " << coros->size()
810 << " functions built from async.execute operations\n";
811 });
812
813 // Returns true if operation is inside the coroutine.
814 auto isInCoroutine = [&](Operation *op) -> bool {
815 auto parentFunc = op->getParentOfType<func::FuncOp>();
816 return coros->contains(parentFunc);
817 };
818
819 // Lower async operations to async.runtime operations.
820 MLIRContext *ctx = module->getContext();
821 RewritePatternSet asyncPatterns(ctx);
822
823 // Conversion to async runtime augments original CFG with the coroutine CFG,
824 // and we have to make sure that structured control flow operations with async
825 // operations in nested regions will be converted to branch-based control flow
826 // before we add the coroutine basic blocks.
// NOTE(review): the listing's numbering skips line 827 here, so a statement
// may be missing — presumably the one that adds the SCF-to-branch conversion
// patterns to `asyncPatterns`; verify against the original file.
828
829 // Async lowering does not use type converter because it must preserve all
830 // types for async.runtime operations.
831 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
832
833 asyncPatterns
834 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
835 ctx, coros, /*should_lower_blocking_wait=*/true);
836
837 // Lower assertions to conditional branches into error blocks.
838 asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
839
840 // All high level async operations must be lowered to the runtime operations.
841 ConversionTarget runtimeTarget(*ctx);
842 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
843 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
844 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
845
846 // Decide if structured control flow has to be lowered to branch-based CFG.
// An SCF op is legal unless it contains an async dialect op that sits inside
// one of the outlined coroutine functions.
847 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
848 auto walkResult = op->walk([&](Operation *nested) {
849 bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
850 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
851 : WalkResult::advance();
852 });
853 return !walkResult.wasInterrupted();
854 });
// NOTE(review): cf::AssertOp is listed as legal here and then registered as
// dynamically legal just below; presumably the later registration takes
// precedence — confirm.
855 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
856 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
857
858 // Assertions must be converted to runtime errors inside async functions.
859 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
860 [&](cf::AssertOp op) -> bool {
861 auto func = op->getParentOfType<func::FuncOp>();
862 return !coros->contains(func);
863 });
864
865 if (failed(applyPartialConversion(module, runtimeTarget,
866 std::move(asyncPatterns)))) {
867 signalPassFailure();
868 return;
869 }
870}
871
872//===----------------------------------------------------------------------===//
875 // Functions with coroutine CFG setups, which are results of converting
876 // async.func.
877 FuncCoroMapPtr coros =
878 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
879 MLIRContext *ctx = patterns.getContext();
880 // Lower async.func to func.func with coroutine cfg.
881 patterns.add<AsyncCallOpLowering>(ctx);
882 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
883
884 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
885 ctx, coros, /*should_lower_blocking_wait=*/false);
886 patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
887
888 target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
889 [coros](Operation *op) {
890 auto exec = op->getParentOfType<ExecuteOp>();
891 auto func = op->getParentOfType<func::FuncOp>();
892 return exec || !coros->contains(func);
893 });
894}
895
896void AsyncFuncToAsyncRuntimePass::runOnOperation() {
897 ModuleOp module = getOperation();
898
899 // Lower async operations to async.runtime operations.
900 MLIRContext *ctx = module->getContext();
901 RewritePatternSet asyncPatterns(ctx);
902 ConversionTarget runtimeTarget(*ctx);
903
904 // Lower async.func to func.func with coroutine cfg.
906 runtimeTarget);
907
908 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
909 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
910
911 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
912 cf::BranchOp, cf::CondBranchOp>();
913
914 if (failed(applyPartialConversion(module, runtimeTarget,
915 std::move(asyncPatterns)))) {
916 signalPassFailure();
917 return;
918 }
919}
return success()
static Block * setupCleanupForDestroyBlock(ImplicitLocOpBuilder &builder, CoroMachinery &coro)
static constexpr const char kAsyncFnPrefix[]
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static Block * setupSetErrorBlock(CoroMachinery &coro)
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM coroutines switched-resume lowering using async.runtime.* and async.coro.* operations.
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
typename cf::AssertOp::Adaptor OpAdaptor
Definition Pattern.h:229
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition Builders.h:632
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
Definition Builders.h:642
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:237
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.
Definition SymbolTable.h:97
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult interrupt()
Definition WalkResult.h:46
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region into the region entry block.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the ControlFlow dialect.
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)