MLIR  19.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include <utility>

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include <optional>
33 
// Emit the generated pass base-class definitions used below
// (impl::AsyncToAsyncRuntimeBase and impl::AsyncFuncToAsyncRuntimeBase).
namespace mlir {
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir
39 
40 using namespace mlir;
41 using namespace mlir::async;
42 
#define DEBUG_TYPE "async-to-async-runtime"

/// Symbol name given to every function outlined from an `async.execute` region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
46 
47 namespace {
48 
49 class AsyncToAsyncRuntimePass
50  : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
51 public:
52  AsyncToAsyncRuntimePass() = default;
53  void runOnOperation() override;
54 };
55 
56 } // namespace
57 
58 namespace {
59 
60 class AsyncFuncToAsyncRuntimePass
61  : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
62 public:
63  AsyncFuncToAsyncRuntimePass() = default;
64  void runOnOperation() override;
65 };
66 
67 } // namespace
68 
69 /// Function targeted for coroutine transformation has two additional blocks at
70 /// the end: coroutine cleanup and coroutine suspension.
71 ///
72 /// async.await op lowering additionaly creates a resume block for each
73 /// operation to enable non-blocking waiting via coroutine suspension.
74 namespace {
75 struct CoroMachinery {
76  func::FuncOp func;
77 
78  // Async function returns an optional token, followed by some async values
79  //
80  // async.func @foo() -> !async.value<T> {
81  // %cst = arith.constant 42.0 : T
82  // return %cst: T
83  // }
84  // Async execute region returns a completion token, and an async value for
85  // each yielded value.
86  //
87  // %token, %result = async.execute -> !async.value<T> {
88  // %0 = arith.constant ... : T
89  // async.yield %0 : T
90  // }
91  std::optional<Value> asyncToken; // returned completion token
92  llvm::SmallVector<Value, 4> returnValues; // returned async values
93 
94  Value coroHandle; // coroutine handle (!async.coro.getHandle value)
95  Block *entry; // coroutine entry block
96  std::optional<Block *> setError; // set returned values to error state
97  Block *cleanup; // coroutine cleanup block
98 
99  // Coroutine cleanup block for destroy after the coroutine is resumed,
100  // e.g. async.coro.suspend state, [suspend], [resume], [destroy]
101  //
102  // This cleanup block is a duplicate of the cleanup block followed by the
103  // resume block. The purpose of having a duplicate cleanup block for destroy
104  // is to make the CFG clear so that the control flow analysis won't confuse.
105  //
106  // The overall structure of the lowered CFG can be the following,
107  //
108  // Entry (calling async.coro.suspend)
109  // | \
110  // Resume Destroy (duplicate of Cleanup)
111  // | |
112  // Cleanup |
113  // | /
114  // End (ends the corontine)
115  //
116  // If there is resume-specific cleanup logic, it can go into the Cleanup
117  // block but not the destroy block. Otherwise, it can fail block dominance
118  // check.
119  Block *cleanupForDestroy;
120  Block *suspend; // coroutine suspension block
121 };
122 } // namespace
123 
125  std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
126 
127 /// Utility to partially update the regular function CFG to the coroutine CFG
128 /// compatible with LLVM coroutines switched-resume lowering using
129 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
130 /// that branches into preexisting entry block. Also inserts trailing blocks.
131 ///
132 /// The result types of the passed `func` start with an optional `async.token`
133 /// and be continued with some number of `async.value`s.
134 ///
135 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
136 ///
137 /// - `entry` block sets up the coroutine.
138 /// - `set_error` block sets completion token and async values state to error.
139 /// - `cleanup` block cleans up the coroutine state.
140 /// - `suspend block after the @llvm.coro.end() defines what value will be
141 /// returned to the initial caller of a coroutine. Everything before the
142 /// @llvm.coro.end() will be executed at every suspension point.
143 ///
144 /// Coroutine structure (only the important bits):
145 ///
146 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
147 /// {
148 /// ^entry(<function-arguments>):
149 /// %token = <async token> : !async.token // create async runtime token
150 /// %value = <async value> : !async.value<T> // create async value
151 /// %id = async.coro.getId // create a coroutine id
152 /// %hdl = async.coro.begin %id // create a coroutine handle
153 /// cf.br ^preexisting_entry_block
154 ///
155 /// /* preexisting blocks modified to branch to the cleanup block */
156 ///
157 /// ^set_error: // this block created lazily only if needed (see code below)
158 /// async.runtime.set_error %token : !async.token
159 /// async.runtime.set_error %value : !async.value<T>
160 /// cf.br ^cleanup
161 ///
162 /// ^cleanup:
163 /// async.coro.free %hdl // delete the coroutine state
164 /// cf.br ^suspend
165 ///
166 /// ^suspend:
167 /// async.coro.end %hdl // marks the end of a coroutine
168 /// return %token, %value : !async.token, !async.value<T>
169 /// }
170 ///
171 static CoroMachinery setupCoroMachinery(func::FuncOp func) {
172  assert(!func.getBlocks().empty() && "Function must have an entry block");
173 
174  MLIRContext *ctx = func.getContext();
175  Block *entryBlock = &func.getBlocks().front();
176  Block *originalEntryBlock =
177  entryBlock->splitBlock(entryBlock->getOperations().begin());
178  auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
179 
180  // ------------------------------------------------------------------------ //
181  // Allocate async token/values that we will return from a ramp function.
182  // ------------------------------------------------------------------------ //
183 
184  // We treat TokenType as state update marker to represent side-effects of
185  // async computations
186  bool isStateful = isa<TokenType>(func.getResultTypes().front());
187 
188  std::optional<Value> retToken;
189  if (isStateful)
190  retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
191 
192  llvm::SmallVector<Value, 4> retValues;
193  ArrayRef<Type> resValueTypes =
194  isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195  for (auto resType : resValueTypes)
196  retValues.emplace_back(
197  builder.create<RuntimeCreateOp>(resType).getResult());
198 
199  // ------------------------------------------------------------------------ //
200  // Initialize coroutine: get coroutine id and coroutine handle.
201  // ------------------------------------------------------------------------ //
202  auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
203  auto coroHdlOp =
204  builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
205  builder.create<cf::BranchOp>(originalEntryBlock);
206 
207  Block *cleanupBlock = func.addBlock();
208  Block *cleanupBlockForDestroy = func.addBlock();
209  Block *suspendBlock = func.addBlock();
210 
211  // ------------------------------------------------------------------------ //
212  // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
213  // ------------------------------------------------------------------------ //
214  auto buildCleanupBlock = [&](Block *cb) {
215  builder.setInsertionPointToStart(cb);
216  builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
217 
218  // Branch into the suspend block.
219  builder.create<cf::BranchOp>(suspendBlock);
220  };
221  buildCleanupBlock(cleanupBlock);
222  buildCleanupBlock(cleanupBlockForDestroy);
223 
224  // ------------------------------------------------------------------------ //
225  // Coroutine suspend block: mark the end of a coroutine and return allocated
226  // async token.
227  // ------------------------------------------------------------------------ //
228  builder.setInsertionPointToStart(suspendBlock);
229 
230  // Mark the end of a coroutine: async.coro.end
231  builder.create<CoroEndOp>(coroHdlOp.getHandle());
232 
233  // Return created optional `async.token` and `async.values` from the suspend
234  // block. This will be the return value of a coroutine ramp function.
236  if (retToken)
237  ret.push_back(*retToken);
238  ret.insert(ret.end(), retValues.begin(), retValues.end());
239  builder.create<func::ReturnOp>(ret);
240 
241  // `async.await` op lowering will create resume blocks for async
242  // continuations, and will conditionally branch to cleanup or suspend blocks.
243 
244  // The switch-resumed API based coroutine should be marked with
245  // presplitcoroutine attribute to mark the function as a coroutine.
246  func->setAttr("passthrough", builder.getArrayAttr(
247  StringAttr::get(ctx, "presplitcoroutine")));
248 
249  CoroMachinery machinery;
250  machinery.func = func;
251  machinery.asyncToken = retToken;
252  machinery.returnValues = retValues;
253  machinery.coroHandle = coroHdlOp.getHandle();
254  machinery.entry = entryBlock;
255  machinery.setError = std::nullopt; // created lazily only if needed
256  machinery.cleanup = cleanupBlock;
257  machinery.cleanupForDestroy = cleanupBlockForDestroy;
258  machinery.suspend = suspendBlock;
259  return machinery;
260 }
261 
262 // Lazily creates `set_error` block only if it is required for lowering to the
263 // runtime operations (see for example lowering of assert operation).
264 static Block *setupSetErrorBlock(CoroMachinery &coro) {
265  if (coro.setError)
266  return *coro.setError;
267 
268  coro.setError = coro.func.addBlock();
269  (*coro.setError)->moveBefore(coro.cleanup);
270 
271  auto builder =
272  ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
273 
274  // Coroutine set_error block: set error on token and all returned values.
275  if (coro.asyncToken)
276  builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
277 
278  for (Value retValue : coro.returnValues)
279  builder.create<RuntimeSetErrorOp>(retValue);
280 
281  // Branch into the cleanup block.
282  builder.create<cf::BranchOp>(coro.cleanup);
283 
284  return *coro.setError;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // async.execute op outlining to the coroutine functions.
289 //===----------------------------------------------------------------------===//
290 
291 /// Outline the body region attached to the `async.execute` op into a standalone
292 /// function.
293 ///
294 /// Note that this is not reversible transformation.
295 static std::pair<func::FuncOp, CoroMachinery>
296 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
297  ModuleOp module = execute->getParentOfType<ModuleOp>();
298 
299  MLIRContext *ctx = module.getContext();
300  Location loc = execute.getLoc();
301 
302  // Make sure that all constants will be inside the outlined async function to
303  // reduce the number of function arguments.
304  cloneConstantsIntoTheRegion(execute.getBodyRegion());
305 
306  // Collect all outlined function inputs.
307  SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
308  execute.getDependencies().end());
309  functionInputs.insert(execute.getBodyOperands().begin(),
310  execute.getBodyOperands().end());
311  getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
312 
313  // Collect types for the outlined function inputs and outputs.
314  auto typesRange = llvm::map_range(
315  functionInputs, [](Value value) { return value.getType(); });
316  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
317  auto outputTypes = execute.getResultTypes();
318 
319  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
320  auto funcAttrs = ArrayRef<NamedAttribute>();
321 
322  // TODO: Derive outlined function name from the parent FuncOp (support
323  // multiple nested async.execute operations).
324  func::FuncOp func =
325  func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
326  symbolTable.insert(func);
327 
329  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
330 
331  // Prepare for coroutine conversion by creating the body of the function.
332  {
333  size_t numDependencies = execute.getDependencies().size();
334  size_t numOperands = execute.getBodyOperands().size();
335 
336  // Await on all dependencies before starting to execute the body region.
337  for (size_t i = 0; i < numDependencies; ++i)
338  builder.create<AwaitOp>(func.getArgument(i));
339 
340  // Await on all async value operands and unwrap the payload.
341  SmallVector<Value, 4> unwrappedOperands(numOperands);
342  for (size_t i = 0; i < numOperands; ++i) {
343  Value operand = func.getArgument(numDependencies + i);
344  unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
345  }
346 
347  // Map from function inputs defined above the execute op to the function
348  // arguments.
349  IRMapping valueMapping;
350  valueMapping.map(functionInputs, func.getArguments());
351  valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
352 
353  // Clone all operations from the execute operation body into the outlined
354  // function body.
355  for (Operation &op : execute.getBodyRegion().getOps())
356  builder.clone(op, valueMapping);
357  }
358 
359  // Adding entry/cleanup/suspend blocks.
360  CoroMachinery coro = setupCoroMachinery(func);
361 
362  // Suspend async function at the end of an entry block, and resume it using
363  // Async resume operation (execution will be resumed in a thread managed by
364  // the async runtime).
365  {
366  cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
367  builder.setInsertionPointToEnd(coro.entry);
368 
369  // Save the coroutine state: async.coro.save
370  auto coroSaveOp =
371  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
372 
373  // Pass coroutine to the runtime to be resumed on a runtime managed
374  // thread.
375  builder.create<RuntimeResumeOp>(coro.coroHandle);
376 
377  // Add async.coro.suspend as a suspended block terminator.
378  builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
379  branch.getDest(), coro.cleanupForDestroy);
380 
381  branch.erase();
382  }
383 
384  // Replace the original `async.execute` with a call to outlined function.
385  {
386  ImplicitLocOpBuilder callBuilder(loc, execute);
387  auto callOutlinedFunc = callBuilder.create<func::CallOp>(
388  func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
389  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
390  execute.erase();
391  }
392 
393  return {func, coro};
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // Convert async.create_group operation to async.runtime.create_group
398 //===----------------------------------------------------------------------===//
399 
400 namespace {
401 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
402 public:
404 
406  matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
407  ConversionPatternRewriter &rewriter) const override {
408  rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
409  op, GroupType::get(op->getContext()), adaptor.getOperands());
410  return success();
411  }
412 };
413 } // namespace
414 
415 //===----------------------------------------------------------------------===//
416 // Convert async.add_to_group operation to async.runtime.add_to_group.
417 //===----------------------------------------------------------------------===//
418 
419 namespace {
420 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
421 public:
423 
425  matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
426  ConversionPatternRewriter &rewriter) const override {
427  rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
428  op, rewriter.getIndexType(), adaptor.getOperands());
429  return success();
430  }
431 };
432 } // namespace
433 
434 //===----------------------------------------------------------------------===//
// Convert async.func, async.return and async.call operations to non-blocking
// operations based on LLVM coroutines.
437 //===----------------------------------------------------------------------===//
438 
439 namespace {
440 
441 //===----------------------------------------------------------------------===//
442 // Convert async.func operation to func.func
443 //===----------------------------------------------------------------------===//
444 
445 class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
446 public:
447  AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
448  : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
449 
451  matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
452  ConversionPatternRewriter &rewriter) const override {
453  Location loc = op->getLoc();
454 
455  auto newFuncOp =
456  rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
457 
460  // Copy over all attributes other than the name.
461  for (const auto &namedAttr : op->getAttrs()) {
462  if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
463  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
464  }
465 
466  rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
467  newFuncOp.end());
468 
469  CoroMachinery coro = setupCoroMachinery(newFuncOp);
470  (*coros)[newFuncOp] = coro;
471  // no initial suspend, we should hot-start
472 
473  rewriter.eraseOp(op);
474  return success();
475  }
476 
477 private:
478  FuncCoroMapPtr coros;
479 };
480 
481 //===----------------------------------------------------------------------===//
482 // Convert async.call operation to func.call
483 //===----------------------------------------------------------------------===//
484 
485 class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
486 public:
487  AsyncCallOpLowering(MLIRContext *ctx)
488  : OpConversionPattern<async::CallOp>(ctx) {}
489 
491  matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
492  ConversionPatternRewriter &rewriter) const override {
493  rewriter.replaceOpWithNewOp<func::CallOp>(
494  op, op.getCallee(), op.getResultTypes(), op.getOperands());
495  return success();
496  }
497 };
498 
499 //===----------------------------------------------------------------------===//
500 // Convert async.return operation to async.runtime operations.
501 //===----------------------------------------------------------------------===//
502 
503 class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
504 public:
505  AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
506  : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
507 
509  matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
510  ConversionPatternRewriter &rewriter) const override {
511  auto func = op->template getParentOfType<func::FuncOp>();
512  auto funcCoro = coros->find(func);
513  if (funcCoro == coros->end())
514  return rewriter.notifyMatchFailure(
515  op, "operation is not inside the async coroutine function");
516 
517  Location loc = op->getLoc();
518  const CoroMachinery &coro = funcCoro->getSecond();
519  rewriter.setInsertionPointAfter(op);
520 
521  // Store return values into the async values storage and switch async
522  // values state to available.
523  for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524  Value returnValue = std::get<0>(tuple);
525  Value asyncValue = std::get<1>(tuple);
526  rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
527  rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
528  }
529 
530  if (coro.asyncToken)
531  // Switch the coroutine completion token to available state.
532  rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
533 
534  rewriter.eraseOp(op);
535  rewriter.create<cf::BranchOp>(loc, coro.cleanup);
536  return success();
537  }
538 
539 private:
540  FuncCoroMapPtr coros;
541 };
542 } // namespace
543 
544 //===----------------------------------------------------------------------===//
545 // Convert async.await and async.await_all operations to the async.runtime.await
546 // or async.runtime.await_and_resume operations.
547 //===----------------------------------------------------------------------===//
548 
549 namespace {
550 template <typename AwaitType, typename AwaitableType>
551 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
552  using AwaitAdaptor = typename AwaitType::Adaptor;
553 
554 public:
555  AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
556  bool shouldLowerBlockingWait)
557  : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
558  shouldLowerBlockingWait(shouldLowerBlockingWait) {}
559 
561  matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
562  ConversionPatternRewriter &rewriter) const override {
563  // We can only await on one the `AwaitableType` (for `await` it can be
564  // a `token` or a `value`, for `await_all` it must be a `group`).
565  if (!isa<AwaitableType>(op.getOperand().getType()))
566  return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
567 
568  // Check if await operation is inside the coroutine function.
569  auto func = op->template getParentOfType<func::FuncOp>();
570  auto funcCoro = coros->find(func);
571  const bool isInCoroutine = funcCoro != coros->end();
572 
573  Location loc = op->getLoc();
574  Value operand = adaptor.getOperand();
575 
576  Type i1 = rewriter.getI1Type();
577 
578  // Delay lowering to block wait in case await op is inside async.execute
579  if (!isInCoroutine && !shouldLowerBlockingWait)
580  return failure();
581 
582  // Inside regular functions we use the blocking wait operation to wait for
583  // the async object (token, value or group) to become available.
584  if (!isInCoroutine) {
585  ImplicitLocOpBuilder builder(loc, rewriter);
586  builder.create<RuntimeAwaitOp>(loc, operand);
587 
588  // Assert that the awaited operands is not in the error state.
589  Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
590  Value notError = builder.create<arith::XOrIOp>(
591  isError, builder.create<arith::ConstantOp>(
592  loc, i1, builder.getIntegerAttr(i1, 1)));
593 
594  builder.create<cf::AssertOp>(notError,
595  "Awaited async operand is in error state");
596  }
597 
598  // Inside the coroutine we convert await operation into coroutine suspension
599  // point, and resume execution asynchronously.
600  if (isInCoroutine) {
601  CoroMachinery &coro = funcCoro->getSecond();
602  Block *suspended = op->getBlock();
603 
604  ImplicitLocOpBuilder builder(loc, rewriter);
605  MLIRContext *ctx = op->getContext();
606 
607  // Save the coroutine state and resume on a runtime managed thread when
608  // the operand becomes available.
609  auto coroSaveOp =
610  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
611  builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
612 
613  // Split the entry block before the await operation.
614  Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
615 
616  // Add async.coro.suspend as a suspended block terminator.
617  builder.setInsertionPointToEnd(suspended);
618  builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
619  coro.cleanupForDestroy);
620 
621  // Split the resume block into error checking and continuation.
622  Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
623 
624  // Check if the awaited value is in the error state.
625  builder.setInsertionPointToStart(resume);
626  auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
627  builder.create<cf::CondBranchOp>(isError,
628  /*trueDest=*/setupSetErrorBlock(coro),
629  /*trueArgs=*/ArrayRef<Value>(),
630  /*falseDest=*/continuation,
631  /*falseArgs=*/ArrayRef<Value>());
632 
633  // Make sure that replacement value will be constructed in the
634  // continuation block.
635  rewriter.setInsertionPointToStart(continuation);
636  }
637 
638  // Erase or replace the await operation with the new value.
639  if (Value replaceWith = getReplacementValue(op, operand, rewriter))
640  rewriter.replaceOp(op, replaceWith);
641  else
642  rewriter.eraseOp(op);
643 
644  return success();
645  }
646 
647  virtual Value getReplacementValue(AwaitType op, Value operand,
648  ConversionPatternRewriter &rewriter) const {
649  return Value();
650  }
651 
652 private:
653  FuncCoroMapPtr coros;
654  bool shouldLowerBlockingWait;
655 };
656 
657 /// Lowering for `async.await` with a token operand.
658 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
659  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
660 
661 public:
662  using Base::Base;
663 };
664 
665 /// Lowering for `async.await` with a value operand.
666 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
667  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
668 
669 public:
670  using Base::Base;
671 
672  Value
673  getReplacementValue(AwaitOp op, Value operand,
674  ConversionPatternRewriter &rewriter) const override {
675  // Load from the async value storage.
676  auto valueType = cast<ValueType>(operand.getType()).getValueType();
677  return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
678  }
679 };
680 
681 /// Lowering for `async.await_all` operation.
682 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
684 
685 public:
686  using Base::Base;
687 };
688 
689 } // namespace
690 
691 //===----------------------------------------------------------------------===//
692 // Convert async.yield operation to async.runtime operations.
693 //===----------------------------------------------------------------------===//
694 
695 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
696 public:
698  : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
699 
701  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
702  ConversionPatternRewriter &rewriter) const override {
703  // Check if yield operation is inside the async coroutine function.
704  auto func = op->template getParentOfType<func::FuncOp>();
705  auto funcCoro = coros->find(func);
706  if (funcCoro == coros->end())
707  return rewriter.notifyMatchFailure(
708  op, "operation is not inside the async coroutine function");
709 
710  Location loc = op->getLoc();
711  const CoroMachinery &coro = funcCoro->getSecond();
712 
713  // Store yielded values into the async values storage and switch async
714  // values state to available.
715  for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716  Value yieldValue = std::get<0>(tuple);
717  Value asyncValue = std::get<1>(tuple);
718  rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
719  rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
720  }
721 
722  if (coro.asyncToken)
723  // Switch the coroutine completion token to available state.
724  rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
725 
726  rewriter.eraseOp(op);
727  rewriter.create<cf::BranchOp>(loc, coro.cleanup);
728 
729  return success();
730  }
731 
732 private:
733  FuncCoroMapPtr coros;
734 };
735 
736 //===----------------------------------------------------------------------===//
737 // Convert cf.assert operation to cf.cond_br into `set_error` block.
738 //===----------------------------------------------------------------------===//
739 
740 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
741 public:
743  : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
744 
746  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
747  ConversionPatternRewriter &rewriter) const override {
748  // Check if assert operation is inside the async coroutine function.
749  auto func = op->template getParentOfType<func::FuncOp>();
750  auto funcCoro = coros->find(func);
751  if (funcCoro == coros->end())
752  return rewriter.notifyMatchFailure(
753  op, "operation is not inside the async coroutine function");
754 
755  Location loc = op->getLoc();
756  CoroMachinery &coro = funcCoro->getSecond();
757 
758  Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
759  rewriter.setInsertionPointToEnd(cont->getPrevNode());
760  rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
761  /*trueDest=*/cont,
762  /*trueArgs=*/ArrayRef<Value>(),
763  /*falseDest=*/setupSetErrorBlock(coro),
764  /*falseArgs=*/ArrayRef<Value>());
765  rewriter.eraseOp(op);
766 
767  return success();
768  }
769 
770 private:
771  FuncCoroMapPtr coros;
772 };
773 
774 //===----------------------------------------------------------------------===//
775 void AsyncToAsyncRuntimePass::runOnOperation() {
776  ModuleOp module = getOperation();
777  SymbolTable symbolTable(module);
778 
779  // Functions with coroutine CFG setups, which are results of outlining
780  // `async.execute` body regions
781  FuncCoroMapPtr coros =
782  std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783 
784  module.walk([&](ExecuteOp execute) {
785  coros->insert(outlineExecuteOp(symbolTable, execute));
786  });
787 
788  LLVM_DEBUG({
789  llvm::dbgs() << "Outlined " << coros->size()
790  << " functions built from async.execute operations\n";
791  });
792 
793  // Returns true if operation is inside the coroutine.
794  auto isInCoroutine = [&](Operation *op) -> bool {
795  auto parentFunc = op->getParentOfType<func::FuncOp>();
796  return coros->find(parentFunc) != coros->end();
797  };
798 
799  // Lower async operations to async.runtime operations.
800  MLIRContext *ctx = module->getContext();
801  RewritePatternSet asyncPatterns(ctx);
802 
803  // Conversion to async runtime augments original CFG with the coroutine CFG,
804  // and we have to make sure that structured control flow operations with async
805  // operations in nested regions will be converted to branch-based control flow
806  // before we add the coroutine basic blocks.
808 
809  // Async lowering does not use type converter because it must preserve all
810  // types for async.runtime operations.
811  asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812 
813  asyncPatterns
814  .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
815  ctx, coros, /*should_lower_blocking_wait=*/true);
816 
817  // Lower assertions to conditional branches into error blocks.
818  asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
819 
820  // All high level async operations must be lowered to the runtime operations.
821  ConversionTarget runtimeTarget(*ctx);
822  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823  runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
825 
826  // Decide if structured control flow has to be lowered to branch-based CFG.
827  runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828  auto walkResult = op->walk([&](Operation *nested) {
829  bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
830  return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
831  : WalkResult::advance();
832  });
833  return !walkResult.wasInterrupted();
834  });
835  runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836  func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837 
838  // Assertions must be converted to runtime errors inside async functions.
839  runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840  [&](cf::AssertOp op) -> bool {
841  auto func = op->getParentOfType<func::FuncOp>();
842  return !coros->contains(func);
843  });
844 
845  if (failed(applyPartialConversion(module, runtimeTarget,
846  std::move(asyncPatterns)))) {
847  signalPassFailure();
848  return;
849  }
850 }
851 
852 //===----------------------------------------------------------------------===//
854  RewritePatternSet &patterns, ConversionTarget &target) {
855  // Functions with coroutine CFG setups, which are results of converting
856  // async.func.
857  FuncCoroMapPtr coros =
858  std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
859  MLIRContext *ctx = patterns.getContext();
860  // Lower async.func to func.func with coroutine cfg.
861  patterns.add<AsyncCallOpLowering>(ctx);
862  patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
863 
864  patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
865  ctx, coros, /*should_lower_blocking_wait=*/false);
866  patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
867 
868  target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
869  [coros](Operation *op) {
870  auto exec = op->getParentOfType<ExecuteOp>();
871  auto func = op->getParentOfType<func::FuncOp>();
872  return exec || !coros->contains(func);
873  });
874 }
875 
876 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877  ModuleOp module = getOperation();
878 
879  // Lower async operations to async.runtime operations.
880  MLIRContext *ctx = module->getContext();
881  RewritePatternSet asyncPatterns(ctx);
882  ConversionTarget runtimeTarget(*ctx);
883 
884  // Lower async.func to func.func with coroutine cfg.
886  runtimeTarget);
887 
888  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889  runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890 
891  runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892  cf::BranchOp, cf::CondBranchOp>();
893 
894  if (failed(applyPartialConversion(module, runtimeTarget,
895  std::move(asyncPatterns)))) {
896  signalPassFailure();
897  return;
898  }
899 }
900 
901 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
902  return std::make_unique<AsyncToAsyncRuntimePass>();
903 }
904 
905 std::unique_ptr<OperationPass<ModuleOp>>
907  return std::make_unique<AsyncFuncToAsyncRuntimePass>();
908 }
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:307
OpListType & getOperations()
Definition: Block.h:134
Operation & front()
Definition: Block.h:150
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:145
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult interrupt()
Definition: Visitors.h:51
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Definition: PassDetail.cpp:15
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:63
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::unique_ptr< OperationPass< ModuleOp > > createAsyncFuncToAsyncRuntimePass()
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26