MLIR  22.0.0git
AsyncToLLVM.cpp
Go to the documentation of this file.
1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 #define DEBUG_TYPE "convert-async-to-llvm"
32 
33 using namespace mlir;
34 using namespace mlir::async;
35 
36 //===----------------------------------------------------------------------===//
37 // Async Runtime C API declaration.
38 //===----------------------------------------------------------------------===//
39 
// Symbol names of the Async Runtime C API entry points. Declarations for these
// names are appended to the module by addAsyncRuntimeApiDeclarations, and the
// lowering patterns in this file emit calls to them by name, so each string is
// expected to match the symbol exported by the async runtime library exactly.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
// NOTE(review): "Runtim" (missing 'e') looks like a typo, but this is a
// linker-visible symbol name. It presumably matches the spelling exported by
// the runtime library; confirm there before "fixing" it here, or the call
// will no longer resolve.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
68 
69 namespace {
70 /// Async Runtime API function types.
71 ///
72 /// Because we can't create API function signature for type parametrized
73 /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
74 /// lowering all async data types become opaque pointers at runtime.
75 struct AsyncAPI {
76  // All async types are lowered to opaque LLVM pointers at runtime.
77  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
78  return LLVM::LLVMPointerType::get(ctx);
79  }
80 
81  static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
82  return LLVM::LLVMTokenType::get(ctx);
83  }
84 
85  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
86  auto ref = opaquePointerType(ctx);
87  auto count = IntegerType::get(ctx, 64);
88  return FunctionType::get(ctx, {ref, count}, {});
89  }
90 
91  static FunctionType createTokenFunctionType(MLIRContext *ctx) {
92  return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
93  }
94 
95  static FunctionType createValueFunctionType(MLIRContext *ctx) {
96  auto i64 = IntegerType::get(ctx, 64);
97  auto value = opaquePointerType(ctx);
98  return FunctionType::get(ctx, {i64}, {value});
99  }
100 
101  static FunctionType createGroupFunctionType(MLIRContext *ctx) {
102  auto i64 = IntegerType::get(ctx, 64);
103  return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
104  }
105 
106  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
107  auto ptrType = opaquePointerType(ctx);
108  return FunctionType::get(ctx, {ptrType}, {ptrType});
109  }
110 
111  static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
112  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
113  }
114 
115  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
116  auto value = opaquePointerType(ctx);
117  return FunctionType::get(ctx, {value}, {});
118  }
119 
120  static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
121  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
122  }
123 
124  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
125  auto value = opaquePointerType(ctx);
126  return FunctionType::get(ctx, {value}, {});
127  }
128 
129  static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
130  auto i1 = IntegerType::get(ctx, 1);
131  return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
132  }
133 
134  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
135  auto value = opaquePointerType(ctx);
136  auto i1 = IntegerType::get(ctx, 1);
137  return FunctionType::get(ctx, {value}, {i1});
138  }
139 
140  static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
141  auto i1 = IntegerType::get(ctx, 1);
142  return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
143  }
144 
145  static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
146  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
147  }
148 
149  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
150  auto value = opaquePointerType(ctx);
151  return FunctionType::get(ctx, {value}, {});
152  }
153 
154  static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
155  return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
156  }
157 
158  static FunctionType executeFunctionType(MLIRContext *ctx) {
159  auto ptrType = opaquePointerType(ctx);
160  return FunctionType::get(ctx, {ptrType, ptrType}, {});
161  }
162 
163  static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
164  auto i64 = IntegerType::get(ctx, 64);
165  return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
166  {i64});
167  }
168 
169  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
170  auto ptrType = opaquePointerType(ctx);
171  return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
172  }
173 
174  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
175  auto ptrType = opaquePointerType(ctx);
176  return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
177  }
178 
179  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
180  auto ptrType = opaquePointerType(ctx);
181  return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
182  }
183 
184  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
185  return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
186  }
187 
188  // Auxiliary coroutine resume intrinsic wrapper.
189  static Type resumeFunctionType(MLIRContext *ctx) {
190  auto voidTy = LLVM::LLVMVoidType::get(ctx);
191  auto ptrType = opaquePointerType(ctx);
192  return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
193  }
194 };
195 } // namespace
196 
197 /// Adds Async Runtime C API declarations to the module.
198 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
199  auto builder =
200  ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
201 
202  auto addFuncDecl = [&](StringRef name, FunctionType type) {
203  if (module.lookupSymbol(name))
204  return;
205  func::FuncOp::create(builder, name, type).setPrivate();
206  };
207 
208  MLIRContext *ctx = module.getContext();
209  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
210  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
211  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
212  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
213  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
214  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
215  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
216  addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
217  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
218  addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
219  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
220  addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
221  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
222  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
223  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
224  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
225  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
226  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
227  addFuncDecl(kAwaitTokenAndExecute,
228  AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
229  addFuncDecl(kAwaitValueAndExecute,
230  AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
231  addFuncDecl(kAwaitAllAndExecute,
232  AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
233  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // Coroutine resume function wrapper.
238 //===----------------------------------------------------------------------===//
239 
240 static constexpr const char *kResume = "__resume";
241 
242 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
243 /// intrinsics. We need this function to be able to pass it to the async
244 /// runtime execute API.
245 static void addResumeFunction(ModuleOp module) {
246  if (module.lookupSymbol(kResume))
247  return;
248 
249  MLIRContext *ctx = module.getContext();
250  auto loc = module.getLoc();
251  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
252 
253  auto voidTy = LLVM::LLVMVoidType::get(ctx);
254  Type ptrType = AsyncAPI::opaquePointerType(ctx);
255 
256  auto resumeOp = LLVM::LLVMFuncOp::create(
257  moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
258  resumeOp.setPrivate();
259 
260  auto *block = resumeOp.addEntryBlock(moduleBuilder);
261  auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
262 
263  LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0));
264  LLVM::ReturnOp::create(blockBuilder, ValueRange());
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Convert Async dialect types to LLVM types.
269 //===----------------------------------------------------------------------===//
270 
271 namespace {
272 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
273 /// their runtime type (opaque pointers) and does not convert any other types.
274 class AsyncRuntimeTypeConverter : public TypeConverter {
275 public:
276  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
277  addConversion([](Type type) { return type; });
278  addConversion([](Type type) { return convertAsyncTypes(type); });
279 
280  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
281  // in patterns for other dialects.
282  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
283  ValueRange inputs, Location loc) -> Value {
284  auto cast =
285  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
286  return cast.getResult(0);
287  };
288 
289  addSourceMaterialization(addUnrealizedCast);
290  addTargetMaterialization(addUnrealizedCast);
291  }
292 
293  static std::optional<Type> convertAsyncTypes(Type type) {
294  if (isa<TokenType, GroupType, ValueType>(type))
295  return AsyncAPI::opaquePointerType(type.getContext());
296 
297  if (isa<CoroIdType, CoroStateType>(type))
298  return AsyncAPI::tokenType(type.getContext());
299  if (isa<CoroHandleType>(type))
300  return AsyncAPI::opaquePointerType(type.getContext());
301 
302  return std::nullopt;
303  }
304 };
305 
306 /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
307 /// as type converter. Allows access to it via the 'getTypeConverter'
308 /// convenience method.
309 template <typename SourceOp>
310 class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
311 
312  using Base = OpConversionPattern<SourceOp>;
313 
314 public:
315  AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
316  MLIRContext *context)
317  : Base(typeConverter, context) {}
318 
319  /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
320  const AsyncRuntimeTypeConverter *getTypeConverter() const {
321  return static_cast<const AsyncRuntimeTypeConverter *>(
322  Base::getTypeConverter());
323  }
324 };
325 
326 } // namespace
327 
328 //===----------------------------------------------------------------------===//
329 // Convert async.coro.id to @llvm.coro.id intrinsic.
330 //===----------------------------------------------------------------------===//
331 
332 namespace {
333 class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
334 public:
335  using AsyncOpConversionPattern::AsyncOpConversionPattern;
336 
337  LogicalResult
338  matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
339  ConversionPatternRewriter &rewriter) const override {
340  auto token = AsyncAPI::tokenType(op->getContext());
341  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
342  auto loc = op->getLoc();
343 
344  // Constants for initializing coroutine frame.
345  auto constZero =
346  LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0);
347  auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType);
348 
349  // Get coroutine id: @llvm.coro.id.
350  rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
351  op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
352 
353  return success();
354  }
355 };
356 } // namespace
357 
358 //===----------------------------------------------------------------------===//
359 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
360 //===----------------------------------------------------------------------===//
361 
362 namespace {
363 class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
364 public:
365  using AsyncOpConversionPattern::AsyncOpConversionPattern;
366 
367  LogicalResult
368  matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
369  ConversionPatternRewriter &rewriter) const override {
370  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
371  auto loc = op->getLoc();
372 
373  // Get coroutine frame size: @llvm.coro.size.i64.
374  Value coroSize =
375  LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type());
376  // Get coroutine frame alignment: @llvm.coro.align.i64.
377  Value coroAlign =
378  LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type());
379 
380  // Round up the size to be multiple of the alignment. Since aligned_alloc
381  // requires the size parameter be an integral multiple of the alignment
382  // parameter.
383  auto makeConstant = [&](uint64_t c) {
384  return LLVM::ConstantOp::create(rewriter, op->getLoc(),
385  rewriter.getI64Type(), c);
386  };
387  coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign);
388  coroSize =
389  LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1));
390  Value negCoroAlign =
391  LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign);
392  coroSize =
393  LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign);
394 
395  // Allocate memory for the coroutine frame.
396  auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
397  rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
398  if (failed(allocFuncOp))
399  return failure();
400  auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
401  ValueRange{coroAlign, coroSize});
402 
403  // Begin a coroutine: @llvm.coro.begin.
404  auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
405  rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
406  op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
407 
408  return success();
409  }
410 };
411 } // namespace
412 
413 //===----------------------------------------------------------------------===//
414 // Convert async.coro.free to @llvm.coro.free intrinsic.
415 //===----------------------------------------------------------------------===//
416 
417 namespace {
418 class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
419 public:
420  using AsyncOpConversionPattern::AsyncOpConversionPattern;
421 
422  LogicalResult
423  matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
424  ConversionPatternRewriter &rewriter) const override {
425  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
426  auto loc = op->getLoc();
427 
428  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
429  auto coroMem =
430  LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands());
431 
432  // Free the memory.
433  auto freeFuncOp =
434  LLVM::lookupOrCreateFreeFn(rewriter, op->getParentOfType<ModuleOp>());
435  if (failed(freeFuncOp))
436  return failure();
437  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
438  ValueRange(coroMem.getResult()));
439 
440  return success();
441  }
442 };
443 } // namespace
444 
445 //===----------------------------------------------------------------------===//
446 // Convert async.coro.end to @llvm.coro.end intrinsic.
447 //===----------------------------------------------------------------------===//
448 
449 namespace {
450 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
451 public:
453 
454  LogicalResult
455  matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
456  ConversionPatternRewriter &rewriter) const override {
457  // We are not in the block that is part of the unwind sequence.
458  auto constFalse =
459  LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
460  rewriter.getBoolAttr(false));
461  auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc());
462 
463  // Mark the end of a coroutine: @llvm.coro.end.
464  auto coroHdl = adaptor.getHandle();
465  LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
466  ValueRange({coroHdl, constFalse, noneToken}));
467  rewriter.eraseOp(op);
468 
469  return success();
470  }
471 };
472 } // namespace
473 
474 //===----------------------------------------------------------------------===//
475 // Convert async.coro.save to @llvm.coro.save intrinsic.
476 //===----------------------------------------------------------------------===//
477 
478 namespace {
479 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
480 public:
482 
483  LogicalResult
484  matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
485  ConversionPatternRewriter &rewriter) const override {
486  // Save the coroutine state: @llvm.coro.save
487  rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
488  op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
489 
490  return success();
491  }
492 };
493 } // namespace
494 
495 //===----------------------------------------------------------------------===//
496 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
497 //===----------------------------------------------------------------------===//
498 
499 namespace {
500 
501 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
502 /// branch to the appropriate block based on the return code.
503 ///
504 /// Before:
505 ///
506 /// ^suspended:
507 /// "opBefore"(...)
508 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
509 /// ^resume:
510 /// "op"(...)
511 /// ^cleanup: ...
512 /// ^suspend: ...
513 ///
514 /// After:
515 ///
516 /// ^suspended:
517 /// "opBefore"(...)
518 /// %suspend = llmv.intr.coro.suspend ...
519 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
520 /// ^resume:
521 /// "op"(...)
522 /// ^cleanup: ...
523 /// ^suspend: ...
524 ///
525 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
526 public:
528 
529  LogicalResult
530  matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
531  ConversionPatternRewriter &rewriter) const override {
532  auto i8 = rewriter.getIntegerType(8);
533  auto i32 = rewriter.getI32Type();
534  auto loc = op->getLoc();
535 
536  // This is not a final suspension point.
537  auto constFalse = LLVM::ConstantOp::create(
538  rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
539 
540  // Suspend a coroutine: @llvm.coro.suspend
541  auto coroState = adaptor.getState();
542  auto coroSuspend = LLVM::CoroSuspendOp::create(
543  rewriter, loc, i8, ValueRange({coroState, constFalse}));
544 
545  // Cast return code to i32.
546 
547  // After a suspension point decide if we should branch into resume, cleanup
548  // or suspend block of the coroutine (see @llvm.coro.suspend return code
549  // documentation).
550  llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
551  llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
552  op.getCleanupDest()};
553  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
554  op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()),
555  /*defaultDestination=*/op.getSuspendDest(),
556  /*defaultOperands=*/ValueRange(),
557  /*caseValues=*/caseValues,
558  /*caseDestinations=*/caseDest,
559  /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
560  /*branchWeights=*/ArrayRef<int32_t>());
561 
562  return success();
563  }
564 };
565 } // namespace
566 
567 //===----------------------------------------------------------------------===//
568 // Convert async.runtime.create to the corresponding runtime API call.
569 //
570 // To allocate storage for the async values we use getelementptr trick:
571 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
572 //===----------------------------------------------------------------------===//
573 
574 namespace {
575 class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
576 public:
578 
579  LogicalResult
580  matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
581  ConversionPatternRewriter &rewriter) const override {
582  const TypeConverter *converter = getTypeConverter();
583  Type resultType = op->getResultTypes()[0];
584 
585  // Tokens creation maps to a simple function call.
586  if (isa<TokenType>(resultType)) {
587  rewriter.replaceOpWithNewOp<func::CallOp>(
588  op, kCreateToken, converter->convertType(resultType));
589  return success();
590  }
591 
592  // To create a value we need to compute the storage requirement.
593  if (auto value = dyn_cast<ValueType>(resultType)) {
594  // Returns the size requirements for the async value storage.
595  auto sizeOf = [&](ValueType valueType) -> Value {
596  auto loc = op->getLoc();
597  auto i64 = rewriter.getI64Type();
598 
599  auto storedType = converter->convertType(valueType.getValueType());
600  auto storagePtrType =
601  AsyncAPI::opaquePointerType(rewriter.getContext());
602 
603  // %Size = getelementptr %T* null, int 1
604  // %SizeI = ptrtoint %T* %Size to i64
605  auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType);
606  auto gep =
607  LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType,
608  nullPtr, ArrayRef<LLVM::GEPArg>{1});
609  return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep);
610  };
611 
612  rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
613  sizeOf(value));
614 
615  return success();
616  }
617 
618  return rewriter.notifyMatchFailure(op, "unsupported async type");
619  }
620 };
621 } // namespace
622 
623 //===----------------------------------------------------------------------===//
624 // Convert async.runtime.create_group to the corresponding runtime API call.
625 //===----------------------------------------------------------------------===//
626 
627 namespace {
628 class RuntimeCreateGroupOpLowering
629  : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
630 public:
632 
633  LogicalResult
634  matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
635  ConversionPatternRewriter &rewriter) const override {
636  const TypeConverter *converter = getTypeConverter();
637  Type resultType = op.getResult().getType();
638 
639  rewriter.replaceOpWithNewOp<func::CallOp>(
640  op, kCreateGroup, converter->convertType(resultType),
641  adaptor.getOperands());
642  return success();
643  }
644 };
645 } // namespace
646 
647 //===----------------------------------------------------------------------===//
648 // Convert async.runtime.set_available to the corresponding runtime API call.
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 class RuntimeSetAvailableOpLowering
653  : public OpConversionPattern<RuntimeSetAvailableOp> {
654 public:
656 
657  LogicalResult
658  matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
659  ConversionPatternRewriter &rewriter) const override {
660  StringRef apiFuncName =
661  TypeSwitch<Type, StringRef>(op.getOperand().getType())
662  .Case<TokenType>([](Type) { return kEmplaceToken; })
663  .Case<ValueType>([](Type) { return kEmplaceValue; });
664 
665  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
666  adaptor.getOperands());
667 
668  return success();
669  }
670 };
671 } // namespace
672 
673 //===----------------------------------------------------------------------===//
674 // Convert async.runtime.set_error to the corresponding runtime API call.
675 //===----------------------------------------------------------------------===//
676 
677 namespace {
678 class RuntimeSetErrorOpLowering
679  : public OpConversionPattern<RuntimeSetErrorOp> {
680 public:
682 
683  LogicalResult
684  matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
685  ConversionPatternRewriter &rewriter) const override {
686  StringRef apiFuncName =
687  TypeSwitch<Type, StringRef>(op.getOperand().getType())
688  .Case<TokenType>([](Type) { return kSetTokenError; })
689  .Case<ValueType>([](Type) { return kSetValueError; });
690 
691  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
692  adaptor.getOperands());
693 
694  return success();
695  }
696 };
697 } // namespace
698 
699 //===----------------------------------------------------------------------===//
700 // Convert async.runtime.is_error to the corresponding runtime API call.
701 //===----------------------------------------------------------------------===//
702 
703 namespace {
704 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
705 public:
707 
708  LogicalResult
709  matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
710  ConversionPatternRewriter &rewriter) const override {
711  StringRef apiFuncName =
712  TypeSwitch<Type, StringRef>(op.getOperand().getType())
713  .Case<TokenType>([](Type) { return kIsTokenError; })
714  .Case<GroupType>([](Type) { return kIsGroupError; })
715  .Case<ValueType>([](Type) { return kIsValueError; });
716 
717  rewriter.replaceOpWithNewOp<func::CallOp>(
718  op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
719  return success();
720  }
721 };
722 } // namespace
723 
724 //===----------------------------------------------------------------------===//
725 // Convert async.runtime.await to the corresponding runtime API call.
726 //===----------------------------------------------------------------------===//
727 
728 namespace {
729 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
730 public:
732 
733  LogicalResult
734  matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
735  ConversionPatternRewriter &rewriter) const override {
736  StringRef apiFuncName =
737  TypeSwitch<Type, StringRef>(op.getOperand().getType())
738  .Case<TokenType>([](Type) { return kAwaitToken; })
739  .Case<ValueType>([](Type) { return kAwaitValue; })
740  .Case<GroupType>([](Type) { return kAwaitGroup; });
741 
742  func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
743  adaptor.getOperands());
744  rewriter.eraseOp(op);
745 
746  return success();
747  }
748 };
749 } // namespace
750 
751 //===----------------------------------------------------------------------===//
752 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
753 //===----------------------------------------------------------------------===//
754 
755 namespace {
756 class RuntimeAwaitAndResumeOpLowering
757  : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
758 public:
759  using AsyncOpConversionPattern::AsyncOpConversionPattern;
760 
761  LogicalResult
762  matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
763  ConversionPatternRewriter &rewriter) const override {
764  StringRef apiFuncName =
765  TypeSwitch<Type, StringRef>(op.getOperand().getType())
766  .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
767  .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
768  .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
769 
770  Value operand = adaptor.getOperand();
771  Value handle = adaptor.getHandle();
772 
773  // A pointer to coroutine resume intrinsic wrapper.
774  addResumeFunction(op->getParentOfType<ModuleOp>());
775  auto resumePtr = LLVM::AddressOfOp::create(
776  rewriter, op->getLoc(),
777  AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
778 
779  func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
780  ValueRange({operand, handle, resumePtr.getRes()}));
781  rewriter.eraseOp(op);
782 
783  return success();
784  }
785 };
786 } // namespace
787 
788 //===----------------------------------------------------------------------===//
789 // Convert async.runtime.resume to the corresponding runtime API call.
790 //===----------------------------------------------------------------------===//
791 
792 namespace {
793 class RuntimeResumeOpLowering
794  : public AsyncOpConversionPattern<RuntimeResumeOp> {
795 public:
796  using AsyncOpConversionPattern::AsyncOpConversionPattern;
797 
798  LogicalResult
799  matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
800  ConversionPatternRewriter &rewriter) const override {
801  // A pointer to coroutine resume intrinsic wrapper.
802  addResumeFunction(op->getParentOfType<ModuleOp>());
803  auto resumePtr = LLVM::AddressOfOp::create(
804  rewriter, op->getLoc(),
805  AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
806 
807  // Call async runtime API to execute a coroutine in the managed thread.
808  auto coroHdl = adaptor.getHandle();
809  rewriter.replaceOpWithNewOp<func::CallOp>(
810  op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
811 
812  return success();
813  }
814 };
815 } // namespace
816 
817 //===----------------------------------------------------------------------===//
818 // Convert async.runtime.store to the corresponding runtime API call.
819 //===----------------------------------------------------------------------===//
820 
821 namespace {
822 class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
823 public:
825 
826  LogicalResult
827  matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
828  ConversionPatternRewriter &rewriter) const override {
829  Location loc = op->getLoc();
830 
831  // Get a pointer to the async value storage from the runtime.
832  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
833  auto storage = adaptor.getStorage();
834  auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
835  TypeRange(ptrType), storage);
836 
837  // Cast from i8* to the LLVM pointer type.
838  auto valueType = op.getValue().getType();
839  auto llvmValueType = getTypeConverter()->convertType(valueType);
840  if (!llvmValueType)
841  return rewriter.notifyMatchFailure(
842  op, "failed to convert stored value type to LLVM type");
843 
844  Value castedStoragePtr = storagePtr.getResult(0);
845  // Store the yielded value into the async value storage.
846  auto value = adaptor.getValue();
847  LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr);
848 
849  // Erase the original runtime store operation.
850  rewriter.eraseOp(op);
851 
852  return success();
853  }
854 };
855 } // namespace
856 
857 //===----------------------------------------------------------------------===//
858 // Convert async.runtime.load to the corresponding runtime API call.
859 //===----------------------------------------------------------------------===//
860 
861 namespace {
862 class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
863 public:
865 
866  LogicalResult
867  matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
868  ConversionPatternRewriter &rewriter) const override {
869  Location loc = op->getLoc();
870 
871  // Get a pointer to the async value storage from the runtime.
872  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
873  auto storage = adaptor.getStorage();
874  auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
875  TypeRange(ptrType), storage);
876 
877  // Cast from i8* to the LLVM pointer type.
878  auto valueType = op.getResult().getType();
879  auto llvmValueType = getTypeConverter()->convertType(valueType);
880  if (!llvmValueType)
881  return rewriter.notifyMatchFailure(
882  op, "failed to convert loaded value type to LLVM type");
883 
884  Value castedStoragePtr = storagePtr.getResult(0);
885 
886  // Load from the casted pointer.
887  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
888  castedStoragePtr);
889 
890  return success();
891  }
892 };
893 } // namespace
894 
895 //===----------------------------------------------------------------------===//
896 // Convert async.runtime.add_to_group to the corresponding runtime API call.
897 //===----------------------------------------------------------------------===//
898 
899 namespace {
900 class RuntimeAddToGroupOpLowering
901  : public OpConversionPattern<RuntimeAddToGroupOp> {
902 public:
904 
905  LogicalResult
906  matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
907  ConversionPatternRewriter &rewriter) const override {
908  // Currently we can only add tokens to the group.
909  if (!isa<TokenType>(op.getOperand().getType()))
910  return rewriter.notifyMatchFailure(op, "only token type is supported");
911 
912  // Replace with a runtime API function call.
913  rewriter.replaceOpWithNewOp<func::CallOp>(
914  op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
915 
916  return success();
917  }
918 };
919 } // namespace
920 
921 //===----------------------------------------------------------------------===//
922 // Convert async.runtime.num_worker_threads to the corresponding runtime API
923 // call.
924 //===----------------------------------------------------------------------===//
925 
926 namespace {
927 class RuntimeNumWorkerThreadsOpLowering
928  : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
929 public:
931 
932  LogicalResult
933  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
934  ConversionPatternRewriter &rewriter) const override {
935 
936  // Replace with a runtime API function call.
937  rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
938  rewriter.getIndexType());
939 
940  return success();
941  }
942 };
943 } // namespace
944 
945 //===----------------------------------------------------------------------===//
946 // Async reference counting ops lowering (`async.runtime.add_ref` and
947 // `async.runtime.drop_ref` to the corresponding API calls).
948 //===----------------------------------------------------------------------===//
949 
950 namespace {
951 template <typename RefCountingOp>
952 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
953 public:
954  explicit RefCountingOpLowering(const TypeConverter &converter,
955  MLIRContext *ctx, StringRef apiFunctionName)
956  : OpConversionPattern<RefCountingOp>(converter, ctx),
957  apiFunctionName(apiFunctionName) {}
958 
959  LogicalResult
960  matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
961  ConversionPatternRewriter &rewriter) const override {
962  auto count =
963  arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(),
964  rewriter.getI64IntegerAttr(op.getCount()));
965 
966  auto operand = adaptor.getOperand();
967  rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
968  ValueRange({operand, count}));
969 
970  return success();
971  }
972 
973 private:
974  StringRef apiFunctionName;
975 };
976 
977 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
978 public:
979  explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
980  MLIRContext *ctx)
981  : RefCountingOpLowering(converter, ctx, kAddRef) {}
982 };
983 
984 class RuntimeDropRefOpLowering
985  : public RefCountingOpLowering<RuntimeDropRefOp> {
986 public:
987  explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
988  MLIRContext *ctx)
989  : RefCountingOpLowering(converter, ctx, kDropRef) {}
990 };
991 } // namespace
992 
993 //===----------------------------------------------------------------------===//
994 // Convert return operations that return async values from async regions.
995 //===----------------------------------------------------------------------===//
996 
997 namespace {
998 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
999 public:
1001 
1002  LogicalResult
1003  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1004  ConversionPatternRewriter &rewriter) const override {
1005  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1006  return success();
1007  }
1008 };
1009 } // namespace
1010 
1011 //===----------------------------------------------------------------------===//
1012 
1013 namespace {
1014 struct ConvertAsyncToLLVMPass
1015  : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1016  using Base::Base;
1017 
1018  void runOnOperation() override;
1019 };
1020 } // namespace
1021 
1022 void ConvertAsyncToLLVMPass::runOnOperation() {
1023  ModuleOp module = getOperation();
1024  MLIRContext *ctx = module->getContext();
1025 
1027 
1028  // Add declarations for most functions required by the coroutines lowering.
1029  // We delay adding the resume function until it's needed because it currently
1030  // fails to compile unless '-O0' is specified.
1032 
1033  // Lower async.runtime and async.coro operations to Async Runtime API and
1034  // LLVM coroutine intrinsics.
1035 
1036  // Convert async dialect types and operations to LLVM dialect.
1037  AsyncRuntimeTypeConverter converter(options);
1039 
1040  // We use conversion to LLVM type to lower async.runtime load and store
1041  // operations.
1042  LLVMTypeConverter llvmConverter(ctx, options);
1043  llvmConverter.addConversion([&](Type type) {
1044  return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1045  });
1046 
1047  // Convert async types in function signatures and function calls.
1048  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1049  converter);
1051 
1052  // Convert return operations inside async.execute regions.
1053  patterns.add<ReturnOpOpConversion>(converter, ctx);
1054 
1055  // Lower async.runtime operations to the async runtime API calls.
1056  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1057  RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1058  RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1059  RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1060  RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1061  ctx);
1062 
1063  // Lower async.runtime operations that rely on LLVM type converter to convert
1064  // from async value payload type to the LLVM type.
1065  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1066  RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
1067 
1068  // Lower async coroutine operations to LLVM coroutine intrinsics.
1069  patterns
1070  .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1071  CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1072  converter, ctx);
1073 
1074  ConversionTarget target(*ctx);
1075  target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1076  UnrealizedConversionCastOp>();
1077  target.addLegalDialect<LLVM::LLVMDialect>();
1078 
1079  // All operations from Async dialect must be lowered to the runtime API and
1080  // LLVM intrinsics calls.
1081  target.addIllegalDialect<AsyncDialect>();
1082 
1083  // Add dynamic legality constraints to apply conversions defined above.
1084  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1085  return converter.isSignatureLegal(op.getFunctionType());
1086  });
1087  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1088  return converter.isLegal(op.getOperandTypes());
1089  });
1090  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1091  return converter.isSignatureLegal(op.getCalleeType());
1092  });
1093 
1094  if (failed(applyPartialConversion(module, target, std::move(patterns))))
1095  signalPassFailure();
1096 }
1097 
1098 //===----------------------------------------------------------------------===//
1099 // Patterns for structural type conversions for the Async dialect operations.
1100 //===----------------------------------------------------------------------===//
1101 
1102 namespace {
1103 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1104 public:
1106  LogicalResult
1107  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1108  ConversionPatternRewriter &rewriter) const override {
1109  ExecuteOp newOp =
1110  cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1111  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1112  newOp.getRegion().end());
1113 
1114  // Set operands and update block argument and result types.
1115  newOp->setOperands(adaptor.getOperands());
1116  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1117  return failure();
1118  for (auto result : newOp.getResults())
1119  result.setType(typeConverter->convertType(result.getType()));
1120 
1121  rewriter.replaceOp(op, newOp.getResults());
1122  return success();
1123  }
1124 };
1125 
1126 // Dummy pattern to trigger the appropriate type conversion / materialization.
1127 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1128 public:
1130  LogicalResult
1131  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1132  ConversionPatternRewriter &rewriter) const override {
1133  rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1134  return success();
1135  }
1136 };
1137 
1138 // Dummy pattern to trigger the appropriate type conversion / materialization.
1139 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1140 public:
1142  LogicalResult
1143  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1144  ConversionPatternRewriter &rewriter) const override {
1145  rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1146  return success();
1147  }
1148 };
1149 } // namespace
1150 
1152  TypeConverter &typeConverter, RewritePatternSet &patterns,
1153  ConversionTarget &target) {
1154  typeConverter.addConversion([&](TokenType type) { return type; });
1155  typeConverter.addConversion([&](ValueType type) {
1156  Type converted = typeConverter.convertType(type.getValueType());
1157  return converted ? ValueType::get(converted) : converted;
1158  });
1159 
1160  patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1161  typeConverter, patterns.getContext());
1162 
1163  target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1164  [&](Operation *op) { return typeConverter.isLegal(op); });
1165 }
static constexpr const char * kAwaitValueAndExecute
Definition: AsyncToLLVM.cpp:62
static constexpr const char * kCreateValue
Definition: AsyncToLLVM.cpp:43
static constexpr const char * kCreateGroup
Definition: AsyncToLLVM.cpp:44
static constexpr const char * kCreateToken
Definition: AsyncToLLVM.cpp:42
static constexpr const char * kEmplaceValue
Definition: AsyncToLLVM.cpp:46
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls the llvm.coro.resume intrinsic.
static constexpr const char * kEmplaceToken
Definition: AsyncToLLVM.cpp:45
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
static constexpr const char * kResume
static constexpr const char * kAddRef
Definition: AsyncToLLVM.cpp:40
static constexpr const char * kAwaitTokenAndExecute
Definition: AsyncToLLVM.cpp:60
static constexpr const char * kAwaitValue
Definition: AsyncToLLVM.cpp:53
static constexpr const char * kSetTokenError
Definition: AsyncToLLVM.cpp:47
static constexpr const char * kExecute
Definition: AsyncToLLVM.cpp:55
static constexpr const char * kAddTokenToGroup
Definition: AsyncToLLVM.cpp:58
static constexpr const char * kIsGroupError
Definition: AsyncToLLVM.cpp:51
static constexpr const char * kSetValueError
Definition: AsyncToLLVM.cpp:48
static constexpr const char * kIsTokenError
Definition: AsyncToLLVM.cpp:49
static constexpr const char * kAwaitGroup
Definition: AsyncToLLVM.cpp:54
static constexpr const char * kAwaitAllAndExecute
Definition: AsyncToLLVM.cpp:64
static constexpr const char * kGetNumWorkerThreads
Definition: AsyncToLLVM.cpp:66
static constexpr const char * kDropRef
Definition: AsyncToLLVM.cpp:41
static constexpr const char * kIsValueError
Definition: AsyncToLLVM.cpp:50
static constexpr const char * kAwaitToken
Definition: AsyncToLLVM.cpp:52
static constexpr const char * kGetValueStorage
Definition: AsyncToLLVM.cpp:56
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:638
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:585
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.