MLIR 23.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering from high level async operations to async.coro
10// and async.runtime operations.
11//
12//===----------------------------------------------------------------------===//
13
#include <utility>

#include "PassDetail.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/Support/Debug.h"
#include <optional>
31
32namespace mlir {
33#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
34#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
35#include "mlir/Dialect/Async/Passes.h.inc"
36} // namespace mlir
37
38using namespace mlir;
39using namespace mlir::async;
40
41#define DEBUG_TYPE "async-to-async-runtime"
42// Prefix for functions outlined from `async.execute` op regions.
43static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
44
45namespace {
46
47class AsyncToAsyncRuntimePass
48 : public impl::AsyncToAsyncRuntimePassBase<AsyncToAsyncRuntimePass> {
49public:
50 AsyncToAsyncRuntimePass() = default;
51 void runOnOperation() override;
52};
53
54} // namespace
55
56namespace {
57
58class AsyncFuncToAsyncRuntimePass
59 : public impl::AsyncFuncToAsyncRuntimePassBase<
60 AsyncFuncToAsyncRuntimePass> {
61public:
62 AsyncFuncToAsyncRuntimePass() = default;
63 void runOnOperation() override;
64};
65
66} // namespace
67
68/// Function targeted for coroutine transformation has two additional blocks at
69/// the end: coroutine cleanup and coroutine suspension.
70///
71/// async.await op lowering additionally creates a resume block for each
72/// operation to enable non-blocking waiting via coroutine suspension.
73namespace {
74struct CoroMachinery {
75 func::FuncOp func;
76
77 // Async function returns an optional token, followed by some async values
78 //
79 // async.func @foo() -> !async.value<T> {
80 // %cst = arith.constant 42.0 : T
81 // return %cst: T
82 // }
83 // Async execute region returns a completion token, and an async value for
84 // each yielded value.
85 //
86 // %token, %result = async.execute -> !async.value<T> {
87 // %0 = arith.constant ... : T
88 // async.yield %0 : T
89 // }
90 std::optional<Value> asyncToken; // returned completion token
91 llvm::SmallVector<Value, 4> returnValues; // returned async values
92
93 Value coroHandle; // coroutine handle (!async.coro.getHandle value)
94 Block *entry; // coroutine entry block
95 std::optional<Block *> setError; // set returned values to error state
96 Block *cleanup; // coroutine cleanup block
97
98 // Coroutine cleanup block for destroy after the coroutine is resumed,
99 // e.g. async.coro.suspend state, [suspend], [resume], [destroy]
100 //
101 // This cleanup block is a duplicate of the cleanup block followed by the
102 // resume block. The purpose of having a duplicate cleanup block for destroy
103 // is to make the CFG clear so that the control flow analysis won't confuse.
104 //
105 // The overall structure of the lowered CFG can be the following,
106 //
107 // Entry (calling async.coro.suspend)
108 // | \
109 // Resume Destroy (duplicate of Cleanup)
110 // | |
111 // Cleanup |
112 // | /
113 // End (ends the corontine)
114 //
115 // If there is resume-specific cleanup logic, it can go into the Cleanup
116 // block but not the destroy block. Otherwise, it can fail block dominance
117 // check.
118 Block *cleanupForDestroy;
119 Block *suspend; // coroutine suspension block
120};
121} // namespace
122
124 std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
125
126/// Utility to partially update the regular function CFG to the coroutine CFG
127/// compatible with LLVM coroutines switched-resume lowering using
128/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
129/// that branches into preexisting entry block. Also inserts trailing blocks.
130///
131/// The result types of the passed `func` start with an optional `async.token`
132/// and be continued with some number of `async.value`s.
133///
134/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
135///
136/// - `entry` block sets up the coroutine.
137/// - `set_error` block sets completion token and async values state to error.
138/// - `cleanup` block cleans up the coroutine state.
139/// - `suspend block after the @llvm.coro.end() defines what value will be
140/// returned to the initial caller of a coroutine. Everything before the
141/// @llvm.coro.end() will be executed at every suspension point.
142///
143/// Coroutine structure (only the important bits):
144///
145/// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
146/// {
147/// ^entry(<function-arguments>):
148/// %token = <async token> : !async.token // create async runtime token
149/// %value = <async value> : !async.value<T> // create async value
150/// %id = async.coro.getId // create a coroutine id
151/// %hdl = async.coro.begin %id // create a coroutine handle
152/// cf.br ^preexisting_entry_block
153///
154/// /* preexisting blocks modified to branch to the cleanup block */
155///
156/// ^set_error: // this block created lazily only if needed (see code below)
157/// async.runtime.set_error %token : !async.token
158/// async.runtime.set_error %value : !async.value<T>
159/// cf.br ^cleanup
160///
161/// ^cleanup:
162/// async.coro.free %hdl // delete the coroutine state
163/// cf.br ^suspend
164///
165/// ^suspend:
166/// async.coro.end %hdl // marks the end of a coroutine
167/// return %token, %value : !async.token, !async.value<T>
168/// }
169///
170static CoroMachinery setupCoroMachinery(func::FuncOp func) {
171 assert(!func.getBlocks().empty() && "Function must have an entry block");
172
173 MLIRContext *ctx = func.getContext();
174 Block *entryBlock = &func.getBlocks().front();
175 Block *originalEntryBlock =
176 entryBlock->splitBlock(entryBlock->getOperations().begin());
177 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
178
179 // ------------------------------------------------------------------------ //
180 // Allocate async token/values that we will return from a ramp function.
181 // ------------------------------------------------------------------------ //
182
183 // We treat TokenType as state update marker to represent side-effects of
184 // async computations
185 bool isStateful = isa<TokenType>(func.getResultTypes().front());
186
187 std::optional<Value> retToken;
188 if (isStateful)
189 retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx)));
190
192 ArrayRef<Type> resValueTypes =
193 isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
194 for (auto resType : resValueTypes)
195 retValues.emplace_back(
196 RuntimeCreateOp::create(builder, resType).getResult());
197
198 // ------------------------------------------------------------------------ //
199 // Initialize coroutine: get coroutine id and coroutine handle.
200 // ------------------------------------------------------------------------ //
201 auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
202 auto coroHdlOp =
203 CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
204 cf::BranchOp::create(builder, originalEntryBlock);
205
206 Block *cleanupBlock = func.addBlock();
207 Block *cleanupBlockForDestroy = func.addBlock();
208 Block *suspendBlock = func.addBlock();
209
210 // ------------------------------------------------------------------------ //
211 // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
212 // ------------------------------------------------------------------------ //
213 auto buildCleanupBlock = [&](Block *cb) {
214 builder.setInsertionPointToStart(cb);
215 CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle());
216
217 // Branch into the suspend block.
218 cf::BranchOp::create(builder, suspendBlock);
219 };
220 buildCleanupBlock(cleanupBlock);
221 buildCleanupBlock(cleanupBlockForDestroy);
222
223 // ------------------------------------------------------------------------ //
224 // Coroutine suspend block: mark the end of a coroutine and return allocated
225 // async token.
226 // ------------------------------------------------------------------------ //
227 builder.setInsertionPointToStart(suspendBlock);
228
229 // Mark the end of a coroutine: async.coro.end
230 CoroEndOp::create(builder, coroHdlOp.getHandle());
231
232 // Return created optional `async.token` and `async.values` from the suspend
233 // block. This will be the return value of a coroutine ramp function.
235 if (retToken)
236 ret.push_back(*retToken);
237 llvm::append_range(ret, retValues);
238 func::ReturnOp::create(builder, ret);
239
240 // `async.await` op lowering will create resume blocks for async
241 // continuations, and will conditionally branch to cleanup or suspend blocks.
242
243 // The switch-resumed API based coroutine should be marked with
244 // presplitcoroutine attribute to mark the function as a coroutine.
245 func->setAttr("passthrough", builder.getArrayAttr(
246 StringAttr::get(ctx, "presplitcoroutine")));
247
248 CoroMachinery machinery;
249 machinery.func = func;
250 machinery.asyncToken = retToken;
251 machinery.returnValues = retValues;
252 machinery.coroHandle = coroHdlOp.getHandle();
253 machinery.entry = entryBlock;
254 machinery.setError = std::nullopt; // created lazily only if needed
255 machinery.cleanup = cleanupBlock;
256 machinery.cleanupForDestroy = cleanupBlockForDestroy;
257 machinery.suspend = suspendBlock;
258 return machinery;
259}
260
261// Lazily creates `set_error` block only if it is required for lowering to the
262// runtime operations (see for example lowering of assert operation).
263static Block *setupSetErrorBlock(CoroMachinery &coro) {
264 if (coro.setError)
265 return *coro.setError;
266
267 coro.setError = coro.func.addBlock();
268 (*coro.setError)->moveBefore(coro.cleanup);
269
270 auto builder =
271 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
272
273 // Coroutine set_error block: set error on token and all returned values.
274 if (coro.asyncToken)
275 RuntimeSetErrorOp::create(builder, *coro.asyncToken);
276
277 for (Value retValue : coro.returnValues)
278 RuntimeSetErrorOp::create(builder, retValue);
279
280 // Branch into the cleanup block.
281 cf::BranchOp::create(builder, coro.cleanup);
282
283 return *coro.setError;
284}
285
286//===----------------------------------------------------------------------===//
287// async.execute op outlining to the coroutine functions.
288//===----------------------------------------------------------------------===//
289
290/// Outline the body region attached to the `async.execute` op into a standalone
291/// function.
292///
293/// Note that this is not reversible transformation.
294static std::pair<func::FuncOp, CoroMachinery>
295outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
296 ModuleOp module = execute->getParentOfType<ModuleOp>();
297
298 MLIRContext *ctx = module.getContext();
299 Location loc = execute.getLoc();
300
301 // Make sure that all constants will be inside the outlined async function to
302 // reduce the number of function arguments.
303 cloneConstantsIntoTheRegion(execute.getBodyRegion());
304
305 // Collect all outlined function inputs.
306 SetVector<mlir::Value> functionInputs(llvm::from_range,
307 execute.getDependencies());
308 functionInputs.insert_range(execute.getBodyOperands());
309 getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
310
311 // Collect types for the outlined function inputs and outputs.
312 auto typesRange = llvm::map_range(
313 functionInputs, [](Value value) { return value.getType(); });
314 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
315 auto outputTypes = execute.getResultTypes();
316
317 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
318 auto funcAttrs = ArrayRef<NamedAttribute>();
319
320 // TODO: Derive outlined function name from the parent FuncOp (support
321 // multiple nested async.execute operations).
322 func::FuncOp func =
323 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
324 symbolTable.insert(func);
325
327 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
328
329 // Prepare for coroutine conversion by creating the body of the function.
330 {
331 size_t numDependencies = execute.getDependencies().size();
332 size_t numOperands = execute.getBodyOperands().size();
333
334 // Await on all dependencies before starting to execute the body region.
335 for (size_t i = 0; i < numDependencies; ++i)
336 AwaitOp::create(builder, func.getArgument(i));
337
338 // Await on all async value operands and unwrap the payload.
339 SmallVector<Value, 4> unwrappedOperands(numOperands);
340 for (size_t i = 0; i < numOperands; ++i) {
341 Value operand = func.getArgument(numDependencies + i);
342 unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult();
343 }
344
345 // Map from function inputs defined above the execute op to the function
346 // arguments.
347 IRMapping valueMapping;
348 valueMapping.map(functionInputs, func.getArguments());
349 valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
350
351 // Clone all operations from the execute operation body into the outlined
352 // function body.
353 for (Operation &op : execute.getBodyRegion().getOps())
354 builder.clone(op, valueMapping);
355 }
356
357 // Adding entry/cleanup/suspend blocks.
358 CoroMachinery coro = setupCoroMachinery(func);
359
360 // Suspend async function at the end of an entry block, and resume it using
361 // Async resume operation (execution will be resumed in a thread managed by
362 // the async runtime).
363 {
364 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
365 builder.setInsertionPointToEnd(coro.entry);
366
367 // Save the coroutine state: async.coro.save
368 auto coroSaveOp =
369 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
370
371 // Pass coroutine to the runtime to be resumed on a runtime managed
372 // thread.
373 RuntimeResumeOp::create(builder, coro.coroHandle);
374
375 // Add async.coro.suspend as a suspended block terminator.
376 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
377 branch.getDest(), coro.cleanupForDestroy);
378
379 branch.erase();
380 }
381
382 // Replace the original `async.execute` with a call to outlined function.
383 {
384 ImplicitLocOpBuilder callBuilder(loc, execute);
385 auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(),
386 execute.getResultTypes(),
387 functionInputs.getArrayRef());
388 execute.replaceAllUsesWith(callOutlinedFunc.getResults());
389 execute.erase();
390 }
391
392 return {func, coro};
393}
394
395//===----------------------------------------------------------------------===//
396// Convert async.create_group operation to async.runtime.create_group
397//===----------------------------------------------------------------------===//
398
399namespace {
400class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
401public:
402 using OpConversionPattern::OpConversionPattern;
403
404 LogicalResult
405 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const override {
407 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
408 op, GroupType::get(op->getContext()), adaptor.getOperands());
409 return success();
410 }
411};
412} // namespace
413
414//===----------------------------------------------------------------------===//
415// Convert async.add_to_group operation to async.runtime.add_to_group.
416//===----------------------------------------------------------------------===//
417
418namespace {
419class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
420public:
421 using OpConversionPattern::OpConversionPattern;
422
423 LogicalResult
424 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
427 op, rewriter.getIndexType(), adaptor.getOperands());
428 return success();
429 }
430};
431} // namespace
432
433//===----------------------------------------------------------------------===//
434// Convert async.func, async.return and async.call operations to non-blocking
435// operations based on llvm coroutine
436//===----------------------------------------------------------------------===//
437
438namespace {
439
440//===----------------------------------------------------------------------===//
441// Convert async.func operation to func.func
442//===----------------------------------------------------------------------===//
443
444class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
445public:
446 AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
447 : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
448
449 LogicalResult
450 matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override {
452 Location loc = op->getLoc();
453
454 auto newFuncOp =
455 func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType());
456
459 // Copy over all attributes other than the name.
460 for (const auto &namedAttr : op->getAttrs()) {
461 if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
462 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
463 }
464
465 rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
466 newFuncOp.end());
467
468 CoroMachinery coro = setupCoroMachinery(newFuncOp);
469 (*coros)[newFuncOp] = coro;
470 // no initial suspend, we should hot-start
471
472 rewriter.eraseOp(op);
473 return success();
474 }
475
476private:
477 FuncCoroMapPtr coros;
478};
479
480//===----------------------------------------------------------------------===//
481// Convert async.call operation to func.call
482//===----------------------------------------------------------------------===//
483
484class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
485public:
486 AsyncCallOpLowering(MLIRContext *ctx)
487 : OpConversionPattern<async::CallOp>(ctx) {}
488
489 LogicalResult
490 matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter) const override {
492 rewriter.replaceOpWithNewOp<func::CallOp>(
493 op, op.getCallee(), op.getResultTypes(), op.getOperands());
494 return success();
495 }
496};
497
498//===----------------------------------------------------------------------===//
499// Convert async.return operation to async.runtime operations.
500//===----------------------------------------------------------------------===//
501
502class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
503public:
504 AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
505 : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
506
507 LogicalResult
508 matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 auto func = op->template getParentOfType<func::FuncOp>();
511 auto funcCoro = coros->find(func);
512 if (funcCoro == coros->end())
513 return rewriter.notifyMatchFailure(
514 op, "operation is not inside the async coroutine function");
515
516 Location loc = op->getLoc();
517 const CoroMachinery &coro = funcCoro->getSecond();
518 rewriter.setInsertionPointAfter(op);
519
520 // Store return values into the async values storage and switch async
521 // values state to available.
522 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
523 Value returnValue = std::get<0>(tuple);
524 Value asyncValue = std::get<1>(tuple);
525 RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue);
526 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
527 }
528
529 if (coro.asyncToken)
530 // Switch the coroutine completion token to available state.
531 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
532
533 rewriter.eraseOp(op);
534 cf::BranchOp::create(rewriter, loc, coro.cleanup);
535 return success();
536 }
537
538private:
539 FuncCoroMapPtr coros;
540};
541} // namespace
542
543//===----------------------------------------------------------------------===//
544// Convert async.await and async.await_all operations to the async.runtime.await
545// or async.runtime.await_and_resume operations.
546//===----------------------------------------------------------------------===//
547
548namespace {
549template <typename AwaitType, typename AwaitableType>
550class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
551 using AwaitAdaptor = typename AwaitType::Adaptor;
552
553public:
554 AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
555 bool shouldLowerBlockingWait)
556 : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
557 shouldLowerBlockingWait(shouldLowerBlockingWait) {}
558
559 LogicalResult
560 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
561 ConversionPatternRewriter &rewriter) const override {
562 // We can only await on one the `AwaitableType` (for `await` it can be
563 // a `token` or a `value`, for `await_all` it must be a `group`).
564 if (!isa<AwaitableType>(op.getOperand().getType()))
565 return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
566
567 // Check if await operation is inside the coroutine function.
568 auto func = op->template getParentOfType<func::FuncOp>();
569 auto funcCoro = coros->find(func);
570 const bool isInCoroutine = funcCoro != coros->end();
571
572 Location loc = op->getLoc();
573 Value operand = adaptor.getOperand();
574
575 Type i1 = rewriter.getI1Type();
576
577 // Delay lowering to block wait in case await op is inside async.execute
578 if (!isInCoroutine && !shouldLowerBlockingWait)
579 return failure();
580
581 // Inside regular functions we use the blocking wait operation to wait for
582 // the async object (token, value or group) to become available.
583 if (!isInCoroutine) {
584 ImplicitLocOpBuilder builder(loc, rewriter);
585 RuntimeAwaitOp::create(builder, loc, operand);
586
587 // Assert that the awaited operands is not in the error state.
588 Value isError = RuntimeIsErrorOp::create(builder, i1, operand);
589 Value notError = arith::XOrIOp::create(
590 builder, isError,
591 arith::ConstantOp::create(builder, loc, i1,
592 builder.getIntegerAttr(i1, 1)));
593
594 cf::AssertOp::create(builder, notError,
595 "Awaited async operand is in error state");
596 }
597
598 // Inside the coroutine we convert await operation into coroutine suspension
599 // point, and resume execution asynchronously.
600 if (isInCoroutine) {
601 CoroMachinery &coro = funcCoro->getSecond();
602 Block *suspended = op->getBlock();
603
604 ImplicitLocOpBuilder builder(loc, rewriter);
605 MLIRContext *ctx = op->getContext();
606
607 // Save the coroutine state and resume on a runtime managed thread when
608 // the operand becomes available.
609 auto coroSaveOp =
610 CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle);
611 RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle);
612
613 // Split the entry block before the await operation.
614 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
615
616 // Add async.coro.suspend as a suspended block terminator.
617 builder.setInsertionPointToEnd(suspended);
618 CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend,
619 resume, coro.cleanupForDestroy);
620
621 // Split the resume block into error checking and continuation.
622 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
623
624 // Check if the awaited value is in the error state.
625 builder.setInsertionPointToStart(resume);
626 auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
627 cf::CondBranchOp::create(builder, isError,
628 /*trueDest=*/setupSetErrorBlock(coro),
629 /*trueArgs=*/ArrayRef<Value>(),
630 /*falseDest=*/continuation,
631 /*falseArgs=*/ArrayRef<Value>());
632
633 // Make sure that replacement value will be constructed in the
634 // continuation block.
635 rewriter.setInsertionPointToStart(continuation);
636 }
637
638 // Erase or replace the await operation with the new value.
639 if (Value replaceWith = getReplacementValue(op, operand, rewriter))
640 rewriter.replaceOp(op, replaceWith);
641 else
642 rewriter.eraseOp(op);
643
644 return success();
645 }
646
647 virtual Value getReplacementValue(AwaitType op, Value operand,
648 ConversionPatternRewriter &rewriter) const {
649 return Value();
650 }
651
652private:
653 FuncCoroMapPtr coros;
654 bool shouldLowerBlockingWait;
655};
656
657/// Lowering for `async.await` with a token operand.
658class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
659 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
660
661public:
662 using Base::Base;
663};
664
665/// Lowering for `async.await` with a value operand.
666class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
667 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
668
669public:
670 using Base::Base;
671
672 Value
673 getReplacementValue(AwaitOp op, Value operand,
674 ConversionPatternRewriter &rewriter) const override {
675 // Load from the async value storage.
676 auto valueType = cast<ValueType>(operand.getType()).getValueType();
677 return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand);
678 }
679};
680
681/// Lowering for `async.await_all` operation.
682class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
684
685public:
686 using Base::Base;
687};
688
689} // namespace
690
691//===----------------------------------------------------------------------===//
692// Convert async.yield operation to async.runtime operations.
693//===----------------------------------------------------------------------===//
694
695class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
696public:
698 : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
699
700 LogicalResult
701 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
702 ConversionPatternRewriter &rewriter) const override {
703 // Check if yield operation is inside the async coroutine function.
704 auto func = op->template getParentOfType<func::FuncOp>();
705 auto funcCoro = coros->find(func);
706 if (funcCoro == coros->end())
707 return rewriter.notifyMatchFailure(
708 op, "operation is not inside the async coroutine function");
709
710 Location loc = op->getLoc();
711 const CoroMachinery &coro = funcCoro->getSecond();
712
713 // Store yielded values into the async values storage and switch async
714 // values state to available.
715 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716 Value yieldValue = std::get<0>(tuple);
717 Value asyncValue = std::get<1>(tuple);
718 RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue);
719 RuntimeSetAvailableOp::create(rewriter, loc, asyncValue);
720 }
721
722 if (coro.asyncToken)
723 // Switch the coroutine completion token to available state.
724 RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken);
725
726 cf::BranchOp::create(rewriter, loc, coro.cleanup);
727 rewriter.eraseOp(op);
728
729 return success();
730 }
731
732private:
733 FuncCoroMapPtr coros;
734};
735
736//===----------------------------------------------------------------------===//
737// Convert cf.assert operation to cf.cond_br into `set_error` block.
738//===----------------------------------------------------------------------===//
739
740class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
741public:
743 : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
744
745 LogicalResult
746 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter) const override {
748 // Check if assert operation is inside the async coroutine function.
749 auto func = op->template getParentOfType<func::FuncOp>();
750 auto funcCoro = coros->find(func);
751 if (funcCoro == coros->end())
752 return rewriter.notifyMatchFailure(
753 op, "operation is not inside the async coroutine function");
754
755 Location loc = op->getLoc();
756 CoroMachinery &coro = funcCoro->getSecond();
757
758 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
759 rewriter.setInsertionPointToEnd(cont->getPrevNode());
760 cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(),
761 /*trueDest=*/cont,
762 /*trueArgs=*/ArrayRef<Value>(),
763 /*falseDest=*/setupSetErrorBlock(coro),
764 /*falseArgs=*/ArrayRef<Value>());
765 rewriter.eraseOp(op);
766
767 return success();
768 }
769
770private:
771 FuncCoroMapPtr coros;
772};
773
774//===----------------------------------------------------------------------===//
775void AsyncToAsyncRuntimePass::runOnOperation() {
776 ModuleOp module = getOperation();
777 SymbolTable symbolTable(module);
778
779 // Functions with coroutine CFG setups, which are results of outlining
780 // `async.execute` body regions
781 FuncCoroMapPtr coros =
782 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783
784 module.walk([&](ExecuteOp execute) {
785 coros->insert(outlineExecuteOp(symbolTable, execute));
786 });
787
788 LLVM_DEBUG({
789 llvm::dbgs() << "Outlined " << coros->size()
790 << " functions built from async.execute operations\n";
791 });
792
793 // Returns true if operation is inside the coroutine.
794 auto isInCoroutine = [&](Operation *op) -> bool {
795 auto parentFunc = op->getParentOfType<func::FuncOp>();
796 return coros->contains(parentFunc);
797 };
798
799 // Lower async operations to async.runtime operations.
800 MLIRContext *ctx = module->getContext();
801 RewritePatternSet asyncPatterns(ctx);
802
803 // Conversion to async runtime augments original CFG with the coroutine CFG,
804 // and we have to make sure that structured control flow operations with async
805 // operations in nested regions will be converted to branch-based control flow
806 // before we add the coroutine basic blocks.
808
809 // Async lowering does not use type converter because it must preserve all
810 // types for async.runtime operations.
811 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812
813 asyncPatterns
814 .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
815 ctx, coros, /*should_lower_blocking_wait=*/true);
816
817 // Lower assertions to conditional branches into error blocks.
818 asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
819
820 // All high level async operations must be lowered to the runtime operations.
821 ConversionTarget runtimeTarget(*ctx);
822 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
825
826 // Decide if structured control flow has to be lowered to branch-based CFG.
827 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828 auto walkResult = op->walk([&](Operation *nested) {
829 bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
830 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
831 : WalkResult::advance();
832 });
833 return !walkResult.wasInterrupted();
834 });
835 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837
838 // Assertions must be converted to runtime errors inside async functions.
839 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840 [&](cf::AssertOp op) -> bool {
841 auto func = op->getParentOfType<func::FuncOp>();
842 return !coros->contains(func);
843 });
844
845 if (failed(applyPartialConversion(module, runtimeTarget,
846 std::move(asyncPatterns)))) {
847 signalPassFailure();
848 return;
849 }
850}
851
852//===----------------------------------------------------------------------===//
855 // Functions with coroutine CFG setups, which are results of converting
856 // async.func.
857 FuncCoroMapPtr coros =
858 std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
859 MLIRContext *ctx = patterns.getContext();
860 // Lower async.func to func.func with coroutine cfg.
861 patterns.add<AsyncCallOpLowering>(ctx);
862 patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
863
864 patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
865 ctx, coros, /*should_lower_blocking_wait=*/false);
866 patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
867
868 target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
869 [coros](Operation *op) {
870 auto exec = op->getParentOfType<ExecuteOp>();
871 auto func = op->getParentOfType<func::FuncOp>();
872 return exec || !coros->contains(func);
873 });
874}
875
876void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877 ModuleOp module = getOperation();
878
879 // Lower async operations to async.runtime operations.
880 MLIRContext *ctx = module->getContext();
881 RewritePatternSet asyncPatterns(ctx);
882 ConversionTarget runtimeTarget(*ctx);
883
884 // Lower async.func to func.func with coroutine cfg.
886 runtimeTarget);
887
888 runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889 runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890
891 runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892 cf::BranchOp, cf::CondBranchOp>();
893
894 if (failed(applyPartialConversion(module, runtimeTarget,
895 std::move(asyncPatterns)))) {
896 signalPassFailure();
897 return;
898 }
899}
return success()
static constexpr const char kAsyncFnPrefix[]
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static Block * setupSetErrorBlock(CoroMachinery &coro)
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM coroutines switched-resume lowering using `async.runtime.*` and `async.coro.*` operations.
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
typename cf::AssertOp::Adaptor OpAdaptor
Definition Pattern.h:229
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition Builders.h:632
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
Definition Builders.h:642
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:241
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.
Definition SymbolTable.h:97
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult interrupt()
Definition WalkResult.h:46
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region into the region entry block.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the ControlFlow dialect.
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)