MLIR 22.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering from high level async operations to async.coro
10// and async.runtime operations.
11//
12//===----------------------------------------------------------------------===//
13
#include <utility>

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include <optional>
31
32namespace mlir {
33#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
34#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
35#include "mlir/Dialect/Async/Passes.h.inc"
36} // namespace mlir
using namespace mlir;
using namespace mlir::async;

// Debug tag for the LLVM_DEBUG output emitted by this file.
#define DEBUG_TYPE "async-to-async-runtime"
// Prefix for functions outlined from `async.execute` op regions; used as the
// symbol name when `outlineExecuteOp` creates the outlined function.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
44
45namespace {
46
47class AsyncToAsyncRuntimePass
48 : public impl::AsyncToAsyncRuntimePassBase<AsyncToAsyncRuntimePass> {
49public:
50 AsyncToAsyncRuntimePass() = default;
51 void runOnOperation() override;
52};
54} // namespace
55
56namespace {
57
58class AsyncFuncToAsyncRuntimePass
60 AsyncFuncToAsyncRuntimePass> {
61public:
62 AsyncFuncToAsyncRuntimePass() = default;
63 void runOnOperation() override;
64};
65
66} // namespace
67
68/// Function targeted for coroutine transformation has two additional blocks at
69/// the end: coroutine cleanup and coroutine suspension.
70///
71/// async.await op lowering additionaly creates a resume block for each
72/// operation to enable non-blocking waiting via coroutine suspension.
73namespace {
74struct CoroMachinery {
75 func::FuncOp func;
76
77 // Async function returns an optional token, followed by some async values
78 //
79 // async.func @foo() -> !async.value<T> {
80 // %cst = arith.constant 42.0 : T
81 // return %cst: T
82 // }
83 // Async execute region returns a completion token, and an async value for
84 // each yielded value.
85 //
86 // %token, %result = async.execute -> !async.value<T> {
87 // %0 = arith.constant ... : T
88 // async.yield %0 : T
89 // }
90 std::optional<Value> asyncToken; // returned completion token
91 llvm::SmallVector<Value, 4> returnValues; // returned async values
92
93 Value coroHandle; // coroutine handle (!async.coro.getHandle value)
94 Block *entry; // coroutine entry block
95 std::optional<Block *> setError; // set returned values to error state
96 Block *cleanup; // coroutine cleanup block
97
98 // Coroutine cleanup block for destroy after the coroutine is resumed,
99 // e.g. async.coro.suspend state, [suspend], [resume], [destroy]
100 //
101 // This cleanup block is a duplicate of the cleanup block followed by the
102 // resume block. The purpose of having a duplicate cleanup block for destroy
103 // is to make the CFG clear so that the control flow analysis won't confuse.
104 //
105 // The overall structure of the lowered CFG can be the following,
106 //
107 // Entry (calling async.coro.suspend)
108 // | \
109 // Resume Destroy (duplicate of Cleanup)
110 // | |
111 // Cleanup |
112 // | /
113 // End (ends the corontine)
114 //
115 // If there is resume-specific cleanup logic, it can go into the Cleanup
116 // block but not the destroy block. Otherwise, it can fail block dominance
117 // check.
118 Block *cleanupForDestroy;
119 Block *suspend; // coroutine suspension block
120};
121} // namespace
122
124 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
125
126/// Utility to partially update the regular function CFG to the coroutine CFG
127/// compatible with LLVM coroutines switched-resume lowering using
128/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
129/// that branches into preexisting entry block. Also inserts trailing blocks.
130///
131/// The result types of the passed `func` start with an optional `async.token`
132/// and be continued with some number of `async.value`s.
133///
134/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
135///
136/// - `entry` block sets up the coroutine.
137/// - `set_error` block sets completion token and async values state to error.
138/// - `cleanup` block cleans up the coroutine state.
139/// - `suspend block after the @llvm.coro.end() defines what value will be
140/// returned to the initial caller of a coroutine. Everything before the
141/// @llvm.coro.end() will be executed at every suspension point.
142///
143/// Coroutine structure (only the important bits):
144///
145/// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
146/// {
147/// ^entry(<function-arguments>):
148/// %token = <async token> : !async.token // create async runtime token
149/// %value = <async value> : !async.value<T> // create async value
150/// %id = async.coro.getId // create a coroutine id
151/// %hdl = async.coro.begin %id // create a coroutine handle
152/// cf.br ^preexisting_entry_block
153///
154/// /* preexisting blocks modified to branch to the cleanup block */
155///
156/// ^set_error: // this block created lazily only if needed (see code below)
157/// async.runtime.set_error %token : !async.token
158/// async.runtime.set_error %value : !async.value<T>
159/// cf.br ^cleanup
160///
161/// ^cleanup:
162/// async.coro.free %hdl // delete the coroutine state
163/// cf.br ^suspend
164///
165/// ^suspend:
166/// async.coro.end %hdl // marks the end of a coroutine
167/// return %token, %value : !async.token, !async.value<T>
168/// }
169///
170static CoroMachinery setupCoroMachinery(func::FuncOp func) {
171 assert(!func.getBlocks().empty() && "Function must have an entry block");
172
173 MLIRContext *ctx = func.getContext();
174 Block *entryBlock = &func.getBlocks().front();
175 Block *originalEntryBlock =
176 entryBlock->splitBlock(entryBlock->getOperations().begin());
177 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
178
179 // ------------------------------------------------------------------------ //
180 // Allocate async token/values that we will return from a ramp function.
181 // ------------------------------------------------------------------------ //
182
183 // We treat TokenType as state update marker to represent side-effects of
184 // async computations
185 bool isStateful = isa<TokenType>(func.getResultTypes().front());
186
187 std::optional<Value> retToken;
188 if (isStateful)
189 retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
190
192 ArrayRef<Type> resValueTypes =
193 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
194 for (auto resType : resValueTypes)
195 retValues.emplace_back(
196 RuntimeCreateOp::create(builder, resType).getResult());
197
198 // ------------------------------------------------------------------------ //
199 // Initialize coroutine: get coroutine id and coroutine handle.
200 // ------------------------------------------------------------------------ //
201 auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
202 auto coroHdlOp =
203 CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
204 cf::BranchOp::create(builder, originalEntryBlock);
205
206 Block *cleanupBlock = func.addBlock();
207 Block *cleanupBlockForDestroy = func.addBlock();
208 Block *suspendBlock = func.addBlock();
209
210 // ------------------------------------------------------------------------ //
211 // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
212 // ------------------------------------------------------------------------ //
213 auto buildCleanupBlock = [&](Block *cb) {
214 builder.setInsertionPointToStart(cb);
215 CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
216
217 // Branch into the suspend block.
218 cf::BranchOp::create(builder, suspendBlock);
219 };
220 buildCleanupBlock(cleanupBlock);
221 buildCleanupBlock(cleanupBlockForDestroy);
222
223 // ------------------------------------------------------------------------ //
224 // Coroutine suspend block: mark the end of a coroutine and return allocated
225 // async token.
226 // ------------------------------------------------------------------------ //
227 builder.setInsertionPointToStart(suspendBlock);
228
229 // Mark the end of a coroutine: async.coro.end
230 CoroEndOp::create(builder, coroHdlOp.getHandle());
231
232 // Return created optional `async.token` and `async.values` from the suspend
233 // block. This will be the return value of a coroutine ramp function.
235 if (retToken)
236 ret.push_back(*retToken);
237 llvm::append_range(ret, retValues);
238 func::ReturnOp::create(builder, ret);
239
240 // `async.await` op lowering will create resume blocks for async
241 // continuations, and will conditionally branch to cleanup or suspend blocks.
242
243 // The switch-resumed API based coroutine should be marked with
244 // presplitcoroutine attribute to mark the function as a coroutine.
245 func->setAttr("passthrough", builder.getArrayAttr(
246 StringAttr::get(ctx, "presplitcoroutine")));
247
248 CoroMachinery machinery;
249 machinery.func = func;
250 machinery.asyncToken = retToken;
251 machinery.returnValues = retValues;
252 machinery.coroHandle = coroHdlOp.getHandle();
253 machinery.entry = entryBlock;
254 machinery.setError = std::nullopt; // created lazily only if needed
255 machinery.cleanup = cleanupBlock;
256 machinery.cleanupForDestroy = cleanupBlockForDestroy;
257 machinery.suspend = suspendBlock;
258 return machinery;
259}
260
261// Lazily creates `set_error` block only if it is required for lowering to the
262// runtime operations (see for example lowering of assert operation).
263static Block *setupSetErrorBlock(CoroMachinery &coro) {
264 if (coro.setError)
265 return *coro.setError;
266
267 coro.setError = coro.func.addBlock();
268 (*coro.setError)->moveBefore(coro.cleanup);
269
270 auto builder =
271 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
272
273 // Coroutine set_error block: set error on token and all returned values.
274 if (coro.asyncToken)
275 RuntimeSetErrorOp::create(builder, *coro.asyncToken);
276
277 for (Value retValue : coro.returnValues)
278 RuntimeSetErrorOp::create(builder, retValue);
279
280 // Branch into the cleanup block.
281 cf::BranchOp::create(builder, coro.cleanup);
282
283 return *coro.setError;
284}
285
286//===----------------------------------------------------------------------===//
287// async.execute op outlining to the coroutine functions.
288//===----------------------------------------------------------------------===//
289
290/// Outline the body region attached to the `async.execute` op into a standalone
291/// function.
292///
293/// Note that this is not reversible transformation.
294static std::pair<func::FuncOp, CoroMachinery>
295outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
296 ModuleOp module = execute->getParentOfType<ModuleOp>();
297
298 MLIRContext *ctx = module.getContext();
299 Location loc = execute.getLoc();
300
301 // Make sure that all constants will be inside the outlined async function to
302 // reduce the number of function arguments.
303 cloneConstantsIntoTheRegion(execute.getBodyRegion());
304
305 // Collect all outlined function inputs.
306 SetVector<mlir::Value> functionInputs(llvm::from_range,
307 execute.getDependencies());
308 functionInputs.insert_range(execute.getBodyOperands());
309 getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
310
311 // Collect types for the outlined function inputs and outputs.
312 auto typesRange = llvm::map_range(
313 functionInputs, [](Value value) { return value.getType(); });
314 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
315 auto outputTypes = execute.getResultTypes();
316
317 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
318 auto funcAttrs = ArrayRef<NamedAttribute>();
319
320 // TODO: Derive outlined function name from the parent FuncOp (support
321 // multiple nested async.execute operations).
322 func::FuncOp func =
323 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
324 symbolTable.insert(func);
325
327 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
328
329 // Prepare for coroutine conversion by creating the body of the function.
330 {
331 size_t numDependencies = execute.getDependencies().size();
332 size_t numOperands = execute.getBodyOperands().size();
333
334 // Await on all dependencies before starting to execute the body region.
335 for (size_t i = 0; i < numDependencies; ++i)
336 AwaitOp::create(builder, func.getArgument(i));
337
338 // Await on all async value operands and unwrap the payload.
339 SmallVector<Value, 4> unwrappedOperands(numOperands);
340 for (size_t i = 0; i < numOperands; ++i) {
341 Value operand = func.getArgument(numDependencies + i);
342 unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
343 }
344
345 // Map from function inputs defined above the execute op to the function
346 // arguments.
347 IRMapping valueMapping;
348 valueMapping.map(functionInputs, func.getArguments());
349 valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
350
351 // Clone all operations from the execute operation body into the outlined
352 // function body.
353 for (Operation &op : execute.getBodyRegion().getOps())
354 builder.clone(op, valueMapping);
355 }
356
357 // Adding entry/cleanup/suspend blocks.
358 CoroMachinery coro = setupCoroMachinery(func);
359
360 // Suspend async function at the end of an entry block, and resume it using
361 // Async resume operation (execution will be resumed in a thread managed by
362 // the async runtime).
363 {
364 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
365 builder.setInsertionPointToEnd(coro.entry);
366
367 // Save the coroutine state: async.coro.save
368 auto coroSaveOp =
369 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
370
371 // Pass coroutine to the runtime to be resumed on a runtime managed
372 // thread.
373 RuntimeResumeOp::create(builder, coro.coroHandle);
374
375 // Add async.coro.suspend as a suspended block terminator.
376 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
377 branch.getDest(), coro.cleanupForDestroy);
378
379 branch.erase();
380 }
381
382 // Replace the original `async.execute` with a call to outlined function.
383 {
384 ImplicitLocOpBuilder callBuilder(loc, execute);
385 auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(),
386 execute.getResultTypes(),
387 functionInputs.getArrayRef());
388 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
389 execute.erase();
390 }
391
392 return {func, coro};
393}
394
395//===----------------------------------------------------------------------===//
396// Convert async.create_group operation to async.runtime.create_group
397//===----------------------------------------------------------------------===//
398
399namespace {
400class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
401public:
402 using OpConversionPattern::OpConversionPattern;
403
404 LogicalResult
405 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const override {
407 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
408 op, GroupType::get(op->getContext()), adaptor.getOperands());
409 return success();
410 }
411};
412} // namespace
413
414//===----------------------------------------------------------------------===//
415// Convert async.add_to_group operation to async.runtime.add_to_group.
416//===----------------------------------------------------------------------===//
417
418namespace {
419class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
420public:
421 using OpConversionPattern::OpConversionPattern;
422
423 LogicalResult
424 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
427 op, rewriter.getIndexType(), adaptor.getOperands());
428 return success();
429 }
430};
431} // namespace
432
433//===----------------------------------------------------------------------===//
// Convert async.func, async.return and async.call operations to non-blocking
// operations based on LLVM coroutines.
436//===----------------------------------------------------------------------===//
438namespace {
440//===----------------------------------------------------------------------===//
441// Convert async.func operation to func.func
442//===----------------------------------------------------------------------===//
444class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
445public:
446 AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
447 : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
448
449 LogicalResult
450 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override {
452 Location loc = op->getLoc();
453
454 auto newFuncOp =
455 func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
456
459 // Copy over all attributes other than the name.
460 for (const auto &namedAttr : op->getAttrs()) {
461 if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
462 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
463 }
464
465 rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
466 newFuncOp.end());
467
468 CoroMachinery coro = setupCoroMachinery(newFuncOp);
469 (*coros)[newFuncOp] = coro;
470 // no initial suspend, we should hot-start
472 rewriter.eraseOp(op);
473 return success();
474 }
475
476private:
477 FuncCoroMapPtr coros;
478};
479
480//===----------------------------------------------------------------------===//
481// Convert async.call operation to func.call
482//===----------------------------------------------------------------------===//
483
484class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
485public:
486 AsyncCallOpLowering(MLIRContext *ctx)
487 : OpConversionPattern<async::CallOp>(ctx) {}
488
489 LogicalResult
490 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter) const override {
492 rewriter.replaceOpWithNewOp<func::CallOp>(
493 op, op.getCallee(), op.getResultTypes(), op.getOperands());
494 return success();
495 }
496};
497
498//===----------------------------------------------------------------------===//
499// Convert async.return operation to async.runtime operations.
500//===----------------------------------------------------------------------===//
501
502class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
503public:
504 AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
505 : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
506
507 LogicalResult
508 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 auto func = op->template getParentOfType<func::FuncOp>();
511 auto funcCoro = coros->find(func);
512 if (funcCoro == coros->end())
513 return rewriter.notifyMatchFailure(
514 op, "operation is not inside the async coroutine function");
515
516 Location loc = op->getLoc();
517 const CoroMachinery &coro = funcCoro->getSecond();
518 rewriter.setInsertionPointAfter(op);
519
520 // Store return values into the async values storage and switch async
521 // values state to available.
522 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
523 Value returnValue = std::get<0>(tuple);
524 Value asyncValue = std::get<1>(tuple);
525 RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
526 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
527 }
528
529 if (coro.asyncToken)
530 // Switch the coroutine completion token to available state.
531 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
532
533 rewriter.eraseOp(op);
534 cf::BranchOp::create(rewriter, loc, coro.cleanup);
535 return success();
536 }
537
538private:
539 FuncCoroMapPtr coros;
540};
541} // namespace
542
543//===----------------------------------------------------------------------===//
544// Convert async.await and async.await_all operations to the async.runtime.await
545// or async.runtime.await_and_resume operations.
546//===----------------------------------------------------------------------===//
547
548namespace {
549template <typename AwaitType, typename AwaitableType>
550class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
551 using AwaitAdaptor = typename AwaitType::Adaptor;
552
553public:
554 AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
555 bool shouldLowerBlockingWait)
556 : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
557 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
558
559 LogicalResult
560 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
561 ConversionPatternRewriter &rewriter) const override {
562 // We can only await on one the `AwaitableType` (for `await` it can be
563 // a `token` or a `value`, for `await_all` it must be a `group`).
564 if (!isa<AwaitableType>(op.getOperand().getType()))
565 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
566
567 // Check if await operation is inside the coroutine function.
568 auto func = op->template getParentOfType<func::FuncOp>();
569 auto funcCoro = coros->find(func);
570 const bool isInCoroutine = funcCoro != coros->end();
571
572 Location loc = op->getLoc();
573 Value operand = adaptor.getOperand();
574
575 Type i1 = rewriter.getI1Type();
576
577 // Delay lowering to block wait in case await op is inside async.execute
578 if (!isInCoroutine && !shouldLowerBlockingWait)
579 return failure();
580
581 // Inside regular functions we use the blocking wait operation to wait for
582 // the async object (token, value or group) to become available.
583 if (!isInCoroutine) {
584 ImplicitLocOpBuilder builder(loc, rewriter);
585 RuntimeAwaitOp::create(builder, loc, operand);
586
587 // Assert that the awaited operands is not in the error state.
588 Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
589 Value notError = arith::XOrIOp::create(
590 builder, isError,
591 arith::ConstantOp::create(builder, loc, i1,
592 builder.getIntegerAttr(i1, 1)));
593
594 cf::AssertOp::create(builder, notError,
595 "Awaited async operand is in error state");
596 }
597
598 // Inside the coroutine we convert await operation into coroutine suspension
599 // point, and resume execution asynchronously.
600 if (isInCoroutine) {
601 CoroMachinery &coro = funcCoro->getSecond();
602 Block *suspended = op->getBlock();
603
604 ImplicitLocOpBuilder builder(loc, rewriter);
605 MLIRContext *ctx = op->getContext();
606
607 // Save the coroutine state and resume on a runtime managed thread when
608 // the operand becomes available.
609 auto coroSaveOp =
610 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
611 RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
612
613 // Split the entry block before the await operation.
614 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
615
616 // Add async.coro.suspend as a suspended block terminator.
617 builder.setInsertionPointToEnd(suspended);
618 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
619 resume, coro.cleanupForDestroy);
620
621 // Split the resume block into error checking and continuation.
622 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
623
624 // Check if the awaited value is in the error state.
625 builder.setInsertionPointToStart(resume);
626 auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
627 cf::CondBranchOp::create(builder, isError,
628 /*trueDest=*/setupSetErrorBlock(coro),
629 /*trueArgs=*/ArrayRef<Value>(),
630 /*falseDest=*/continuation,
631 /*falseArgs=*/ArrayRef<Value>());
632
633 // Make sure that replacement value will be constructed in the
634 // continuation block.
635 rewriter.setInsertionPointToStart(continuation);
636 }
637
638 // Erase or replace the await operation with the new value.
639 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
640 rewriter.replaceOp(op, replaceWith);
641 else
642 rewriter.eraseOp(op);
643
644 return success();
645 }
646
647 virtual Value getReplacementValue(AwaitType op, Value operand,
648 ConversionPatternRewriter &rewriter) const {
649 return Value();
650 }
651
652private:
653 FuncCoroMapPtr coros;
654 bool shouldLowerBlockingWait;
655};
656
657/// Lowering for `async.await` with a token operand.
658class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
660
661public:
662 using Base::Base;
663};
664
665/// Lowering for `async.await` with a value operand.
666class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
668
669public:
670 using Base::Base;
671
672 Value
673 getReplacementValue(AwaitOp op, Value operand,
674 ConversionPatternRewriter &rewriter) const override {
675 // Load from the async value storage.
676 auto valueType = cast<ValueType>(operand.getType()).getValueType();
677 return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
678 }
679};
680
681/// Lowering for `async.await_all` operation.
682class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
684
685public:
686 using Base::Base;
687};
688
689} // namespace
690
691//===----------------------------------------------------------------------===//
692// Convert async.yield operation to async.runtime operations.
693//===----------------------------------------------------------------------===//
694
695class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
696public:
698 : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
699
700 LogicalResult
701 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
702 ConversionPatternRewriter &rewriter) const override {
703 // Check if yield operation is inside the async coroutine function.
704 auto func = op->template getParentOfType<func::FuncOp>();
705 auto funcCoro = coros->find(func);
706 if (funcCoro == coros->end())
707 return rewriter.notifyMatchFailure(
708 op, "operation is not inside the async coroutine function");
709
710 Location loc = op->getLoc();
711 const CoroMachinery &coro = funcCoro->getSecond();
712
713 // Store yielded values into the async values storage and switch async
714 // values state to available.
715 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716 Value yieldValue = std::get<0>(tuple);
717 Value asyncValue = std::get<1>(tuple);
718 RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
719 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
720 }
721
722 if (coro.asyncToken)
723 // Switch the coroutine completion token to available state.
724 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
725
726 cf::BranchOp::create(rewriter, loc, coro.cleanup);
727 rewriter.eraseOp(op);
728
729 return success();
730 }
731
732private:
733 FuncCoroMapPtr coros;
734};
735
736//===----------------------------------------------------------------------===//
737// Convert cf.assert operation to cf.cond_br into `set_error` block.
738//===----------------------------------------------------------------------===//
739
740class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
741public:
743 : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
744
745 LogicalResult
746 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter) const override {
748 // Check if assert operation is inside the async coroutine function.
749 auto func = op->template getParentOfType<func::FuncOp>();
750 auto funcCoro = coros->find(func);
751 if (funcCoro == coros->end())
752 return rewriter.notifyMatchFailure(
753 op, "operation is not inside the async coroutine function");
754
755 Location loc = op->getLoc();
756 CoroMachinery &coro = funcCoro->getSecond();
757
758 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
759 rewriter.setInsertionPointToEnd(cont->getPrevNode());
760 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
761 /*trueDest=*/cont,
762 /*trueArgs=*/ArrayRef<Value>(),
763 /*falseDest=*/setupSetErrorBlock(coro),
764 /*falseArgs=*/ArrayRef<Value>());
765 rewriter.eraseOp(op);
766
767 return success();
768 }
769
770private:
771 FuncCoroMapPtr coros;
772};
773
774//===----------------------------------------------------------------------===//
775void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
777 SymbolTable symbolTable(module);
778
779 // Functions with coroutine CFG setups, which are results of outlining
780 // `async.execute` body regions
781 FuncCoroMapPtr coros =
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783
784 module.walk([&](ExecuteOp execute) {
785 coros->insert(outlineExecuteOp(symbolTable, execute));
786 });
787
788 LLVM_DEBUG({
789 llvm::dbgs() << "Outlined " << coros->size()
790 << " functions built from async.execute operations\n";
791 });
792
793 // Returns true if operation is inside the coroutine.
794 auto isInCoroutine = [&](Operation *op) -> bool {
795 auto parentFunc = op->getParentOfType<func::FuncOp>();
796 return coros->contains(parentFunc);
797 };
798
799 // Lower async operations to async.runtime operations.
800 MLIRContext *ctx = module->getContext();
801 RewritePatternSet asyncPatterns(ctx);
802
803 // Conversion to async runtime augments original CFG with the coroutine CFG,
804 // and we have to make sure that structured control flow operations with async
805 // operations in nested regions will be converted to branch-based control flow
806 // before we add the coroutine basic blocks.
808
809 // Async lowering does not use type converter because it must preserve all
810 // types for async.runtime operations.
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812
813 asyncPatterns
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
815 ctx, coros, /*should_lower_blocking_wait=*/true);
816
817 // Lower assertions to conditional branches into error blocks.
818 asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
819
820 // All high level async operations must be lowered to the runtime operations.
821 ConversionTarget runtimeTarget(*ctx);
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
825
826 // Decide if structured control flow has to be lowered to branch-based CFG.
827 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828 auto walkResult = op->walk([&](Operation *nested) {
829 bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
830 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
831 : WalkResult::advance();
832 });
833 return !walkResult.wasInterrupted();
834 });
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837
838 // Assertions must be converted to runtime errors inside async functions.
839 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840 [&](cf::AssertOp op) -> bool {
841 auto func = op->getParentOfType<func::FuncOp>();
842 return !coros->contains(func);
843 });
844
845 if (failed(applyPartialConversion(module, runtimeTarget,
846 std::move(asyncPatterns)))) {
847 signalPassFailure();
848 return;
849 }
850}
851
852//===----------------------------------------------------------------------===//
855 // Functions with coroutine CFG setups, which are results of converting
856 // async.func.
857 FuncCoroMapPtr coros =
858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
859 MLIRContext *ctx = patterns.getContext();
860 // Lower async.func to func.func with coroutine cfg.
861 patterns.add<AsyncCallOpLowering>(ctx);
862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
863
864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
865 ctx, coros, /*should_lower_blocking_wait=*/false);
866 patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
867
868 target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
869 [coros](Operation *op) {
870 auto exec = op->getParentOfType<ExecuteOp>();
871 auto func = op->getParentOfType<func::FuncOp>();
872 return exec || !coros->contains(func);
873 });
874}
875
876void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
878
879 // Lower async operations to async.runtime operations.
880 MLIRContext *ctx = module->getContext();
881 RewritePatternSet asyncPatterns(ctx);
882 ConversionTarget runtimeTarget(*ctx);
883
884 // Lower async.func to func.func with coroutine cfg.
886 runtimeTarget);
887
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
893
894 if (failed(applyPartialConversion(module, runtimeTarget,
895 std::move(asyncPatterns)))) {
896 signalPassFailure();
897 return;
898 }
899}
return success()
static constexpr const char kAsyncFnPrefix[]
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static Block * setupSetErrorBlock(CoroMachinery &coro)
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:318
OpListType & getOperations()
Definition Block.h:137
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
typename cf::AssertOp::Adaptor OpAdaptor
Definition Pattern.h:211
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:640
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
Definition SymbolTable.h:97
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult interrupt()
Definition WalkResult.h:46
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)