MLIR  15.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
24 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/Debug.h"
29 
30 using namespace mlir;
31 using namespace mlir::async;
32 
33 #define DEBUG_TYPE "async-to-async-runtime"
34 // Prefix for functions outlined from `async.execute` op regions.
// This string is the symbol name passed to `func::FuncOp::create` in
// `outlineExecuteOp`; the created function is then inserted into the module
// symbol table.
35 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
36 
37 namespace {
38 
39 class AsyncToAsyncRuntimePass
40  : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
41 public:
42  AsyncToAsyncRuntimePass() = default;
43  void runOnOperation() override;
44 };
45 
46 } // namespace
47 
48 //===----------------------------------------------------------------------===//
49 // async.execute op outlining to the coroutine functions.
50 //===----------------------------------------------------------------------===//
51 
52 /// Function targeted for coroutine transformation has two additional blocks at
53 /// the end: coroutine cleanup and coroutine suspension.
54 ///
55 /// async.await op lowering additionaly creates a resume block for each
56 /// operation to enable non-blocking waiting via coroutine suspension.
57 namespace {
58 struct CoroMachinery {
59  func::FuncOp func;
60 
61  // Async execute region returns a completion token, and an async value for
62  // each yielded value.
63  //
64  // %token, %result = async.execute -> !async.value<T> {
65  // %0 = arith.constant ... : T
66  // async.yield %0 : T
67  // }
68  Value asyncToken; // token representing completion of the async region
69  llvm::SmallVector<Value, 4> returnValues; // returned async values
70 
71  Value coroHandle; // coroutine handle (!async.coro.handle value)
72  Block *entry; // coroutine entry block
73  Block *setError; // switch completion token and all values to error state
74  Block *cleanup; // coroutine cleanup block
75  Block *suspend; // coroutine suspension block
76 };
77 } // namespace
78 
79 /// Utility to partially update the regular function CFG to the coroutine CFG
80 /// compatible with LLVM coroutines switched-resume lowering using
81 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
82 /// that branches into preexisting entry block. Also inserts trailing blocks.
83 ///
84 /// The result types of the passed `func` must start with an `async.token`
85 /// and be continued with some number of `async.value`s.
86 ///
87 /// The func given to this function needs to have been preprocessed to have
88 /// either branch or yield ops as terminators. Branches to the cleanup block are
89 /// inserted after each yield.
90 ///
91 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
92 ///
93 /// - `entry` block sets up the coroutine.
94 /// - `set_error` block sets completion token and async values state to error.
95 /// - `cleanup` block cleans up the coroutine state.
96 /// - `suspend block after the @llvm.coro.end() defines what value will be
97 /// returned to the initial caller of a coroutine. Everything before the
98 /// @llvm.coro.end() will be executed at every suspension point.
99 ///
100 /// Coroutine structure (only the important bits):
101 ///
102 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
103 /// {
104 /// ^entry(<function-arguments>):
105 /// %token = <async token> : !async.token // create async runtime token
106 /// %value = <async value> : !async.value<T> // create async value
107 /// %id = async.coro.id // create a coroutine id
108 /// %hdl = async.coro.begin %id // create a coroutine handle
109 /// cf.br ^preexisting_entry_block
110 ///
111 /// /* preexisting blocks modified to branch to the cleanup block */
112 ///
113 /// ^set_error: // this block created lazily only if needed (see code below)
114 /// async.runtime.set_error %token : !async.token
115 /// async.runtime.set_error %value : !async.value<T>
116 /// cf.br ^cleanup
117 ///
118 /// ^cleanup:
119 /// async.coro.free %hdl // delete the coroutine state
120 /// cf.br ^suspend
121 ///
122 /// ^suspend:
123 /// async.coro.end %hdl // marks the end of a coroutine
124 /// return %token, %value : !async.token, !async.value<T>
125 /// }
126 ///
127 static CoroMachinery setupCoroMachinery(func::FuncOp func) {
128  assert(!func.getBlocks().empty() && "Function must have an entry block");
129 
130  MLIRContext *ctx = func.getContext();
131  Block *entryBlock = &func.getBlocks().front();
132  Block *originalEntryBlock =
133  entryBlock->splitBlock(entryBlock->getOperations().begin());
134  auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
135 
136  // ------------------------------------------------------------------------ //
137  // Allocate async token/values that we will return from a ramp function.
138  // ------------------------------------------------------------------------ //
139  auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
140 
141  llvm::SmallVector<Value, 4> retValues;
142  for (auto resType : func.getCallableResults().drop_front())
143  retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
144 
145  // ------------------------------------------------------------------------ //
146  // Initialize coroutine: get coroutine id and coroutine handle.
147  // ------------------------------------------------------------------------ //
148  auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
149  auto coroHdlOp =
150  builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
151  builder.create<cf::BranchOp>(originalEntryBlock);
152 
153  Block *cleanupBlock = func.addBlock();
154  Block *suspendBlock = func.addBlock();
155 
156  // ------------------------------------------------------------------------ //
157  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
158  // ------------------------------------------------------------------------ //
159  builder.setInsertionPointToStart(cleanupBlock);
160  builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
161 
162  // Branch into the suspend block.
163  builder.create<cf::BranchOp>(suspendBlock);
164 
165  // ------------------------------------------------------------------------ //
166  // Coroutine suspend block: mark the end of a coroutine and return allocated
167  // async token.
168  // ------------------------------------------------------------------------ //
169  builder.setInsertionPointToStart(suspendBlock);
170 
171  // Mark the end of a coroutine: async.coro.end
172  builder.create<CoroEndOp>(coroHdlOp.handle());
173 
174  // Return created `async.token` and `async.values` from the suspend block.
175  // This will be the return value of a coroutine ramp function.
176  SmallVector<Value, 4> ret{retToken};
177  ret.insert(ret.end(), retValues.begin(), retValues.end());
178  builder.create<func::ReturnOp>(ret);
179 
180  // `async.await` op lowering will create resume blocks for async
181  // continuations, and will conditionally branch to cleanup or suspend blocks.
182 
183  for (Block &block : func.getBody().getBlocks()) {
184  if (&block == entryBlock || &block == cleanupBlock ||
185  &block == suspendBlock)
186  continue;
187  Operation *terminator = block.getTerminator();
188  if (auto yield = dyn_cast<YieldOp>(terminator)) {
189  builder.setInsertionPointToEnd(&block);
190  builder.create<cf::BranchOp>(cleanupBlock);
191  }
192  }
193 
194  // The switch-resumed API based coroutine should be marked with
195  // coroutine.presplit attribute to mark the function as a coroutine.
196  func->setAttr("passthrough", builder.getArrayAttr(
197  StringAttr::get(ctx, "presplitcoroutine")));
198 
199  CoroMachinery machinery;
200  machinery.func = func;
201  machinery.asyncToken = retToken;
202  machinery.returnValues = retValues;
203  machinery.coroHandle = coroHdlOp.handle();
204  machinery.entry = entryBlock;
205  machinery.setError = nullptr; // created lazily only if needed
206  machinery.cleanup = cleanupBlock;
207  machinery.suspend = suspendBlock;
208  return machinery;
209 }
210 
211 // Lazily creates `set_error` block only if it is required for lowering to the
212 // runtime operations (see for example lowering of assert operation).
213 static Block *setupSetErrorBlock(CoroMachinery &coro) {
214  if (coro.setError)
215  return coro.setError;
216 
217  coro.setError = coro.func.addBlock();
218  coro.setError->moveBefore(coro.cleanup);
219 
220  auto builder =
221  ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
222 
223  // Coroutine set_error block: set error on token and all returned values.
224  builder.create<RuntimeSetErrorOp>(coro.asyncToken);
225  for (Value retValue : coro.returnValues)
226  builder.create<RuntimeSetErrorOp>(retValue);
227 
228  // Branch into the cleanup block.
229  builder.create<cf::BranchOp>(coro.cleanup);
230 
231  return coro.setError;
232 }
233 
234 /// Outline the body region attached to the `async.execute` op into a standalone
235 /// function.
236 ///
237 /// Note that this is not a reversible transformation: the `async.execute` op is
/// replaced with a `func.call` to the outlined function and then erased.
///
/// The outlined function is inserted into `symbolTable`. Its inputs are
/// collected from the execute op dependencies, then its operands, then the
/// values used inside the body but defined above it; its results are the
/// result types of the execute op.
///
/// Returns the outlined function together with the `CoroMachinery` produced
/// for it by `setupCoroMachinery`.
238 static std::pair<func::FuncOp, CoroMachinery>
239 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
240  ModuleOp module = execute->getParentOfType<ModuleOp>();
241 
242  MLIRContext *ctx = module.getContext();
243  Location loc = execute.getLoc();
244 
245  // Make sure that all constants will be inside the outlined async function to
246  // reduce the number of function arguments.
247  cloneConstantsIntoTheRegion(execute.body());
248 
249  // Collect all outlined function inputs.
250  SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
251  execute.dependencies().end());
252  functionInputs.insert(execute.operands().begin(), execute.operands().end());
253  getUsedValuesDefinedAbove(execute.body(), functionInputs);
254 
255  // Collect types for the outlined function inputs and outputs.
256  auto typesRange = llvm::map_range(
257  functionInputs, [](Value value) { return value.getType(); });
258  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
259  auto outputTypes = execute.getResultTypes();
260 
261  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
262  auto funcAttrs = ArrayRef<NamedAttribute>();
263 
264  // TODO: Derive outlined function name from the parent FuncOp (support
265  // multiple nested async.execute operations).
266  func::FuncOp func =
267  func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
268  symbolTable.insert(func);
269 
// NOTE(review): the embedded listing numbering jumps from 269 to 271 here, so
// one source line appears to be missing from this extract -- confirm against
// the upstream file before editing this function.
271  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
272 
273  // Prepare for coroutine conversion by creating the body of the function.
274  {
275  size_t numDependencies = execute.dependencies().size();
276  size_t numOperands = execute.operands().size();
277 
278  // Await on all dependencies before starting to execute the body region.
// The first `numDependencies` function arguments are read as the
// dependencies.
279  for (size_t i = 0; i < numDependencies; ++i)
280  builder.create<AwaitOp>(func.getArgument(i));
281 
282  // Await on all async value operands and unwrap the payload.
// The arguments following the dependencies are read as the async operands.
283  SmallVector<Value, 4> unwrappedOperands(numOperands);
284  for (size_t i = 0; i < numOperands; ++i) {
285  Value operand = func.getArgument(numDependencies + i);
286  unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
287  }
288 
289  // Map from function inputs defined above the execute op to the function
290  // arguments.
// The body block arguments are mapped to the unwrapped (awaited) operands.
291  BlockAndValueMapping valueMapping;
292  valueMapping.map(functionInputs, func.getArguments());
293  valueMapping.map(execute.body().getArguments(), unwrappedOperands);
294 
295  // Clone all operations from the execute operation body into the outlined
296  // function body.
297  for (Operation &op : execute.body().getOps())
298  builder.clone(op, valueMapping);
299  }
300 
301  // Adding entry/cleanup/suspend blocks.
302  CoroMachinery coro = setupCoroMachinery(func);
303 
304  // Suspend async function at the end of an entry block, and resume it using
305  // Async resume operation (execution will be resumed in a thread managed by
306  // the async runtime).
307  {
// The entry block built by `setupCoroMachinery` ends with a `cf.br`; it is
// replaced below by an `async.coro.suspend` that resumes at the same
// destination.
308  cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
309  builder.setInsertionPointToEnd(coro.entry);
310 
311  // Save the coroutine state: async.coro.save
312  auto coroSaveOp =
313  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
314 
315  // Pass coroutine to the runtime to be resumed on a runtime managed
316  // thread.
317  builder.create<RuntimeResumeOp>(coro.coroHandle);
318 
319  // Add async.coro.suspend as a suspended block terminator.
320  builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
321  branch.getDest(), coro.cleanup);
322 
323  branch.erase();
324  }
325 
326  // Replace the original `async.execute` with a call to outlined function.
327  {
328  ImplicitLocOpBuilder callBuilder(loc, execute);
329  auto callOutlinedFunc = callBuilder.create<func::CallOp>(
330  func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
331  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
332  execute.erase();
333  }
334 
335  return {func, coro};
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // Convert async.create_group operation to async.runtime.create_group.
340 //===----------------------------------------------------------------------===//
341 
342 namespace {
// Conversion pattern that replaces every `CreateGroupOp` with a
// `RuntimeCreateGroupOp` producing a `GroupType` value from the adaptor
// operands.
// NOTE(review): the embedded listing numbering skips 345 and 347 in this
// class, so some source lines appear to be missing from this extract --
// confirm against the upstream file.
343 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
344 public:
346 
348  matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override {
350  rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
351  op, GroupType::get(op->getContext()), adaptor.getOperands());
// The rewrite is unconditional, so the pattern always reports success.
352  return success();
353  }
354 };
355 } // namespace
356 
357 //===----------------------------------------------------------------------===//
358 // Convert async.add_to_group operation to async.runtime.add_to_group.
359 //===----------------------------------------------------------------------===//
360 
361 namespace {
// Conversion pattern that replaces every `AddToGroupOp` with a
// `RuntimeAddToGroupOp` whose result type is the builtin index type and whose
// operands are the adaptor operands.
// NOTE(review): the embedded listing numbering skips 364 and 366 in this
// class, so some source lines appear to be missing from this extract --
// confirm against the upstream file.
362 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
363 public:
365 
367  matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
368  ConversionPatternRewriter &rewriter) const override {
369  rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
370  op, rewriter.getIndexType(), adaptor.getOperands());
// The rewrite is unconditional, so the pattern always reports success.
371  return success();
372  }
373 };
374 } // namespace
375 
376 //===----------------------------------------------------------------------===//
377 // Convert async.await and async.await_all operations to the async.runtime.await
378 // or async.runtime.await_and_resume operations.
379 //===----------------------------------------------------------------------===//
380 
381 namespace {
// Shared lowering for await-like operations. `AwaitType` is the operation
// being lowered and `AwaitableType` is the only operand type this
// instantiation accepts; any other operand type is reported as a match
// failure.
//
// Outside of the functions recorded in `outlinedFunctions` the operation
// becomes a blocking `RuntimeAwaitOp` followed by an assertion that the
// operand is not in the error state. Inside of them it becomes a coroutine
// suspension point followed by an error check that branches to the `set_error`
// block.
//
// NOTE(review): the embedded listing numbering skips 389-390, 393 and 482 in
// this class, so some source lines (part of the constructor, the return type
// of `matchAndRewrite`, and the private member) appear to be missing from this
// extract -- confirm against the upstream file.
382 template <typename AwaitType, typename AwaitableType>
383 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// NOTE(review): this alias is not referenced in the visible part of the
// class; `matchAndRewrite` spells out `typename AwaitType::Adaptor` instead.
384  using AwaitAdaptor = typename AwaitType::Adaptor;
385 
386 public:
387  AwaitOpLoweringBase(
388  MLIRContext *ctx,
391  outlinedFunctions(outlinedFunctions) {}
392 
394  matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
395  ConversionPatternRewriter &rewriter) const override {
396  // We can only await on an operand of the `AwaitableType` (for `await` it
397  // can be a `token` or a `value`, for `await_all` it must be a `group`).
398  if (!op.operand().getType().template isa<AwaitableType>())
399  return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
400 
401  // Check if await operation is inside the outlined coroutine function.
402  auto func = op->template getParentOfType<func::FuncOp>();
403  auto outlined = outlinedFunctions.find(func);
404  const bool isInCoroutine = outlined != outlinedFunctions.end();
405 
406  Location loc = op->getLoc();
407  Value operand = adaptor.operand();
408 
409  Type i1 = rewriter.getI1Type();
410 
411  // Inside regular functions we use the blocking wait operation to wait for
412  // the async object (token, value or group) to become available.
413  if (!isInCoroutine) {
414  ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
415  builder.create<RuntimeAwaitOp>(loc, operand);
416 
417  // Assert that the awaited operand is not in the error state.
// `notError` is computed as `isError xor 1`, i.e. the logical negation of
// the i1 error flag.
418  Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
419  Value notError = builder.create<arith::XOrIOp>(
420  isError, builder.create<arith::ConstantOp>(
421  loc, i1, builder.getIntegerAttr(i1, 1)));
422 
423  builder.create<cf::AssertOp>(notError,
424  "Awaited async operand is in error state");
425  }
426 
427  // Inside the coroutine we convert await operation into coroutine suspension
428  // point, and resume execution asynchronously.
// This branch is the exact complement of the `!isInCoroutine` branch above.
429  if (isInCoroutine) {
430  CoroMachinery &coro = outlined->getSecond();
431  Block *suspended = op->getBlock();
432 
433  ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
434  MLIRContext *ctx = op->getContext();
435 
436  // Save the coroutine state and resume on a runtime managed thread when
437  // the operand becomes available.
438  auto coroSaveOp =
439  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
440  builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
441 
442  // Split the block containing the await operation right before it.
443  Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
444 
445  // Add async.coro.suspend as a suspended block terminator.
446  builder.setInsertionPointToEnd(suspended);
447  builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
448  coro.cleanup);
449 
450  // Split the resume block into error checking and continuation.
451  Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
452 
453  // Check if the awaited value is in the error state.
454  builder.setInsertionPointToStart(resume);
455  auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
456  builder.create<cf::CondBranchOp>(isError,
457  /*trueDest=*/setupSetErrorBlock(coro),
458  /*trueArgs=*/ArrayRef<Value>(),
459  /*falseDest=*/continuation,
460  /*falseArgs=*/ArrayRef<Value>());
461 
462  // Make sure that replacement value will be constructed in the
463  // continuation block.
464  rewriter.setInsertionPointToStart(continuation);
465  }
466 
467  // Erase or replace the await operation with the new value.
// A null `Value` from `getReplacementValue` means there is nothing to
// replace the operation with, so it is simply erased.
468  if (Value replaceWith = getReplacementValue(op, operand, rewriter))
469  rewriter.replaceOp(op, replaceWith);
470  else
471  rewriter.eraseOp(op);
472 
473  return success();
474  }
475 
// Hook for derived patterns: returns the value that replaces the await
// operation. The default implementation returns a null `Value`.
476  virtual Value getReplacementValue(AwaitType op, Value operand,
477  ConversionPatternRewriter &rewriter) const {
478  return Value();
479  }
480 
481 private:
483 };
484 
485 /// Lowering for `async.await` with a token operand.
486 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
487  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
488 
489 public:
490  using Base::Base;
491 };
492 
493 /// Lowering for `async.await` with a value operand.
494 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
495  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
496 
497 public:
498  using Base::Base;
499 
500  Value
501  getReplacementValue(AwaitOp op, Value operand,
502  ConversionPatternRewriter &rewriter) const override {
503  // Load from the async value storage.
504  auto valueType = operand.getType().cast<ValueType>().getValueType();
505  return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
506  }
507 };
508 
509 /// Lowering for `async.await_all` operation.
510 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
511  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
512 
513 public:
514  using Base::Base;
515 };
516 
517 } // namespace
518 
519 //===----------------------------------------------------------------------===//
520 // Convert async.yield operation to async.runtime operations.
521 //===----------------------------------------------------------------------===//
522 
// Conversion pattern for `async.yield` inside functions recorded in
// `outlinedFunctions`: stores each yielded operand into the matching async
// return value, marks those values available, and finally replaces the yield
// with an operation marking the completion token available. Yields in any
// other function are reported as a match failure.
// NOTE(review): the embedded listing numbering skips 525 and 531 in this
// class, so some source lines appear to be missing from this extract --
// confirm against the upstream file.
523 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
524 public:
526  MLIRContext *ctx,
527  const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
528  : OpConversionPattern<async::YieldOp>(ctx),
529  outlinedFunctions(outlinedFunctions) {}
530 
532  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
533  ConversionPatternRewriter &rewriter) const override {
534  // Check if yield operation is inside the async coroutine function.
535  auto func = op->template getParentOfType<func::FuncOp>();
536  auto outlined = outlinedFunctions.find(func);
537  if (outlined == outlinedFunctions.end())
538  return rewriter.notifyMatchFailure(
539  op, "operation is not inside the async coroutine function");
540 
541  Location loc = op->getLoc();
542  const CoroMachinery &coro = outlined->getSecond();
543 
544  // Store yielded values into the async values storage and switch async
545  // values state to available.
// Yield operands and `coro.returnValues` are paired positionally.
546  for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
547  Value yieldValue = std::get<0>(tuple);
548  Value asyncValue = std::get<1>(tuple);
549  rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
550  rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
551  }
552 
553  // Switch the coroutine completion token to available state.
554  rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
555 
556  return success();
557  }
558 
559 private:
// Map from coroutine functions to their machinery; held by reference, so the
// map must outlive this pattern.
560  const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
561 };
562 
563 //===----------------------------------------------------------------------===//
564 // Convert cf.assert operation to cf.cond_br into `set_error` block.
565 //===----------------------------------------------------------------------===//
566 
// Conversion pattern for `cf.assert` inside functions recorded in
// `outlinedFunctions`: the block is split at the assert, and the assert is
// replaced with a conditional branch that continues when the condition holds
// and jumps to the coroutine `set_error` block otherwise. Asserts in any other
// function are reported as a match failure.
// NOTE(review): the embedded listing numbering skips 569, 571, 575 and 601 in
// this class, so some source lines appear to be missing from this extract --
// confirm against the upstream file.
567 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
568 public:
570  MLIRContext *ctx,
572  : OpConversionPattern<cf::AssertOp>(ctx),
573  outlinedFunctions(outlinedFunctions) {}
574 
576  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
577  ConversionPatternRewriter &rewriter) const override {
578  // Check if assert operation is inside the async coroutine function.
579  auto func = op->template getParentOfType<func::FuncOp>();
580  auto outlined = outlinedFunctions.find(func);
581  if (outlined == outlinedFunctions.end())
582  return rewriter.notifyMatchFailure(
583  op, "operation is not inside the async coroutine function");
584 
585  Location loc = op->getLoc();
586  CoroMachinery &coro = outlined->getSecond();
587 
// Split right before the assert; the branch is then appended to the block
// preceding `cont` (presumably the first half of the block just split).
588  Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
589  rewriter.setInsertionPointToEnd(cont->getPrevNode());
590  rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
591  /*trueDest=*/cont,
592  /*trueArgs=*/ArrayRef<Value>(),
593  /*falseDest=*/setupSetErrorBlock(coro),
594  /*falseArgs=*/ArrayRef<Value>());
595  rewriter.eraseOp(op);
596 
597  return success();
598  }
599 
600 private:
602 };
603 
604 //===----------------------------------------------------------------------===//
605 
606 /// Rewrite a func as a coroutine by:
607 /// 1) Wrapping the results into `async.value`.
608 /// 2) Prepending the results with `async.token`.
609 /// 3) Setting up coroutine blocks.
610 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
611 static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func) {
612  auto *ctx = func->getContext();
613  auto loc = func.getLoc();
614  SmallVector<Type> resultTypes;
615  resultTypes.reserve(func.getCallableResults().size());
616  llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
617  [](Type type) { return ValueType::get(type); });
618  func.setType(
619  FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes));
620  func.insertResult(0, TokenType::get(ctx), {});
621  for (Block &block : func.getBlocks()) {
622  Operation *terminator = block.getTerminator();
623  if (auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) {
624  ImplicitLocOpBuilder builder(loc, returnOp);
625  builder.create<YieldOp>(returnOp.getOperands());
626  returnOp.erase();
627  }
628  }
629  return setupCoroMachinery(func);
630 }
631 
632 /// Rewrites a call into a function that has been rewritten as a coroutine.
633 ///
634 /// The invocation of this function is safe only when call ops are traversed in
635 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
636 static void rewriteCallsiteForCoroutine(func::CallOp oldCall,
637  func::FuncOp func) {
638  auto loc = func.getLoc();
639  ImplicitLocOpBuilder callBuilder(loc, oldCall);
640  auto newCall = callBuilder.create<func::CallOp>(
641  func.getName(), func.getCallableResults(), oldCall.getArgOperands());
642 
643  // Await on the async token and all the value results and unwrap the latter.
644  callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
645  SmallVector<Value> unwrappedResults;
646  unwrappedResults.reserve(newCall->getResults().size() - 1);
647  for (Value result : newCall.getResults().drop_front())
648  unwrappedResults.push_back(
649  callBuilder.create<AwaitOp>(loc, result).result());
650  // Careful, when result of a call is piped into another call this could lead
651  // to a dangling pointer.
652  oldCall.replaceAllUsesWith(unwrappedResults);
653  oldCall.erase();
654 }
655 
656 static bool isAllowedToBlock(func::FuncOp func) {
657  return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
658 }
659 
661  ModuleOp module,
662  llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) {
// NOTE(review): the embedded listing numbering starts this definition at 661;
// the line carrying the function name and return type (660) appears to be
// missing from this extract -- confirm against the upstream file.
663  // The following code supports the general case when 2 functions mutually
664  // recurse into each other. Because of this and that we are relying on
665  // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
666  // a FuncOp while inserting an equivalent coroutine, because that could lead
667  // to dangling pointers.
668 
669  SmallVector<func::FuncOp> funcWorklist;
670 
671  // Careful, it's okay to add a func to the worklist multiple times if and only
672  // if the loop processing the worklist will skip the functions that have
673  // already been converted to coroutines.
674  auto addToWorklist = [&](func::FuncOp func) {
675  if (isAllowedToBlock(func))
676  return;
677  // N.B. To refactor this code into a separate pass the lookup in
678  // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
679  // func and recognizing if it has a coroutine structure is messy. Passing
680  // this dict between the passes is ugly.
// NOTE(review): `isAllowedToBlock(func)` is always false at this point
// because of the early return above, so only the `outlinedFunctions` lookup
// decides whether the body is scanned.
681  if (isAllowedToBlock(func) ||
682  outlinedFunctions.find(func) == outlinedFunctions.end()) {
// Enqueue the function once if its body contains an await or await_all op.
683  for (Operation &op : func.getBody().getOps()) {
684  if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
685  funcWorklist.push_back(func);
686  break;
687  }
688  }
689  }
690  };
691 
692  // Seed the worklist: offer every func op of the module to `addToWorklist`.
693  for (func::FuncOp func : module.getOps<func::FuncOp>())
694  addToWorklist(func);
695 
696  SymbolTableCollection symbolTable;
697  SymbolUserMap symbolUserMap(symbolTable, module);
698 
699  // Rewrite funcs, while updating call sites and adding them to the worklist.
700  while (!funcWorklist.empty()) {
701  auto func = funcWorklist.pop_back_val();
702  auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
703  if (!insertion.second)
704  // This function has already been processed because this is either
705  // the corecursive case, or a caller with multiple calls to a newly
706  // created coroutine. Either way, skip updating the call sites.
707  continue;
708  insertion.first->second = rewriteFuncAsCoroutine(func);
709  SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
710  symbolUserMap.getUsers(func).end());
711  // If there are multiple calls from the same block they need to be traversed
712  // in reverse order so that symbolUserMap references are not invalidated
713  // when updating the users of the call op which is earlier in the block.
714  llvm::sort(users, [](Operation *a, Operation *b) {
715  Block *blockA = a->getBlock();
716  Block *blockB = b->getBlock();
717  // Impose arbitrary order on blocks so that there is a well-defined order.
718  return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
719  });
720  // Rewrite the callsites to await on results of the newly created coroutine.
// Any user that is not a `func.call` is reported as an error and aborts the
// whole conversion with failure.
721  for (Operation *op : users) {
722  if (func::CallOp call = dyn_cast<func::CallOp>(*op)) {
723  func::FuncOp caller = call->getParentOfType<func::FuncOp>();
724  rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
725  addToWorklist(caller);
726  } else {
727  op->emitError("Unexpected reference to func referenced by symbol");
728  return failure();
729  }
730  }
731  }
732  return success();
733 }
734 
735 //===----------------------------------------------------------------------===//
// Pass entry point. Outlines every `async.execute` region into a coroutine
// function, optionally converts functions containing await ops into
// coroutines (`eliminateBlockingAwaitOps`), and then runs a partial dialect
// conversion that lowers the remaining high level async operations to
// `async.runtime` operations. Any failure is reported via
// `signalPassFailure`.
// NOTE(review): the embedded listing numbering skips 741, 772 and 795 in this
// function, so some source lines appear to be missing from this extract (for
// example no declaration of `outlinedFunctions` is visible, and the
// conditional expression at 794 has no visible else-arm) -- confirm against
// the upstream file.
736 void AsyncToAsyncRuntimePass::runOnOperation() {
737  ModuleOp module = getOperation();
738  SymbolTable symbolTable(module);
739 
740  // Outline all `async.execute` body regions into async functions (coroutines).
742 
743  module.walk([&](ExecuteOp execute) {
744  outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
745  });
746 
747  LLVM_DEBUG({
748  llvm::dbgs() << "Outlined " << outlinedFunctions.size()
749  << " functions built from async.execute operations\n";
750  });
751 
752  // Returns true if operation is inside the coroutine.
753  auto isInCoroutine = [&](Operation *op) -> bool {
754  auto parentFunc = op->getParentOfType<func::FuncOp>();
755  return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
756  };
757 
// Optional step: turn functions with await ops into coroutines as well.
758  if (eliminateBlockingAwaitOps &&
759  failed(funcsToCoroutines(module, outlinedFunctions))) {
760  signalPassFailure();
761  return;
762  }
763 
764  // Lower async operations to async.runtime operations.
765  MLIRContext *ctx = module->getContext();
766  RewritePatternSet asyncPatterns(ctx);
767 
768  // Conversion to async runtime augments original CFG with the coroutine CFG,
769  // and we have to make sure that structured control flow operations with async
770  // operations in nested regions will be converted to branch-based control flow
771  // before we add the coroutine basic blocks.
773 
774  // Async lowering does not use type converter because it must preserve all
775  // types for async.runtime operations.
776  asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
777  asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
778  AwaitAllOpLowering, YieldOpLowering>(ctx,
779  outlinedFunctions);
780 
781  // Lower assertions to conditional branches into error blocks.
782  asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
783 
784  // All high level async operations must be lowered to the runtime operations.
785  ConversionTarget runtimeTarget(*ctx);
786  runtimeTarget.addLegalDialect<AsyncDialect>();
787  runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
788  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
789 
790  // Decide if structured control flow has to be lowered to branch-based CFG.
// An scf op is legal only if the walk over it was not interrupted.
791  runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
792  auto walkResult = op->walk([&](Operation *nested) {
793  bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
794  return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
796  });
797  return !walkResult.wasInterrupted();
798  });
799  runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
800  func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
801 
802  // Assertions must be converted to runtime errors inside async functions.
// NOTE(review): `cf::AssertOp` is also listed as unconditionally legal just
// above; presumably this later dynamic rule is the one that takes effect --
// confirm against the ConversionTarget documentation.
803  runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
804  [&](cf::AssertOp op) -> bool {
805  auto func = op->getParentOfType<func::FuncOp>();
806  return outlinedFunctions.find(func) == outlinedFunctions.end();
807  });
808 
// With blocking awaits eliminated, a blocking `RuntimeAwaitOp` is legal only
// inside functions for which `isAllowedToBlock` returns true.
809  if (eliminateBlockingAwaitOps)
810  runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
811  [&](RuntimeAwaitOp op) -> bool {
812  return isAllowedToBlock(op->getParentOfType<func::FuncOp>());
813  });
814 
815  if (failed(applyPartialConversion(module, runtimeTarget,
816  std::move(asyncPatterns)))) {
817  signalPassFailure();
818  return;
819  }
820 }
821 
822 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
823  return std::make_unique<AsyncToAsyncRuntimePass>();
824 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, ArrayRef< NamedAttribute > attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:27
typename cf::AssertOp ::Adaptor OpAdaptor
Definition: Pattern.h:134
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Block represents an ordered list of Operations.
Definition: Block.h:29
OpListType & getOperations()
Definition: Block.h:128
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
typename async::YieldOp ::Adaptor OpAdaptor
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:271
Operation & front()
Definition: Block.h:144
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:151
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
YieldOpLowering(MLIRContext *ctx, const llvm::DenseMap< func::FuncOp, CoroMachinery > &outlinedFunctions)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpListType::iterator iterator
Definition: Block.h:131
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:59
static void rewriteCallsiteForCoroutine(func::CallOp oldCall, func::FuncOp func)
Rewrites a call into a function that has been rewritten as a coroutine.
static LogicalResult funcsToCoroutines(ModuleOp module, llvm::DenseMap< func::FuncOp, CoroMachinery > &outlinedFunctions)
IntegerType getI1Type()
Definition: Builders.cpp:50
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
static WalkResult advance()
Definition: Visitors.h:51
static constexpr const char kAsyncFnPrefix[]
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener...
Definition: Builders.h:258
static WalkResult interrupt()
Definition: Visitors.h:50
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
static std::pair< func::FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
static bool isAllowedToBlock(func::FuncOp func)
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Definition: PassDetail.cpp:15
This class represents a map of symbols to users, and provides efficient implementations of symbol que...
Definition: SymbolTable.h:290
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
ArrayRef< Operation * > getUsers(Operation *symbol) const
Return the users of the provided symbol operation.
Definition: SymbolTable.h:299
static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func)
Rewrite a func as a coroutine by: 1) Wrapping the results into async.value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
AssertOpLowering(MLIRContext *ctx, llvm::DenseMap< func::FuncOp, CoroMachinery > &outlinedFunctions)
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:402
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:158
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class implements a pattern rewriter for use with ConversionPatterns.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:374
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
value !async value< T > ***static CoroMachinery setupCoroMachinery(func::FuncOp func)
This class describes a specific conversion target.
static Block * setupSetErrorBlock(CoroMachinery &coro)
U cast() const
Definition: Types.h:262
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289