24#include "llvm/ADT/TypeSwitch.h"
27#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
28#include "mlir/Conversion/Passes.h.inc"
31#define DEBUG_TYPE "convert-async-to-llvm"
//===----------------------------------------------------------------------===//
// Async Runtime C API function names.
//
// Lowered code declares these symbols as private `func.func`s and emits
// `func.call`s to them.
//===----------------------------------------------------------------------===//

static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
// NOTE(review): "Runtim" (missing 'e') is spelled this way in the original;
// presumably it must match the symbol exported by the runtime library, so
// confirm against the runtime before "fixing" it.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
77 static LLVM::LLVMPointerType opaquePointerType(
MLIRContext *ctx) {
78 return LLVM::LLVMPointerType::get(ctx);
81 static LLVM::LLVMTokenType tokenType(
MLIRContext *ctx) {
82 return LLVM::LLVMTokenType::get(ctx);
85 static FunctionType addOrDropRefFunctionType(
MLIRContext *ctx) {
86 auto ref = opaquePointerType(ctx);
87 auto count = IntegerType::get(ctx, 64);
88 return FunctionType::get(ctx, {ref, count}, {});
91 static FunctionType createTokenFunctionType(
MLIRContext *ctx) {
92 return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
95 static FunctionType createValueFunctionType(
MLIRContext *ctx) {
96 auto i64 = IntegerType::get(ctx, 64);
97 auto value = opaquePointerType(ctx);
98 return FunctionType::get(ctx, {i64}, {value});
101 static FunctionType createGroupFunctionType(
MLIRContext *ctx) {
102 auto i64 = IntegerType::get(ctx, 64);
103 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106 static FunctionType getValueStorageFunctionType(
MLIRContext *ctx) {
107 auto ptrType = opaquePointerType(ctx);
108 return FunctionType::get(ctx, {ptrType}, {ptrType});
111 static FunctionType emplaceTokenFunctionType(
MLIRContext *ctx) {
112 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115 static FunctionType emplaceValueFunctionType(
MLIRContext *ctx) {
116 auto value = opaquePointerType(ctx);
117 return FunctionType::get(ctx, {value}, {});
120 static FunctionType setTokenErrorFunctionType(
MLIRContext *ctx) {
121 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
124 static FunctionType setValueErrorFunctionType(
MLIRContext *ctx) {
125 auto value = opaquePointerType(ctx);
126 return FunctionType::get(ctx, {value}, {});
129 static FunctionType isTokenErrorFunctionType(
MLIRContext *ctx) {
130 auto i1 = IntegerType::get(ctx, 1);
131 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
134 static FunctionType isValueErrorFunctionType(
MLIRContext *ctx) {
135 auto value = opaquePointerType(ctx);
136 auto i1 = IntegerType::get(ctx, 1);
137 return FunctionType::get(ctx, {value}, {i1});
140 static FunctionType isGroupErrorFunctionType(
MLIRContext *ctx) {
141 auto i1 = IntegerType::get(ctx, 1);
142 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145 static FunctionType awaitTokenFunctionType(
MLIRContext *ctx) {
146 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
149 static FunctionType awaitValueFunctionType(
MLIRContext *ctx) {
150 auto value = opaquePointerType(ctx);
151 return FunctionType::get(ctx, {value}, {});
154 static FunctionType awaitGroupFunctionType(
MLIRContext *ctx) {
155 return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158 static FunctionType executeFunctionType(
MLIRContext *ctx) {
159 auto ptrType = opaquePointerType(ctx);
160 return FunctionType::get(ctx, {ptrType, ptrType}, {});
163 static FunctionType addTokenToGroupFunctionType(
MLIRContext *ctx) {
164 auto i64 = IntegerType::get(ctx, 64);
165 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
169 static FunctionType awaitTokenAndExecuteFunctionType(
MLIRContext *ctx) {
170 auto ptrType = opaquePointerType(ctx);
171 return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
174 static FunctionType awaitValueAndExecuteFunctionType(
MLIRContext *ctx) {
175 auto ptrType = opaquePointerType(ctx);
176 return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179 static FunctionType awaitAllAndExecuteFunctionType(
MLIRContext *ctx) {
180 auto ptrType = opaquePointerType(ctx);
181 return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184 static FunctionType getNumWorkerThreads(
MLIRContext *ctx) {
185 return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
190 auto voidTy = LLVM::LLVMVoidType::get(ctx);
191 auto ptrType = opaquePointerType(ctx);
192 return LLVM::LLVMFunctionType::get(voidTy, {ptrType},
false);
202 auto addFuncDecl = [&](StringRef name, FunctionType type) {
203 if (module.lookupSymbol(name))
205 func::FuncOp::create(builder, name, type).setPrivate();
209 addFuncDecl(
kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
210 addFuncDecl(
kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
211 addFuncDecl(
kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
212 addFuncDecl(
kCreateValue, AsyncAPI::createValueFunctionType(ctx));
213 addFuncDecl(
kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
214 addFuncDecl(
kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
215 addFuncDecl(
kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
216 addFuncDecl(
kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
217 addFuncDecl(
kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
218 addFuncDecl(
kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
219 addFuncDecl(
kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
220 addFuncDecl(
kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
221 addFuncDecl(
kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
222 addFuncDecl(
kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
223 addFuncDecl(
kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
224 addFuncDecl(
kExecute, AsyncAPI::executeFunctionType(ctx));
228 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
230 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
232 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
/// Symbol name of the private LLVM function created by `addResumeFunction`.
/// That function takes a single pointer argument (a coroutine handle) and
/// calls `llvm.coro.resume` on it. Its address is what gets passed to the
/// runtime's *Execute APIs.
static constexpr const char *kResume = "__resume";
246 if (module.lookupSymbol(
kResume))
250 auto loc =
module.getLoc();
253 auto voidTy = LLVM::LLVMVoidType::get(ctx);
254 Type ptrType = AsyncAPI::opaquePointerType(ctx);
256 auto resumeOp = LLVM::LLVMFuncOp::create(
257 moduleBuilder,
kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
258 resumeOp.setPrivate();
260 auto *block = resumeOp.addEntryBlock(moduleBuilder);
263 LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0));
264 LLVM::ReturnOp::create(blockBuilder,
ValueRange());
276 AsyncRuntimeTypeConverter(
const LowerToLLVMOptions &
options) {
277 addConversion([](Type type) {
return type; });
278 addConversion([](Type type) {
return convertAsyncTypes(type); });
282 auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
286 return cast.getResult(0);
289 addSourceMaterialization(addUnrealizedCast);
290 addTargetMaterialization(addUnrealizedCast);
293 static std::optional<Type> convertAsyncTypes(Type type) {
294 if (isa<TokenType, GroupType, ValueType>(type))
295 return AsyncAPI::opaquePointerType(type.
getContext());
297 if (isa<CoroIdType, CoroStateType>(type))
298 return AsyncAPI::tokenType(type.
getContext());
299 if (isa<CoroHandleType>(type))
300 return AsyncAPI::opaquePointerType(type.
getContext());
309template <
typename SourceOp>
310class AsyncOpConversionPattern :
public OpConversionPattern<SourceOp> {
312 using Base = OpConversionPattern<SourceOp>;
315 AsyncOpConversionPattern(
const AsyncRuntimeTypeConverter &typeConverter,
316 MLIRContext *context)
317 : Base(typeConverter, context) {}
320 const AsyncRuntimeTypeConverter *getTypeConverter()
const {
321 return static_cast<const AsyncRuntimeTypeConverter *
>(
322 Base::getTypeConverter());
333class CoroIdOpConversion :
public AsyncOpConversionPattern<CoroIdOp> {
335 using AsyncOpConversionPattern::AsyncOpConversionPattern;
338 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override {
340 auto token = AsyncAPI::tokenType(op->getContext());
341 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
342 auto loc = op->getLoc();
346 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0);
347 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType);
350 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
351 op, token,
ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
363class CoroBeginOpConversion :
public AsyncOpConversionPattern<CoroBeginOp> {
365 using AsyncOpConversionPattern::AsyncOpConversionPattern;
368 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter)
const override {
370 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
371 auto loc = op->getLoc();
375 LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type());
378 LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type());
383 auto makeConstant = [&](uint64_t c) {
384 return LLVM::ConstantOp::create(rewriter, op->getLoc(),
385 rewriter.getI64Type(), c);
387 coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign);
389 LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1));
391 LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign);
393 LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign);
396 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
397 rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
400 auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
404 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
405 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
406 op, ptrType,
ValueRange({coroId, coroAlloc.getResult()}));
418class CoroFreeOpConversion :
public AsyncOpConversionPattern<CoroFreeOp> {
420 using AsyncOpConversionPattern::AsyncOpConversionPattern;
423 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter)
const override {
425 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
426 auto loc = op->getLoc();
430 LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands());
434 LLVM::lookupOrCreateFreeFn(rewriter, op->getParentOfType<ModuleOp>());
437 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
450class CoroEndOpConversion :
public OpConversionPattern<CoroEndOp> {
452 using OpConversionPattern::OpConversionPattern;
455 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
456 ConversionPatternRewriter &rewriter)
const override {
459 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
460 rewriter.getBoolAttr(
false));
461 auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc());
464 auto coroHdl = adaptor.getHandle();
465 LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
466 ValueRange({coroHdl, constFalse, noneToken}));
467 rewriter.eraseOp(op);
479class CoroSaveOpConversion :
public OpConversionPattern<CoroSaveOp> {
481 using OpConversionPattern::OpConversionPattern;
484 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter)
const override {
487 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
488 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
525class CoroSuspendOpConversion :
public OpConversionPattern<CoroSuspendOp> {
527 using OpConversionPattern::OpConversionPattern;
530 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
531 ConversionPatternRewriter &rewriter)
const override {
532 auto i8 = rewriter.getIntegerType(8);
533 auto i32 = rewriter.getI32Type();
534 auto loc = op->getLoc();
537 auto constFalse = LLVM::ConstantOp::create(
538 rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(
false));
541 auto coroState = adaptor.getState();
542 auto coroSuspend = LLVM::CoroSuspendOp::create(
543 rewriter, loc, i8,
ValueRange({coroState, constFalse}));
550 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
551 llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
552 op.getCleanupDest()};
553 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
554 op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()),
560 ArrayRef<int32_t>());
580 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter)
const override {
582 const TypeConverter *converter = getTypeConverter();
583 Type resultType = op->getResultTypes()[0];
586 if (isa<TokenType>(resultType)) {
587 rewriter.replaceOpWithNewOp<func::CallOp>(
593 if (
auto value = dyn_cast<ValueType>(resultType)) {
595 auto sizeOf = [&](ValueType valueType) -> Value {
596 auto loc = op->getLoc();
597 auto i64 = rewriter.getI64Type();
599 auto storedType = converter->convertType(valueType.getValueType());
600 auto storagePtrType =
601 AsyncAPI::opaquePointerType(rewriter.getContext());
605 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType);
607 LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType,
608 nullPtr, ArrayRef<LLVM::GEPArg>{1});
609 return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep);
612 rewriter.replaceOpWithNewOp<func::CallOp>(op,
kCreateValue, resultType,
618 return rewriter.notifyMatchFailure(op,
"unsupported async type");
628class RuntimeCreateGroupOpLowering
634 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
635 ConversionPatternRewriter &rewriter)
const override {
636 const TypeConverter *converter = getTypeConverter();
637 Type resultType = op.getResult().getType();
639 rewriter.replaceOpWithNewOp<func::CallOp>(
641 adaptor.getOperands());
652class RuntimeSetAvailableOpLowering
653 :
public OpConversionPattern<RuntimeSetAvailableOp> {
655 using OpConversionPattern::OpConversionPattern;
658 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
659 ConversionPatternRewriter &rewriter)
const override {
660 StringRef apiFuncName =
665 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName,
TypeRange(),
666 adaptor.getOperands());
678class RuntimeSetErrorOpLowering
679 :
public OpConversionPattern<RuntimeSetErrorOp> {
681 using OpConversionPattern::OpConversionPattern;
684 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
685 ConversionPatternRewriter &rewriter)
const override {
686 StringRef apiFuncName =
691 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName,
TypeRange(),
692 adaptor.getOperands());
704class RuntimeIsErrorOpLowering :
public OpConversionPattern<RuntimeIsErrorOp> {
706 using OpConversionPattern::OpConversionPattern;
709 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 StringRef apiFuncName =
717 rewriter.replaceOpWithNewOp<func::CallOp>(
718 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
729class RuntimeAwaitOpLowering :
public OpConversionPattern<RuntimeAwaitOp> {
731 using OpConversionPattern::OpConversionPattern;
734 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
735 ConversionPatternRewriter &rewriter)
const override {
736 StringRef apiFuncName =
740 .Case<GroupType>([](Type) {
return kAwaitGroup; });
742 func::CallOp::create(rewriter, op->getLoc(), apiFuncName,
TypeRange(),
743 adaptor.getOperands());
744 rewriter.eraseOp(op);
756class RuntimeAwaitAndResumeOpLowering
757 :
public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
759 using AsyncOpConversionPattern::AsyncOpConversionPattern;
762 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
763 ConversionPatternRewriter &rewriter)
const override {
764 StringRef apiFuncName =
770 Value operand = adaptor.getOperand();
771 Value handle = adaptor.getHandle();
775 auto resumePtr = LLVM::AddressOfOp::create(
776 rewriter, op->getLoc(),
777 AsyncAPI::opaquePointerType(rewriter.getContext()),
kResume);
779 func::CallOp::create(rewriter, op->getLoc(), apiFuncName,
TypeRange(),
780 ValueRange({operand, handle, resumePtr.getRes()}));
781 rewriter.eraseOp(op);
793class RuntimeResumeOpLowering
794 :
public AsyncOpConversionPattern<RuntimeResumeOp> {
796 using AsyncOpConversionPattern::AsyncOpConversionPattern;
799 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter)
const override {
803 auto resumePtr = LLVM::AddressOfOp::create(
804 rewriter, op->getLoc(),
805 AsyncAPI::opaquePointerType(rewriter.getContext()),
kResume);
808 auto coroHdl = adaptor.getHandle();
809 rewriter.replaceOpWithNewOp<func::CallOp>(
827 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
828 ConversionPatternRewriter &rewriter)
const override {
829 Location loc = op->getLoc();
832 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
833 auto storage = adaptor.getStorage();
838 auto valueType = op.getValue().getType();
839 auto llvmValueType = getTypeConverter()->convertType(valueType);
841 return rewriter.notifyMatchFailure(
842 op,
"failed to convert stored value type to LLVM type");
844 Value castedStoragePtr = storagePtr.getResult(0);
846 auto value = adaptor.getValue();
847 LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr);
850 rewriter.eraseOp(op);
867 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
868 ConversionPatternRewriter &rewriter)
const override {
869 Location loc = op->getLoc();
872 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
873 auto storage = adaptor.getStorage();
878 auto valueType = op.getResult().getType();
879 auto llvmValueType = getTypeConverter()->convertType(valueType);
881 return rewriter.notifyMatchFailure(
882 op,
"failed to convert loaded value type to LLVM type");
884 Value castedStoragePtr = storagePtr.getResult(0);
887 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
900class RuntimeAddToGroupOpLowering
901 :
public OpConversionPattern<RuntimeAddToGroupOp> {
903 using OpConversionPattern::OpConversionPattern;
906 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
907 ConversionPatternRewriter &rewriter)
const override {
909 if (!isa<TokenType>(op.getOperand().getType()))
910 return rewriter.notifyMatchFailure(op,
"only token type is supported");
913 rewriter.replaceOpWithNewOp<func::CallOp>(
927class RuntimeNumWorkerThreadsOpLowering
928 :
public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
930 using OpConversionPattern::OpConversionPattern;
933 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
934 ConversionPatternRewriter &rewriter)
const override {
938 rewriter.getIndexType());
951template <
typename RefCountingOp>
952class RefCountingOpLowering :
public OpConversionPattern<RefCountingOp> {
954 explicit RefCountingOpLowering(
const TypeConverter &converter,
955 MLIRContext *ctx, StringRef apiFunctionName)
956 : OpConversionPattern<RefCountingOp>(converter, ctx),
957 apiFunctionName(apiFunctionName) {}
960 matchAndRewrite(RefCountingOp op,
typename RefCountingOp::Adaptor adaptor,
961 ConversionPatternRewriter &rewriter)
const override {
963 arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(),
964 rewriter.getI64IntegerAttr(op.getCount()));
966 auto operand = adaptor.getOperand();
967 rewriter.replaceOpWithNewOp<func::CallOp>(op,
TypeRange(), apiFunctionName,
974 StringRef apiFunctionName;
977class RuntimeAddRefOpLowering :
public RefCountingOpLowering<RuntimeAddRefOp> {
979 explicit RuntimeAddRefOpLowering(
const TypeConverter &converter,
981 : RefCountingOpLowering(converter, ctx,
kAddRef) {}
984class RuntimeDropRefOpLowering
985 :
public RefCountingOpLowering<RuntimeDropRefOp> {
987 explicit RuntimeDropRefOpLowering(
const TypeConverter &converter,
989 : RefCountingOpLowering(converter, ctx,
kDropRef) {}
998class ReturnOpOpConversion :
public OpConversionPattern<func::ReturnOp> {
1000 using OpConversionPattern::OpConversionPattern;
1003 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1004 ConversionPatternRewriter &rewriter)
const override {
1005 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1014struct ConvertAsyncToLLVMPass
1015 :
public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1018 void runOnOperation()
override;
1022void ConvertAsyncToLLVMPass::runOnOperation() {
1023 ModuleOp module = getOperation();
1024 MLIRContext *ctx =
module->getContext();
1026 LowerToLLVMOptions
options(ctx);
1037 AsyncRuntimeTypeConverter converter(
options);
1042 LLVMTypeConverter llvmConverter(ctx,
options);
1043 llvmConverter.addConversion([&](Type type) {
1044 return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1048 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns,
1053 patterns.add<ReturnOpOpConversion>(converter, ctx);
1056 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1057 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1058 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1059 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1060 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1065 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1066 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
1070 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1071 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1074 ConversionTarget
target(*ctx);
1075 target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1076 UnrealizedConversionCastOp>();
1077 target.addLegalDialect<LLVM::LLVMDialect>();
1081 target.addIllegalDialect<AsyncDialect>();
1084 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1085 return converter.isSignatureLegal(op.getFunctionType());
1087 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1088 return converter.isLegal(op.getOperandTypes());
1090 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1091 return converter.isSignatureLegal(op.getCalleeType());
1095 signalPassFailure();
1103class ConvertExecuteOpTypes :
public OpConversionPattern<ExecuteOp> {
1105 using OpConversionPattern::OpConversionPattern;
1107 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1108 ConversionPatternRewriter &rewriter)
const override {
1110 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1111 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1112 newOp.getRegion().end());
1115 newOp->setOperands(adaptor.getOperands());
1116 if (
failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1118 for (
auto result : newOp.getResults())
1119 result.setType(typeConverter->convertType(
result.getType()));
1121 rewriter.replaceOp(op, newOp.getResults());
1127class ConvertAwaitOpTypes :
public OpConversionPattern<AwaitOp> {
1129 using OpConversionPattern::OpConversionPattern;
1131 matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1132 ConversionPatternRewriter &rewriter)
const override {
1133 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1139class ConvertYieldOpTypes :
public OpConversionPattern<async::YieldOp> {
1141 using OpConversionPattern::OpConversionPattern;
1143 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1144 ConversionPatternRewriter &rewriter)
const override {
1145 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1154 typeConverter.addConversion([&](TokenType type) {
return type; });
1155 typeConverter.addConversion([&](ValueType type) {
1156 Type converted = typeConverter.convertType(type.getValueType());
1157 return converted ? ValueType::get(converted) : converted;
1160 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1161 typeConverter,
patterns.getContext());
1163 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1164 [&](
Operation *op) {
return typeConverter.isLegal(op); });
static constexpr const char * kAwaitValueAndExecute
static constexpr const char * kCreateValue
static constexpr const char * kCreateGroup
static constexpr const char * kCreateToken
static constexpr const char * kEmplaceValue
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls the llvm.coro.resume intrinsic.
static constexpr const char * kEmplaceToken
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
static constexpr const char * kResume
static constexpr const char * kAddRef
static constexpr const char * kAwaitTokenAndExecute
static constexpr const char * kAwaitValue
static constexpr const char * kSetTokenError
static constexpr const char * kExecute
static constexpr const char * kAddTokenToGroup
static constexpr const char * kIsGroupError
static constexpr const char * kSetValueError
static constexpr const char * kIsTokenError
static constexpr const char * kAwaitGroup
static constexpr const char * kAwaitAllAndExecute
static constexpr const char * kGetNumWorkerThreads
static constexpr const char * kDropRef
static constexpr const char * kIsValueError
static constexpr const char * kAwaitToken
static constexpr const char * kGetValueStorage
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Include the generated interface declarations.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
const FrozenRewritePatternSet & patterns
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
llvm::TypeSwitch< T, ResultT > TypeSwitch