MLIR  21.0.0git
AsyncToLLVM.cpp
Go to the documentation of this file.
1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 #define DEBUG_TYPE "convert-async-to-llvm"
33 
34 using namespace mlir;
35 using namespace mlir::async;
36 
37 //===----------------------------------------------------------------------===//
38 // Async Runtime C API declaration.
39 //===----------------------------------------------------------------------===//
40 
// Symbol names of the async runtime C API. Each string is used verbatim as
// the name of a private `func.func` declaration (see
// addAsyncRuntimeApiDeclarations), so it must match the runtime's exported
// symbol exactly.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Switching an async object to the available or error state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error-state queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Blocking waits.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Waits that hand a coroutine handle plus resume function to the runtime.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Scheduling, value storage access, group membership and queries.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// NOTE(review): "Runtim" (missing 'e') looks like a typo, but this string names
// an external runtime symbol; only change it together with the runtime
// library's definition, otherwise the declaration will no longer resolve.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
69 
70 namespace {
71 /// Async Runtime API function types.
72 ///
73 /// Because we can't create API function signature for type parametrized
74 /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
75 /// lowering all async data types become opaque pointers at runtime.
76 struct AsyncAPI {
77  // All async types are lowered to opaque LLVM pointers at runtime.
78  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
79  return LLVM::LLVMPointerType::get(ctx);
80  }
81 
82  static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
83  return LLVM::LLVMTokenType::get(ctx);
84  }
85 
86  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
87  auto ref = opaquePointerType(ctx);
88  auto count = IntegerType::get(ctx, 64);
89  return FunctionType::get(ctx, {ref, count}, {});
90  }
91 
92  static FunctionType createTokenFunctionType(MLIRContext *ctx) {
93  return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
94  }
95 
96  static FunctionType createValueFunctionType(MLIRContext *ctx) {
97  auto i64 = IntegerType::get(ctx, 64);
98  auto value = opaquePointerType(ctx);
99  return FunctionType::get(ctx, {i64}, {value});
100  }
101 
102  static FunctionType createGroupFunctionType(MLIRContext *ctx) {
103  auto i64 = IntegerType::get(ctx, 64);
104  return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
105  }
106 
107  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
108  auto ptrType = opaquePointerType(ctx);
109  return FunctionType::get(ctx, {ptrType}, {ptrType});
110  }
111 
112  static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
113  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
114  }
115 
116  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
117  auto value = opaquePointerType(ctx);
118  return FunctionType::get(ctx, {value}, {});
119  }
120 
121  static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
122  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
123  }
124 
125  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
126  auto value = opaquePointerType(ctx);
127  return FunctionType::get(ctx, {value}, {});
128  }
129 
130  static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
131  auto i1 = IntegerType::get(ctx, 1);
132  return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
133  }
134 
135  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
136  auto value = opaquePointerType(ctx);
137  auto i1 = IntegerType::get(ctx, 1);
138  return FunctionType::get(ctx, {value}, {i1});
139  }
140 
141  static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
142  auto i1 = IntegerType::get(ctx, 1);
143  return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
144  }
145 
146  static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
147  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
148  }
149 
150  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
151  auto value = opaquePointerType(ctx);
152  return FunctionType::get(ctx, {value}, {});
153  }
154 
155  static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
156  return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
157  }
158 
159  static FunctionType executeFunctionType(MLIRContext *ctx) {
160  auto ptrType = opaquePointerType(ctx);
161  return FunctionType::get(ctx, {ptrType, ptrType}, {});
162  }
163 
164  static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
165  auto i64 = IntegerType::get(ctx, 64);
166  return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
167  {i64});
168  }
169 
170  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
171  auto ptrType = opaquePointerType(ctx);
172  return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
173  }
174 
175  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
176  auto ptrType = opaquePointerType(ctx);
177  return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
178  }
179 
180  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
181  auto ptrType = opaquePointerType(ctx);
182  return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
183  }
184 
185  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
186  return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
187  }
188 
189  // Auxiliary coroutine resume intrinsic wrapper.
190  static Type resumeFunctionType(MLIRContext *ctx) {
191  auto voidTy = LLVM::LLVMVoidType::get(ctx);
192  auto ptrType = opaquePointerType(ctx);
193  return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
194  }
195 };
196 } // namespace
197 
198 /// Adds Async Runtime C API declarations to the module.
199 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
200  auto builder =
201  ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
202 
203  auto addFuncDecl = [&](StringRef name, FunctionType type) {
204  if (module.lookupSymbol(name))
205  return;
206  builder.create<func::FuncOp>(name, type).setPrivate();
207  };
208 
209  MLIRContext *ctx = module.getContext();
210  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
211  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
213  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
214  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
215  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
216  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
217  addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
218  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
219  addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
220  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
221  addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
222  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
223  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
224  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
225  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
226  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
227  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
228  addFuncDecl(kAwaitTokenAndExecute,
229  AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
230  addFuncDecl(kAwaitValueAndExecute,
231  AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
232  addFuncDecl(kAwaitAllAndExecute,
233  AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
234  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // Coroutine resume function wrapper.
239 //===----------------------------------------------------------------------===//
240 
241 static constexpr const char *kResume = "__resume";
242 
243 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
244 /// intrinsics. We need this function to be able to pass it to the async
245 /// runtime execute API.
246 static void addResumeFunction(ModuleOp module) {
247  if (module.lookupSymbol(kResume))
248  return;
249 
250  MLIRContext *ctx = module.getContext();
251  auto loc = module.getLoc();
252  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
253 
254  auto voidTy = LLVM::LLVMVoidType::get(ctx);
255  Type ptrType = AsyncAPI::opaquePointerType(ctx);
256 
257  auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
258  kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
259  resumeOp.setPrivate();
260 
261  auto *block = resumeOp.addEntryBlock(moduleBuilder);
262  auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
263 
264  blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
265  blockBuilder.create<LLVM::ReturnOp>(ValueRange());
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Convert Async dialect types to LLVM types.
270 //===----------------------------------------------------------------------===//
271 
272 namespace {
273 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
274 /// their runtime type (opaque pointers) and does not convert any other types.
275 class AsyncRuntimeTypeConverter : public TypeConverter {
276 public:
277  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
278  addConversion([](Type type) { return type; });
279  addConversion([](Type type) { return convertAsyncTypes(type); });
280 
281  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
282  // in patterns for other dialects.
283  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
284  ValueRange inputs, Location loc) -> Value {
285  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
286  return cast.getResult(0);
287  };
288 
289  addSourceMaterialization(addUnrealizedCast);
290  addTargetMaterialization(addUnrealizedCast);
291  }
292 
293  static std::optional<Type> convertAsyncTypes(Type type) {
294  if (isa<TokenType, GroupType, ValueType>(type))
295  return AsyncAPI::opaquePointerType(type.getContext());
296 
297  if (isa<CoroIdType, CoroStateType>(type))
298  return AsyncAPI::tokenType(type.getContext());
299  if (isa<CoroHandleType>(type))
300  return AsyncAPI::opaquePointerType(type.getContext());
301 
302  return std::nullopt;
303  }
304 };
305 
306 /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
307 /// as type converter. Allows access to it via the 'getTypeConverter'
308 /// convenience method.
309 template <typename SourceOp>
310 class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
311 
312  using Base = OpConversionPattern<SourceOp>;
313 
314 public:
315  AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
316  MLIRContext *context)
317  : Base(typeConverter, context) {}
318 
319  /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
320  const AsyncRuntimeTypeConverter *getTypeConverter() const {
321  return static_cast<const AsyncRuntimeTypeConverter *>(
322  Base::getTypeConverter());
323  }
324 };
325 
326 } // namespace
327 
328 //===----------------------------------------------------------------------===//
329 // Convert async.coro.id to @llvm.coro.id intrinsic.
330 //===----------------------------------------------------------------------===//
331 
332 namespace {
333 class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
334 public:
335  using AsyncOpConversionPattern::AsyncOpConversionPattern;
336 
337  LogicalResult
338  matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
339  ConversionPatternRewriter &rewriter) const override {
340  auto token = AsyncAPI::tokenType(op->getContext());
341  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
342  auto loc = op->getLoc();
343 
344  // Constants for initializing coroutine frame.
345  auto constZero =
346  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
347  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
348 
349  // Get coroutine id: @llvm.coro.id.
350  rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
351  op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
352 
353  return success();
354  }
355 };
356 } // namespace
357 
358 //===----------------------------------------------------------------------===//
359 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
360 //===----------------------------------------------------------------------===//
361 
362 namespace {
363 class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
364 public:
365  using AsyncOpConversionPattern::AsyncOpConversionPattern;
366 
367  LogicalResult
368  matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
369  ConversionPatternRewriter &rewriter) const override {
370  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
371  auto loc = op->getLoc();
372 
373  // Get coroutine frame size: @llvm.coro.size.i64.
374  Value coroSize =
375  rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
376  // Get coroutine frame alignment: @llvm.coro.align.i64.
377  Value coroAlign =
378  rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
379 
380  // Round up the size to be multiple of the alignment. Since aligned_alloc
381  // requires the size parameter be an integral multiple of the alignment
382  // parameter.
383  auto makeConstant = [&](uint64_t c) {
384  return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
385  rewriter.getI64Type(), c);
386  };
387  coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
388  coroSize =
389  rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
390  Value negCoroAlign =
391  rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
392  coroSize =
393  rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
394 
395  // Allocate memory for the coroutine frame.
396  auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
397  rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
398  if (failed(allocFuncOp))
399  return failure();
400  auto coroAlloc = rewriter.create<LLVM::CallOp>(
401  loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
402 
403  // Begin a coroutine: @llvm.coro.begin.
404  auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
405  rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
406  op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
407 
408  return success();
409  }
410 };
411 } // namespace
412 
413 //===----------------------------------------------------------------------===//
414 // Convert async.coro.free to @llvm.coro.free intrinsic.
415 //===----------------------------------------------------------------------===//
416 
417 namespace {
418 class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
419 public:
420  using AsyncOpConversionPattern::AsyncOpConversionPattern;
421 
422  LogicalResult
423  matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
424  ConversionPatternRewriter &rewriter) const override {
425  auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
426  auto loc = op->getLoc();
427 
428  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
429  auto coroMem =
430  rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
431 
432  // Free the memory.
433  auto freeFuncOp =
434  LLVM::lookupOrCreateFreeFn(rewriter, op->getParentOfType<ModuleOp>());
435  if (failed(freeFuncOp))
436  return failure();
437  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
438  ValueRange(coroMem.getResult()));
439 
440  return success();
441  }
442 };
443 } // namespace
444 
445 //===----------------------------------------------------------------------===//
446 // Convert async.coro.end to @llvm.coro.end intrinsic.
447 //===----------------------------------------------------------------------===//
448 
449 namespace {
450 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
451 public:
453 
454  LogicalResult
455  matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
456  ConversionPatternRewriter &rewriter) const override {
457  // We are not in the block that is part of the unwind sequence.
458  auto constFalse = rewriter.create<LLVM::ConstantOp>(
459  op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
460  auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
461 
462  // Mark the end of a coroutine: @llvm.coro.end.
463  auto coroHdl = adaptor.getHandle();
464  rewriter.create<LLVM::CoroEndOp>(
465  op->getLoc(), rewriter.getI1Type(),
466  ValueRange({coroHdl, constFalse, noneToken}));
467  rewriter.eraseOp(op);
468 
469  return success();
470  }
471 };
472 } // namespace
473 
474 //===----------------------------------------------------------------------===//
475 // Convert async.coro.save to @llvm.coro.save intrinsic.
476 //===----------------------------------------------------------------------===//
477 
478 namespace {
479 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
480 public:
482 
483  LogicalResult
484  matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
485  ConversionPatternRewriter &rewriter) const override {
486  // Save the coroutine state: @llvm.coro.save
487  rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
488  op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
489 
490  return success();
491  }
492 };
493 } // namespace
494 
495 //===----------------------------------------------------------------------===//
496 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
497 //===----------------------------------------------------------------------===//
498 
499 namespace {
500 
501 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
502 /// branch to the appropriate block based on the return code.
503 ///
504 /// Before:
505 ///
506 /// ^suspended:
507 /// "opBefore"(...)
508 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
509 /// ^resume:
510 /// "op"(...)
511 /// ^cleanup: ...
512 /// ^suspend: ...
513 ///
514 /// After:
515 ///
516 /// ^suspended:
517 /// "opBefore"(...)
518 /// %suspend = llmv.intr.coro.suspend ...
519 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
520 /// ^resume:
521 /// "op"(...)
522 /// ^cleanup: ...
523 /// ^suspend: ...
524 ///
525 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
526 public:
528 
529  LogicalResult
530  matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
531  ConversionPatternRewriter &rewriter) const override {
532  auto i8 = rewriter.getIntegerType(8);
533  auto i32 = rewriter.getI32Type();
534  auto loc = op->getLoc();
535 
536  // This is not a final suspension point.
537  auto constFalse = rewriter.create<LLVM::ConstantOp>(
538  loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
539 
540  // Suspend a coroutine: @llvm.coro.suspend
541  auto coroState = adaptor.getState();
542  auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
543  loc, i8, ValueRange({coroState, constFalse}));
544 
545  // Cast return code to i32.
546 
547  // After a suspension point decide if we should branch into resume, cleanup
548  // or suspend block of the coroutine (see @llvm.coro.suspend return code
549  // documentation).
550  llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
551  llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
552  op.getCleanupDest()};
553  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
554  op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
555  /*defaultDestination=*/op.getSuspendDest(),
556  /*defaultOperands=*/ValueRange(),
557  /*caseValues=*/caseValues,
558  /*caseDestinations=*/caseDest,
559  /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
560  /*branchWeights=*/ArrayRef<int32_t>());
561 
562  return success();
563  }
564 };
565 } // namespace
566 
567 //===----------------------------------------------------------------------===//
568 // Convert async.runtime.create to the corresponding runtime API call.
569 //
570 // To allocate storage for the async values we use getelementptr trick:
571 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
572 //===----------------------------------------------------------------------===//
573 
574 namespace {
575 class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
576 public:
578 
579  LogicalResult
580  matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
581  ConversionPatternRewriter &rewriter) const override {
582  const TypeConverter *converter = getTypeConverter();
583  Type resultType = op->getResultTypes()[0];
584 
585  // Tokens creation maps to a simple function call.
586  if (isa<TokenType>(resultType)) {
587  rewriter.replaceOpWithNewOp<func::CallOp>(
588  op, kCreateToken, converter->convertType(resultType));
589  return success();
590  }
591 
592  // To create a value we need to compute the storage requirement.
593  if (auto value = dyn_cast<ValueType>(resultType)) {
594  // Returns the size requirements for the async value storage.
595  auto sizeOf = [&](ValueType valueType) -> Value {
596  auto loc = op->getLoc();
597  auto i64 = rewriter.getI64Type();
598 
599  auto storedType = converter->convertType(valueType.getValueType());
600  auto storagePtrType =
601  AsyncAPI::opaquePointerType(rewriter.getContext());
602 
603  // %Size = getelementptr %T* null, int 1
604  // %SizeI = ptrtoint %T* %Size to i64
605  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
606  auto gep =
607  rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
608  nullPtr, ArrayRef<LLVM::GEPArg>{1});
609  return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
610  };
611 
612  rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
613  sizeOf(value));
614 
615  return success();
616  }
617 
618  return rewriter.notifyMatchFailure(op, "unsupported async type");
619  }
620 };
621 } // namespace
622 
623 //===----------------------------------------------------------------------===//
624 // Convert async.runtime.create_group to the corresponding runtime API call.
625 //===----------------------------------------------------------------------===//
626 
627 namespace {
628 class RuntimeCreateGroupOpLowering
629  : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
630 public:
632 
633  LogicalResult
634  matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
635  ConversionPatternRewriter &rewriter) const override {
636  const TypeConverter *converter = getTypeConverter();
637  Type resultType = op.getResult().getType();
638 
639  rewriter.replaceOpWithNewOp<func::CallOp>(
640  op, kCreateGroup, converter->convertType(resultType),
641  adaptor.getOperands());
642  return success();
643  }
644 };
645 } // namespace
646 
647 //===----------------------------------------------------------------------===//
648 // Convert async.runtime.set_available to the corresponding runtime API call.
649 //===----------------------------------------------------------------------===//
650 
651 namespace {
652 class RuntimeSetAvailableOpLowering
653  : public OpConversionPattern<RuntimeSetAvailableOp> {
654 public:
656 
657  LogicalResult
658  matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
659  ConversionPatternRewriter &rewriter) const override {
660  StringRef apiFuncName =
661  TypeSwitch<Type, StringRef>(op.getOperand().getType())
662  .Case<TokenType>([](Type) { return kEmplaceToken; })
663  .Case<ValueType>([](Type) { return kEmplaceValue; });
664 
665  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
666  adaptor.getOperands());
667 
668  return success();
669  }
670 };
671 } // namespace
672 
673 //===----------------------------------------------------------------------===//
674 // Convert async.runtime.set_error to the corresponding runtime API call.
675 //===----------------------------------------------------------------------===//
676 
677 namespace {
678 class RuntimeSetErrorOpLowering
679  : public OpConversionPattern<RuntimeSetErrorOp> {
680 public:
682 
683  LogicalResult
684  matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
685  ConversionPatternRewriter &rewriter) const override {
686  StringRef apiFuncName =
687  TypeSwitch<Type, StringRef>(op.getOperand().getType())
688  .Case<TokenType>([](Type) { return kSetTokenError; })
689  .Case<ValueType>([](Type) { return kSetValueError; });
690 
691  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
692  adaptor.getOperands());
693 
694  return success();
695  }
696 };
697 } // namespace
698 
699 //===----------------------------------------------------------------------===//
700 // Convert async.runtime.is_error to the corresponding runtime API call.
701 //===----------------------------------------------------------------------===//
702 
703 namespace {
704 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
705 public:
707 
708  LogicalResult
709  matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
710  ConversionPatternRewriter &rewriter) const override {
711  StringRef apiFuncName =
712  TypeSwitch<Type, StringRef>(op.getOperand().getType())
713  .Case<TokenType>([](Type) { return kIsTokenError; })
714  .Case<GroupType>([](Type) { return kIsGroupError; })
715  .Case<ValueType>([](Type) { return kIsValueError; });
716 
717  rewriter.replaceOpWithNewOp<func::CallOp>(
718  op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
719  return success();
720  }
721 };
722 } // namespace
723 
724 //===----------------------------------------------------------------------===//
725 // Convert async.runtime.await to the corresponding runtime API call.
726 //===----------------------------------------------------------------------===//
727 
728 namespace {
729 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
730 public:
732 
733  LogicalResult
734  matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
735  ConversionPatternRewriter &rewriter) const override {
736  StringRef apiFuncName =
737  TypeSwitch<Type, StringRef>(op.getOperand().getType())
738  .Case<TokenType>([](Type) { return kAwaitToken; })
739  .Case<ValueType>([](Type) { return kAwaitValue; })
740  .Case<GroupType>([](Type) { return kAwaitGroup; });
741 
742  rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
743  adaptor.getOperands());
744  rewriter.eraseOp(op);
745 
746  return success();
747  }
748 };
749 } // namespace
750 
751 //===----------------------------------------------------------------------===//
752 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
753 //===----------------------------------------------------------------------===//
754 
755 namespace {
756 class RuntimeAwaitAndResumeOpLowering
757  : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
758 public:
759  using AsyncOpConversionPattern::AsyncOpConversionPattern;
760 
761  LogicalResult
762  matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
763  ConversionPatternRewriter &rewriter) const override {
764  StringRef apiFuncName =
765  TypeSwitch<Type, StringRef>(op.getOperand().getType())
766  .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
767  .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
768  .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
769 
770  Value operand = adaptor.getOperand();
771  Value handle = adaptor.getHandle();
772 
773  // A pointer to coroutine resume intrinsic wrapper.
774  addResumeFunction(op->getParentOfType<ModuleOp>());
775  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
776  op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
777  kResume);
778 
779  rewriter.create<func::CallOp>(
780  op->getLoc(), apiFuncName, TypeRange(),
781  ValueRange({operand, handle, resumePtr.getRes()}));
782  rewriter.eraseOp(op);
783 
784  return success();
785  }
786 };
787 } // namespace
788 
789 //===----------------------------------------------------------------------===//
790 // Convert async.runtime.resume to the corresponding runtime API call.
791 //===----------------------------------------------------------------------===//
792 
793 namespace {
794 class RuntimeResumeOpLowering
795  : public AsyncOpConversionPattern<RuntimeResumeOp> {
796 public:
797  using AsyncOpConversionPattern::AsyncOpConversionPattern;
798 
799  LogicalResult
800  matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
801  ConversionPatternRewriter &rewriter) const override {
802  // A pointer to coroutine resume intrinsic wrapper.
803  addResumeFunction(op->getParentOfType<ModuleOp>());
804  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
805  op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
806  kResume);
807 
808  // Call async runtime API to execute a coroutine in the managed thread.
809  auto coroHdl = adaptor.getHandle();
810  rewriter.replaceOpWithNewOp<func::CallOp>(
811  op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
812 
813  return success();
814  }
815 };
816 } // namespace
817 
818 //===----------------------------------------------------------------------===//
819 // Convert async.runtime.store to the corresponding runtime API call.
820 //===----------------------------------------------------------------------===//
821 
822 namespace {
823 class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
824 public:
826 
827  LogicalResult
828  matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
829  ConversionPatternRewriter &rewriter) const override {
830  Location loc = op->getLoc();
831 
832  // Get a pointer to the async value storage from the runtime.
833  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
834  auto storage = adaptor.getStorage();
835  auto storagePtr = rewriter.create<func::CallOp>(
836  loc, kGetValueStorage, TypeRange(ptrType), storage);
837 
838  // Cast from i8* to the LLVM pointer type.
839  auto valueType = op.getValue().getType();
840  auto llvmValueType = getTypeConverter()->convertType(valueType);
841  if (!llvmValueType)
842  return rewriter.notifyMatchFailure(
843  op, "failed to convert stored value type to LLVM type");
844 
845  Value castedStoragePtr = storagePtr.getResult(0);
846  // Store the yielded value into the async value storage.
847  auto value = adaptor.getValue();
848  rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
849 
850  // Erase the original runtime store operation.
851  rewriter.eraseOp(op);
852 
853  return success();
854  }
855 };
856 } // namespace
857 
858 //===----------------------------------------------------------------------===//
859 // Convert async.runtime.load to the corresponding runtime API call.
860 //===----------------------------------------------------------------------===//
861 
862 namespace {
863 class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
864 public:
866 
867  LogicalResult
868  matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
869  ConversionPatternRewriter &rewriter) const override {
870  Location loc = op->getLoc();
871 
872  // Get a pointer to the async value storage from the runtime.
873  auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
874  auto storage = adaptor.getStorage();
875  auto storagePtr = rewriter.create<func::CallOp>(
876  loc, kGetValueStorage, TypeRange(ptrType), storage);
877 
878  // Cast from i8* to the LLVM pointer type.
879  auto valueType = op.getResult().getType();
880  auto llvmValueType = getTypeConverter()->convertType(valueType);
881  if (!llvmValueType)
882  return rewriter.notifyMatchFailure(
883  op, "failed to convert loaded value type to LLVM type");
884 
885  Value castedStoragePtr = storagePtr.getResult(0);
886 
887  // Load from the casted pointer.
888  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
889  castedStoragePtr);
890 
891  return success();
892  }
893 };
894 } // namespace
895 
896 //===----------------------------------------------------------------------===//
897 // Convert async.runtime.add_to_group to the corresponding runtime API call.
898 //===----------------------------------------------------------------------===//
899 
900 namespace {
901 class RuntimeAddToGroupOpLowering
902  : public OpConversionPattern<RuntimeAddToGroupOp> {
903 public:
905 
906  LogicalResult
907  matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
908  ConversionPatternRewriter &rewriter) const override {
909  // Currently we can only add tokens to the group.
910  if (!isa<TokenType>(op.getOperand().getType()))
911  return rewriter.notifyMatchFailure(op, "only token type is supported");
912 
913  // Replace with a runtime API function call.
914  rewriter.replaceOpWithNewOp<func::CallOp>(
915  op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
916 
917  return success();
918  }
919 };
920 } // namespace
921 
922 //===----------------------------------------------------------------------===//
923 // Convert async.runtime.num_worker_threads to the corresponding runtime API
924 // call.
925 //===----------------------------------------------------------------------===//
926 
927 namespace {
928 class RuntimeNumWorkerThreadsOpLowering
929  : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
930 public:
932 
933  LogicalResult
934  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
935  ConversionPatternRewriter &rewriter) const override {
936 
937  // Replace with a runtime API function call.
938  rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
939  rewriter.getIndexType());
940 
941  return success();
942  }
943 };
944 } // namespace
945 
946 //===----------------------------------------------------------------------===//
947 // Async reference counting ops lowering (`async.runtime.add_ref` and
948 // `async.runtime.drop_ref` to the corresponding API calls).
949 //===----------------------------------------------------------------------===//
950 
951 namespace {
952 template <typename RefCountingOp>
953 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
954 public:
955  explicit RefCountingOpLowering(const TypeConverter &converter,
956  MLIRContext *ctx, StringRef apiFunctionName)
957  : OpConversionPattern<RefCountingOp>(converter, ctx),
958  apiFunctionName(apiFunctionName) {}
959 
960  LogicalResult
961  matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
962  ConversionPatternRewriter &rewriter) const override {
963  auto count = rewriter.create<arith::ConstantOp>(
964  op->getLoc(), rewriter.getI64Type(),
965  rewriter.getI64IntegerAttr(op.getCount()));
966 
967  auto operand = adaptor.getOperand();
968  rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
969  ValueRange({operand, count}));
970 
971  return success();
972  }
973 
974 private:
975  StringRef apiFunctionName;
976 };
977 
978 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
979 public:
980  explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
981  MLIRContext *ctx)
982  : RefCountingOpLowering(converter, ctx, kAddRef) {}
983 };
984 
985 class RuntimeDropRefOpLowering
986  : public RefCountingOpLowering<RuntimeDropRefOp> {
987 public:
988  explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
989  MLIRContext *ctx)
990  : RefCountingOpLowering(converter, ctx, kDropRef) {}
991 };
992 } // namespace
993 
994 //===----------------------------------------------------------------------===//
995 // Convert return operations that return async values from async regions.
996 //===----------------------------------------------------------------------===//
997 
998 namespace {
999 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
1000 public:
1002 
1003  LogicalResult
1004  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1005  ConversionPatternRewriter &rewriter) const override {
1006  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1007  return success();
1008  }
1009 };
1010 } // namespace
1011 
1012 //===----------------------------------------------------------------------===//
1013 
1014 namespace {
1015 struct ConvertAsyncToLLVMPass
1016  : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1017  using Base::Base;
1018 
1019  void runOnOperation() override;
1020 };
1021 } // namespace
1022 
1023 void ConvertAsyncToLLVMPass::runOnOperation() {
1024  ModuleOp module = getOperation();
1025  MLIRContext *ctx = module->getContext();
1026 
1028 
1029  // Add declarations for most functions required by the coroutines lowering.
1030  // We delay adding the resume function until it's needed because it currently
1031  // fails to compile unless '-O0' is specified.
1033 
1034  // Lower async.runtime and async.coro operations to Async Runtime API and
1035  // LLVM coroutine intrinsics.
1036 
1037  // Convert async dialect types and operations to LLVM dialect.
1038  AsyncRuntimeTypeConverter converter(options);
1040 
1041  // We use conversion to LLVM type to lower async.runtime load and store
1042  // operations.
1043  LLVMTypeConverter llvmConverter(ctx, options);
1044  llvmConverter.addConversion([&](Type type) {
1045  return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1046  });
1047 
1048  // Convert async types in function signatures and function calls.
1049  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1050  converter);
1052 
1053  // Convert return operations inside async.execute regions.
1054  patterns.add<ReturnOpOpConversion>(converter, ctx);
1055 
1056  // Lower async.runtime operations to the async runtime API calls.
1057  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1058  RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1059  RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1060  RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1061  RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1062  ctx);
1063 
1064  // Lower async.runtime operations that rely on LLVM type converter to convert
1065  // from async value payload type to the LLVM type.
1066  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1067  RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
1068 
1069  // Lower async coroutine operations to LLVM coroutine intrinsics.
1070  patterns
1071  .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1072  CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1073  converter, ctx);
1074 
1075  ConversionTarget target(*ctx);
1076  target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1077  UnrealizedConversionCastOp>();
1078  target.addLegalDialect<LLVM::LLVMDialect>();
1079 
1080  // All operations from Async dialect must be lowered to the runtime API and
1081  // LLVM intrinsics calls.
1082  target.addIllegalDialect<AsyncDialect>();
1083 
1084  // Add dynamic legality constraints to apply conversions defined above.
1085  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1086  return converter.isSignatureLegal(op.getFunctionType());
1087  });
1088  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1089  return converter.isLegal(op.getOperandTypes());
1090  });
1091  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1092  return converter.isSignatureLegal(op.getCalleeType());
1093  });
1094 
1095  if (failed(applyPartialConversion(module, target, std::move(patterns))))
1096  signalPassFailure();
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
1100 // Patterns for structural type conversions for the Async dialect operations.
1101 //===----------------------------------------------------------------------===//
1102 
1103 namespace {
1104 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1105 public:
1107  LogicalResult
1108  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1109  ConversionPatternRewriter &rewriter) const override {
1110  ExecuteOp newOp =
1111  cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1112  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1113  newOp.getRegion().end());
1114 
1115  // Set operands and update block argument and result types.
1116  newOp->setOperands(adaptor.getOperands());
1117  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1118  return failure();
1119  for (auto result : newOp.getResults())
1120  result.setType(typeConverter->convertType(result.getType()));
1121 
1122  rewriter.replaceOp(op, newOp.getResults());
1123  return success();
1124  }
1125 };
1126 
1127 // Dummy pattern to trigger the appropriate type conversion / materialization.
1128 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1129 public:
1131  LogicalResult
1132  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1133  ConversionPatternRewriter &rewriter) const override {
1134  rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1135  return success();
1136  }
1137 };
1138 
1139 // Dummy pattern to trigger the appropriate type conversion / materialization.
1140 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1141 public:
1143  LogicalResult
1144  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1145  ConversionPatternRewriter &rewriter) const override {
1146  rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1147  return success();
1148  }
1149 };
1150 } // namespace
1151 
1153  TypeConverter &typeConverter, RewritePatternSet &patterns,
1154  ConversionTarget &target) {
1155  typeConverter.addConversion([&](TokenType type) { return type; });
1156  typeConverter.addConversion([&](ValueType type) {
1157  Type converted = typeConverter.convertType(type.getValueType());
1158  return converted ? ValueType::get(converted) : converted;
1159  });
1160 
1161  patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1162  typeConverter, patterns.getContext());
1163 
1164  target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1165  [&](Operation *op) { return typeConverter.isLegal(op); });
1166 }
static constexpr const char * kAwaitValueAndExecute
Definition: AsyncToLLVM.cpp:63
static constexpr const char * kCreateValue
Definition: AsyncToLLVM.cpp:44
static constexpr const char * kCreateGroup
Definition: AsyncToLLVM.cpp:45
static constexpr const char * kCreateToken
Definition: AsyncToLLVM.cpp:43
static constexpr const char * kEmplaceValue
Definition: AsyncToLLVM.cpp:47
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls a llvm.coro.resume intrinsics.
static constexpr const char * kEmplaceToken
Definition: AsyncToLLVM.cpp:46
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
static constexpr const char * kResume
static constexpr const char * kAddRef
Definition: AsyncToLLVM.cpp:41
static constexpr const char * kAwaitTokenAndExecute
Definition: AsyncToLLVM.cpp:61
static constexpr const char * kAwaitValue
Definition: AsyncToLLVM.cpp:54
static constexpr const char * kSetTokenError
Definition: AsyncToLLVM.cpp:48
static constexpr const char * kExecute
Definition: AsyncToLLVM.cpp:56
static constexpr const char * kAddTokenToGroup
Definition: AsyncToLLVM.cpp:59
static constexpr const char * kIsGroupError
Definition: AsyncToLLVM.cpp:52
static constexpr const char * kSetValueError
Definition: AsyncToLLVM.cpp:49
static constexpr const char * kIsTokenError
Definition: AsyncToLLVM.cpp:50
static constexpr const char * kAwaitGroup
Definition: AsyncToLLVM.cpp:55
static constexpr const char * kAwaitAllAndExecute
Definition: AsyncToLLVM.cpp:65
static constexpr const char * kGetNumWorkerThreads
Definition: AsyncToLLVM.cpp:67
static constexpr const char * kDropRef
Definition: AsyncToLLVM.cpp:42
static constexpr const char * kIsValueError
Definition: AsyncToLLVM.cpp:51
static constexpr const char * kAwaitToken
Definition: AsyncToLLVM.cpp:53
static constexpr const char * kGetValueStorage
Definition: AsyncToLLVM.cpp:57
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:585
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.