MLIR  17.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "PassDetail.h"

#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

#include <optional>
#include <utility>
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
37 #include "mlir/Dialect/Async/Passes.h.inc"
38 } // namespace mlir
39 
40 using namespace mlir;
41 using namespace mlir::async;
42 
#define DEBUG_TYPE "async-to-async-runtime"

/// Symbol name given to every function outlined from an `async.execute`
/// region (the symbol table may uniquify it on insertion).
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
46 
47 namespace {
48 
49 class AsyncToAsyncRuntimePass
50  : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
51 public:
52  AsyncToAsyncRuntimePass() = default;
53  void runOnOperation() override;
54 };
55 
56 } // namespace
57 
58 namespace {
59 
60 class AsyncFuncToAsyncRuntimePass
61  : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
62 public:
63  AsyncFuncToAsyncRuntimePass() = default;
64  void runOnOperation() override;
65 };
66 
67 } // namespace
68 
69 /// Function targeted for coroutine transformation has two additional blocks at
70 /// the end: coroutine cleanup and coroutine suspension.
71 ///
72 /// async.await op lowering additionaly creates a resume block for each
73 /// operation to enable non-blocking waiting via coroutine suspension.
74 namespace {
75 struct CoroMachinery {
76  func::FuncOp func;
77 
78  // Async function returns an optional token, followed by some async values
79  //
80  // async.func @foo() -> !async.value<T> {
81  // %cst = arith.constant 42.0 : T
82  // return %cst: T
83  // }
84  // Async execute region returns a completion token, and an async value for
85  // each yielded value.
86  //
87  // %token, %result = async.execute -> !async.value<T> {
88  // %0 = arith.constant ... : T
89  // async.yield %0 : T
90  // }
91  std::optional<Value> asyncToken; // returned completion token
92  llvm::SmallVector<Value, 4> returnValues; // returned async values
93 
94  Value coroHandle; // coroutine handle (!async.coro.getHandle value)
95  Block *entry; // coroutine entry block
96  std::optional<Block *> setError; // set returned values to error state
97  Block *cleanup; // coroutine cleanup block
98  Block *suspend; // coroutine suspension block
99 };
100 } // namespace
101 
103  std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
104 
105 /// Utility to partially update the regular function CFG to the coroutine CFG
106 /// compatible with LLVM coroutines switched-resume lowering using
107 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
108 /// that branches into preexisting entry block. Also inserts trailing blocks.
109 ///
110 /// The result types of the passed `func` start with an optional `async.token`
111 /// and be continued with some number of `async.value`s.
112 ///
113 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
114 ///
115 /// - `entry` block sets up the coroutine.
116 /// - `set_error` block sets completion token and async values state to error.
117 /// - `cleanup` block cleans up the coroutine state.
118 /// - `suspend block after the @llvm.coro.end() defines what value will be
119 /// returned to the initial caller of a coroutine. Everything before the
120 /// @llvm.coro.end() will be executed at every suspension point.
121 ///
122 /// Coroutine structure (only the important bits):
123 ///
124 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
125 /// {
126 /// ^entry(<function-arguments>):
127 /// %token = <async token> : !async.token // create async runtime token
128 /// %value = <async value> : !async.value<T> // create async value
129 /// %id = async.coro.getId // create a coroutine id
130 /// %hdl = async.coro.begin %id // create a coroutine handle
131 /// cf.br ^preexisting_entry_block
132 ///
133 /// /* preexisting blocks modified to branch to the cleanup block */
134 ///
135 /// ^set_error: // this block created lazily only if needed (see code below)
136 /// async.runtime.set_error %token : !async.token
137 /// async.runtime.set_error %value : !async.value<T>
138 /// cf.br ^cleanup
139 ///
140 /// ^cleanup:
141 /// async.coro.free %hdl // delete the coroutine state
142 /// cf.br ^suspend
143 ///
144 /// ^suspend:
145 /// async.coro.end %hdl // marks the end of a coroutine
146 /// return %token, %value : !async.token, !async.value<T>
147 /// }
148 ///
149 static CoroMachinery setupCoroMachinery(func::FuncOp func) {
150  assert(!func.getBlocks().empty() && "Function must have an entry block");
151 
152  MLIRContext *ctx = func.getContext();
153  Block *entryBlock = &func.getBlocks().front();
154  Block *originalEntryBlock =
155  entryBlock->splitBlock(entryBlock->getOperations().begin());
156  auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
157 
158  // ------------------------------------------------------------------------ //
159  // Allocate async token/values that we will return from a ramp function.
160  // ------------------------------------------------------------------------ //
161 
162  // We treat TokenType as state update marker to represent side-effects of
163  // async computations
164  bool isStateful = func.getCallableResults().front().isa<TokenType>();
165 
166  std::optional<Value> retToken;
167  if (isStateful)
168  retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
169 
170  llvm::SmallVector<Value, 4> retValues;
171  ArrayRef<Type> resValueTypes = isStateful
172  ? func.getCallableResults().drop_front()
173  : func.getCallableResults();
174  for (auto resType : resValueTypes)
175  retValues.emplace_back(
176  builder.create<RuntimeCreateOp>(resType).getResult());
177 
178  // ------------------------------------------------------------------------ //
179  // Initialize coroutine: get coroutine id and coroutine handle.
180  // ------------------------------------------------------------------------ //
181  auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
182  auto coroHdlOp =
183  builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
184  builder.create<cf::BranchOp>(originalEntryBlock);
185 
186  Block *cleanupBlock = func.addBlock();
187  Block *suspendBlock = func.addBlock();
188 
189  // ------------------------------------------------------------------------ //
190  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
191  // ------------------------------------------------------------------------ //
192  builder.setInsertionPointToStart(cleanupBlock);
193  builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
194 
195  // Branch into the suspend block.
196  builder.create<cf::BranchOp>(suspendBlock);
197 
198  // ------------------------------------------------------------------------ //
199  // Coroutine suspend block: mark the end of a coroutine and return allocated
200  // async token.
201  // ------------------------------------------------------------------------ //
202  builder.setInsertionPointToStart(suspendBlock);
203 
204  // Mark the end of a coroutine: async.coro.end
205  builder.create<CoroEndOp>(coroHdlOp.getHandle());
206 
207  // Return created optional `async.token` and `async.values` from the suspend
208  // block. This will be the return value of a coroutine ramp function.
210  if (retToken)
211  ret.push_back(*retToken);
212  ret.insert(ret.end(), retValues.begin(), retValues.end());
213  builder.create<func::ReturnOp>(ret);
214 
215  // `async.await` op lowering will create resume blocks for async
216  // continuations, and will conditionally branch to cleanup or suspend blocks.
217 
218  // The switch-resumed API based coroutine should be marked with
219  // coroutine.presplit attribute to mark the function as a coroutine.
220  func->setAttr("passthrough", builder.getArrayAttr(
221  StringAttr::get(ctx, "presplitcoroutine")));
222 
223  CoroMachinery machinery;
224  machinery.func = func;
225  machinery.asyncToken = retToken;
226  machinery.returnValues = retValues;
227  machinery.coroHandle = coroHdlOp.getHandle();
228  machinery.entry = entryBlock;
229  machinery.setError = std::nullopt; // created lazily only if needed
230  machinery.cleanup = cleanupBlock;
231  machinery.suspend = suspendBlock;
232  return machinery;
233 }
234 
235 // Lazily creates `set_error` block only if it is required for lowering to the
236 // runtime operations (see for example lowering of assert operation).
237 static Block *setupSetErrorBlock(CoroMachinery &coro) {
238  if (coro.setError)
239  return *coro.setError;
240 
241  coro.setError = coro.func.addBlock();
242  (*coro.setError)->moveBefore(coro.cleanup);
243 
244  auto builder =
245  ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
246 
247  // Coroutine set_error block: set error on token and all returned values.
248  if (coro.asyncToken)
249  builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
250 
251  for (Value retValue : coro.returnValues)
252  builder.create<RuntimeSetErrorOp>(retValue);
253 
254  // Branch into the cleanup block.
255  builder.create<cf::BranchOp>(coro.cleanup);
256 
257  return *coro.setError;
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // async.execute op outlining to the coroutine functions.
262 //===----------------------------------------------------------------------===//
263 
264 /// Outline the body region attached to the `async.execute` op into a standalone
265 /// function.
266 ///
267 /// Note that this is not reversible transformation.
268 static std::pair<func::FuncOp, CoroMachinery>
269 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
270  ModuleOp module = execute->getParentOfType<ModuleOp>();
271 
272  MLIRContext *ctx = module.getContext();
273  Location loc = execute.getLoc();
274 
275  // Make sure that all constants will be inside the outlined async function to
276  // reduce the number of function arguments.
277  cloneConstantsIntoTheRegion(execute.getBodyRegion());
278 
279  // Collect all outlined function inputs.
280  SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
281  execute.getDependencies().end());
282  functionInputs.insert(execute.getBodyOperands().begin(),
283  execute.getBodyOperands().end());
284  getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
285 
286  // Collect types for the outlined function inputs and outputs.
287  auto typesRange = llvm::map_range(
288  functionInputs, [](Value value) { return value.getType(); });
289  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
290  auto outputTypes = execute.getResultTypes();
291 
292  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
293  auto funcAttrs = ArrayRef<NamedAttribute>();
294 
295  // TODO: Derive outlined function name from the parent FuncOp (support
296  // multiple nested async.execute operations).
297  func::FuncOp func =
298  func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
299  symbolTable.insert(func);
300 
302  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
303 
304  // Prepare for coroutine conversion by creating the body of the function.
305  {
306  size_t numDependencies = execute.getDependencies().size();
307  size_t numOperands = execute.getBodyOperands().size();
308 
309  // Await on all dependencies before starting to execute the body region.
310  for (size_t i = 0; i < numDependencies; ++i)
311  builder.create<AwaitOp>(func.getArgument(i));
312 
313  // Await on all async value operands and unwrap the payload.
314  SmallVector<Value, 4> unwrappedOperands(numOperands);
315  for (size_t i = 0; i < numOperands; ++i) {
316  Value operand = func.getArgument(numDependencies + i);
317  unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
318  }
319 
320  // Map from function inputs defined above the execute op to the function
321  // arguments.
322  IRMapping valueMapping;
323  valueMapping.map(functionInputs, func.getArguments());
324  valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
325 
326  // Clone all operations from the execute operation body into the outlined
327  // function body.
328  for (Operation &op : execute.getBodyRegion().getOps())
329  builder.clone(op, valueMapping);
330  }
331 
332  // Adding entry/cleanup/suspend blocks.
333  CoroMachinery coro = setupCoroMachinery(func);
334 
335  // Suspend async function at the end of an entry block, and resume it using
336  // Async resume operation (execution will be resumed in a thread managed by
337  // the async runtime).
338  {
339  cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
340  builder.setInsertionPointToEnd(coro.entry);
341 
342  // Save the coroutine state: async.coro.save
343  auto coroSaveOp =
344  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
345 
346  // Pass coroutine to the runtime to be resumed on a runtime managed
347  // thread.
348  builder.create<RuntimeResumeOp>(coro.coroHandle);
349 
350  // Add async.coro.suspend as a suspended block terminator.
351  builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
352  branch.getDest(), coro.cleanup);
353 
354  branch.erase();
355  }
356 
357  // Replace the original `async.execute` with a call to outlined function.
358  {
359  ImplicitLocOpBuilder callBuilder(loc, execute);
360  auto callOutlinedFunc = callBuilder.create<func::CallOp>(
361  func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
362  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
363  execute.erase();
364  }
365 
366  return {func, coro};
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // Convert async.create_group operation to async.runtime.create_group
371 //===----------------------------------------------------------------------===//
372 
373 namespace {
374 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
375 public:
377 
379  matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
380  ConversionPatternRewriter &rewriter) const override {
381  rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
382  op, GroupType::get(op->getContext()), adaptor.getOperands());
383  return success();
384  }
385 };
386 } // namespace
387 
388 //===----------------------------------------------------------------------===//
389 // Convert async.add_to_group operation to async.runtime.add_to_group.
390 //===----------------------------------------------------------------------===//
391 
392 namespace {
393 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
394 public:
396 
398  matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
399  ConversionPatternRewriter &rewriter) const override {
400  rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
401  op, rewriter.getIndexType(), adaptor.getOperands());
402  return success();
403  }
404 };
405 } // namespace
406 
407 //===----------------------------------------------------------------------===//
408 // Convert async.func, async.return and async.call operations to non-blocking
409 // operations based on llvm coroutine
410 //===----------------------------------------------------------------------===//
411 
412 namespace {
413 
414 //===----------------------------------------------------------------------===//
415 // Convert async.func operation to func.func
416 //===----------------------------------------------------------------------===//
417 
418 class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
419 public:
420  AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
421  : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
422 
424  matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
425  ConversionPatternRewriter &rewriter) const override {
426  Location loc = op->getLoc();
427 
428  auto newFuncOp =
429  rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
430 
433  // Copy over all attributes other than the name.
434  for (const auto &namedAttr : op->getAttrs()) {
435  if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
436  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
437  }
438 
439  rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
440  newFuncOp.end());
441 
442  CoroMachinery coro = setupCoroMachinery(newFuncOp);
443  (*coros)[newFuncOp] = coro;
444  // no initial suspend, we should hot-start
445 
446  rewriter.eraseOp(op);
447  return success();
448  }
449 
450 private:
451  FuncCoroMapPtr coros;
452 };
453 
454 //===----------------------------------------------------------------------===//
455 // Convert async.call operation to func.call
456 //===----------------------------------------------------------------------===//
457 
458 class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
459 public:
460  AsyncCallOpLowering(MLIRContext *ctx)
461  : OpConversionPattern<async::CallOp>(ctx) {}
462 
464  matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
465  ConversionPatternRewriter &rewriter) const override {
466  rewriter.replaceOpWithNewOp<func::CallOp>(
467  op, op.getCallee(), op.getResultTypes(), op.getOperands());
468  return success();
469  }
470 };
471 
472 //===----------------------------------------------------------------------===//
473 // Convert async.return operation to async.runtime operations.
474 //===----------------------------------------------------------------------===//
475 
476 class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
477 public:
478  AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
479  : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
480 
482  matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
483  ConversionPatternRewriter &rewriter) const override {
484  auto func = op->template getParentOfType<func::FuncOp>();
485  auto funcCoro = coros->find(func);
486  if (funcCoro == coros->end())
487  return rewriter.notifyMatchFailure(
488  op, "operation is not inside the async coroutine function");
489 
490  Location loc = op->getLoc();
491  const CoroMachinery &coro = funcCoro->getSecond();
492  rewriter.setInsertionPointAfter(op);
493 
494  // Store return values into the async values storage and switch async
495  // values state to available.
496  for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
497  Value returnValue = std::get<0>(tuple);
498  Value asyncValue = std::get<1>(tuple);
499  rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
500  rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
501  }
502 
503  if (coro.asyncToken)
504  // Switch the coroutine completion token to available state.
505  rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
506 
507  rewriter.eraseOp(op);
508  rewriter.create<cf::BranchOp>(loc, coro.cleanup);
509  return success();
510  }
511 
512 private:
513  FuncCoroMapPtr coros;
514 };
515 } // namespace
516 
517 //===----------------------------------------------------------------------===//
518 // Convert async.await and async.await_all operations to the async.runtime.await
519 // or async.runtime.await_and_resume operations.
520 //===----------------------------------------------------------------------===//
521 
522 namespace {
523 template <typename AwaitType, typename AwaitableType>
524 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
525  using AwaitAdaptor = typename AwaitType::Adaptor;
526 
527 public:
528  AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
529  bool shouldLowerBlockingWait)
530  : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
531  shouldLowerBlockingWait(shouldLowerBlockingWait) {}
532 
534  matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
535  ConversionPatternRewriter &rewriter) const override {
536  // We can only await on one the `AwaitableType` (for `await` it can be
537  // a `token` or a `value`, for `await_all` it must be a `group`).
538  if (!op.getOperand().getType().template isa<AwaitableType>())
539  return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
540 
541  // Check if await operation is inside the coroutine function.
542  auto func = op->template getParentOfType<func::FuncOp>();
543  auto funcCoro = coros->find(func);
544  const bool isInCoroutine = funcCoro != coros->end();
545 
546  Location loc = op->getLoc();
547  Value operand = adaptor.getOperand();
548 
549  Type i1 = rewriter.getI1Type();
550 
551  // Delay lowering to block wait in case await op is inside async.execute
552  if (!isInCoroutine && !shouldLowerBlockingWait)
553  return failure();
554 
555  // Inside regular functions we use the blocking wait operation to wait for
556  // the async object (token, value or group) to become available.
557  if (!isInCoroutine) {
558  ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
559  builder.create<RuntimeAwaitOp>(loc, operand);
560 
561  // Assert that the awaited operands is not in the error state.
562  Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
563  Value notError = builder.create<arith::XOrIOp>(
564  isError, builder.create<arith::ConstantOp>(
565  loc, i1, builder.getIntegerAttr(i1, 1)));
566 
567  builder.create<cf::AssertOp>(notError,
568  "Awaited async operand is in error state");
569  }
570 
571  // Inside the coroutine we convert await operation into coroutine suspension
572  // point, and resume execution asynchronously.
573  if (isInCoroutine) {
574  CoroMachinery &coro = funcCoro->getSecond();
575  Block *suspended = op->getBlock();
576 
577  ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
578  MLIRContext *ctx = op->getContext();
579 
580  // Save the coroutine state and resume on a runtime managed thread when
581  // the operand becomes available.
582  auto coroSaveOp =
583  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
584  builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
585 
586  // Split the entry block before the await operation.
587  Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
588 
589  // Add async.coro.suspend as a suspended block terminator.
590  builder.setInsertionPointToEnd(suspended);
591  builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
592  coro.cleanup);
593 
594  // Split the resume block into error checking and continuation.
595  Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
596 
597  // Check if the awaited value is in the error state.
598  builder.setInsertionPointToStart(resume);
599  auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
600  builder.create<cf::CondBranchOp>(isError,
601  /*trueDest=*/setupSetErrorBlock(coro),
602  /*trueArgs=*/ArrayRef<Value>(),
603  /*falseDest=*/continuation,
604  /*falseArgs=*/ArrayRef<Value>());
605 
606  // Make sure that replacement value will be constructed in the
607  // continuation block.
608  rewriter.setInsertionPointToStart(continuation);
609  }
610 
611  // Erase or replace the await operation with the new value.
612  if (Value replaceWith = getReplacementValue(op, operand, rewriter))
613  rewriter.replaceOp(op, replaceWith);
614  else
615  rewriter.eraseOp(op);
616 
617  return success();
618  }
619 
620  virtual Value getReplacementValue(AwaitType op, Value operand,
621  ConversionPatternRewriter &rewriter) const {
622  return Value();
623  }
624 
625 private:
626  FuncCoroMapPtr coros;
627  bool shouldLowerBlockingWait;
628 };
629 
630 /// Lowering for `async.await` with a token operand.
631 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
632  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
633 
634 public:
635  using Base::Base;
636 };
637 
638 /// Lowering for `async.await` with a value operand.
639 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
640  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
641 
642 public:
643  using Base::Base;
644 
645  Value
646  getReplacementValue(AwaitOp op, Value operand,
647  ConversionPatternRewriter &rewriter) const override {
648  // Load from the async value storage.
649  auto valueType = operand.getType().cast<ValueType>().getValueType();
650  return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
651  }
652 };
653 
654 /// Lowering for `async.await_all` operation.
655 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
656  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
657 
658 public:
659  using Base::Base;
660 };
661 
662 } // namespace
663 
664 //===----------------------------------------------------------------------===//
665 // Convert async.yield operation to async.runtime operations.
666 //===----------------------------------------------------------------------===//
667 
668 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
669 public:
671  : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
672 
674  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
675  ConversionPatternRewriter &rewriter) const override {
676  // Check if yield operation is inside the async coroutine function.
677  auto func = op->template getParentOfType<func::FuncOp>();
678  auto funcCoro = coros->find(func);
679  if (funcCoro == coros->end())
680  return rewriter.notifyMatchFailure(
681  op, "operation is not inside the async coroutine function");
682 
683  Location loc = op->getLoc();
684  const CoroMachinery &coro = funcCoro->getSecond();
685 
686  // Store yielded values into the async values storage and switch async
687  // values state to available.
688  for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
689  Value yieldValue = std::get<0>(tuple);
690  Value asyncValue = std::get<1>(tuple);
691  rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
692  rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
693  }
694 
695  if (coro.asyncToken)
696  // Switch the coroutine completion token to available state.
697  rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
698 
699  rewriter.eraseOp(op);
700  rewriter.create<cf::BranchOp>(loc, coro.cleanup);
701 
702  return success();
703  }
704 
705 private:
706  FuncCoroMapPtr coros;
707 };
708 
709 //===----------------------------------------------------------------------===//
710 // Convert cf.assert operation to cf.cond_br into `set_error` block.
711 //===----------------------------------------------------------------------===//
712 
713 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
714 public:
716  : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
717 
719  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
720  ConversionPatternRewriter &rewriter) const override {
721  // Check if assert operation is inside the async coroutine function.
722  auto func = op->template getParentOfType<func::FuncOp>();
723  auto funcCoro = coros->find(func);
724  if (funcCoro == coros->end())
725  return rewriter.notifyMatchFailure(
726  op, "operation is not inside the async coroutine function");
727 
728  Location loc = op->getLoc();
729  CoroMachinery &coro = funcCoro->getSecond();
730 
731  Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
732  rewriter.setInsertionPointToEnd(cont->getPrevNode());
733  rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
734  /*trueDest=*/cont,
735  /*trueArgs=*/ArrayRef<Value>(),
736  /*falseDest=*/setupSetErrorBlock(coro),
737  /*falseArgs=*/ArrayRef<Value>());
738  rewriter.eraseOp(op);
739 
740  return success();
741  }
742 
743 private:
744  FuncCoroMapPtr coros;
745 };
746 
747 //===----------------------------------------------------------------------===//
748 void AsyncToAsyncRuntimePass::runOnOperation() {
749  ModuleOp module = getOperation();
750  SymbolTable symbolTable(module);
751 
752  // Functions with coroutine CFG setups, which are results of outlining
753  // `async.execute` body regions
754  FuncCoroMapPtr coros =
755  std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
756 
757  module.walk([&](ExecuteOp execute) {
758  coros->insert(outlineExecuteOp(symbolTable, execute));
759  });
760 
761  LLVM_DEBUG({
762  llvm::dbgs() << "Outlined " << coros->size()
763  << " functions built from async.execute operations\n";
764  });
765 
766  // Returns true if operation is inside the coroutine.
767  auto isInCoroutine = [&](Operation *op) -> bool {
768  auto parentFunc = op->getParentOfType<func::FuncOp>();
769  return coros->find(parentFunc) != coros->end();
770  };
771 
772  // Lower async operations to async.runtime operations.
773  MLIRContext *ctx = module->getContext();
774  RewritePatternSet asyncPatterns(ctx);
775 
776  // Conversion to async runtime augments original CFG with the coroutine CFG,
777  // and we have to make sure that structured control flow operations with async
778  // operations in nested regions will be converted to branch-based control flow
779  // before we add the coroutine basic blocks.
781 
782  // Async lowering does not use type converter because it must preserve all
783  // types for async.runtime operations.
784  asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
785 
786  asyncPatterns
787  .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
788  ctx, coros, /*should_lower_blocking_wait=*/true);
789 
790  // Lower assertions to conditional branches into error blocks.
791  asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
792 
793  // All high level async operations must be lowered to the runtime operations.
794  ConversionTarget runtimeTarget(*ctx);
795  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
796  runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
797  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
798 
799  // Decide if structured control flow has to be lowered to branch-based CFG.
800  runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
801  auto walkResult = op->walk([&](Operation *nested) {
802  bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
803  return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
804  : WalkResult::advance();
805  });
806  return !walkResult.wasInterrupted();
807  });
808  runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
809  func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
810 
811  // Assertions must be converted to runtime errors inside async functions.
812  runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
813  [&](cf::AssertOp op) -> bool {
814  auto func = op->getParentOfType<func::FuncOp>();
815  return coros->find(func) == coros->end();
816  });
817 
818  if (failed(applyPartialConversion(module, runtimeTarget,
819  std::move(asyncPatterns)))) {
820  signalPassFailure();
821  return;
822  }
823 }
824 
825 //===----------------------------------------------------------------------===//
827  RewritePatternSet &patterns, ConversionTarget &target) {
828  // Functions with coroutine CFG setups, which are results of converting
829  // async.func.
830  FuncCoroMapPtr coros =
831  std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
832  MLIRContext *ctx = patterns.getContext();
833  // Lower async.func to func.func with coroutine cfg.
834  patterns.add<AsyncCallOpLowering>(ctx);
835  patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
836 
837  patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
838  ctx, coros, /*should_lower_blocking_wait=*/false);
839  patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
840 
841  target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
842  [coros](Operation *op) {
843  auto exec = op->getParentOfType<ExecuteOp>();
844  auto func = op->getParentOfType<func::FuncOp>();
845  return exec || coros->find(func) == coros->end();
846  });
847 }
848 
849 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
850  ModuleOp module = getOperation();
851 
852  // Lower async operations to async.runtime operations.
853  MLIRContext *ctx = module->getContext();
854  RewritePatternSet asyncPatterns(ctx);
855  ConversionTarget runtimeTarget(*ctx);
856 
857  // Lower async.func to func.func with coroutine cfg.
859  runtimeTarget);
860 
861  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
862  runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
863 
864  runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
865  cf::BranchOp, cf::CondBranchOp>();
866 
867  if (failed(applyPartialConversion(module, runtimeTarget,
868  std::move(asyncPatterns)))) {
869  signalPassFailure();
870  return;
871  }
872 }
873 
874 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
875  return std::make_unique<AsyncToAsyncRuntimePass>();
876 }
877 
878 std::unique_ptr<OperationPass<ModuleOp>>
880  return std::make_unique<AsyncFuncToAsyncRuntimePass>();
881 }
static Block * setupSetErrorBlock(CoroMachinery &coro)
std::shared_ptr< llvm::DenseMap< func::FuncOp, CoroMachinery > > FuncCoroMapPtr
static constexpr const char kAsyncFnPrefix[]
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static CoroMachinery setupCoroMachinery(func::FuncOp func)
Utility to partially update the regular function CFG to the coroutine CFG compatible with LLVM corout...
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:291
OpListType & getOperations()
Definition: Block.h:126
Operation & front()
Definition: Block.h:142
IntegerType getI1Type()
Definition: Builders.cpp:58
IndexType getIndexType()
Definition: Builders.cpp:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
typename SourceOp::Adaptor OpAdaptor
Definition: Pattern.h:137
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:384
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:389
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:273
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:365
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:195
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
static Visibility getSymbolVisibility(Operation *symbol)
Returns the visibility of the given symbol operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:58
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult interrupt()
Definition: Visitors.h:51
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Definition: PassDetail.cpp:15
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:59
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
void populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet &patterns, ConversionTarget &target)
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::unique_ptr< OperationPass< ModuleOp > > createAsyncFuncToAsyncRuntimePass()
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26