MLIR  16.0.0git
AsyncToLLVM.cpp
Go to the documentation of this file.
1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_CONVERTASYNCTOLLVM
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 #define DEBUG_TYPE "convert-async-to-llvm"
32 
33 using namespace mlir;
34 using namespace mlir::async;
35 
36 //===----------------------------------------------------------------------===//
37 // Async Runtime C API declaration.
38 //===----------------------------------------------------------------------===//
39 
// Symbol names of the Async Runtime C API functions referenced by this
// lowering. Each one is declared in the module as a private function.

// Declared with the (pointer, i64) add/drop-ref signature.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Targets of async.runtime.create / async.runtime.create_group.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Targets of async.runtime.set_available.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Targets of async.runtime.set_error.
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Targets of async.runtime.is_error.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Targets of async.runtime.await.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Target of async.runtime.resume.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";

// Used by async.runtime.store / async.runtime.load to reach value storage.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";

// Target of async.runtime.add_to_group.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Targets of async.runtime.await_and_resume.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// NOTE(review): "Runtim" (no trailing 'e') looks like a typo, but the string
// must match the symbol exported by the runtime library -- confirm against the
// runtime before changing it.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
68 
69 namespace {
70 /// Async Runtime API function types.
71 ///
72 /// Because we can't create API function signature for type parametrized
73 /// async.getValue type, we use opaque pointers (!llvm.ptr<i8>) instead. After
74 /// lowering all async data types become opaque pointers at runtime.
75 struct AsyncAPI {
76  // All async types are lowered to opaque i8* LLVM pointers at runtime.
77  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
78  return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
79  }
80 
81  static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
82  return LLVM::LLVMTokenType::get(ctx);
83  }
84 
85  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
86  auto ref = opaquePointerType(ctx);
87  auto count = IntegerType::get(ctx, 64);
88  return FunctionType::get(ctx, {ref, count}, {});
89  }
90 
91  static FunctionType createTokenFunctionType(MLIRContext *ctx) {
92  return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
93  }
94 
95  static FunctionType createValueFunctionType(MLIRContext *ctx) {
96  auto i64 = IntegerType::get(ctx, 64);
97  auto value = opaquePointerType(ctx);
98  return FunctionType::get(ctx, {i64}, {value});
99  }
100 
101  static FunctionType createGroupFunctionType(MLIRContext *ctx) {
102  auto i64 = IntegerType::get(ctx, 64);
103  return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
104  }
105 
106  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
107  auto value = opaquePointerType(ctx);
108  auto storage = opaquePointerType(ctx);
109  return FunctionType::get(ctx, {value}, {storage});
110  }
111 
112  static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
113  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
114  }
115 
116  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
117  auto value = opaquePointerType(ctx);
118  return FunctionType::get(ctx, {value}, {});
119  }
120 
121  static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
122  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
123  }
124 
125  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
126  auto value = opaquePointerType(ctx);
127  return FunctionType::get(ctx, {value}, {});
128  }
129 
130  static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
131  auto i1 = IntegerType::get(ctx, 1);
132  return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
133  }
134 
135  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
136  auto value = opaquePointerType(ctx);
137  auto i1 = IntegerType::get(ctx, 1);
138  return FunctionType::get(ctx, {value}, {i1});
139  }
140 
141  static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
142  auto i1 = IntegerType::get(ctx, 1);
143  return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
144  }
145 
146  static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
147  return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
148  }
149 
150  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
151  auto value = opaquePointerType(ctx);
152  return FunctionType::get(ctx, {value}, {});
153  }
154 
155  static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
156  return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
157  }
158 
159  static FunctionType executeFunctionType(MLIRContext *ctx) {
160  auto hdl = opaquePointerType(ctx);
161  auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
162  return FunctionType::get(ctx, {hdl, resume}, {});
163  }
164 
165  static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
166  auto i64 = IntegerType::get(ctx, 64);
167  return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
168  {i64});
169  }
170 
171  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172  auto hdl = opaquePointerType(ctx);
173  auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
174  return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
175  }
176 
177  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
178  auto value = opaquePointerType(ctx);
179  auto hdl = opaquePointerType(ctx);
180  auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
181  return FunctionType::get(ctx, {value, hdl, resume}, {});
182  }
183 
184  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
185  auto hdl = opaquePointerType(ctx);
186  auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
187  return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
188  }
189 
190  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
191  return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
192  }
193 
194  // Auxiliary coroutine resume intrinsic wrapper.
195  static Type resumeFunctionType(MLIRContext *ctx) {
196  auto voidTy = LLVM::LLVMVoidType::get(ctx);
197  auto i8Ptr = opaquePointerType(ctx);
198  return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
199  }
200 };
201 } // namespace
202 
203 /// Adds Async Runtime C API declarations to the module.
204 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
205  auto builder =
206  ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
207 
208  auto addFuncDecl = [&](StringRef name, FunctionType type) {
209  if (module.lookupSymbol(name))
210  return;
211  builder.create<func::FuncOp>(name, type).setPrivate();
212  };
213 
214  MLIRContext *ctx = module.getContext();
215  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
216  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
217  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
218  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
219  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
220  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
221  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
222  addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
223  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
224  addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
225  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
226  addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
227  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
228  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
229  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
230  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
231  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
232  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
233  addFuncDecl(kAwaitTokenAndExecute,
234  AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
235  addFuncDecl(kAwaitValueAndExecute,
236  AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
237  addFuncDecl(kAwaitAllAndExecute,
238  AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
239  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Coroutine resume function wrapper.
244 //===----------------------------------------------------------------------===//
245 
246 static constexpr const char *kResume = "__resume";
247 
248 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
249 /// intrinsics. We need this function to be able to pass it to the async
250 /// runtime execute API.
251 static void addResumeFunction(ModuleOp module) {
252  if (module.lookupSymbol(kResume))
253  return;
254 
255  MLIRContext *ctx = module.getContext();
256  auto loc = module.getLoc();
257  auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
258 
259  auto voidTy = LLVM::LLVMVoidType::get(ctx);
260  auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
261 
262  auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
263  kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
264  resumeOp.setPrivate();
265 
266  auto *block = resumeOp.addEntryBlock();
267  auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
268 
269  blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
270  blockBuilder.create<LLVM::ReturnOp>(ValueRange());
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // Convert Async dialect types to LLVM types.
275 //===----------------------------------------------------------------------===//
276 
277 namespace {
278 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
279 /// their runtime type (opaque pointers) and does not convert any other types.
280 class AsyncRuntimeTypeConverter : public TypeConverter {
281 public:
282  AsyncRuntimeTypeConverter() {
283  addConversion([](Type type) { return type; });
284  addConversion(convertAsyncTypes);
285 
286  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
287  // in patterns for other dialects.
288  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
289  ValueRange inputs, Location loc) {
290  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
291  return Optional<Value>(cast.getResult(0));
292  };
293 
294  addSourceMaterialization(addUnrealizedCast);
295  addTargetMaterialization(addUnrealizedCast);
296  }
297 
298  static Optional<Type> convertAsyncTypes(Type type) {
299  if (type.isa<TokenType, GroupType, ValueType>())
300  return AsyncAPI::opaquePointerType(type.getContext());
301 
302  if (type.isa<CoroIdType, CoroStateType>())
303  return AsyncAPI::tokenType(type.getContext());
304  if (type.isa<CoroHandleType>())
305  return AsyncAPI::opaquePointerType(type.getContext());
306 
307  return llvm::None;
308  }
309 };
310 } // namespace
311 
312 //===----------------------------------------------------------------------===//
313 // Convert async.coro.id to @llvm.coro.id intrinsic.
314 //===----------------------------------------------------------------------===//
315 
316 namespace {
317 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
318 public:
320 
322  matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
323  ConversionPatternRewriter &rewriter) const override {
324  auto token = AsyncAPI::tokenType(op->getContext());
325  auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
326  auto loc = op->getLoc();
327 
328  // Constants for initializing coroutine frame.
329  auto constZero =
330  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
331  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
332 
333  // Get coroutine id: @llvm.coro.id.
334  rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
335  op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
336 
337  return success();
338  }
339 };
340 } // namespace
341 
342 //===----------------------------------------------------------------------===//
343 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
344 //===----------------------------------------------------------------------===//
345 
346 namespace {
347 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
348 public:
350 
352  matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
353  ConversionPatternRewriter &rewriter) const override {
354  auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
355  auto loc = op->getLoc();
356 
357  // Get coroutine frame size: @llvm.coro.size.i64.
358  Value coroSize =
359  rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
360  // Get coroutine frame alignment: @llvm.coro.align.i64.
361  Value coroAlign =
362  rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
363 
364  // Round up the size to be multiple of the alignment. Since aligned_alloc
365  // requires the size parameter be an integral multiple of the alignment
366  // parameter.
367  auto makeConstant = [&](uint64_t c) {
368  return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
369  rewriter.getI64Type(), c);
370  };
371  coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
372  coroSize =
373  rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
374  Value negCoroAlign =
375  rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
376  coroSize =
377  rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
378 
379  // Allocate memory for the coroutine frame.
380  auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
381  op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
382  auto coroAlloc = rewriter.create<LLVM::CallOp>(
383  loc, allocFuncOp, ValueRange{coroAlign, coroSize});
384 
385  // Begin a coroutine: @llvm.coro.begin.
386  auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
387  rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
388  op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
389 
390  return success();
391  }
392 };
393 } // namespace
394 
395 //===----------------------------------------------------------------------===//
396 // Convert async.coro.free to @llvm.coro.free intrinsic.
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
401 public:
403 
405  matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
406  ConversionPatternRewriter &rewriter) const override {
407  auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
408  auto loc = op->getLoc();
409 
410  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
411  auto coroMem =
412  rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
413 
414  // Free the memory.
415  auto freeFuncOp =
416  LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
417  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
418  ValueRange(coroMem.getResult()));
419 
420  return success();
421  }
422 };
423 } // namespace
424 
425 //===----------------------------------------------------------------------===//
426 // Convert async.coro.end to @llvm.coro.end intrinsic.
427 //===----------------------------------------------------------------------===//
428 
429 namespace {
430 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
431 public:
433 
435  matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
436  ConversionPatternRewriter &rewriter) const override {
437  // We are not in the block that is part of the unwind sequence.
438  auto constFalse = rewriter.create<LLVM::ConstantOp>(
439  op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
440 
441  // Mark the end of a coroutine: @llvm.coro.end.
442  auto coroHdl = adaptor.getHandle();
443  rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
444  ValueRange({coroHdl, constFalse}));
445  rewriter.eraseOp(op);
446 
447  return success();
448  }
449 };
450 } // namespace
451 
452 //===----------------------------------------------------------------------===//
453 // Convert async.coro.save to @llvm.coro.save intrinsic.
454 //===----------------------------------------------------------------------===//
455 
456 namespace {
457 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
458 public:
460 
462  matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
463  ConversionPatternRewriter &rewriter) const override {
464  // Save the coroutine state: @llvm.coro.save
465  rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
466  op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
467 
468  return success();
469  }
470 };
471 } // namespace
472 
473 //===----------------------------------------------------------------------===//
474 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
475 //===----------------------------------------------------------------------===//
476 
477 namespace {
478 
479 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
480 /// branch to the appropriate block based on the return code.
481 ///
482 /// Before:
483 ///
484 /// ^suspended:
485 /// "opBefore"(...)
486 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
487 /// ^resume:
488 /// "op"(...)
489 /// ^cleanup: ...
490 /// ^suspend: ...
491 ///
492 /// After:
493 ///
494 /// ^suspended:
495 /// "opBefore"(...)
496 /// %suspend = llmv.intr.coro.suspend ...
497 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
498 /// ^resume:
499 /// "op"(...)
500 /// ^cleanup: ...
501 /// ^suspend: ...
502 ///
503 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
504 public:
506 
508  matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
509  ConversionPatternRewriter &rewriter) const override {
510  auto i8 = rewriter.getIntegerType(8);
511  auto i32 = rewriter.getI32Type();
512  auto loc = op->getLoc();
513 
514  // This is not a final suspension point.
515  auto constFalse = rewriter.create<LLVM::ConstantOp>(
516  loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
517 
518  // Suspend a coroutine: @llvm.coro.suspend
519  auto coroState = adaptor.getState();
520  auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
521  loc, i8, ValueRange({coroState, constFalse}));
522 
523  // Cast return code to i32.
524 
525  // After a suspension point decide if we should branch into resume, cleanup
526  // or suspend block of the coroutine (see @llvm.coro.suspend return code
527  // documentation).
528  llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
529  llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
530  op.getCleanupDest()};
531  rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
532  op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
533  /*defaultDestination=*/op.getSuspendDest(),
534  /*defaultOperands=*/ValueRange(),
535  /*caseValues=*/caseValues,
536  /*caseDestinations=*/caseDest,
537  /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
538  /*branchWeights=*/ArrayRef<int32_t>());
539 
540  return success();
541  }
542 };
543 } // namespace
544 
545 //===----------------------------------------------------------------------===//
546 // Convert async.runtime.create to the corresponding runtime API call.
547 //
548 // To allocate storage for the async values we use the getelementptr trick:
549 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
550 //===----------------------------------------------------------------------===//
551 
552 namespace {
553 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
554 public:
556 
558  matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
559  ConversionPatternRewriter &rewriter) const override {
560  TypeConverter *converter = getTypeConverter();
561  Type resultType = op->getResultTypes()[0];
562 
563  // Tokens creation maps to a simple function call.
564  if (resultType.isa<TokenType>()) {
565  rewriter.replaceOpWithNewOp<func::CallOp>(
566  op, kCreateToken, converter->convertType(resultType));
567  return success();
568  }
569 
570  // To create a value we need to compute the storage requirement.
571  if (auto value = resultType.dyn_cast<ValueType>()) {
572  // Returns the size requirements for the async value storage.
573  auto sizeOf = [&](ValueType valueType) -> Value {
574  auto loc = op->getLoc();
575  auto i64 = rewriter.getI64Type();
576 
577  auto storedType = converter->convertType(valueType.getValueType());
578  auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
579 
580  // %Size = getelementptr %T* null, int 1
581  // %SizeI = ptrtoint %T* %Size to i64
582  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
583  auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
585  return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
586  };
587 
588  rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
589  sizeOf(value));
590 
591  return success();
592  }
593 
594  return rewriter.notifyMatchFailure(op, "unsupported async type");
595  }
596 };
597 } // namespace
598 
599 //===----------------------------------------------------------------------===//
600 // Convert async.runtime.create_group to the corresponding runtime API call.
601 //===----------------------------------------------------------------------===//
602 
603 namespace {
604 class RuntimeCreateGroupOpLowering
605  : public OpConversionPattern<RuntimeCreateGroupOp> {
606 public:
608 
610  matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
611  ConversionPatternRewriter &rewriter) const override {
612  TypeConverter *converter = getTypeConverter();
613  Type resultType = op.getResult().getType();
614 
615  rewriter.replaceOpWithNewOp<func::CallOp>(
616  op, kCreateGroup, converter->convertType(resultType),
617  adaptor.getOperands());
618  return success();
619  }
620 };
621 } // namespace
622 
623 //===----------------------------------------------------------------------===//
624 // Convert async.runtime.set_available to the corresponding runtime API call.
625 //===----------------------------------------------------------------------===//
626 
627 namespace {
628 class RuntimeSetAvailableOpLowering
629  : public OpConversionPattern<RuntimeSetAvailableOp> {
630 public:
632 
634  matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
635  ConversionPatternRewriter &rewriter) const override {
636  StringRef apiFuncName =
637  TypeSwitch<Type, StringRef>(op.getOperand().getType())
638  .Case<TokenType>([](Type) { return kEmplaceToken; })
639  .Case<ValueType>([](Type) { return kEmplaceValue; });
640 
641  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
642  adaptor.getOperands());
643 
644  return success();
645  }
646 };
647 } // namespace
648 
649 //===----------------------------------------------------------------------===//
650 // Convert async.runtime.set_error to the corresponding runtime API call.
651 //===----------------------------------------------------------------------===//
652 
653 namespace {
654 class RuntimeSetErrorOpLowering
655  : public OpConversionPattern<RuntimeSetErrorOp> {
656 public:
658 
660  matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
661  ConversionPatternRewriter &rewriter) const override {
662  StringRef apiFuncName =
663  TypeSwitch<Type, StringRef>(op.getOperand().getType())
664  .Case<TokenType>([](Type) { return kSetTokenError; })
665  .Case<ValueType>([](Type) { return kSetValueError; });
666 
667  rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
668  adaptor.getOperands());
669 
670  return success();
671  }
672 };
673 } // namespace
674 
675 //===----------------------------------------------------------------------===//
676 // Convert async.runtime.is_error to the corresponding runtime API call.
677 //===----------------------------------------------------------------------===//
678 
679 namespace {
680 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
681 public:
683 
685  matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
686  ConversionPatternRewriter &rewriter) const override {
687  StringRef apiFuncName =
688  TypeSwitch<Type, StringRef>(op.getOperand().getType())
689  .Case<TokenType>([](Type) { return kIsTokenError; })
690  .Case<GroupType>([](Type) { return kIsGroupError; })
691  .Case<ValueType>([](Type) { return kIsValueError; });
692 
693  rewriter.replaceOpWithNewOp<func::CallOp>(
694  op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
695  return success();
696  }
697 };
698 } // namespace
699 
700 //===----------------------------------------------------------------------===//
701 // Convert async.runtime.await to the corresponding runtime API call.
702 //===----------------------------------------------------------------------===//
703 
704 namespace {
705 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
706 public:
708 
710  matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
711  ConversionPatternRewriter &rewriter) const override {
712  StringRef apiFuncName =
713  TypeSwitch<Type, StringRef>(op.getOperand().getType())
714  .Case<TokenType>([](Type) { return kAwaitToken; })
715  .Case<ValueType>([](Type) { return kAwaitValue; })
716  .Case<GroupType>([](Type) { return kAwaitGroup; });
717 
718  rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
719  adaptor.getOperands());
720  rewriter.eraseOp(op);
721 
722  return success();
723  }
724 };
725 } // namespace
726 
727 //===----------------------------------------------------------------------===//
728 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
729 //===----------------------------------------------------------------------===//
730 
731 namespace {
732 class RuntimeAwaitAndResumeOpLowering
733  : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
734 public:
736 
738  matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
739  ConversionPatternRewriter &rewriter) const override {
740  StringRef apiFuncName =
741  TypeSwitch<Type, StringRef>(op.getOperand().getType())
742  .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
743  .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
744  .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
745 
746  Value operand = adaptor.getOperand();
747  Value handle = adaptor.getHandle();
748 
749  // A pointer to coroutine resume intrinsic wrapper.
750  addResumeFunction(op->getParentOfType<ModuleOp>());
751  auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
752  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
753  op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
754 
755  rewriter.create<func::CallOp>(
756  op->getLoc(), apiFuncName, TypeRange(),
757  ValueRange({operand, handle, resumePtr.getRes()}));
758  rewriter.eraseOp(op);
759 
760  return success();
761  }
762 };
763 } // namespace
764 
765 //===----------------------------------------------------------------------===//
766 // Convert async.runtime.resume to the corresponding runtime API call.
767 //===----------------------------------------------------------------------===//
768 
769 namespace {
770 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
771 public:
773 
775  matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
776  ConversionPatternRewriter &rewriter) const override {
777  // A pointer to coroutine resume intrinsic wrapper.
778  addResumeFunction(op->getParentOfType<ModuleOp>());
779  auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
780  auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
781  op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
782 
783  // Call async runtime API to execute a coroutine in the managed thread.
784  auto coroHdl = adaptor.getHandle();
785  rewriter.replaceOpWithNewOp<func::CallOp>(
786  op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
787 
788  return success();
789  }
790 };
791 } // namespace
792 
793 //===----------------------------------------------------------------------===//
794 // Convert async.runtime.store to the corresponding runtime API call.
795 //===----------------------------------------------------------------------===//
796 
797 namespace {
798 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
799 public:
801 
803  matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
804  ConversionPatternRewriter &rewriter) const override {
805  Location loc = op->getLoc();
806 
807  // Get a pointer to the async value storage from the runtime.
808  auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
809  auto storage = adaptor.getStorage();
810  auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
811  TypeRange(i8Ptr), storage);
812 
813  // Cast from i8* to the LLVM pointer type.
814  auto valueType = op.getValue().getType();
815  auto llvmValueType = getTypeConverter()->convertType(valueType);
816  if (!llvmValueType)
817  return rewriter.notifyMatchFailure(
818  op, "failed to convert stored value type to LLVM type");
819 
820  auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
821  loc, LLVM::LLVMPointerType::get(llvmValueType),
822  storagePtr.getResult(0));
823 
824  // Store the yielded value into the async value storage.
825  auto value = adaptor.getValue();
826  rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
827 
828  // Erase the original runtime store operation.
829  rewriter.eraseOp(op);
830 
831  return success();
832  }
833 };
834 } // namespace
835 
836 //===----------------------------------------------------------------------===//
837 // Convert async.runtime.load to the corresponding runtime API call.
838 //===----------------------------------------------------------------------===//
839 
840 namespace {
841 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
842 public:
844 
846  matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
847  ConversionPatternRewriter &rewriter) const override {
848  Location loc = op->getLoc();
849 
850  // Get a pointer to the async value storage from the runtime.
851  auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
852  auto storage = adaptor.getStorage();
853  auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
854  TypeRange(i8Ptr), storage);
855 
856  // Cast from i8* to the LLVM pointer type.
857  auto valueType = op.getResult().getType();
858  auto llvmValueType = getTypeConverter()->convertType(valueType);
859  if (!llvmValueType)
860  return rewriter.notifyMatchFailure(
861  op, "failed to convert loaded value type to LLVM type");
862 
863  auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
864  loc, LLVM::LLVMPointerType::get(llvmValueType),
865  storagePtr.getResult(0));
866 
867  // Load from the casted pointer.
868  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
869 
870  return success();
871  }
872 };
873 } // namespace
874 
875 //===----------------------------------------------------------------------===//
876 // Convert async.runtime.add_to_group to the corresponding runtime API call.
877 //===----------------------------------------------------------------------===//
878 
879 namespace {
880 class RuntimeAddToGroupOpLowering
881  : public OpConversionPattern<RuntimeAddToGroupOp> {
882 public:
884 
886  matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
887  ConversionPatternRewriter &rewriter) const override {
888  // Currently we can only add tokens to the group.
889  if (!op.getOperand().getType().isa<TokenType>())
890  return rewriter.notifyMatchFailure(op, "only token type is supported");
891 
892  // Replace with a runtime API function call.
893  rewriter.replaceOpWithNewOp<func::CallOp>(
894  op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
895 
896  return success();
897  }
898 };
899 } // namespace
900 
901 //===----------------------------------------------------------------------===//
902 // Convert async.runtime.num_worker_threads to the corresponding runtime API
903 // call.
904 //===----------------------------------------------------------------------===//
905 
906 namespace {
907 class RuntimeNumWorkerThreadsOpLowering
908  : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
909 public:
911 
913  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
914  ConversionPatternRewriter &rewriter) const override {
915 
916  // Replace with a runtime API function call.
917  rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
918  rewriter.getIndexType());
919 
920  return success();
921  }
922 };
923 } // namespace
924 
925 //===----------------------------------------------------------------------===//
926 // Async reference counting ops lowering (`async.runtime.add_ref` and
927 // `async.runtime.drop_ref` to the corresponding API calls).
928 //===----------------------------------------------------------------------===//
929 
930 namespace {
931 template <typename RefCountingOp>
932 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
933 public:
934  explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
935  StringRef apiFunctionName)
936  : OpConversionPattern<RefCountingOp>(converter, ctx),
937  apiFunctionName(apiFunctionName) {}
938 
940  matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
941  ConversionPatternRewriter &rewriter) const override {
942  auto count = rewriter.create<arith::ConstantOp>(
943  op->getLoc(), rewriter.getI64Type(),
944  rewriter.getI64IntegerAttr(op.getCount()));
945 
946  auto operand = adaptor.getOperand();
947  rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
948  ValueRange({operand, count}));
949 
950  return success();
951  }
952 
953 private:
954  StringRef apiFunctionName;
955 };
956 
957 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
958 public:
959  explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
960  : RefCountingOpLowering(converter, ctx, kAddRef) {}
961 };
962 
963 class RuntimeDropRefOpLowering
964  : public RefCountingOpLowering<RuntimeDropRefOp> {
965 public:
966  explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
967  : RefCountingOpLowering(converter, ctx, kDropRef) {}
968 };
969 } // namespace
970 
971 //===----------------------------------------------------------------------===//
972 // Convert return operations that return async values from async regions.
973 //===----------------------------------------------------------------------===//
974 
975 namespace {
976 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
977 public:
979 
981  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
982  ConversionPatternRewriter &rewriter) const override {
983  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
984  return success();
985  }
986 };
987 } // namespace
988 
989 //===----------------------------------------------------------------------===//
990 
991 namespace {
992 struct ConvertAsyncToLLVMPass
993  : public impl::ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
994  void runOnOperation() override;
995 };
996 } // namespace
997 
998 void ConvertAsyncToLLVMPass::runOnOperation() {
999  ModuleOp module = getOperation();
1000  MLIRContext *ctx = module->getContext();
1001 
1002  // Add declarations for most functions required by the coroutines lowering.
1003  // We delay adding the resume function until it's needed because it currently
1004  // fails to compile unless '-O0' is specified.
1006 
1007  // Lower async.runtime and async.coro operations to Async Runtime API and
1008  // LLVM coroutine intrinsics.
1009 
1010  // Convert async dialect types and operations to LLVM dialect.
1011  AsyncRuntimeTypeConverter converter;
1012  RewritePatternSet patterns(ctx);
1013 
1014  // We use conversion to LLVM type to lower async.runtime load and store
1015  // operations.
1016  LLVMTypeConverter llvmConverter(ctx);
1017  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
1018 
1019  // Convert async types in function signatures and function calls.
1020  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1021  converter);
1022  populateCallOpTypeConversionPattern(patterns, converter);
1023 
1024  // Convert return operations inside async.execute regions.
1025  patterns.add<ReturnOpOpConversion>(converter, ctx);
1026 
1027  // Lower async.runtime operations to the async runtime API calls.
1028  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1029  RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1030  RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1031  RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1032  RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1033  ctx);
1034 
1035  // Lower async.runtime operations that rely on LLVM type converter to convert
1036  // from async value payload type to the LLVM type.
1037  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1038  RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
1039  ctx);
1040 
1041  // Lower async coroutine operations to LLVM coroutine intrinsics.
1042  patterns
1043  .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1044  CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1045  converter, ctx);
1046 
1047  ConversionTarget target(*ctx);
1048  target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1049  UnrealizedConversionCastOp>();
1050  target.addLegalDialect<LLVM::LLVMDialect>();
1051 
1052  // All operations from Async dialect must be lowered to the runtime API and
1053  // LLVM intrinsics calls.
1054  target.addIllegalDialect<AsyncDialect>();
1055 
1056  // Add dynamic legality constraints to apply conversions defined above.
1057  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1058  return converter.isSignatureLegal(op.getFunctionType());
1059  });
1060  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1061  return converter.isLegal(op.getOperandTypes());
1062  });
1063  target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1064  return converter.isSignatureLegal(op.getCalleeType());
1065  });
1066 
1067  if (failed(applyPartialConversion(module, target, std::move(patterns))))
1068  signalPassFailure();
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // Patterns for structural type conversions for the Async dialect operations.
1073 //===----------------------------------------------------------------------===//
1074 
1075 namespace {
1076 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1077 public:
1080  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1081  ConversionPatternRewriter &rewriter) const override {
1082  ExecuteOp newOp =
1083  cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1084  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1085  newOp.getRegion().end());
1086 
1087  // Set operands and update block argument and result types.
1088  newOp->setOperands(adaptor.getOperands());
1089  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1090  return failure();
1091  for (auto result : newOp.getResults())
1092  result.setType(typeConverter->convertType(result.getType()));
1093 
1094  rewriter.replaceOp(op, newOp.getResults());
1095  return success();
1096  }
1097 };
1098 
1099 // Dummy pattern to trigger the appropriate type conversion / materialization.
1100 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1101 public:
1104  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1105  ConversionPatternRewriter &rewriter) const override {
1106  rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1107  return success();
1108  }
1109 };
1110 
1111 // Dummy pattern to trigger the appropriate type conversion / materialization.
1112 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1113 public:
1116  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1117  ConversionPatternRewriter &rewriter) const override {
1118  rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1119  return success();
1120  }
1121 };
1122 } // namespace
1123 
1124 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1125  return std::make_unique<ConvertAsyncToLLVMPass>();
1126 }
1127 
1129  TypeConverter &typeConverter, RewritePatternSet &patterns,
1130  ConversionTarget &target) {
1131  typeConverter.addConversion([&](TokenType type) { return type; });
1132  typeConverter.addConversion([&](ValueType type) {
1133  Type converted = typeConverter.convertType(type.getValueType());
1134  return converted ? ValueType::get(converted) : converted;
1135  });
1136 
1137  patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1138  typeConverter, patterns.getContext());
1139 
1140  target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1141  [&](Operation *op) { return typeConverter.isLegal(op); });
1142 }
Include the generated interface declarations.
static constexpr const char * kAddTokenToGroup
Definition: AsyncToLLVM.cpp:58
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
static constexpr const char * kIsGroupError
Definition: AsyncToLLVM.cpp:51
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
static constexpr const char * kGetNumWorkerThreads
Definition: AsyncToLLVM.cpp:66
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static constexpr const char * kEmplaceToken
Definition: AsyncToLLVM.cpp:45
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:525
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static constexpr const char * kCreateToken
Definition: AsyncToLLVM.cpp:42
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
Definition: LLVMTypes.cpp:112
static constexpr const char * kAwaitToken
Definition: AsyncToLLVM.cpp:52
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
static constexpr const char * kSetTokenError
Definition: AsyncToLLVM.cpp:47
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static constexpr const char * kAwaitValue
Definition: AsyncToLLVM.cpp:53
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
static constexpr const char * kAwaitAllAndExecute
Definition: AsyncToLLVM.cpp:64
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
static constexpr const char * kIsTokenError
Definition: AsyncToLLVM.cpp:49
U dyn_cast() const
Definition: Types.h:269
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:109
static constexpr const char * kSetValueError
Definition: AsyncToLLVM.cpp:48
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:68
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:198
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
IntegerType getI1Type()
Definition: Builders.cpp:54
static constexpr const char * kAddRef
Definition: AsyncToLLVM.cpp:40
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
static constexpr const char * kCreateValue
Definition: AsyncToLLVM.cpp:43
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, Type indexType)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static constexpr const char * kAwaitTokenAndExecute
Definition: AsyncToLLVM.cpp:60
static constexpr const char * kAwaitGroup
Definition: AsyncToLLVM.cpp:54
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
static constexpr const char * kEmplaceValue
Definition: AsyncToLLVM.cpp:46
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
IntegerType getI64Type()
Definition: Builders.cpp:66
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls a llvm.coro.resume intrinsics. ...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void addConversion(FnT &&callback)
Register a conversion function.
static constexpr const char * kExecute
Definition: AsyncToLLVM.cpp:55
static constexpr const char * kAwaitValueAndExecute
Definition: AsyncToLLVM.cpp:62
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:52
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static constexpr const char * kIsValueError
Definition: AsyncToLLVM.cpp:50
Type conversion class.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:97
static constexpr const char * kGetValueStorage
Definition: AsyncToLLVM.cpp:56
static constexpr const char * kDropRef
Definition: AsyncToLLVM.cpp:41
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
std::unique_ptr< OperationPass< ModuleOp > > createConvertAsyncToLLVMPass()
Create a pass to convert Async operations to the LLVM dialect.
This class implements a pattern rewriter for use with ConversionPatterns.
LLVM dialect pointer type.
Definition: LLVMTypes.h:194
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
bool isa() const
Definition: Types.h:259
This class helps build Operations.
Definition: Builders.h:197
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static constexpr const char * kCreateGroup
Definition: AsyncToLLVM.cpp:44
static constexpr const char * kResume
MLIRContext * getContext() const
IntegerType getI32Type()
Definition: Builders.cpp:64
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp)