MLIR  20.0.0git
AsyncToLLVM.cpp
Go to the documentation of this file.
1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Pass/Pass.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "convert-async-to-llvm"
34 
35 using namespace mlir;
36 using namespace mlir::async;
37 
//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//

// Symbol names of the Async Runtime C API functions that this lowering emits
// declarations of and calls to, grouped by purpose.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Marking tokens and values as available or as errors.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting tokens, values and groups.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Scheduling coroutines for execution (immediately or once an operand is
// ready).
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Value storage access and group membership.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Runtime introspection.
// NOTE(review): "Runtim" (missing 'e') looks like a typo, but this string must
// match the symbol exported by the runtime library exactly; confirm against
// the runtime before changing it.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70 
71 namespace {
72 /// Async Runtime API function types.
73 ///
74 /// Because we can't create API function signature for type parametrized
75 /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
76 /// lowering all async data types become opaque pointers at runtime.
77 struct AsyncAPI {
78  // All async types are lowered to opaque LLVM pointers at runtime.
79  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
80  return LLVM::LLVMPointerType::get(ctx);
81  }
82 
83  static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
84  return LLVM::LLVMTokenType::get(ctx);
85  }
86 
87  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
88  auto ref = opaquePointerType(ctx);
89  auto count = IntegerType::get(ctx, 64);
90  return FunctionType::get(ctx, {ref, count}, {});
91  }
92 
93  static FunctionType createTokenFunctionType(MLIRContext *ctx) {
94  return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
95  }
96 
97  static FunctionType createValueFunctionType(MLIRContext *ctx) {
98  auto i64 = IntegerType::get(ctx, 64);
99  auto value = opaquePointerType(ctx);
100  return FunctionType::get(ctx, {i64}, {value});
101  }
102 
103  static FunctionType createGroupFunctionType(MLIRContext *ctx) {
104  auto i64 = IntegerType::get(ctx, 64);
105  return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106  }
107 
108  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
109  auto ptrType = opaquePointerType(ctx);
110  return FunctionType::get(ctx, {ptrType}, {ptrType});
111  }
112 
113  static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
114  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115  }
116 
117  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
118  auto value = opaquePointerType(ctx);
119  return FunctionType::get(ctx, {value}, {});
120  }
121 
122  static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
123  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
124  }
125 
126  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
127  auto value = opaquePointerType(ctx);
128  return FunctionType::get(ctx, {value}, {});
129  }
130 
131  static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
132  auto i1 = IntegerType::get(ctx, 1);
133  return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
134  }
135 
136  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
137  auto value = opaquePointerType(ctx);
138  auto i1 = IntegerType::get(ctx, 1);
139  return FunctionType::get(ctx, {value}, {i1});
140  }
141 
142  static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
143  auto i1 = IntegerType::get(ctx, 1);
144  return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145  }
146 
147  static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
148  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
149  }
150 
151  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
152  auto value = opaquePointerType(ctx);
153  return FunctionType::get(ctx, {value}, {});
154  }
155 
156  static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
157  return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158  }
159 
160  static FunctionType executeFunctionType(MLIRContext *ctx) {
161  auto ptrType = opaquePointerType(ctx);
162  return FunctionType::get(ctx, {ptrType, ptrType}, {});
163  }
164 
165  static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
166  auto i64 = IntegerType::get(ctx, 64);
167  return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
168  {i64});
169  }
170 
171  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172  auto ptrType = opaquePointerType(ctx);
173  return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
174  }
175 
176  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
177  auto ptrType = opaquePointerType(ctx);
178  return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179  }
180 
181  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
182  auto ptrType = opaquePointerType(ctx);
183  return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184  }
185 
186  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187  return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188  }
189 
190  // Auxiliary coroutine resume intrinsic wrapper.
191  static Type resumeFunctionType(MLIRContext *ctx) {
192  auto voidTy = LLVM::LLVMVoidType::get(ctx);
193  auto ptrType = opaquePointerType(ctx);
194  return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
195  }
196 };
197 } // namespace
198 
199 /// Adds Async Runtime C API declarations to the module.
200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201  auto builder =
202  ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
203 
204  auto addFuncDecl = [&](StringRef name, FunctionType type) {
205  if (module.lookupSymbol(name))
206  return;
207  builder.create<func::FuncOp>(name, type).setPrivate();
208  };
209 
210  MLIRContext *ctx = module.getContext();
211  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218  addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220  addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222  addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229  addFuncDecl(kAwaitTokenAndExecute,
230  AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231  addFuncDecl(kAwaitValueAndExecute,
232  AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233  addFuncDecl(kAwaitAllAndExecute,
234  AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Coroutine resume function wrapper.
240 //===----------------------------------------------------------------------===//
241 
242 static constexpr const char *kResume = "__resume";
243 
244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245 /// intrinsics. We need this function to be able to pass it to the async
246 /// runtime execute API.
247 static void addResumeFunction(ModuleOp module) {
248  if (module.lookupSymbol(kResume))
249  return;
250 
251  MLIRContext *ctx = module.getContext();
252  auto loc = module.getLoc();
253  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
254 
255  auto voidTy = LLVM::LLVMVoidType::get(ctx);
256  Type ptrType = AsyncAPI::opaquePointerType(ctx);
257 
258  auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259  kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
260  resumeOp.setPrivate();
261 
262  auto *block = resumeOp.addEntryBlock(moduleBuilder);
263  auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
264 
265  blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266  blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Convert Async dialect types to LLVM types.
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275 /// their runtime type (opaque pointers) and does not convert any other types.
276 class AsyncRuntimeTypeConverter : public TypeConverter {
277 public:
278  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
279  addConversion([](Type type) { return type; });
280  addConversion([](Type type) { return convertAsyncTypes(type); });
281 
282  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
283  // in patterns for other dialects.
284  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285  ValueRange inputs, Location loc) -> Value {
286  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287  return cast.getResult(0);
288  };
289 
290  addSourceMaterialization(addUnrealizedCast);
291  addTargetMaterialization(addUnrealizedCast);
292  }
293 
294  static std::optional<Type> convertAsyncTypes(Type type) {
295  if (isa<TokenType, GroupType, ValueType>(type))
296  return AsyncAPI::opaquePointerType(type.getContext());
297 
298  if (isa<CoroIdType, CoroStateType>(type))
299  return AsyncAPI::tokenType(type.getContext());
300  if (isa<CoroHandleType>(type))
301  return AsyncAPI::opaquePointerType(type.getContext());
302 
303  return std::nullopt;
304  }
305 };
306 
307 /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
308 /// as type converter. Allows access to it via the 'getTypeConverter'
309 /// convenience method.
310 template <typename SourceOp>
311 class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
312 
313  using Base = OpConversionPattern<SourceOp>;
314 
315 public:
316  AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
317  MLIRContext *context)
318  : Base(typeConverter, context) {}
319 
320  /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
321  const AsyncRuntimeTypeConverter *getTypeConverter() const {
322  return static_cast<const AsyncRuntimeTypeConverter *>(
323  Base::getTypeConverter());
324  }
325 };
326 
327 } // namespace
328 
329 //===----------------------------------------------------------------------===//
330 // Convert async.coro.id to @llvm.coro.id intrinsic.
331 //===----------------------------------------------------------------------===//
332 
333 namespace {
334 class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
335 public:
336  using AsyncOpConversionPattern::AsyncOpConversionPattern;
337 
338  LogicalResult
339  matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
340  ConversionPatternRewriter &rewriter) const override {
341  auto token = AsyncAPI::tokenType(op->getContext());
342  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
343  auto loc = op->getLoc();
344 
345  // Constants for initializing coroutine frame.
346  auto constZero =
347  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
348  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
349 
350  // Get coroutine id: @llvm.coro.id.
351  rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352  op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
353 
354  return success();
355  }
356 };
357 } // namespace
358 
359 //===----------------------------------------------------------------------===//
360 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
361 //===----------------------------------------------------------------------===//
362 
363 namespace {
364 class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
365 public:
366  using AsyncOpConversionPattern::AsyncOpConversionPattern;
367 
368  LogicalResult
369  matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
370  ConversionPatternRewriter &rewriter) const override {
371  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
372  auto loc = op->getLoc();
373 
374  // Get coroutine frame size: @llvm.coro.size.i64.
375  Value coroSize =
376  rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
377  // Get coroutine frame alignment: @llvm.coro.align.i64.
378  Value coroAlign =
379  rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
380 
381  // Round up the size to be multiple of the alignment. Since aligned_alloc
382  // requires the size parameter be an integral multiple of the alignment
383  // parameter.
384  auto makeConstant = [&](uint64_t c) {
385  return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
386  rewriter.getI64Type(), c);
387  };
388  coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
389  coroSize =
390  rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
391  Value negCoroAlign =
392  rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
393  coroSize =
394  rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
395 
396  // Allocate memory for the coroutine frame.
397  auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398  op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399  auto coroAlloc = rewriter.create<LLVM::CallOp>(
400  loc, allocFuncOp, ValueRange{coroAlign, coroSize});
401 
402  // Begin a coroutine: @llvm.coro.begin.
403  auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
404  rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
405  op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
406 
407  return success();
408  }
409 };
410 } // namespace
411 
412 //===----------------------------------------------------------------------===//
413 // Convert async.coro.free to @llvm.coro.free intrinsic.
414 //===----------------------------------------------------------------------===//
415 
416 namespace {
417 class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
418 public:
419  using AsyncOpConversionPattern::AsyncOpConversionPattern;
420 
421  LogicalResult
422  matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
423  ConversionPatternRewriter &rewriter) const override {
424  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
425  auto loc = op->getLoc();
426 
427  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
428  auto coroMem =
429  rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
430 
431  // Free the memory.
432  auto freeFuncOp =
433  LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
434  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
435  ValueRange(coroMem.getResult()));
436 
437  return success();
438  }
439 };
440 } // namespace
441 
442 //===----------------------------------------------------------------------===//
443 // Convert async.coro.end to @llvm.coro.end intrinsic.
444 //===----------------------------------------------------------------------===//
445 
446 namespace {
447 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
448 public:
450 
451  LogicalResult
452  matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
453  ConversionPatternRewriter &rewriter) const override {
454  // We are not in the block that is part of the unwind sequence.
455  auto constFalse = rewriter.create<LLVM::ConstantOp>(
456  op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
457  auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
458 
459  // Mark the end of a coroutine: @llvm.coro.end.
460  auto coroHdl = adaptor.getHandle();
461  rewriter.create<LLVM::CoroEndOp>(
462  op->getLoc(), rewriter.getI1Type(),
463  ValueRange({coroHdl, constFalse, noneToken}));
464  rewriter.eraseOp(op);
465 
466  return success();
467  }
468 };
469 } // namespace
470 
471 //===----------------------------------------------------------------------===//
472 // Convert async.coro.save to @llvm.coro.save intrinsic.
473 //===----------------------------------------------------------------------===//
474 
475 namespace {
476 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
477 public:
479 
480  LogicalResult
481  matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
482  ConversionPatternRewriter &rewriter) const override {
483  // Save the coroutine state: @llvm.coro.save
484  rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
485  op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
486 
487  return success();
488  }
489 };
490 } // namespace
491 
492 //===----------------------------------------------------------------------===//
493 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
494 //===----------------------------------------------------------------------===//
495 
496 namespace {
497 
498 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
499 /// branch to the appropriate block based on the return code.
500 ///
501 /// Before:
502 ///
503 /// ^suspended:
504 /// "opBefore"(...)
505 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
506 /// ^resume:
507 /// "op"(...)
508 /// ^cleanup: ...
509 /// ^suspend: ...
510 ///
511 /// After:
512 ///
513 /// ^suspended:
514 /// "opBefore"(...)
515 /// %suspend = llmv.intr.coro.suspend ...
516 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
517 /// ^resume:
518 /// "op"(...)
519 /// ^cleanup: ...
520 /// ^suspend: ...
521 ///
522 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
523 public:
525 
526  LogicalResult
527  matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
528  ConversionPatternRewriter &rewriter) const override {
529  auto i8 = rewriter.getIntegerType(8);
530  auto i32 = rewriter.getI32Type();
531  auto loc = op->getLoc();
532 
533  // This is not a final suspension point.
534  auto constFalse = rewriter.create<LLVM::ConstantOp>(
535  loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
536 
537  // Suspend a coroutine: @llvm.coro.suspend
538  auto coroState = adaptor.getState();
539  auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
540  loc, i8, ValueRange({coroState, constFalse}));
541 
542  // Cast return code to i32.
543 
544  // After a suspension point decide if we should branch into resume, cleanup
545  // or suspend block of the coroutine (see @llvm.coro.suspend return code
546  // documentation).
547  llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
548  llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
549  op.getCleanupDest()};
550  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
551  op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
552  /*defaultDestination=*/op.getSuspendDest(),
553  /*defaultOperands=*/ValueRange(),
554  /*caseValues=*/caseValues,
555  /*caseDestinations=*/caseDest,
556  /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
557  /*branchWeights=*/ArrayRef<int32_t>());
558 
559  return success();
560  }
561 };
562 } // namespace
563 
564 //===----------------------------------------------------------------------===//
565 // Convert async.runtime.create to the corresponding runtime API call.
566 //
567 // To allocate storage for the async values we use getelementptr trick:
568 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
569 //===----------------------------------------------------------------------===//
570 
571 namespace {
572 class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
573 public:
575 
576  LogicalResult
577  matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
578  ConversionPatternRewriter &rewriter) const override {
579  const TypeConverter *converter = getTypeConverter();
580  Type resultType = op->getResultTypes()[0];
581 
582  // Tokens creation maps to a simple function call.
583  if (isa<TokenType>(resultType)) {
584  rewriter.replaceOpWithNewOp<func::CallOp>(
585  op, kCreateToken, converter->convertType(resultType));
586  return success();
587  }
588 
589  // To create a value we need to compute the storage requirement.
590  if (auto value = dyn_cast<ValueType>(resultType)) {
591  // Returns the size requirements for the async value storage.
592  auto sizeOf = [&](ValueType valueType) -> Value {
593  auto loc = op->getLoc();
594  auto i64 = rewriter.getI64Type();
595 
596  auto storedType = converter->convertType(valueType.getValueType());
597  auto storagePtrType =
598  AsyncAPI::opaquePointerType(rewriter.getContext());
599 
600  // %Size = getelementptr %T* null, int 1
601  // %SizeI = ptrtoint %T* %Size to i64
602  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
603  auto gep =
604  rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
605  nullPtr, ArrayRef<LLVM::GEPArg>{1});
606  return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
607  };
608 
609  rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
610  sizeOf(value));
611 
612  return success();
613  }
614 
615  return rewriter.notifyMatchFailure(op, "unsupported async type");
616  }
617 };
618 } // namespace
619 
620 //===----------------------------------------------------------------------===//
621 // Convert async.runtime.create_group to the corresponding runtime API call.
622 //===----------------------------------------------------------------------===//
623 
624 namespace {
625 class RuntimeCreateGroupOpLowering
626  : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
627 public:
629 
630  LogicalResult
631  matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const override {
633  const TypeConverter *converter = getTypeConverter();
634  Type resultType = op.getResult().getType();
635 
636  rewriter.replaceOpWithNewOp<func::CallOp>(
637  op, kCreateGroup, converter->convertType(resultType),
638  adaptor.getOperands());
639  return success();
640  }
641 };
642 } // namespace
643 
644 //===----------------------------------------------------------------------===//
645 // Convert async.runtime.set_available to the corresponding runtime API call.
646 //===----------------------------------------------------------------------===//
647 
648 namespace {
649 class RuntimeSetAvailableOpLowering
650  : public OpConversionPattern<RuntimeSetAvailableOp> {
651 public:
653 
654  LogicalResult
655  matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
656  ConversionPatternRewriter &rewriter) const override {
657  StringRef apiFuncName =
658  TypeSwitch<Type, StringRef>(op.getOperand().getType())
659  .Case<TokenType>([](Type) { return kEmplaceToken; })
660  .Case<ValueType>([](Type) { return kEmplaceValue; });
661 
662  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
663  adaptor.getOperands());
664 
665  return success();
666  }
667 };
668 } // namespace
669 
670 //===----------------------------------------------------------------------===//
671 // Convert async.runtime.set_error to the corresponding runtime API call.
672 //===----------------------------------------------------------------------===//
673 
674 namespace {
675 class RuntimeSetErrorOpLowering
676  : public OpConversionPattern<RuntimeSetErrorOp> {
677 public:
679 
680  LogicalResult
681  matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
682  ConversionPatternRewriter &rewriter) const override {
683  StringRef apiFuncName =
684  TypeSwitch<Type, StringRef>(op.getOperand().getType())
685  .Case<TokenType>([](Type) { return kSetTokenError; })
686  .Case<ValueType>([](Type) { return kSetValueError; });
687 
688  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
689  adaptor.getOperands());
690 
691  return success();
692  }
693 };
694 } // namespace
695 
696 //===----------------------------------------------------------------------===//
697 // Convert async.runtime.is_error to the corresponding runtime API call.
698 //===----------------------------------------------------------------------===//
699 
700 namespace {
701 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
702 public:
704 
705  LogicalResult
706  matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
707  ConversionPatternRewriter &rewriter) const override {
708  StringRef apiFuncName =
709  TypeSwitch<Type, StringRef>(op.getOperand().getType())
710  .Case<TokenType>([](Type) { return kIsTokenError; })
711  .Case<GroupType>([](Type) { return kIsGroupError; })
712  .Case<ValueType>([](Type) { return kIsValueError; });
713 
714  rewriter.replaceOpWithNewOp<func::CallOp>(
715  op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
716  return success();
717  }
718 };
719 } // namespace
720 
721 //===----------------------------------------------------------------------===//
722 // Convert async.runtime.await to the corresponding runtime API call.
723 //===----------------------------------------------------------------------===//
724 
725 namespace {
726 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
727 public:
729 
730  LogicalResult
731  matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
732  ConversionPatternRewriter &rewriter) const override {
733  StringRef apiFuncName =
734  TypeSwitch<Type, StringRef>(op.getOperand().getType())
735  .Case<TokenType>([](Type) { return kAwaitToken; })
736  .Case<ValueType>([](Type) { return kAwaitValue; })
737  .Case<GroupType>([](Type) { return kAwaitGroup; });
738 
739  rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
740  adaptor.getOperands());
741  rewriter.eraseOp(op);
742 
743  return success();
744  }
745 };
746 } // namespace
747 
748 //===----------------------------------------------------------------------===//
749 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
750 //===----------------------------------------------------------------------===//
751 
752 namespace {
753 class RuntimeAwaitAndResumeOpLowering
754  : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
755 public:
756  using AsyncOpConversionPattern::AsyncOpConversionPattern;
757 
758  LogicalResult
759  matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
760  ConversionPatternRewriter &rewriter) const override {
761  StringRef apiFuncName =
762  TypeSwitch<Type, StringRef>(op.getOperand().getType())
763  .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
764  .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
765  .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
766 
767  Value operand = adaptor.getOperand();
768  Value handle = adaptor.getHandle();
769 
770  // A pointer to coroutine resume intrinsic wrapper.
771  addResumeFunction(op->getParentOfType<ModuleOp>());
772  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
773  op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
774  kResume);
775 
776  rewriter.create<func::CallOp>(
777  op->getLoc(), apiFuncName, TypeRange(),
778  ValueRange({operand, handle, resumePtr.getRes()}));
779  rewriter.eraseOp(op);
780 
781  return success();
782  }
783 };
784 } // namespace
785 
786 //===----------------------------------------------------------------------===//
787 // Convert async.runtime.resume to the corresponding runtime API call.
788 //===----------------------------------------------------------------------===//
789 
790 namespace {
791 class RuntimeResumeOpLowering
792  : public AsyncOpConversionPattern<RuntimeResumeOp> {
793 public:
794  using AsyncOpConversionPattern::AsyncOpConversionPattern;
795 
796  LogicalResult
797  matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
798  ConversionPatternRewriter &rewriter) const override {
799  // A pointer to coroutine resume intrinsic wrapper.
800  addResumeFunction(op->getParentOfType<ModuleOp>());
801  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
802  op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
803  kResume);
804 
805  // Call async runtime API to execute a coroutine in the managed thread.
806  auto coroHdl = adaptor.getHandle();
807  rewriter.replaceOpWithNewOp<func::CallOp>(
808  op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
809 
810  return success();
811  }
812 };
813 } // namespace
814 
815 //===----------------------------------------------------------------------===//
816 // Convert async.runtime.store to the corresponding runtime API call.
817 //===----------------------------------------------------------------------===//
818 
819 namespace {
820 class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
821 public:
823 
824  LogicalResult
825  matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
826  ConversionPatternRewriter &rewriter) const override {
827  Location loc = op->getLoc();
828 
829  // Get a pointer to the async value storage from the runtime.
830  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
831  auto storage = adaptor.getStorage();
832  auto storagePtr = rewriter.create<func::CallOp>(
833  loc, kGetValueStorage, TypeRange(ptrType), storage);
834 
835  // Cast from i8* to the LLVM pointer type.
836  auto valueType = op.getValue().getType();
837  auto llvmValueType = getTypeConverter()->convertType(valueType);
838  if (!llvmValueType)
839  return rewriter.notifyMatchFailure(
840  op, "failed to convert stored value type to LLVM type");
841 
842  Value castedStoragePtr = storagePtr.getResult(0);
843  // Store the yielded value into the async value storage.
844  auto value = adaptor.getValue();
845  rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
846 
847  // Erase the original runtime store operation.
848  rewriter.eraseOp(op);
849 
850  return success();
851  }
852 };
853 } // namespace
854 
855 //===----------------------------------------------------------------------===//
856 // Convert async.runtime.load to the corresponding runtime API call.
857 //===----------------------------------------------------------------------===//
858 
859 namespace {
860 class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
861 public:
863 
864  LogicalResult
865  matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
866  ConversionPatternRewriter &rewriter) const override {
867  Location loc = op->getLoc();
868 
869  // Get a pointer to the async value storage from the runtime.
870  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
871  auto storage = adaptor.getStorage();
872  auto storagePtr = rewriter.create<func::CallOp>(
873  loc, kGetValueStorage, TypeRange(ptrType), storage);
874 
875  // Cast from i8* to the LLVM pointer type.
876  auto valueType = op.getResult().getType();
877  auto llvmValueType = getTypeConverter()->convertType(valueType);
878  if (!llvmValueType)
879  return rewriter.notifyMatchFailure(
880  op, "failed to convert loaded value type to LLVM type");
881 
882  Value castedStoragePtr = storagePtr.getResult(0);
883 
884  // Load from the casted pointer.
885  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
886  castedStoragePtr);
887 
888  return success();
889  }
890 };
891 } // namespace
892 
893 //===----------------------------------------------------------------------===//
894 // Convert async.runtime.add_to_group to the corresponding runtime API call.
895 //===----------------------------------------------------------------------===//
896 
897 namespace {
898 class RuntimeAddToGroupOpLowering
899  : public OpConversionPattern<RuntimeAddToGroupOp> {
900 public:
902 
903  LogicalResult
904  matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
905  ConversionPatternRewriter &rewriter) const override {
906  // Currently we can only add tokens to the group.
907  if (!isa<TokenType>(op.getOperand().getType()))
908  return rewriter.notifyMatchFailure(op, "only token type is supported");
909 
910  // Replace with a runtime API function call.
911  rewriter.replaceOpWithNewOp<func::CallOp>(
912  op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
913 
914  return success();
915  }
916 };
917 } // namespace
918 
919 //===----------------------------------------------------------------------===//
920 // Convert async.runtime.num_worker_threads to the corresponding runtime API
921 // call.
922 //===----------------------------------------------------------------------===//
923 
924 namespace {
925 class RuntimeNumWorkerThreadsOpLowering
926  : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
927 public:
929 
930  LogicalResult
931  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
932  ConversionPatternRewriter &rewriter) const override {
933 
934  // Replace with a runtime API function call.
935  rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
936  rewriter.getIndexType());
937 
938  return success();
939  }
940 };
941 } // namespace
942 
943 //===----------------------------------------------------------------------===//
944 // Async reference counting ops lowering (`async.runtime.add_ref` and
945 // `async.runtime.drop_ref` to the corresponding API calls).
946 //===----------------------------------------------------------------------===//
947 
948 namespace {
949 template <typename RefCountingOp>
950 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
951 public:
952  explicit RefCountingOpLowering(const TypeConverter &converter,
953  MLIRContext *ctx, StringRef apiFunctionName)
954  : OpConversionPattern<RefCountingOp>(converter, ctx),
955  apiFunctionName(apiFunctionName) {}
956 
957  LogicalResult
958  matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
959  ConversionPatternRewriter &rewriter) const override {
960  auto count = rewriter.create<arith::ConstantOp>(
961  op->getLoc(), rewriter.getI64Type(),
962  rewriter.getI64IntegerAttr(op.getCount()));
963 
964  auto operand = adaptor.getOperand();
965  rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
966  ValueRange({operand, count}));
967 
968  return success();
969  }
970 
971 private:
972  StringRef apiFunctionName;
973 };
974 
975 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
976 public:
977  explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
978  MLIRContext *ctx)
979  : RefCountingOpLowering(converter, ctx, kAddRef) {}
980 };
981 
982 class RuntimeDropRefOpLowering
983  : public RefCountingOpLowering<RuntimeDropRefOp> {
984 public:
985  explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
986  MLIRContext *ctx)
987  : RefCountingOpLowering(converter, ctx, kDropRef) {}
988 };
989 } // namespace
990 
991 //===----------------------------------------------------------------------===//
992 // Convert return operations that return async values from async regions.
993 //===----------------------------------------------------------------------===//
994 
995 namespace {
996 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
997 public:
999 
1000  LogicalResult
1001  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1002  ConversionPatternRewriter &rewriter) const override {
1003  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1004  return success();
1005  }
1006 };
1007 } // namespace
1008 
1009 //===----------------------------------------------------------------------===//
1010 
1011 namespace {
1012 struct ConvertAsyncToLLVMPass
1013  : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1014  using Base::Base;
1015 
1016  void runOnOperation() override;
1017 };
1018 } // namespace
1019 
1020 void ConvertAsyncToLLVMPass::runOnOperation() {
1021  ModuleOp module = getOperation();
1022  MLIRContext *ctx = module->getContext();
1023 
1025 
1026  // Add declarations for most functions required by the coroutines lowering.
1027  // We delay adding the resume function until it's needed because it currently
1028  // fails to compile unless '-O0' is specified.
1030 
1031  // Lower async.runtime and async.coro operations to Async Runtime API and
1032  // LLVM coroutine intrinsics.
1033 
1034  // Convert async dialect types and operations to LLVM dialect.
1035  AsyncRuntimeTypeConverter converter(options);
1036  RewritePatternSet patterns(ctx);
1037 
1038  // We use conversion to LLVM type to lower async.runtime load and store
1039  // operations.
1040  LLVMTypeConverter llvmConverter(ctx, options);
1041  llvmConverter.addConversion([&](Type type) {
1042  return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1043  });
1044 
1045  // Convert async types in function signatures and function calls.
1046  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1047  converter);
1048  populateCallOpTypeConversionPattern(patterns, converter);
1049 
1050  // Convert return operations inside async.execute regions.
1051  patterns.add<ReturnOpOpConversion>(converter, ctx);
1052 
1053  // Lower async.runtime operations to the async runtime API calls.
1054  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1055  RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1056  RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1057  RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1058  RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1059  ctx);
1060 
1061  // Lower async.runtime operations that rely on LLVM type converter to convert
1062  // from async value payload type to the LLVM type.
1063  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1064  RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
1065 
1066  // Lower async coroutine operations to LLVM coroutine intrinsics.
1067  patterns
1068  .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1069  CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1070  converter, ctx);
1071 
1072  ConversionTarget target(*ctx);
1073  target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1074  UnrealizedConversionCastOp>();
1075  target.addLegalDialect<LLVM::LLVMDialect>();
1076 
1077  // All operations from Async dialect must be lowered to the runtime API and
1078  // LLVM intrinsics calls.
1079  target.addIllegalDialect<AsyncDialect>();
1080 
1081  // Add dynamic legality constraints to apply conversions defined above.
1082  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1083  return converter.isSignatureLegal(op.getFunctionType());
1084  });
1085  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1086  return converter.isLegal(op.getOperandTypes());
1087  });
1088  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1089  return converter.isSignatureLegal(op.getCalleeType());
1090  });
1091 
1092  if (failed(applyPartialConversion(module, target, std::move(patterns))))
1093  signalPassFailure();
1094 }
1095 
1096 //===----------------------------------------------------------------------===//
1097 // Patterns for structural type conversions for the Async dialect operations.
1098 //===----------------------------------------------------------------------===//
1099 
1100 namespace {
1101 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1102 public:
1104  LogicalResult
1105  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1106  ConversionPatternRewriter &rewriter) const override {
1107  ExecuteOp newOp =
1108  cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1109  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1110  newOp.getRegion().end());
1111 
1112  // Set operands and update block argument and result types.
1113  newOp->setOperands(adaptor.getOperands());
1114  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1115  return failure();
1116  for (auto result : newOp.getResults())
1117  result.setType(typeConverter->convertType(result.getType()));
1118 
1119  rewriter.replaceOp(op, newOp.getResults());
1120  return success();
1121  }
1122 };
1123 
1124 // Dummy pattern to trigger the appropriate type conversion / materialization.
1125 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1126 public:
1128  LogicalResult
1129  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1130  ConversionPatternRewriter &rewriter) const override {
1131  rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1132  return success();
1133  }
1134 };
1135 
1136 // Dummy pattern to trigger the appropriate type conversion / materialization.
1137 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1138 public:
1140  LogicalResult
1141  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1142  ConversionPatternRewriter &rewriter) const override {
1143  rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1144  return success();
1145  }
1146 };
1147 } // namespace
1148 
1150  TypeConverter &typeConverter, RewritePatternSet &patterns,
1151  ConversionTarget &target) {
1152  typeConverter.addConversion([&](TokenType type) { return type; });
1153  typeConverter.addConversion([&](ValueType type) {
1154  Type converted = typeConverter.convertType(type.getValueType());
1155  return converted ? ValueType::get(converted) : converted;
1156  });
1157 
1158  patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1159  typeConverter, patterns.getContext());
1160 
1161  target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1162  [&](Operation *op) { return typeConverter.isLegal(op); });
1163 }
static constexpr const char * kAwaitValueAndExecute
Definition: AsyncToLLVM.cpp:64
static constexpr const char * kCreateValue
Definition: AsyncToLLVM.cpp:45
static constexpr const char * kCreateGroup
Definition: AsyncToLLVM.cpp:46
static constexpr const char * kCreateToken
Definition: AsyncToLLVM.cpp:44
static constexpr const char * kEmplaceValue
Definition: AsyncToLLVM.cpp:48
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls a llvm.coro.resume intrinsics.
static constexpr const char * kEmplaceToken
Definition: AsyncToLLVM.cpp:47
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
static constexpr const char * kResume
static constexpr const char * kAddRef
Definition: AsyncToLLVM.cpp:42
static constexpr const char * kAwaitTokenAndExecute
Definition: AsyncToLLVM.cpp:62
static constexpr const char * kAwaitValue
Definition: AsyncToLLVM.cpp:55
static constexpr const char * kSetTokenError
Definition: AsyncToLLVM.cpp:49
static constexpr const char * kExecute
Definition: AsyncToLLVM.cpp:57
static constexpr const char * kAddTokenToGroup
Definition: AsyncToLLVM.cpp:60
static constexpr const char * kIsGroupError
Definition: AsyncToLLVM.cpp:53
static constexpr const char * kSetValueError
Definition: AsyncToLLVM.cpp:50
static constexpr const char * kIsTokenError
Definition: AsyncToLLVM.cpp:51
static constexpr const char * kAwaitGroup
Definition: AsyncToLLVM.cpp:56
static constexpr const char * kAwaitAllAndExecute
Definition: AsyncToLLVM.cpp:66
static constexpr const char * kGetNumWorkerThreads
Definition: AsyncToLLVM.cpp:68
static constexpr const char * kDropRef
Definition: AsyncToLLVM.cpp:43
static constexpr const char * kIsValueError
Definition: AsyncToLLVM.cpp:52
static constexpr const char * kAwaitToken
Definition: AsyncToLLVM.cpp:54
static constexpr const char * kGetValueStorage
Definition: AsyncToLLVM.cpp:58
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:592
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType)
Include the generated interface declarations.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.