MLIR  21.0.0git
AsyncToLLVM.cpp
Go to the documentation of this file.
1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Pass/Pass.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "convert-async-to-llvm"
34 
35 using namespace mlir;
36 using namespace mlir::async;
37 
38 //===----------------------------------------------------------------------===//
39 // Async Runtime C API declaration.
40 //===----------------------------------------------------------------------===//
41 
// Reference counting of runtime objects.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Transition of tokens and values into the available or the error state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error state queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Blocking waits.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Value storage access and group membership.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Handing a coroutine to the runtime, either immediately or once the awaited
// token / value / group becomes ready.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Runtime configuration queries.
// NOTE(review): "Runtim" (missing 'e') is the spelling used here. The name has
// to match the symbol exported by the async runtime library, so confirm it
// against the runtime before "correcting" it.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70 
71 namespace {
72 /// Async Runtime API function types.
73 ///
74 /// Because we can't create API function signature for type parametrized
75 /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
76 /// lowering all async data types become opaque pointers at runtime.
77 struct AsyncAPI {
78  // All async types are lowered to opaque LLVM pointers at runtime.
79  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
80  return LLVM::LLVMPointerType::get(ctx);
81  }
82 
83  static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
84  return LLVM::LLVMTokenType::get(ctx);
85  }
86 
87  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
88  auto ref = opaquePointerType(ctx);
89  auto count = IntegerType::get(ctx, 64);
90  return FunctionType::get(ctx, {ref, count}, {});
91  }
92 
93  static FunctionType createTokenFunctionType(MLIRContext *ctx) {
94  return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
95  }
96 
97  static FunctionType createValueFunctionType(MLIRContext *ctx) {
98  auto i64 = IntegerType::get(ctx, 64);
99  auto value = opaquePointerType(ctx);
100  return FunctionType::get(ctx, {i64}, {value});
101  }
102 
103  static FunctionType createGroupFunctionType(MLIRContext *ctx) {
104  auto i64 = IntegerType::get(ctx, 64);
105  return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106  }
107 
108  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
109  auto ptrType = opaquePointerType(ctx);
110  return FunctionType::get(ctx, {ptrType}, {ptrType});
111  }
112 
113  static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
114  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115  }
116 
117  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
118  auto value = opaquePointerType(ctx);
119  return FunctionType::get(ctx, {value}, {});
120  }
121 
122  static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
123  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
124  }
125 
126  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
127  auto value = opaquePointerType(ctx);
128  return FunctionType::get(ctx, {value}, {});
129  }
130 
131  static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
132  auto i1 = IntegerType::get(ctx, 1);
133  return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
134  }
135 
136  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
137  auto value = opaquePointerType(ctx);
138  auto i1 = IntegerType::get(ctx, 1);
139  return FunctionType::get(ctx, {value}, {i1});
140  }
141 
142  static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
143  auto i1 = IntegerType::get(ctx, 1);
144  return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145  }
146 
147  static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
148  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
149  }
150 
151  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
152  auto value = opaquePointerType(ctx);
153  return FunctionType::get(ctx, {value}, {});
154  }
155 
156  static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
157  return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158  }
159 
160  static FunctionType executeFunctionType(MLIRContext *ctx) {
161  auto ptrType = opaquePointerType(ctx);
162  return FunctionType::get(ctx, {ptrType, ptrType}, {});
163  }
164 
165  static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
166  auto i64 = IntegerType::get(ctx, 64);
167  return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
168  {i64});
169  }
170 
171  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172  auto ptrType = opaquePointerType(ctx);
173  return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
174  }
175 
176  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
177  auto ptrType = opaquePointerType(ctx);
178  return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179  }
180 
181  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
182  auto ptrType = opaquePointerType(ctx);
183  return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184  }
185 
186  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187  return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188  }
189 
190  // Auxiliary coroutine resume intrinsic wrapper.
191  static Type resumeFunctionType(MLIRContext *ctx) {
192  auto voidTy = LLVM::LLVMVoidType::get(ctx);
193  auto ptrType = opaquePointerType(ctx);
194  return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
195  }
196 };
197 } // namespace
198 
199 /// Adds Async Runtime C API declarations to the module.
200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201  auto builder =
202  ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
203 
204  auto addFuncDecl = [&](StringRef name, FunctionType type) {
205  if (module.lookupSymbol(name))
206  return;
207  builder.create<func::FuncOp>(name, type).setPrivate();
208  };
209 
210  MLIRContext *ctx = module.getContext();
211  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218  addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220  addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222  addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229  addFuncDecl(kAwaitTokenAndExecute,
230  AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231  addFuncDecl(kAwaitValueAndExecute,
232  AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233  addFuncDecl(kAwaitAllAndExecute,
234  AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Coroutine resume function wrapper.
240 //===----------------------------------------------------------------------===//
241 
242 static constexpr const char *kResume = "__resume";
243 
244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245 /// intrinsics. We need this function to be able to pass it to the async
246 /// runtime execute API.
247 static void addResumeFunction(ModuleOp module) {
248  if (module.lookupSymbol(kResume))
249  return;
250 
251  MLIRContext *ctx = module.getContext();
252  auto loc = module.getLoc();
253  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
254 
255  auto voidTy = LLVM::LLVMVoidType::get(ctx);
256  Type ptrType = AsyncAPI::opaquePointerType(ctx);
257 
258  auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259  kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
260  resumeOp.setPrivate();
261 
262  auto *block = resumeOp.addEntryBlock(moduleBuilder);
263  auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
264 
265  blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266  blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Convert Async dialect types to LLVM types.
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275 /// their runtime type (opaque pointers) and does not convert any other types.
276 class AsyncRuntimeTypeConverter : public TypeConverter {
277 public:
278  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
279  addConversion([](Type type) { return type; });
280  addConversion([](Type type) { return convertAsyncTypes(type); });
281 
282  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
283  // in patterns for other dialects.
284  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285  ValueRange inputs, Location loc) -> Value {
286  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287  return cast.getResult(0);
288  };
289 
290  addSourceMaterialization(addUnrealizedCast);
291  addTargetMaterialization(addUnrealizedCast);
292  }
293 
294  static std::optional<Type> convertAsyncTypes(Type type) {
295  if (isa<TokenType, GroupType, ValueType>(type))
296  return AsyncAPI::opaquePointerType(type.getContext());
297 
298  if (isa<CoroIdType, CoroStateType>(type))
299  return AsyncAPI::tokenType(type.getContext());
300  if (isa<CoroHandleType>(type))
301  return AsyncAPI::opaquePointerType(type.getContext());
302 
303  return std::nullopt;
304  }
305 };
306 
307 /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
308 /// as type converter. Allows access to it via the 'getTypeConverter'
309 /// convenience method.
310 template <typename SourceOp>
311 class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
312 
313  using Base = OpConversionPattern<SourceOp>;
314 
315 public:
316  AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
317  MLIRContext *context)
318  : Base(typeConverter, context) {}
319 
320  /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
321  const AsyncRuntimeTypeConverter *getTypeConverter() const {
322  return static_cast<const AsyncRuntimeTypeConverter *>(
323  Base::getTypeConverter());
324  }
325 };
326 
327 } // namespace
328 
329 //===----------------------------------------------------------------------===//
330 // Convert async.coro.id to @llvm.coro.id intrinsic.
331 //===----------------------------------------------------------------------===//
332 
333 namespace {
334 class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
335 public:
336  using AsyncOpConversionPattern::AsyncOpConversionPattern;
337 
338  LogicalResult
339  matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
340  ConversionPatternRewriter &rewriter) const override {
341  auto token = AsyncAPI::tokenType(op->getContext());
342  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
343  auto loc = op->getLoc();
344 
345  // Constants for initializing coroutine frame.
346  auto constZero =
347  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
348  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
349 
350  // Get coroutine id: @llvm.coro.id.
351  rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352  op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
353 
354  return success();
355  }
356 };
357 } // namespace
358 
359 //===----------------------------------------------------------------------===//
360 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
361 //===----------------------------------------------------------------------===//
362 
363 namespace {
364 class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
365 public:
366  using AsyncOpConversionPattern::AsyncOpConversionPattern;
367 
368  LogicalResult
369  matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
370  ConversionPatternRewriter &rewriter) const override {
371  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
372  auto loc = op->getLoc();
373 
374  // Get coroutine frame size: @llvm.coro.size.i64.
375  Value coroSize =
376  rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
377  // Get coroutine frame alignment: @llvm.coro.align.i64.
378  Value coroAlign =
379  rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
380 
381  // Round up the size to be multiple of the alignment. Since aligned_alloc
382  // requires the size parameter be an integral multiple of the alignment
383  // parameter.
384  auto makeConstant = [&](uint64_t c) {
385  return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
386  rewriter.getI64Type(), c);
387  };
388  coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
389  coroSize =
390  rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
391  Value negCoroAlign =
392  rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
393  coroSize =
394  rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
395 
396  // Allocate memory for the coroutine frame.
397  auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398  op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399  if (failed(allocFuncOp))
400  return failure();
401  auto coroAlloc = rewriter.create<LLVM::CallOp>(
402  loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
403 
404  // Begin a coroutine: @llvm.coro.begin.
405  auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
406  rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
407  op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
408 
409  return success();
410  }
411 };
412 } // namespace
413 
414 //===----------------------------------------------------------------------===//
415 // Convert async.coro.free to @llvm.coro.free intrinsic.
416 //===----------------------------------------------------------------------===//
417 
418 namespace {
419 class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
420 public:
421  using AsyncOpConversionPattern::AsyncOpConversionPattern;
422 
423  LogicalResult
424  matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
425  ConversionPatternRewriter &rewriter) const override {
426  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
427  auto loc = op->getLoc();
428 
429  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
430  auto coroMem =
431  rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
432 
433  // Free the memory.
434  auto freeFuncOp =
435  LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
436  if (failed(freeFuncOp))
437  return failure();
438  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
439  ValueRange(coroMem.getResult()));
440 
441  return success();
442  }
443 };
444 } // namespace
445 
446 //===----------------------------------------------------------------------===//
447 // Convert async.coro.end to @llvm.coro.end intrinsic.
448 //===----------------------------------------------------------------------===//
449 
450 namespace {
451 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
452 public:
454 
455  LogicalResult
456  matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
457  ConversionPatternRewriter &rewriter) const override {
458  // We are not in the block that is part of the unwind sequence.
459  auto constFalse = rewriter.create<LLVM::ConstantOp>(
460  op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
461  auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
462 
463  // Mark the end of a coroutine: @llvm.coro.end.
464  auto coroHdl = adaptor.getHandle();
465  rewriter.create<LLVM::CoroEndOp>(
466  op->getLoc(), rewriter.getI1Type(),
467  ValueRange({coroHdl, constFalse, noneToken}));
468  rewriter.eraseOp(op);
469 
470  return success();
471  }
472 };
473 } // namespace
474 
475 //===----------------------------------------------------------------------===//
476 // Convert async.coro.save to @llvm.coro.save intrinsic.
477 //===----------------------------------------------------------------------===//
478 
479 namespace {
480 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
481 public:
483 
484  LogicalResult
485  matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
486  ConversionPatternRewriter &rewriter) const override {
487  // Save the coroutine state: @llvm.coro.save
488  rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
489  op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
490 
491  return success();
492  }
493 };
494 } // namespace
495 
496 //===----------------------------------------------------------------------===//
497 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
498 //===----------------------------------------------------------------------===//
499 
500 namespace {
501 
502 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
503 /// branch to the appropriate block based on the return code.
504 ///
505 /// Before:
506 ///
507 /// ^suspended:
508 /// "opBefore"(...)
509 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
510 /// ^resume:
511 /// "op"(...)
512 /// ^cleanup: ...
513 /// ^suspend: ...
514 ///
515 /// After:
516 ///
517 /// ^suspended:
518 /// "opBefore"(...)
519 /// %suspend = llmv.intr.coro.suspend ...
520 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
521 /// ^resume:
522 /// "op"(...)
523 /// ^cleanup: ...
524 /// ^suspend: ...
525 ///
526 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
527 public:
529 
530  LogicalResult
531  matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
532  ConversionPatternRewriter &rewriter) const override {
533  auto i8 = rewriter.getIntegerType(8);
534  auto i32 = rewriter.getI32Type();
535  auto loc = op->getLoc();
536 
537  // This is not a final suspension point.
538  auto constFalse = rewriter.create<LLVM::ConstantOp>(
539  loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
540 
541  // Suspend a coroutine: @llvm.coro.suspend
542  auto coroState = adaptor.getState();
543  auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
544  loc, i8, ValueRange({coroState, constFalse}));
545 
546  // Cast return code to i32.
547 
548  // After a suspension point decide if we should branch into resume, cleanup
549  // or suspend block of the coroutine (see @llvm.coro.suspend return code
550  // documentation).
551  llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
552  llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
553  op.getCleanupDest()};
554  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
555  op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
556  /*defaultDestination=*/op.getSuspendDest(),
557  /*defaultOperands=*/ValueRange(),
558  /*caseValues=*/caseValues,
559  /*caseDestinations=*/caseDest,
560  /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
561  /*branchWeights=*/ArrayRef<int32_t>());
562 
563  return success();
564  }
565 };
566 } // namespace
567 
568 //===----------------------------------------------------------------------===//
569 // Convert async.runtime.create to the corresponding runtime API call.
570 //
571 // To allocate storage for the async values we use getelementptr trick:
572 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
573 //===----------------------------------------------------------------------===//
574 
575 namespace {
576 class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
577 public:
579 
580  LogicalResult
581  matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
582  ConversionPatternRewriter &rewriter) const override {
583  const TypeConverter *converter = getTypeConverter();
584  Type resultType = op->getResultTypes()[0];
585 
586  // Tokens creation maps to a simple function call.
587  if (isa<TokenType>(resultType)) {
588  rewriter.replaceOpWithNewOp<func::CallOp>(
589  op, kCreateToken, converter->convertType(resultType));
590  return success();
591  }
592 
593  // To create a value we need to compute the storage requirement.
594  if (auto value = dyn_cast<ValueType>(resultType)) {
595  // Returns the size requirements for the async value storage.
596  auto sizeOf = [&](ValueType valueType) -> Value {
597  auto loc = op->getLoc();
598  auto i64 = rewriter.getI64Type();
599 
600  auto storedType = converter->convertType(valueType.getValueType());
601  auto storagePtrType =
602  AsyncAPI::opaquePointerType(rewriter.getContext());
603 
604  // %Size = getelementptr %T* null, int 1
605  // %SizeI = ptrtoint %T* %Size to i64
606  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
607  auto gep =
608  rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
609  nullPtr, ArrayRef<LLVM::GEPArg>{1});
610  return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
611  };
612 
613  rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
614  sizeOf(value));
615 
616  return success();
617  }
618 
619  return rewriter.notifyMatchFailure(op, "unsupported async type");
620  }
621 };
622 } // namespace
623 
624 //===----------------------------------------------------------------------===//
625 // Convert async.runtime.create_group to the corresponding runtime API call.
626 //===----------------------------------------------------------------------===//
627 
628 namespace {
629 class RuntimeCreateGroupOpLowering
630  : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
631 public:
633 
634  LogicalResult
635  matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
636  ConversionPatternRewriter &rewriter) const override {
637  const TypeConverter *converter = getTypeConverter();
638  Type resultType = op.getResult().getType();
639 
640  rewriter.replaceOpWithNewOp<func::CallOp>(
641  op, kCreateGroup, converter->convertType(resultType),
642  adaptor.getOperands());
643  return success();
644  }
645 };
646 } // namespace
647 
648 //===----------------------------------------------------------------------===//
649 // Convert async.runtime.set_available to the corresponding runtime API call.
650 //===----------------------------------------------------------------------===//
651 
652 namespace {
653 class RuntimeSetAvailableOpLowering
654  : public OpConversionPattern<RuntimeSetAvailableOp> {
655 public:
657 
658  LogicalResult
659  matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
660  ConversionPatternRewriter &rewriter) const override {
661  StringRef apiFuncName =
662  TypeSwitch<Type, StringRef>(op.getOperand().getType())
663  .Case<TokenType>([](Type) { return kEmplaceToken; })
664  .Case<ValueType>([](Type) { return kEmplaceValue; });
665 
666  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
667  adaptor.getOperands());
668 
669  return success();
670  }
671 };
672 } // namespace
673 
674 //===----------------------------------------------------------------------===//
675 // Convert async.runtime.set_error to the corresponding runtime API call.
676 //===----------------------------------------------------------------------===//
677 
678 namespace {
679 class RuntimeSetErrorOpLowering
680  : public OpConversionPattern<RuntimeSetErrorOp> {
681 public:
683 
684  LogicalResult
685  matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
686  ConversionPatternRewriter &rewriter) const override {
687  StringRef apiFuncName =
688  TypeSwitch<Type, StringRef>(op.getOperand().getType())
689  .Case<TokenType>([](Type) { return kSetTokenError; })
690  .Case<ValueType>([](Type) { return kSetValueError; });
691 
692  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
693  adaptor.getOperands());
694 
695  return success();
696  }
697 };
698 } // namespace
699 
700 //===----------------------------------------------------------------------===//
701 // Convert async.runtime.is_error to the corresponding runtime API call.
702 //===----------------------------------------------------------------------===//
703 
704 namespace {
705 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
706 public:
708 
709  LogicalResult
710  matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
711  ConversionPatternRewriter &rewriter) const override {
712  StringRef apiFuncName =
713  TypeSwitch<Type, StringRef>(op.getOperand().getType())
714  .Case<TokenType>([](Type) { return kIsTokenError; })
715  .Case<GroupType>([](Type) { return kIsGroupError; })
716  .Case<ValueType>([](Type) { return kIsValueError; });
717 
718  rewriter.replaceOpWithNewOp<func::CallOp>(
719  op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
720  return success();
721  }
722 };
723 } // namespace
724 
725 //===----------------------------------------------------------------------===//
726 // Convert async.runtime.await to the corresponding runtime API call.
727 //===----------------------------------------------------------------------===//
728 
729 namespace {
730 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
731 public:
733 
734  LogicalResult
735  matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
736  ConversionPatternRewriter &rewriter) const override {
737  StringRef apiFuncName =
738  TypeSwitch<Type, StringRef>(op.getOperand().getType())
739  .Case<TokenType>([](Type) { return kAwaitToken; })
740  .Case<ValueType>([](Type) { return kAwaitValue; })
741  .Case<GroupType>([](Type) { return kAwaitGroup; });
742 
743  rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
744  adaptor.getOperands());
745  rewriter.eraseOp(op);
746 
747  return success();
748  }
749 };
750 } // namespace
751 
752 //===----------------------------------------------------------------------===//
753 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
754 //===----------------------------------------------------------------------===//
755 
756 namespace {
757 class RuntimeAwaitAndResumeOpLowering
758  : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
759 public:
760  using AsyncOpConversionPattern::AsyncOpConversionPattern;
761 
762  LogicalResult
763  matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
764  ConversionPatternRewriter &rewriter) const override {
765  StringRef apiFuncName =
766  TypeSwitch<Type, StringRef>(op.getOperand().getType())
767  .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
768  .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
769  .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
770 
771  Value operand = adaptor.getOperand();
772  Value handle = adaptor.getHandle();
773 
774  // A pointer to coroutine resume intrinsic wrapper.
775  addResumeFunction(op->getParentOfType<ModuleOp>());
776  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
777  op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
778  kResume);
779 
780  rewriter.create<func::CallOp>(
781  op->getLoc(), apiFuncName, TypeRange(),
782  ValueRange({operand, handle, resumePtr.getRes()}));
783  rewriter.eraseOp(op);
784 
785  return success();
786  }
787 };
788 } // namespace
789 
790 //===----------------------------------------------------------------------===//
791 // Convert async.runtime.resume to the corresponding runtime API call.
792 //===----------------------------------------------------------------------===//
793 
794 namespace {
795 class RuntimeResumeOpLowering
796  : public AsyncOpConversionPattern<RuntimeResumeOp> {
797 public:
798  using AsyncOpConversionPattern::AsyncOpConversionPattern;
799 
800  LogicalResult
801  matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
802  ConversionPatternRewriter &rewriter) const override {
803  // A pointer to coroutine resume intrinsic wrapper.
804  addResumeFunction(op->getParentOfType<ModuleOp>());
805  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
806  op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
807  kResume);
808 
809  // Call async runtime API to execute a coroutine in the managed thread.
810  auto coroHdl = adaptor.getHandle();
811  rewriter.replaceOpWithNewOp<func::CallOp>(
812  op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
813 
814  return success();
815  }
816 };
817 } // namespace
818 
819 //===----------------------------------------------------------------------===//
820 // Convert async.runtime.store to the corresponding runtime API call.
821 //===----------------------------------------------------------------------===//
822 
823 namespace {
824 class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
825 public:
827 
828  LogicalResult
829  matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
830  ConversionPatternRewriter &rewriter) const override {
831  Location loc = op->getLoc();
832 
833  // Get a pointer to the async value storage from the runtime.
834  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
835  auto storage = adaptor.getStorage();
836  auto storagePtr = rewriter.create<func::CallOp>(
837  loc, kGetValueStorage, TypeRange(ptrType), storage);
838 
839  // Cast from i8* to the LLVM pointer type.
840  auto valueType = op.getValue().getType();
841  auto llvmValueType = getTypeConverter()->convertType(valueType);
842  if (!llvmValueType)
843  return rewriter.notifyMatchFailure(
844  op, "failed to convert stored value type to LLVM type");
845 
846  Value castedStoragePtr = storagePtr.getResult(0);
847  // Store the yielded value into the async value storage.
848  auto value = adaptor.getValue();
849  rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
850 
851  // Erase the original runtime store operation.
852  rewriter.eraseOp(op);
853 
854  return success();
855  }
856 };
857 } // namespace
858 
859 //===----------------------------------------------------------------------===//
860 // Convert async.runtime.load to the corresponding runtime API call.
861 //===----------------------------------------------------------------------===//
862 
863 namespace {
864 class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
865 public:
867 
868  LogicalResult
869  matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
870  ConversionPatternRewriter &rewriter) const override {
871  Location loc = op->getLoc();
872 
873  // Get a pointer to the async value storage from the runtime.
874  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
875  auto storage = adaptor.getStorage();
876  auto storagePtr = rewriter.create<func::CallOp>(
877  loc, kGetValueStorage, TypeRange(ptrType), storage);
878 
879  // Cast from i8* to the LLVM pointer type.
880  auto valueType = op.getResult().getType();
881  auto llvmValueType = getTypeConverter()->convertType(valueType);
882  if (!llvmValueType)
883  return rewriter.notifyMatchFailure(
884  op, "failed to convert loaded value type to LLVM type");
885 
886  Value castedStoragePtr = storagePtr.getResult(0);
887 
888  // Load from the casted pointer.
889  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
890  castedStoragePtr);
891 
892  return success();
893  }
894 };
895 } // namespace
896 
897 //===----------------------------------------------------------------------===//
898 // Convert async.runtime.add_to_group to the corresponding runtime API call.
899 //===----------------------------------------------------------------------===//
900 
901 namespace {
902 class RuntimeAddToGroupOpLowering
903  : public OpConversionPattern<RuntimeAddToGroupOp> {
904 public:
906 
907  LogicalResult
908  matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
909  ConversionPatternRewriter &rewriter) const override {
910  // Currently we can only add tokens to the group.
911  if (!isa<TokenType>(op.getOperand().getType()))
912  return rewriter.notifyMatchFailure(op, "only token type is supported");
913 
914  // Replace with a runtime API function call.
915  rewriter.replaceOpWithNewOp<func::CallOp>(
916  op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
917 
918  return success();
919  }
920 };
921 } // namespace
922 
923 //===----------------------------------------------------------------------===//
924 // Convert async.runtime.num_worker_threads to the corresponding runtime API
925 // call.
926 //===----------------------------------------------------------------------===//
927 
928 namespace {
929 class RuntimeNumWorkerThreadsOpLowering
930  : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
931 public:
933 
934  LogicalResult
935  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
936  ConversionPatternRewriter &rewriter) const override {
937 
938  // Replace with a runtime API function call.
939  rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
940  rewriter.getIndexType());
941 
942  return success();
943  }
944 };
945 } // namespace
946 
947 //===----------------------------------------------------------------------===//
948 // Async reference counting ops lowering (`async.runtime.add_ref` and
949 // `async.runtime.drop_ref` to the corresponding API calls).
950 //===----------------------------------------------------------------------===//
951 
952 namespace {
953 template <typename RefCountingOp>
954 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
955 public:
956  explicit RefCountingOpLowering(const TypeConverter &converter,
957  MLIRContext *ctx, StringRef apiFunctionName)
958  : OpConversionPattern<RefCountingOp>(converter, ctx),
959  apiFunctionName(apiFunctionName) {}
960 
961  LogicalResult
962  matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
963  ConversionPatternRewriter &rewriter) const override {
964  auto count = rewriter.create<arith::ConstantOp>(
965  op->getLoc(), rewriter.getI64Type(),
966  rewriter.getI64IntegerAttr(op.getCount()));
967 
968  auto operand = adaptor.getOperand();
969  rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
970  ValueRange({operand, count}));
971 
972  return success();
973  }
974 
975 private:
976  StringRef apiFunctionName;
977 };
978 
979 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
980 public:
981  explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
982  MLIRContext *ctx)
983  : RefCountingOpLowering(converter, ctx, kAddRef) {}
984 };
985 
986 class RuntimeDropRefOpLowering
987  : public RefCountingOpLowering<RuntimeDropRefOp> {
988 public:
989  explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
990  MLIRContext *ctx)
991  : RefCountingOpLowering(converter, ctx, kDropRef) {}
992 };
993 } // namespace
994 
995 //===----------------------------------------------------------------------===//
996 // Convert return operations that return async values from async regions.
997 //===----------------------------------------------------------------------===//
998 
999 namespace {
1000 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
1001 public:
1003 
1004  LogicalResult
1005  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1006  ConversionPatternRewriter &rewriter) const override {
1007  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1008  return success();
1009  }
1010 };
1011 } // namespace
1012 
1013 //===----------------------------------------------------------------------===//
1014 
1015 namespace {
1016 struct ConvertAsyncToLLVMPass
1017  : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1018  using Base::Base;
1019 
1020  void runOnOperation() override;
1021 };
1022 } // namespace
1023 
1024 void ConvertAsyncToLLVMPass::runOnOperation() {
1025  ModuleOp module = getOperation();
1026  MLIRContext *ctx = module->getContext();
1027 
1029 
1030  // Add declarations for most functions required by the coroutines lowering.
1031  // We delay adding the resume function until it's needed because it currently
1032  // fails to compile unless '-O0' is specified.
1034 
1035  // Lower async.runtime and async.coro operations to Async Runtime API and
1036  // LLVM coroutine intrinsics.
1037 
1038  // Convert async dialect types and operations to LLVM dialect.
1039  AsyncRuntimeTypeConverter converter(options);
1041 
1042  // We use conversion to LLVM type to lower async.runtime load and store
1043  // operations.
1044  LLVMTypeConverter llvmConverter(ctx, options);
1045  llvmConverter.addConversion([&](Type type) {
1046  return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1047  });
1048 
1049  // Convert async types in function signatures and function calls.
1050  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1051  converter);
1053 
1054  // Convert return operations inside async.execute regions.
1055  patterns.add<ReturnOpOpConversion>(converter, ctx);
1056 
1057  // Lower async.runtime operations to the async runtime API calls.
1058  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1059  RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1060  RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1061  RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1062  RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1063  ctx);
1064 
1065  // Lower async.runtime operations that rely on LLVM type converter to convert
1066  // from async value payload type to the LLVM type.
1067  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1068  RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
1069 
1070  // Lower async coroutine operations to LLVM coroutine intrinsics.
1071  patterns
1072  .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1073  CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1074  converter, ctx);
1075 
1076  ConversionTarget target(*ctx);
1077  target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1078  UnrealizedConversionCastOp>();
1079  target.addLegalDialect<LLVM::LLVMDialect>();
1080 
1081  // All operations from Async dialect must be lowered to the runtime API and
1082  // LLVM intrinsics calls.
1083  target.addIllegalDialect<AsyncDialect>();
1084 
1085  // Add dynamic legality constraints to apply conversions defined above.
1086  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1087  return converter.isSignatureLegal(op.getFunctionType());
1088  });
1089  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1090  return converter.isLegal(op.getOperandTypes());
1091  });
1092  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1093  return converter.isSignatureLegal(op.getCalleeType());
1094  });
1095 
1096  if (failed(applyPartialConversion(module, target, std::move(patterns))))
1097  signalPassFailure();
1098 }
1099 
1100 //===----------------------------------------------------------------------===//
1101 // Patterns for structural type conversions for the Async dialect operations.
1102 //===----------------------------------------------------------------------===//
1103 
1104 namespace {
1105 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1106 public:
1108  LogicalResult
1109  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1110  ConversionPatternRewriter &rewriter) const override {
1111  ExecuteOp newOp =
1112  cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1113  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1114  newOp.getRegion().end());
1115 
1116  // Set operands and update block argument and result types.
1117  newOp->setOperands(adaptor.getOperands());
1118  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1119  return failure();
1120  for (auto result : newOp.getResults())
1121  result.setType(typeConverter->convertType(result.getType()));
1122 
1123  rewriter.replaceOp(op, newOp.getResults());
1124  return success();
1125  }
1126 };
1127 
1128 // Dummy pattern to trigger the appropriate type conversion / materialization.
1129 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1130 public:
1132  LogicalResult
1133  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1134  ConversionPatternRewriter &rewriter) const override {
1135  rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1136  return success();
1137  }
1138 };
1139 
1140 // Dummy pattern to trigger the appropriate type conversion / materialization.
1141 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1142 public:
1144  LogicalResult
1145  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1146  ConversionPatternRewriter &rewriter) const override {
1147  rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1148  return success();
1149  }
1150 };
1151 } // namespace
1152 
1154  TypeConverter &typeConverter, RewritePatternSet &patterns,
1155  ConversionTarget &target) {
1156  typeConverter.addConversion([&](TokenType type) { return type; });
1157  typeConverter.addConversion([&](ValueType type) {
1158  Type converted = typeConverter.convertType(type.getValueType());
1159  return converted ? ValueType::get(converted) : converted;
1160  });
1161 
1162  patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1163  typeConverter, patterns.getContext());
1164 
1165  target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1166  [&](Operation *op) { return typeConverter.isLegal(op); });
1167 }
static constexpr const char * kAwaitValueAndExecute
Definition: AsyncToLLVM.cpp:64
static constexpr const char * kCreateValue
Definition: AsyncToLLVM.cpp:45
static constexpr const char * kCreateGroup
Definition: AsyncToLLVM.cpp:46
static constexpr const char * kCreateToken
Definition: AsyncToLLVM.cpp:44
static constexpr const char * kEmplaceValue
Definition: AsyncToLLVM.cpp:48
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls a llvm.coro.resume intrinsics.
static constexpr const char * kEmplaceToken
Definition: AsyncToLLVM.cpp:47
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
static constexpr const char * kResume
static constexpr const char * kAddRef
Definition: AsyncToLLVM.cpp:42
static constexpr const char * kAwaitTokenAndExecute
Definition: AsyncToLLVM.cpp:62
static constexpr const char * kAwaitValue
Definition: AsyncToLLVM.cpp:55
static constexpr const char * kSetTokenError
Definition: AsyncToLLVM.cpp:49
static constexpr const char * kExecute
Definition: AsyncToLLVM.cpp:57
static constexpr const char * kAddTokenToGroup
Definition: AsyncToLLVM.cpp:60
static constexpr const char * kIsGroupError
Definition: AsyncToLLVM.cpp:53
static constexpr const char * kSetValueError
Definition: AsyncToLLVM.cpp:50
static constexpr const char * kIsTokenError
Definition: AsyncToLLVM.cpp:51
static constexpr const char * kAwaitGroup
Definition: AsyncToLLVM.cpp:56
static constexpr const char * kAwaitAllAndExecute
Definition: AsyncToLLVM.cpp:66
static constexpr const char * kGetNumWorkerThreads
Definition: AsyncToLLVM.cpp:68
static constexpr const char * kDropRef
Definition: AsyncToLLVM.cpp:43
static constexpr const char * kIsValueError
Definition: AsyncToLLVM.cpp:52
static constexpr const char * kAwaitToken
Definition: AsyncToLLVM.cpp:54
static constexpr const char * kGetValueStorage
Definition: AsyncToLLVM.cpp:58
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:149
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:582
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(Operation *moduleOp)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.