MLIR  14.0.0git
AsyncToAsyncRuntime.cpp
Go to the documentation of this file.
1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
19 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/Support/Debug.h"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 #define DEBUG_TYPE "async-to-async-runtime"
// Symbol name given to every function outlined from an `async.execute` op
// region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
35 
36 namespace {
37 
38 class AsyncToAsyncRuntimePass
39  : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
40 public:
41  AsyncToAsyncRuntimePass() = default;
42  void runOnOperation() override;
43 };
44 
45 } // namespace
46 
47 //===----------------------------------------------------------------------===//
48 // async.execute op outlining to the coroutine functions.
49 //===----------------------------------------------------------------------===//
50 
51 /// Function targeted for coroutine transformation has two additional blocks at
52 /// the end: coroutine cleanup and coroutine suspension.
53 ///
54 /// async.await op lowering additionally creates a resume block for each
55 /// operation to enable non-blocking waiting via coroutine suspension.
56 namespace {
57 struct CoroMachinery {
58  FuncOp func;
59 
60  // Async execute region returns a completion token, and an async value for
61  // each yielded value.
62  //
63  // %token, %result = async.execute -> !async.value<T> {
64  // %0 = arith.constant ... : T
65  // async.yield %0 : T
66  // }
67  Value asyncToken; // token representing completion of the async region
68  llvm::SmallVector<Value, 4> returnValues; // returned async values
69 
70  Value coroHandle; // coroutine handle (!async.coro.handle value)
71  Block *entry; // coroutine entry block
72  Block *setError; // switch completion token and all values to error state
73  Block *cleanup; // coroutine cleanup block
74  Block *suspend; // coroutine suspension block
75 };
76 } // namespace
77 
78 /// Utility to partially update the regular function CFG to the coroutine CFG
79 /// compatible with LLVM coroutines switched-resume lowering using
80 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
81 /// that branches into preexisting entry block. Also inserts trailing blocks.
82 ///
83 /// The result types of the passed `func` must start with an `async.token`
84 /// and be continued with some number of `async.value`s.
85 ///
86 /// The func given to this function needs to have been preprocessed to have
87 /// either branch or yield ops as terminators. Branches to the cleanup block are
88 /// inserted after each yield.
89 ///
90 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
91 ///
92 /// - `entry` block sets up the coroutine.
93 /// - `set_error` block sets completion token and async values state to error.
94 /// - `cleanup` block cleans up the coroutine state.
95 /// - `suspend` block after the @llvm.coro.end() defines what value will be
96 /// returned to the initial caller of a coroutine. Everything before the
97 /// @llvm.coro.end() will be executed at every suspension point.
98 ///
99 /// Coroutine structure (only the important bits):
100 ///
101 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
102 /// {
103 /// ^entry(<function-arguments>):
104 /// %token = <async token> : !async.token // create async runtime token
105 /// %value = <async value> : !async.value<T> // create async value
106 /// %id = async.coro.id // create a coroutine id
107 /// %hdl = async.coro.begin %id // create a coroutine handle
108 /// br ^preexisting_entry_block
109 ///
110 /// /* preexisting blocks modified to branch to the cleanup block */
111 ///
112 /// ^set_error: // this block created lazily only if needed (see code below)
113 /// async.runtime.set_error %token : !async.token
114 /// async.runtime.set_error %value : !async.value<T>
115 /// br ^cleanup
116 ///
117 /// ^cleanup:
118 /// async.coro.free %hdl // delete the coroutine state
119 /// br ^suspend
120 ///
121 /// ^suspend:
122 /// async.coro.end %hdl // marks the end of a coroutine
123 /// return %token, %value : !async.token, !async.value<T>
124 /// }
125 ///
126 static CoroMachinery setupCoroMachinery(FuncOp func) {
127  assert(!func.getBlocks().empty() && "Function must have an entry block");
128 
129  MLIRContext *ctx = func.getContext();
130  Block *entryBlock = &func.getBlocks().front();
131  Block *originalEntryBlock =
132  entryBlock->splitBlock(entryBlock->getOperations().begin());
133  auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
134 
135  // ------------------------------------------------------------------------ //
136  // Allocate async token/values that we will return from a ramp function.
137  // ------------------------------------------------------------------------ //
138  auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
139 
140  llvm::SmallVector<Value, 4> retValues;
141  for (auto resType : func.getCallableResults().drop_front())
142  retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
143 
144  // ------------------------------------------------------------------------ //
145  // Initialize coroutine: get coroutine id and coroutine handle.
146  // ------------------------------------------------------------------------ //
147  auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
148  auto coroHdlOp =
149  builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
150  builder.create<BranchOp>(originalEntryBlock);
151 
152  Block *cleanupBlock = func.addBlock();
153  Block *suspendBlock = func.addBlock();
154 
155  // ------------------------------------------------------------------------ //
156  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
157  // ------------------------------------------------------------------------ //
158  builder.setInsertionPointToStart(cleanupBlock);
159  builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
160 
161  // Branch into the suspend block.
162  builder.create<BranchOp>(suspendBlock);
163 
164  // ------------------------------------------------------------------------ //
165  // Coroutine suspend block: mark the end of a coroutine and return allocated
166  // async token.
167  // ------------------------------------------------------------------------ //
168  builder.setInsertionPointToStart(suspendBlock);
169 
170  // Mark the end of a coroutine: async.coro.end
171  builder.create<CoroEndOp>(coroHdlOp.handle());
172 
173  // Return created `async.token` and `async.values` from the suspend block.
174  // This will be the return value of a coroutine ramp function.
175  SmallVector<Value, 4> ret{retToken};
176  ret.insert(ret.end(), retValues.begin(), retValues.end());
177  builder.create<ReturnOp>(ret);
178 
179  // `async.await` op lowering will create resume blocks for async
180  // continuations, and will conditionally branch to cleanup or suspend blocks.
181 
182  for (Block &block : func.body().getBlocks()) {
183  if (&block == entryBlock || &block == cleanupBlock ||
184  &block == suspendBlock)
185  continue;
186  Operation *terminator = block.getTerminator();
187  if (auto yield = dyn_cast<YieldOp>(terminator)) {
188  builder.setInsertionPointToEnd(&block);
189  builder.create<BranchOp>(cleanupBlock);
190  }
191  }
192 
193  // The switch-resumed API based coroutine should be marked with
194  // "coroutine.presplit" attribute with value "0" to mark the function as a
195  // coroutine.
196  func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr(
197  {builder.getStringAttr("coroutine.presplit"),
198  builder.getStringAttr("0")})));
199 
200  CoroMachinery machinery;
201  machinery.func = func;
202  machinery.asyncToken = retToken;
203  machinery.returnValues = retValues;
204  machinery.coroHandle = coroHdlOp.handle();
205  machinery.entry = entryBlock;
206  machinery.setError = nullptr; // created lazily only if needed
207  machinery.cleanup = cleanupBlock;
208  machinery.suspend = suspendBlock;
209  return machinery;
210 }
211 
212 // Lazily creates `set_error` block only if it is required for lowering to the
213 // runtime operations (see for example lowering of assert operation).
214 static Block *setupSetErrorBlock(CoroMachinery &coro) {
215  if (coro.setError)
216  return coro.setError;
217 
218  coro.setError = coro.func.addBlock();
219  coro.setError->moveBefore(coro.cleanup);
220 
221  auto builder =
222  ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
223 
224  // Coroutine set_error block: set error on token and all returned values.
225  builder.create<RuntimeSetErrorOp>(coro.asyncToken);
226  for (Value retValue : coro.returnValues)
227  builder.create<RuntimeSetErrorOp>(retValue);
228 
229  // Branch into the cleanup block.
230  builder.create<BranchOp>(coro.cleanup);
231 
232  return coro.setError;
233 }
234 
235 /// Outline the body region attached to the `async.execute` op into a standalone
236 /// function.
237 ///
238 /// Note that this is not reversible transformation.
239 static std::pair<FuncOp, CoroMachinery>
240 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
241  ModuleOp module = execute->getParentOfType<ModuleOp>();
242 
243  MLIRContext *ctx = module.getContext();
244  Location loc = execute.getLoc();
245 
246  // Make sure that all constants will be inside the outlined async function to
247  // reduce the number of function arguments.
248  cloneConstantsIntoTheRegion(execute.body());
249 
250  // Collect all outlined function inputs.
251  SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
252  execute.dependencies().end());
253  functionInputs.insert(execute.operands().begin(), execute.operands().end());
254  getUsedValuesDefinedAbove(execute.body(), functionInputs);
255 
256  // Collect types for the outlined function inputs and outputs.
257  auto typesRange = llvm::map_range(
258  functionInputs, [](Value value) { return value.getType(); });
259  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
260  auto outputTypes = execute.getResultTypes();
261 
262  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
263  auto funcAttrs = ArrayRef<NamedAttribute>();
264 
265  // TODO: Derive outlined function name from the parent FuncOp (support
266  // multiple nested async.execute operations).
267  FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
268  symbolTable.insert(func);
269 
271  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
272 
273  // Prepare for coroutine conversion by creating the body of the function.
274  {
275  size_t numDependencies = execute.dependencies().size();
276  size_t numOperands = execute.operands().size();
277 
278  // Await on all dependencies before starting to execute the body region.
279  for (size_t i = 0; i < numDependencies; ++i)
280  builder.create<AwaitOp>(func.getArgument(i));
281 
282  // Await on all async value operands and unwrap the payload.
283  SmallVector<Value, 4> unwrappedOperands(numOperands);
284  for (size_t i = 0; i < numOperands; ++i) {
285  Value operand = func.getArgument(numDependencies + i);
286  unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
287  }
288 
289  // Map from function inputs defined above the execute op to the function
290  // arguments.
291  BlockAndValueMapping valueMapping;
292  valueMapping.map(functionInputs, func.getArguments());
293  valueMapping.map(execute.body().getArguments(), unwrappedOperands);
294 
295  // Clone all operations from the execute operation body into the outlined
296  // function body.
297  for (Operation &op : execute.body().getOps())
298  builder.clone(op, valueMapping);
299  }
300 
301  // Adding entry/cleanup/suspend blocks.
302  CoroMachinery coro = setupCoroMachinery(func);
303 
304  // Suspend async function at the end of an entry block, and resume it using
305  // Async resume operation (execution will be resumed in a thread managed by
306  // the async runtime).
307  {
308  BranchOp branch = cast<BranchOp>(coro.entry->getTerminator());
309  builder.setInsertionPointToEnd(coro.entry);
310 
311  // Save the coroutine state: async.coro.save
312  auto coroSaveOp =
313  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
314 
315  // Pass coroutine to the runtime to be resumed on a runtime managed
316  // thread.
317  builder.create<RuntimeResumeOp>(coro.coroHandle);
318 
319  // Add async.coro.suspend as a suspended block terminator.
320  builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
321  branch.getDest(), coro.cleanup);
322 
323  branch.erase();
324  }
325 
326  // Replace the original `async.execute` with a call to outlined function.
327  {
328  ImplicitLocOpBuilder callBuilder(loc, execute);
329  auto callOutlinedFunc = callBuilder.create<CallOp>(
330  func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
331  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
332  execute.erase();
333  }
334 
335  return {func, coro};
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // Convert async.create_group operation to async.runtime.create_group
340 //===----------------------------------------------------------------------===//
341 
342 namespace {
343 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
344 public:
346 
348  matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override {
350  rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
351  op, GroupType::get(op->getContext()), adaptor.getOperands());
352  return success();
353  }
354 };
355 } // namespace
356 
357 //===----------------------------------------------------------------------===//
358 // Convert async.add_to_group operation to async.runtime.add_to_group.
359 //===----------------------------------------------------------------------===//
360 
361 namespace {
362 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
363 public:
365 
367  matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
368  ConversionPatternRewriter &rewriter) const override {
369  rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
370  op, rewriter.getIndexType(), adaptor.getOperands());
371  return success();
372  }
373 };
374 } // namespace
375 
376 //===----------------------------------------------------------------------===//
377 // Convert async.await and async.await_all operations to the async.runtime.await
378 // or async.runtime.await_and_resume operations.
379 //===----------------------------------------------------------------------===//
380 
381 namespace {
382 template <typename AwaitType, typename AwaitableType>
383 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
384  using AwaitAdaptor = typename AwaitType::Adaptor;
385 
386 public:
387  AwaitOpLoweringBase(MLIRContext *ctx,
388  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
390  outlinedFunctions(outlinedFunctions) {}
391 
393  matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
394  ConversionPatternRewriter &rewriter) const override {
395  // We can only await on one the `AwaitableType` (for `await` it can be
396  // a `token` or a `value`, for `await_all` it must be a `group`).
397  if (!op.operand().getType().template isa<AwaitableType>())
398  return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
399 
400  // Check if await operation is inside the outlined coroutine function.
401  auto func = op->template getParentOfType<FuncOp>();
402  auto outlined = outlinedFunctions.find(func);
403  const bool isInCoroutine = outlined != outlinedFunctions.end();
404 
405  Location loc = op->getLoc();
406  Value operand = adaptor.operand();
407 
408  Type i1 = rewriter.getI1Type();
409 
410  // Inside regular functions we use the blocking wait operation to wait for
411  // the async object (token, value or group) to become available.
412  if (!isInCoroutine) {
413  ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
414  builder.create<RuntimeAwaitOp>(loc, operand);
415 
416  // Assert that the awaited operands is not in the error state.
417  Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
418  Value notError = builder.create<arith::XOrIOp>(
419  isError, builder.create<arith::ConstantOp>(
420  loc, i1, builder.getIntegerAttr(i1, 1)));
421 
422  builder.create<AssertOp>(notError,
423  "Awaited async operand is in error state");
424  }
425 
426  // Inside the coroutine we convert await operation into coroutine suspension
427  // point, and resume execution asynchronously.
428  if (isInCoroutine) {
429  CoroMachinery &coro = outlined->getSecond();
430  Block *suspended = op->getBlock();
431 
432  ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
433  MLIRContext *ctx = op->getContext();
434 
435  // Save the coroutine state and resume on a runtime managed thread when
436  // the operand becomes available.
437  auto coroSaveOp =
438  builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
439  builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
440 
441  // Split the entry block before the await operation.
442  Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
443 
444  // Add async.coro.suspend as a suspended block terminator.
445  builder.setInsertionPointToEnd(suspended);
446  builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
447  coro.cleanup);
448 
449  // Split the resume block into error checking and continuation.
450  Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
451 
452  // Check if the awaited value is in the error state.
453  builder.setInsertionPointToStart(resume);
454  auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
455  builder.create<CondBranchOp>(isError,
456  /*trueDest=*/setupSetErrorBlock(coro),
457  /*trueArgs=*/ArrayRef<Value>(),
458  /*falseDest=*/continuation,
459  /*falseArgs=*/ArrayRef<Value>());
460 
461  // Make sure that replacement value will be constructed in the
462  // continuation block.
463  rewriter.setInsertionPointToStart(continuation);
464  }
465 
466  // Erase or replace the await operation with the new value.
467  if (Value replaceWith = getReplacementValue(op, operand, rewriter))
468  rewriter.replaceOp(op, replaceWith);
469  else
470  rewriter.eraseOp(op);
471 
472  return success();
473  }
474 
475  virtual Value getReplacementValue(AwaitType op, Value operand,
476  ConversionPatternRewriter &rewriter) const {
477  return Value();
478  }
479 
480 private:
481  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
482 };
483 
484 /// Lowering for `async.await` with a token operand.
485 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
486  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
487 
488 public:
489  using Base::Base;
490 };
491 
492 /// Lowering for `async.await` with a value operand.
493 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
494  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
495 
496 public:
497  using Base::Base;
498 
499  Value
500  getReplacementValue(AwaitOp op, Value operand,
501  ConversionPatternRewriter &rewriter) const override {
502  // Load from the async value storage.
503  auto valueType = operand.getType().cast<ValueType>().getValueType();
504  return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
505  }
506 };
507 
508 /// Lowering for `async.await_all` operation.
509 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
510  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
511 
512 public:
513  using Base::Base;
514 };
515 
516 } // namespace
517 
518 //===----------------------------------------------------------------------===//
519 // Convert async.yield operation to async.runtime operations.
520 //===----------------------------------------------------------------------===//
521 
522 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
523 public:
525  MLIRContext *ctx,
526  const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
527  : OpConversionPattern<async::YieldOp>(ctx),
528  outlinedFunctions(outlinedFunctions) {}
529 
531  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
532  ConversionPatternRewriter &rewriter) const override {
533  // Check if yield operation is inside the async coroutine function.
534  auto func = op->template getParentOfType<FuncOp>();
535  auto outlined = outlinedFunctions.find(func);
536  if (outlined == outlinedFunctions.end())
537  return rewriter.notifyMatchFailure(
538  op, "operation is not inside the async coroutine function");
539 
540  Location loc = op->getLoc();
541  const CoroMachinery &coro = outlined->getSecond();
542 
543  // Store yielded values into the async values storage and switch async
544  // values state to available.
545  for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
546  Value yieldValue = std::get<0>(tuple);
547  Value asyncValue = std::get<1>(tuple);
548  rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
549  rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
550  }
551 
552  // Switch the coroutine completion token to available state.
553  rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
554 
555  return success();
556  }
557 
558 private:
559  const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
560 };
561 
562 //===----------------------------------------------------------------------===//
563 // Convert std.assert operation to cond_br into `set_error` block.
564 //===----------------------------------------------------------------------===//
565 
566 class AssertOpLowering : public OpConversionPattern<AssertOp> {
567 public:
569  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
570  : OpConversionPattern<AssertOp>(ctx),
571  outlinedFunctions(outlinedFunctions) {}
572 
574  matchAndRewrite(AssertOp op, OpAdaptor adaptor,
575  ConversionPatternRewriter &rewriter) const override {
576  // Check if assert operation is inside the async coroutine function.
577  auto func = op->template getParentOfType<FuncOp>();
578  auto outlined = outlinedFunctions.find(func);
579  if (outlined == outlinedFunctions.end())
580  return rewriter.notifyMatchFailure(
581  op, "operation is not inside the async coroutine function");
582 
583  Location loc = op->getLoc();
584  CoroMachinery &coro = outlined->getSecond();
585 
586  Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
587  rewriter.setInsertionPointToEnd(cont->getPrevNode());
588  rewriter.create<CondBranchOp>(loc, adaptor.getArg(),
589  /*trueDest=*/cont,
590  /*trueArgs=*/ArrayRef<Value>(),
591  /*falseDest=*/setupSetErrorBlock(coro),
592  /*falseArgs=*/ArrayRef<Value>());
593  rewriter.eraseOp(op);
594 
595  return success();
596  }
597 
598 private:
599  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
600 };
601 
602 //===----------------------------------------------------------------------===//
603 
604 /// Rewrite a func as a coroutine by:
605 /// 1) Wrapping the results into `async.value`.
606 /// 2) Prepending the results with `async.token`.
607 /// 3) Setting up coroutine blocks.
608 /// 4) Rewriting return ops as yield op and branch op into the suspend block.
609 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) {
610  auto *ctx = func->getContext();
611  auto loc = func.getLoc();
612  SmallVector<Type> resultTypes;
613  resultTypes.reserve(func.getCallableResults().size());
614  llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
615  [](Type type) { return ValueType::get(type); });
616  func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes));
617  func.insertResult(0, TokenType::get(ctx), {});
618  for (Block &block : func.getBlocks()) {
619  Operation *terminator = block.getTerminator();
620  if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) {
621  ImplicitLocOpBuilder builder(loc, returnOp);
622  builder.create<YieldOp>(returnOp.getOperands());
623  returnOp.erase();
624  }
625  }
626  return setupCoroMachinery(func);
627 }
628 
629 /// Rewrites a call into a function that has been rewritten as a coroutine.
630 ///
631 /// The invocation of this function is safe only when call ops are traversed in
632 /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
633 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) {
634  auto loc = func.getLoc();
635  ImplicitLocOpBuilder callBuilder(loc, oldCall);
636  auto newCall = callBuilder.create<CallOp>(
637  func.getName(), func.getCallableResults(), oldCall.getArgOperands());
638 
639  // Await on the async token and all the value results and unwrap the latter.
640  callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
641  SmallVector<Value> unwrappedResults;
642  unwrappedResults.reserve(newCall->getResults().size() - 1);
643  for (Value result : newCall.getResults().drop_front())
644  unwrappedResults.push_back(
645  callBuilder.create<AwaitOp>(loc, result).result());
646  // Careful, when result of a call is piped into another call this could lead
647  // to a dangling pointer.
648  oldCall.replaceAllUsesWith(unwrappedResults);
649  oldCall.erase();
650 }
651 
652 static bool isAllowedToBlock(FuncOp func) {
653  return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
654 }
655 
656 static LogicalResult
657 funcsToCoroutines(ModuleOp module,
658  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
659  // The following code supports the general case when 2 functions mutually
660  // recurse into each other. Because of this and that we are relying on
661  // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
662  // a FuncOp while inserting an equivalent coroutine, because that could lead
663  // to dangling pointers.
664 
665  SmallVector<FuncOp> funcWorklist;
666 
667  // Careful, it's okay to add a func to the worklist multiple times if and only
668  // if the loop processing the worklist will skip the functions that have
669  // already been converted to coroutines.
670  auto addToWorklist = [&](FuncOp func) {
671  if (isAllowedToBlock(func))
672  return;
673  // N.B. To refactor this code into a separate pass the lookup in
674  // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
675  // func and recognizing if it has a coroutine structure is messy. Passing
676  // this dict between the passes is ugly.
677  if (isAllowedToBlock(func) ||
678  outlinedFunctions.find(func) == outlinedFunctions.end()) {
679  for (Operation &op : func.body().getOps()) {
680  if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
681  funcWorklist.push_back(func);
682  break;
683  }
684  }
685  }
686  };
687 
688  // Traverse in post-order collecting for each func op the await ops it has.
689  for (FuncOp func : module.getOps<FuncOp>())
690  addToWorklist(func);
691 
692  SymbolTableCollection symbolTable;
693  SymbolUserMap symbolUserMap(symbolTable, module);
694 
695  // Rewrite funcs, while updating call sites and adding them to the worklist.
696  while (!funcWorklist.empty()) {
697  auto func = funcWorklist.pop_back_val();
698  auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
699  if (!insertion.second)
700  // This function has already been processed because this is either
701  // the corecursive case, or a caller with multiple calls to a newly
702  // created corouting. Either way, skip updating the call sites.
703  continue;
704  insertion.first->second = rewriteFuncAsCoroutine(func);
705  SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
706  symbolUserMap.getUsers(func).end());
707  // If there are multiple calls from the same block they need to be traversed
708  // in reverse order so that symbolUserMap references are not invalidated
709  // when updating the users of the call op which is earlier in the block.
710  llvm::sort(users, [](Operation *a, Operation *b) {
711  Block *blockA = a->getBlock();
712  Block *blockB = b->getBlock();
713  // Impose arbitrary order on blocks so that there is a well-defined order.
714  return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
715  });
716  // Rewrite the callsites to await on results of the newly created coroutine.
717  for (Operation *op : users) {
718  if (CallOp call = dyn_cast<mlir::CallOp>(*op)) {
719  FuncOp caller = call->getParentOfType<FuncOp>();
720  rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
721  addToWorklist(caller);
722  } else {
723  op->emitError("Unexpected reference to func referenced by symbol");
724  return failure();
725  }
726  }
727  }
728  return success();
729 }
730 
731 //===----------------------------------------------------------------------===//
732 void AsyncToAsyncRuntimePass::runOnOperation() {
733  ModuleOp module = getOperation();
734  SymbolTable symbolTable(module);
735 
736  // Outline all `async.execute` body regions into async functions (coroutines).
737  llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
738 
739  module.walk([&](ExecuteOp execute) {
740  outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
741  });
742 
743  LLVM_DEBUG({
744  llvm::dbgs() << "Outlined " << outlinedFunctions.size()
745  << " functions built from async.execute operations\n";
746  });
747 
748  // Returns true if operation is inside the coroutine.
749  auto isInCoroutine = [&](Operation *op) -> bool {
750  auto parentFunc = op->getParentOfType<FuncOp>();
751  return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
752  };
753 
754  if (eliminateBlockingAwaitOps &&
755  failed(funcsToCoroutines(module, outlinedFunctions))) {
756  signalPassFailure();
757  return;
758  }
759 
760  // Lower async operations to async.runtime operations.
761  MLIRContext *ctx = module->getContext();
762  RewritePatternSet asyncPatterns(ctx);
763 
764  // Conversion to async runtime augments original CFG with the coroutine CFG,
765  // and we have to make sure that structured control flow operations with async
766  // operations in nested regions will be converted to branch-based control flow
767  // before we add the coroutine basic blocks.
769 
770  // Async lowering does not use type converter because it must preserve all
771  // types for async.runtime operations.
772  asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
773  asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
774  AwaitAllOpLowering, YieldOpLowering>(ctx,
775  outlinedFunctions);
776 
777  // Lower assertions to conditional branches into error blocks.
778  asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
779 
780  // All high level async operations must be lowered to the runtime operations.
781  ConversionTarget runtimeTarget(*ctx);
782  runtimeTarget.addLegalDialect<AsyncDialect>();
783  runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
784  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
785 
786  // Decide if structured control flow has to be lowered to branch-based CFG.
787  runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
788  auto walkResult = op->walk([&](Operation *nested) {
789  bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
790  return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
792  });
793  return !walkResult.wasInterrupted();
794  });
795  runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp,
796  ConstantOp, BranchOp, CondBranchOp>();
797 
798  // Assertions must be converted to runtime errors inside async functions.
799  runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
800  auto func = op->getParentOfType<FuncOp>();
801  return outlinedFunctions.find(func) == outlinedFunctions.end();
802  });
803 
804  if (eliminateBlockingAwaitOps)
805  runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
806  [&](RuntimeAwaitOp op) -> bool {
807  return isAllowedToBlock(op->getParentOfType<FuncOp>());
808  });
809 
810  if (failed(applyPartialConversion(module, runtimeTarget,
811  std::move(asyncPatterns)))) {
812  signalPassFailure();
813  return;
814  }
815 }
816 
817 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
818  return std::make_unique<AsyncToAsyncRuntimePass>();
819 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, ArrayRef< NamedAttribute > attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:27
typename AssertOp ::Adaptor OpAdaptor
Definition: Pattern.h:134
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Block represents an ordered list of Operations.
Definition: Block.h:29
void populateLoopToStdConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to lower from scf.for, scf.if, and loop.terminator to CFG operations within...
OpListType & getOperations()
Definition: Block.h:128
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
typename async::YieldOp ::Adaptor OpAdaptor
static bool isAllowedToBlock(FuncOp func)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:271
YieldOpLowering(MLIRContext *ctx, const llvm::DenseMap< FuncOp, CoroMachinery > &outlinedFunctions)
LogicalResult notifyMatchFailure(Operation *op, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Operation & front()
Definition: Block.h:144
static std::pair< FuncOp, CoroMachinery > outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute)
Outline the body region attached to the async.execute op into a standalone function.
static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
value !async value< T > ***static CoroMachinery setupCoroMachinery(FuncOp func)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static CoroMachinery rewriteFuncAsCoroutine(FuncOp func)
Rewrite a func as a coroutine by: 1) Wrapping the results into async.value.
LogicalResult matchAndRewrite(AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:242
LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
OpListType::iterator iterator
Definition: Block.h:131
static LogicalResult funcsToCoroutines(ModuleOp module, llvm::DenseMap< FuncOp, CoroMachinery > &outlinedFunctions)
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:59
AssertOpLowering(MLIRContext *ctx, llvm::DenseMap< FuncOp, CoroMachinery > &outlinedFunctions)
IntegerType getI1Type()
Definition: Builders.cpp:50
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
static WalkResult advance()
Definition: Visitors.h:51
static constexpr const char kAsyncFnPrefix[]
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener...
Definition: Builders.h:251
static WalkResult interrupt()
Definition: Visitors.h:50
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
void cloneConstantsIntoTheRegion(Region &region)
Clone ConstantLike operations that are defined above the given region and have users in the region in...
Definition: PassDetail.cpp:15
This class represents a map of symbols to users, and provides efficient implementations of symbol que...
Definition: SymbolTable.h:290
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
ArrayRef< Operation * > getUsers(Operation *symbol) const
Return the users of the provided symbol operation.
Definition: SymbolTable.h:299
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:347
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
IndexType getIndexType()
Definition: Builders.cpp:48
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:103
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
std::unique_ptr< OperationPass< ModuleOp > > createAsyncToAsyncRuntimePass()
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
This class implements a pattern rewriter for use with ConversionPatterns.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
static Block * setupSetErrorBlock(CoroMachinery &coro)
static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func)
Rewrites a call into a function that has been rewritten as a coroutine.
U cast() const
Definition: Types.h:250
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:289