24#include "llvm/ADT/TypeSwitch.h"
27#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
28#include "mlir/Conversion/Passes.h.inc"
31#define DEBUG_TYPE "convert-async-to-llvm"
// Symbol names of the Async runtime C API functions. The lowering declares a
// private function for each of these names in the module and emits calls to
// them by name, so every string must match the runtime's exported symbol
// exactly.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// State transitions: available / error.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error-state queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting. Note that the group variant is named "AwaitAllInGroup".
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Execution of a coroutine handle through the runtime.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Further runtime API symbol-name literals. The declarations that bind each
// literal to a constant are not visible in this chunk.
57 "mlirAsyncRuntimeGetValueStorage";
59 "mlirAsyncRuntimeAddTokenToGroup";
61 "mlirAsyncRuntimeAwaitTokenAndExecute";
63 "mlirAsyncRuntimeAwaitValueAndExecute";
65 "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
// NOTE(review): spelled "Runtim" (no 'e'), unlike every other name in this
// list. Confirm against the runtime library's exported symbol before
// renaming; changing only this side would break the call by name.
67 "mlirAsyncRuntimGetNumWorkerThreads";
// Returns the LLVM dialect pointer type obtained from
// LLVM::LLVMPointerType::get(ctx). The other helpers below use it wherever a
// pointer-typed argument or result is needed.
77 static LLVM::LLVMPointerType opaquePointerType(
MLIRContext *ctx) {
78 return LLVM::LLVMPointerType::get(ctx);
// Returns the token type obtained from TokenType::get(ctx). The coroutine
// conversions use it as the result type of the LLVM coro id/save ops.
81 static mlir::TokenType tokenType(
MLIRContext *ctx) {
82 return mlir::TokenType::get(ctx);
// Function type shared by the add-ref and drop-ref declarations:
// (pointer, i64) -> ().
85 static FunctionType addOrDropRefFunctionType(
MLIRContext *ctx) {
86 auto ref = opaquePointerType(ctx);
87 auto count = IntegerType::get(ctx, 64);
88 return FunctionType::get(ctx, {ref, count}, {});
// Function type of the create-token declaration: () -> async token.
91 static FunctionType createTokenFunctionType(
MLIRContext *ctx) {
92 return FunctionType::get(ctx, {}, {async::TokenType::get(ctx)});
// Function type of the create-value declaration: (i64) -> pointer.
// The lowering passes the byte size of the stored type as the i64 argument.
95 static FunctionType createValueFunctionType(
MLIRContext *ctx) {
96 auto i64 = IntegerType::get(ctx, 64);
97 auto value = opaquePointerType(ctx);
98 return FunctionType::get(ctx, {i64}, {value});
// Function type of the create-group declaration: (i64) -> async group.
101 static FunctionType createGroupFunctionType(
MLIRContext *ctx) {
102 auto i64 = IntegerType::get(ctx, 64);
103 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
// Function type of the get-value-storage declaration: (pointer) -> pointer.
106 static FunctionType getValueStorageFunctionType(
MLIRContext *ctx) {
107 auto ptrType = opaquePointerType(ctx);
108 return FunctionType::get(ctx, {ptrType}, {ptrType});
// Function type of the emplace-token declaration: (async token) -> ().
111 static FunctionType emplaceTokenFunctionType(
MLIRContext *ctx) {
112 return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
// Function type of the emplace-value declaration: (pointer) -> ().
115 static FunctionType emplaceValueFunctionType(
MLIRContext *ctx) {
116 auto value = opaquePointerType(ctx);
117 return FunctionType::get(ctx, {value}, {});
// Function type of the set-token-error declaration: (async token) -> ().
120 static FunctionType setTokenErrorFunctionType(
MLIRContext *ctx) {
121 return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
// Function type of the set-value-error declaration: (pointer) -> ().
124 static FunctionType setValueErrorFunctionType(
MLIRContext *ctx) {
125 auto value = opaquePointerType(ctx);
126 return FunctionType::get(ctx, {value}, {});
// Function type of the is-token-error declaration: (async token) -> i1.
129 static FunctionType isTokenErrorFunctionType(
MLIRContext *ctx) {
130 auto i1 = IntegerType::get(ctx, 1);
131 return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {i1});
// Function type of the is-value-error declaration: (pointer) -> i1.
134 static FunctionType isValueErrorFunctionType(
MLIRContext *ctx) {
135 auto value = opaquePointerType(ctx);
136 auto i1 = IntegerType::get(ctx, 1);
137 return FunctionType::get(ctx, {value}, {i1});
// Function type of the is-group-error declaration: (async group) -> i1.
140 static FunctionType isGroupErrorFunctionType(
MLIRContext *ctx) {
141 auto i1 = IntegerType::get(ctx, 1);
142 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
// Function type of the await-token declaration: (async token) -> ().
145 static FunctionType awaitTokenFunctionType(
MLIRContext *ctx) {
146 return FunctionType::get(ctx, {async::TokenType::get(ctx)}, {});
// Function type of the await-value declaration: (pointer) -> ().
149 static FunctionType awaitValueFunctionType(
MLIRContext *ctx) {
150 auto value = opaquePointerType(ctx);
151 return FunctionType::get(ctx, {value}, {});
// Function type of the await-group declaration: (async group) -> ().
154 static FunctionType awaitGroupFunctionType(
MLIRContext *ctx) {
155 return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
// Function type of the execute declaration: (pointer, pointer) -> ().
// The resume lowering passes a coroutine handle and the address of the
// resume helper function as the two pointers.
158 static FunctionType executeFunctionType(
MLIRContext *ctx) {
159 auto ptrType = opaquePointerType(ctx);
160 return FunctionType::get(ctx, {ptrType, ptrType}, {});
// Function type of the add-token-to-group declaration:
// (async token, async group) -> i64.
163 static FunctionType addTokenToGroupFunctionType(
MLIRContext *ctx) {
164 auto i64 = IntegerType::get(ctx, 64);
165 return FunctionType::get(
166 ctx, {async::TokenType::get(ctx), GroupType::get(ctx)}, {i64});
// Function type of the await-token-and-execute declaration:
// (async token, pointer, pointer) -> ().
169 static FunctionType awaitTokenAndExecuteFunctionType(
MLIRContext *ctx) {
170 auto ptrType = opaquePointerType(ctx);
171 return FunctionType::get(
172 ctx, {async::TokenType::get(ctx), ptrType, ptrType}, {});
// Function type of the await-value-and-execute declaration:
// (pointer, pointer, pointer) -> ().
175 static FunctionType awaitValueAndExecuteFunctionType(
MLIRContext *ctx) {
176 auto ptrType = opaquePointerType(ctx);
177 return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
// Function type of the await-all-and-execute declaration:
// (async group, pointer, pointer) -> ().
180 static FunctionType awaitAllAndExecuteFunctionType(
MLIRContext *ctx) {
181 auto ptrType = opaquePointerType(ctx);
182 return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
// Function type of the get-num-worker-threads declaration: () -> index.
185 static FunctionType getNumWorkerThreads(
MLIRContext *ctx) {
186 return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
// Body of a helper whose signature line is not visible in this chunk.
// Builds an LLVM dialect function type with a void result and a single
// pointer parameter.
191 auto voidTy = LLVM::LLVMVoidType::get(ctx);
192 auto ptrType = opaquePointerType(ctx);
// NOTE(review): the trailing `false` is presumably the is-variadic flag —
// confirm against LLVMFunctionType::get.
193 return LLVM::LLVMFunctionType::get(voidTy, {ptrType},
false);
// Body of the routine that adds the runtime API declarations to the module.
// Its signature and builder setup lines are not visible in this chunk.
//
// `addFuncDecl` looks the name up in the module first; when nothing is found
// it creates a func::FuncOp with that name and type and marks it private.
// NOTE(review): the statement guarded by the lookup is missing here —
// presumably an early return so existing symbols are left untouched.
203 auto addFuncDecl = [&](StringRef name, FunctionType type) {
204 if (module.lookupSymbol(name))
206 func::FuncOp::create(builder, name, type).setPrivate();
// One declaration per runtime entry point, each paired with its AsyncAPI
// function-type helper.
210 addFuncDecl(
kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
211 addFuncDecl(
kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212 addFuncDecl(
kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
213 addFuncDecl(
kCreateValue, AsyncAPI::createValueFunctionType(ctx));
214 addFuncDecl(
kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
215 addFuncDecl(
kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
216 addFuncDecl(
kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
217 addFuncDecl(
kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
218 addFuncDecl(
kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
219 addFuncDecl(
kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
220 addFuncDecl(
kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
221 addFuncDecl(
kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
222 addFuncDecl(
kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
223 addFuncDecl(
kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
224 addFuncDecl(
kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
225 addFuncDecl(
kExecute, AsyncAPI::executeFunctionType(ctx));
// The three lines below are trailing arguments of calls whose opening lines
// are not visible in this chunk.
229 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
// Name of the module-level helper function that resumes a coroutine. The
// helper takes a coroutine handle and issues a coro.resume on it; the
// lowerings take its address and hand it to the runtime as a callback.
static constexpr const char *kResume = "__resume";
// Body of the routine that adds the resume helper function to the module.
// Its signature and the builder construction lines are not visible here.
//
// The module is checked for an existing symbol named kResume first.
// NOTE(review): the guarded statement is missing from this chunk —
// presumably an early return.
247 if (module.lookupSymbol(
kResume))
251 auto loc =
module.getLoc();
254 auto voidTy = LLVM::LLVMVoidType::get(ctx);
255 Type ptrType = AsyncAPI::opaquePointerType(ctx);
// Create a private LLVM function `void kResume(pointer)`.
257 auto resumeOp = LLVM::LLVMFuncOp::create(
258 moduleBuilder,
kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
259 resumeOp.setPrivate();
// Its entry block resumes the coroutine given as argument 0 and returns
// nothing.
261 auto *block = resumeOp.addEntryBlock(moduleBuilder);
264 LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0));
265 LLVM::ReturnOp::create(blockBuilder,
ValueRange());
// Constructs the type converter used by the async runtime lowering.
// `options` is accepted but not referenced on any line visible here.
277 AsyncRuntimeTypeConverter(
const LowerToLLVMOptions &
options) {
// Register an identity conversion, then the async-specific conversion.
// NOTE(review): relative priority depends on TypeConverter's registration
// order rules — confirm that the async conversion is tried first.
278 addConversion([](Type type) {
return type; });
279 addConversion([](Type type) {
return convertAsyncTypes(type); });
// Materialization callback: bridges a type mismatch with an
// UnrealizedConversionCastOp and returns its first result. Part of the
// lambda's parameter list and body is not visible in this chunk.
283 auto addUnrealizedCast = [](OpBuilder &builder, Type type,
286 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
287 return cast.getResult(0);
// The same callback serves both source and target materialization.
290 addSourceMaterialization(addUnrealizedCast);
291 addTargetMaterialization(addUnrealizedCast);
// Maps async dialect types to their lowered representation:
//   - token, group and value types  -> pointer
//   - coroutine id and state types  -> token
//   - coroutine handle type         -> pointer
// The result for any other type is on a line not visible in this chunk.
294 static std::optional<Type> convertAsyncTypes(Type type) {
295 if (isa<async::TokenType, GroupType, ValueType>(type))
296 return AsyncAPI::opaquePointerType(type.
getContext());
298 if (isa<CoroIdType, CoroStateType>(type))
299 return AsyncAPI::tokenType(type.
getContext());
300 if (isa<CoroHandleType>(type))
301 return AsyncAPI::opaquePointerType(type.
getContext());
// Base class for conversion patterns that must be constructed with an
// AsyncRuntimeTypeConverter. It forwards to OpConversionPattern<SourceOp>
// and re-exposes the converter under its concrete type.
310template <
typename SourceOp>
311class AsyncOpConversionPattern :
public OpConversionPattern<SourceOp> {
313 using Base = OpConversionPattern<SourceOp>;
316 AsyncOpConversionPattern(
const AsyncRuntimeTypeConverter &typeConverter,
317 MLIRContext *context)
318 : Base(typeConverter, context) {}
// Returns the base pattern's converter downcast to
// AsyncRuntimeTypeConverter. The static_cast relies on the constructor
// above only accepting that converter type.
321 const AsyncRuntimeTypeConverter *getTypeConverter()
const {
322 return static_cast<const AsyncRuntimeTypeConverter *
>(
323 Base::getTypeConverter());
// Replaces CoroIdOp with LLVM::CoroIdOp producing a token-typed result.
334class CoroIdOpConversion :
public AsyncOpConversionPattern<CoroIdOp> {
336 using AsyncOpConversionPattern::AsyncOpConversionPattern;
339 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter)
const override {
341 auto token = AsyncAPI::tokenType(op->getContext());
342 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
343 auto loc = op->getLoc();
// An i32 constant 0. The left-hand side of this statement is not visible
// here; it is referred to as `constZero` below.
347 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0);
348 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType);
// The LLVM op receives the zero constant followed by three null pointers.
351 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352 op, token,
ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
// Replaces CoroBeginOp with a call to an allocation function followed by
// LLVM::CoroBeginOp. Several assignment targets and the allocation-function
// lookup call are on lines not visible in this chunk.
364class CoroBeginOpConversion :
public AsyncOpConversionPattern<CoroBeginOp> {
366 using AsyncOpConversionPattern::AsyncOpConversionPattern;
369 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter)
const override {
371 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
372 auto loc = op->getLoc();
// Query the coroutine frame size and alignment as i64 values.
376 LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type());
379 LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type());
384 auto makeConstant = [&](uint64_t c) {
385 return LLVM::ConstantOp::create(rewriter, op->getLoc(),
386 rewriter.getI64Type(), c);
// Size arithmetic: add the alignment, subtract 1, then mask with
// (0 - alignment).
// NOTE(review): this looks like rounding the size up to a multiple of the
// alignment (assuming a power-of-two alignment) — confirm against the full
// source, since some assignment targets are missing here.
388 coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign);
390 LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1));
392 LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign);
394 LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign);
// Trailing arguments of the allocation-function lookup (call line missing).
398 rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
401 auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
// The coroutine id operand plus the allocated memory feed LLVM::CoroBeginOp.
405 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
406 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
407 op, ptrType,
ValueRange({coroId, coroAlloc.getResult()}));
// Replaces CoroFreeOp with LLVM::CoroFreeOp followed by a call to a free
// function. The lookup that produces `freeFuncOp` and the call's argument
// list are on lines not visible in this chunk.
419class CoroFreeOpConversion :
public AsyncOpConversionPattern<CoroFreeOp> {
421 using AsyncOpConversionPattern::AsyncOpConversionPattern;
424 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter)
const override {
426 auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
427 auto loc = op->getLoc();
// LLVM::CoroFreeOp yields a pointer computed from the converted operands.
431 LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands());
438 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
// Lowers CoroEndOp to LLVM::CoroEndOp and erases the original op.
451class CoroEndOpConversion :
public OpConversionPattern<CoroEndOp> {
453 using OpConversionPattern::OpConversionPattern;
456 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter)
const override {
// An i1 `false` constant (left-hand side not visible; used as `constFalse`
// below) and a none token.
460 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
461 rewriter.getBoolAttr(
false));
462 auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc());
// The LLVM op takes the coroutine handle, the false flag and the none
// token, and produces an i1 that is not used here.
465 auto coroHdl = adaptor.getHandle();
466 LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
467 ValueRange({coroHdl, constFalse, noneToken}));
468 rewriter.eraseOp(op);
// Replaces CoroSaveOp with LLVM::CoroSaveOp, which takes the converted
// operands and produces a token-typed result.
480class CoroSaveOpConversion :
public OpConversionPattern<CoroSaveOp> {
482 using OpConversionPattern::OpConversionPattern;
485 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter)
const override {
488 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
489 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
// Lowers CoroSuspendOp to LLVM::CoroSuspendOp followed by an LLVM::SwitchOp
// that branches on the suspend result.
526class CoroSuspendOpConversion :
public OpConversionPattern<CoroSuspendOp> {
528 using OpConversionPattern::OpConversionPattern;
531 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
532 ConversionPatternRewriter &rewriter)
const override {
533 auto i8 = rewriter.getIntegerType(8);
534 auto i32 = rewriter.getI32Type();
535 auto loc = op->getLoc();
// The suspend is emitted with an i1 `false` as its second operand.
538 auto constFalse = LLVM::ConstantOp::create(
539 rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(
false));
// The suspend op yields an i8 code.
542 auto coroState = adaptor.getState();
543 auto coroSuspend = LLVM::CoroSuspendOp::create(
544 rewriter, loc, i8,
ValueRange({coroState, constFalse}));
// Case 0 branches to the resume destination, case 1 to the cleanup
// destination. The default destination and the remaining switch arguments
// are on lines not visible in this chunk.
551 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
552 llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
553 op.getCleanupDest()};
// The i8 code is sign-extended to i32 before it is switched on.
554 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
555 op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()),
561 ArrayRef<int32_t>());
// Lowers RuntimeCreateOp to a runtime call chosen by the result type.
// Reports a match failure for anything that is neither a token nor a value.
// The enclosing class header and some call arguments are not visible here.
581 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
582 ConversionPatternRewriter &rewriter)
const override {
583 const TypeConverter *converter = getTypeConverter();
584 Type resultType = op->getResultTypes()[0];
// Token result: replaced by a func::CallOp (arguments not visible here).
587 if (isa<async::TokenType>(resultType)) {
588 rewriter.replaceOpWithNewOp<func::CallOp>(
594 if (
auto value = dyn_cast<ValueType>(resultType)) {
// Computes the size of the converted stored type as an i64: take a null
// pointer, GEP by one element of the stored type, and convert the
// resulting pointer to an integer.
596 auto sizeOf = [&](ValueType valueType) -> Value {
597 auto loc = op->getLoc();
598 auto i64 = rewriter.getI64Type();
600 auto storedType = converter->convertType(valueType.getValueType());
601 auto storagePtrType =
602 AsyncAPI::opaquePointerType(rewriter.getContext());
606 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType);
608 LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType,
609 nullPtr, ArrayRef<LLVM::GEPArg>{1});
610 return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep);
// Value result: replaced by a call to kCreateValue (remaining arguments
// not visible here).
613 rewriter.replaceOpWithNewOp<func::CallOp>(op,
kCreateValue, resultType,
619 return rewriter.notifyMatchFailure(op,
"unsupported async type");
// Replaces RuntimeCreateGroupOp with a func::CallOp that forwards the
// converted operands. The base class, callee name and result type arguments
// are on lines not visible in this chunk.
629class RuntimeCreateGroupOpLowering
635 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
636 ConversionPatternRewriter &rewriter)
const override {
637 const TypeConverter *converter = getTypeConverter();
638 Type resultType = op.getResult().getType();
640 rewriter.replaceOpWithNewOp<func::CallOp>(
642 adaptor.getOperands());
// Replaces RuntimeSetAvailableOp with a result-less func::CallOp to a
// runtime function, forwarding the converted operands.
653class RuntimeSetAvailableOpLowering
654 :
public OpConversionPattern<RuntimeSetAvailableOp> {
656 using OpConversionPattern::OpConversionPattern;
659 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
660 ConversionPatternRewriter &rewriter)
const override {
// The expression that selects the callee name is not visible in this chunk.
661 StringRef apiFuncName =
666 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName,
TypeRange(),
667 adaptor.getOperands());
// Replaces RuntimeSetErrorOp with a result-less func::CallOp to a runtime
// function, forwarding the converted operands.
679class RuntimeSetErrorOpLowering
680 :
public OpConversionPattern<RuntimeSetErrorOp> {
682 using OpConversionPattern::OpConversionPattern;
685 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
686 ConversionPatternRewriter &rewriter)
const override {
// The expression that selects the callee name is not visible in this chunk.
687 StringRef apiFuncName =
692 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName,
TypeRange(),
693 adaptor.getOperands());
// Replaces RuntimeIsErrorOp with a func::CallOp that returns an i1,
// forwarding the converted operands.
705class RuntimeIsErrorOpLowering :
public OpConversionPattern<RuntimeIsErrorOp> {
707 using OpConversionPattern::OpConversionPattern;
710 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
711 ConversionPatternRewriter &rewriter)
const override {
// The expression that selects the callee name is not visible in this chunk.
712 StringRef apiFuncName =
718 rewriter.replaceOpWithNewOp<func::CallOp>(
719 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
// Lowers RuntimeAwaitOp to a result-less runtime call and erases the op.
730class RuntimeAwaitOpLowering :
public OpConversionPattern<RuntimeAwaitOp> {
732 using OpConversionPattern::OpConversionPattern;
735 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter)
const override {
// Callee selection by operand type: token -> kAwaitToken, group ->
// kAwaitGroup. The switch head and at least one case line are not visible
// in this chunk.
737 StringRef apiFuncName =
739 .Case<async::TokenType>([](Type) {
return kAwaitToken; })
741 .Case<GroupType>([](Type) {
return kAwaitGroup; });
// The call has no results, so the original op is erased rather than
// replaced.
743 func::CallOp::create(rewriter, op->getLoc(), apiFuncName,
TypeRange(),
744 adaptor.getOperands());
745 rewriter.eraseOp(op);
// Lowers RuntimeAwaitAndResumeOp to a result-less runtime call taking the
// awaited operand, the coroutine handle and the address of the resume
// helper; the original op is then erased.
757class RuntimeAwaitAndResumeOpLowering
758 :
public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
760 using AsyncOpConversionPattern::AsyncOpConversionPattern;
763 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
764 ConversionPatternRewriter &rewriter)
const override {
// The expression that selects the callee name is not visible in this chunk.
765 StringRef apiFuncName =
771 Value operand = adaptor.getOperand();
772 Value handle = adaptor.getHandle();
// Address of the function named kResume, typed as a pointer.
776 auto resumePtr = LLVM::AddressOfOp::create(
777 rewriter, op->getLoc(),
778 AsyncAPI::opaquePointerType(rewriter.getContext()),
kResume);
780 func::CallOp::create(rewriter, op->getLoc(), apiFuncName,
TypeRange(),
781 ValueRange({operand, handle, resumePtr.getRes()}));
782 rewriter.eraseOp(op);
// Replaces RuntimeResumeOp with a func::CallOp. The call's callee and
// arguments are on lines not visible in this chunk.
794class RuntimeResumeOpLowering
795 :
public AsyncOpConversionPattern<RuntimeResumeOp> {
797 using AsyncOpConversionPattern::AsyncOpConversionPattern;
800 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
801 ConversionPatternRewriter &rewriter)
const override {
// Address of the function named kResume, typed as a pointer.
804 auto resumePtr = LLVM::AddressOfOp::create(
805 rewriter, op->getLoc(),
806 AsyncAPI::opaquePointerType(rewriter.getContext()),
kResume);
809 auto coroHdl = adaptor.getHandle();
810 rewriter.replaceOpWithNewOp<func::CallOp>(
// Lowers RuntimeStoreOp to an LLVM::StoreOp into the value's storage
// pointer, then erases the op. The statement that produces `storagePtr` and
// the condition guarding the failure path are not visible in this chunk.
828 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
829 ConversionPatternRewriter &rewriter)
const override {
830 Location loc = op->getLoc();
833 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
834 auto storage = adaptor.getStorage();
// Convert the stored value's type.
// NOTE(review): the failure below is presumably taken when the conversion
// yields no type — confirm against the full source.
839 auto valueType = op.getValue().getType();
840 auto llvmValueType = getTypeConverter()->convertType(valueType);
842 return rewriter.notifyMatchFailure(
843 op,
"failed to convert stored value type to LLVM type");
845 Value castedStoragePtr = storagePtr.getResult(0);
// Store the converted value through the storage pointer.
847 auto value = adaptor.getValue();
848 LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr);
851 rewriter.eraseOp(op);
// Replaces RuntimeLoadOp with an LLVM::LoadOp of the converted result type.
// The statement that produces `storagePtr`, the condition guarding the
// failure path and the load's remaining arguments are not visible here.
868 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
869 ConversionPatternRewriter &rewriter)
const override {
870 Location loc = op->getLoc();
873 auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
874 auto storage = adaptor.getStorage();
// Convert the loaded value's type.
// NOTE(review): the failure below is presumably taken when the conversion
// yields no type — confirm against the full source.
879 auto valueType = op.getResult().getType();
880 auto llvmValueType = getTypeConverter()->convertType(valueType);
882 return rewriter.notifyMatchFailure(
883 op,
"failed to convert loaded value type to LLVM type");
885 Value castedStoragePtr = storagePtr.getResult(0);
888 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
// Replaces RuntimeAddToGroupOp with a func::CallOp. Only token operands are
// accepted; any other operand type is a match failure. The call's arguments
// are on lines not visible in this chunk.
901class RuntimeAddToGroupOpLowering
902 :
public OpConversionPattern<RuntimeAddToGroupOp> {
904 using OpConversionPattern::OpConversionPattern;
907 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
908 ConversionPatternRewriter &rewriter)
const override {
// The check uses the original (unconverted) operand type.
910 if (!isa<async::TokenType>(op.getOperand().getType()))
911 return rewriter.notifyMatchFailure(op,
"only token type is supported");
914 rewriter.replaceOpWithNewOp<func::CallOp>(
// Conversion pattern for RuntimeNumWorkerThreadsOp. Only the tail of the
// rewrite is visible in this chunk: its last argument is the index type.
// NOTE(review): presumably a call replacement whose result type is index —
// confirm against the full source.
928class RuntimeNumWorkerThreadsOpLowering
929 :
public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
931 using OpConversionPattern::OpConversionPattern;
934 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
935 ConversionPatternRewriter &rewriter)
const override {
939 rewriter.getIndexType());
// Shared lowering for reference-counting ops. The op becomes a result-less
// func::CallOp to `apiFunctionName`. The call's operand list is on a line
// not visible in this chunk.
952template <
typename RefCountingOp>
953class RefCountingOpLowering :
public OpConversionPattern<RefCountingOp> {
// `apiFunctionName` is stored as a StringRef, i.e. without copying the
// characters; the subclasses in this file pass string constants.
955 explicit RefCountingOpLowering(
const TypeConverter &converter,
956 MLIRContext *ctx, StringRef apiFunctionName)
957 : OpConversionPattern<RefCountingOp>(converter, ctx),
958 apiFunctionName(apiFunctionName) {}
961 matchAndRewrite(RefCountingOp op,
typename RefCountingOp::Adaptor adaptor,
962 ConversionPatternRewriter &rewriter)
const override {
// Materialize the op's count as an i64 arith constant (left-hand side of
// this statement is not visible here).
964 arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(),
965 rewriter.getI64IntegerAttr(op.getCount()));
967 auto operand = adaptor.getOperand();
968 rewriter.replaceOpWithNewOp<func::CallOp>(op,
TypeRange(), apiFunctionName,
975 StringRef apiFunctionName;
// RefCountingOpLowering specialised for RuntimeAddRefOp; calls kAddRef.
// Part of the constructor's parameter list is not visible in this chunk.
978class RuntimeAddRefOpLowering :
public RefCountingOpLowering<RuntimeAddRefOp> {
980 explicit RuntimeAddRefOpLowering(
const TypeConverter &converter,
982 : RefCountingOpLowering(converter, ctx,
kAddRef) {}
// RefCountingOpLowering specialised for RuntimeDropRefOp; calls kDropRef.
// Part of the constructor's parameter list is not visible in this chunk.
985class RuntimeDropRefOpLowering
986 :
public RefCountingOpLowering<RuntimeDropRefOp> {
988 explicit RuntimeDropRefOpLowering(
const TypeConverter &converter,
990 : RefCountingOpLowering(converter, ctx,
kDropRef) {}
// Re-creates func::ReturnOp with the adaptor's (type-converted) operands.
999class ReturnOpOpConversion :
public OpConversionPattern<func::ReturnOp> {
1001 using OpConversionPattern::OpConversionPattern;
1004 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1005 ConversionPatternRewriter &rewriter)
const override {
1006 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
// Pass that converts async ops in a module to LLVM dialect ops and runtime
// calls. The base class list is not visible in this chunk; runOnOperation is
// defined out of line.
1015struct ConvertAsyncToLLVMPass
1019 void runOnOperation()
override;
// Runs the conversion on the module: builds two type converters, registers
// the lowering patterns, configures legality and applies a partial
// conversion. The pass is marked failed if the conversion fails. Several
// lines (including the ConversionTarget construction) are not visible here.
1023void ConvertAsyncToLLVMPass::runOnOperation() {
1024 ModuleOp module = getOperation();
1025 MLIRContext *ctx =
module->getContext();
1027 LowerToLLVMOptions
options(ctx);
1038 AsyncRuntimeTypeConverter converter(
options);
1039 RewritePatternSet patterns(ctx);
// A second converter: the LLVM type converter extended with the async
// type conversions.
1043 LLVMTypeConverter llvmConverter(ctx,
options);
1044 llvmConverter.addConversion([&](Type type) {
1045 return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1049 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1054 patterns.add<ReturnOpOpConversion>(converter, ctx);
// Runtime op lowerings registered with the async runtime converter.
1057 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1058 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1059 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1060 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1061 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
// These four are registered with the LLVM type converter instead.
1066 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1067 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
// Coroutine op conversions.
1071 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1072 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
// Legality: a few ops and the whole LLVM dialect are legal; the async
// dialect is illegal.
1076 target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1077 UnrealizedConversionCastOp>();
1078 target.addLegalDialect<LLVM::LLVMDialect>();
1082 target.addIllegalDialect<AsyncDialect>();
// func ops are legal only once their types are legal for `converter`.
1085 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1086 return converter.isSignatureLegal(op.getFunctionType());
1088 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1089 return converter.isLegal(op.getOperandTypes());
1091 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1092 return converter.isSignatureLegal(op.getCalleeType());
1095 if (failed(applyPartialConversion(module,
target, std::move(patterns))))
1096 signalPassFailure();
// Rewrites an ExecuteOp so that its operands, region and result types use
// converted types, then replaces the original op with the new op's results.
1104class ConvertExecuteOpTypes :
public OpConversionPattern<ExecuteOp> {
1106 using OpConversionPattern::OpConversionPattern;
1108 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1109 ConversionPatternRewriter &rewriter)
const override {
// Clone the op without its region (left-hand side `newOp` is on a line not
// visible here), then move the original region into the clone.
1111 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1112 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1113 newOp.getRegion().end());
// Swap in the converted operands and convert the region's types. The
// statement taken when region conversion fails is not visible here.
1116 newOp->setOperands(adaptor.getOperands());
1117 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
// Update each result type in place on the cloned op.
1119 for (
auto result : newOp.getResults())
1120 result.setType(typeConverter->convertType(
result.getType()));
1122 rewriter.replaceOp(op, newOp.getResults());
// Re-creates AwaitOp on the first converted operand.
1128class ConvertAwaitOpTypes :
public OpConversionPattern<AwaitOp> {
1130 using OpConversionPattern::OpConversionPattern;
1132 matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1133 ConversionPatternRewriter &rewriter)
const override {
1134 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
// Re-creates async::YieldOp with the converted operands.
1140class ConvertYieldOpTypes :
public OpConversionPattern<async::YieldOp> {
1142 using OpConversionPattern::OpConversionPattern;
1144 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1145 ConversionPatternRewriter &rewriter)
const override {
1146 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
// Body of the routine that populates async structural type conversions and
// legality. Its signature line is not visible in this chunk.
//
// Token types convert to themselves. A value type converts to a value type
// wrapping its converted element type; when the element type fails to
// convert, the null type is returned.
1155 typeConverter.addConversion([&](async::TokenType type) {
return type; });
1156 typeConverter.addConversion([&](ValueType type) {
1157 Type converted = typeConverter.convertType(type.getValueType());
1158 return converted ? ValueType::get(converted) : converted;
1161 patterns.
add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
// Await, execute and yield ops are legal exactly when the type converter
// considers the op legal.
1164 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1165 [&](
Operation *op) {
return typeConverter.isLegal(op); });
static constexpr const char * kAwaitValueAndExecute
static constexpr const char * kCreateValue
static constexpr const char * kCreateGroup
static constexpr const char * kCreateToken
static constexpr const char * kEmplaceValue
static void addResumeFunction(ModuleOp module)
A function that takes a coroutine handle and calls the llvm.coro.resume intrinsic.
static constexpr const char * kEmplaceToken
static void addAsyncRuntimeApiDeclarations(ModuleOp module)
Adds Async Runtime C API declarations to the module.
static constexpr const char * kResume
static constexpr const char * kAddRef
static constexpr const char * kAwaitTokenAndExecute
static constexpr const char * kAwaitValue
static constexpr const char * kSetTokenError
static constexpr const char * kExecute
static constexpr const char * kAddTokenToGroup
static constexpr const char * kIsGroupError
static constexpr const char * kSetValueError
static constexpr const char * kIsTokenError
static constexpr const char * kAwaitGroup
static constexpr const char * kAwaitAllAndExecute
static constexpr const char * kGetNumWorkerThreads
static constexpr const char * kDropRef
static constexpr const char * kIsValueError
static constexpr const char * kAwaitToken
static constexpr const char * kGetValueStorage
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
Include the generated interface declarations.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
llvm::TypeSwitch< T, ResultT > TypeSwitch