39 #include "llvm/ADT/STLExtras.h"
40 #include "llvm/Support/Error.h"
41 #include "llvm/Support/FormatVariadic.h"
43 #define DEBUG_TYPE "gpu-to-llvm"
46 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
47 #include "mlir/Conversion/Passes.h.inc"
// NOTE(review): garbled extraction — interior lines are elided and each line
// carries a stray original-line-number prefix; code below kept byte-identical.
// Pass wrapper generated via GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS (included
// above); lowers host-side GPU dialect ops to runtime calls — see
// runOnOperation() further down.
55 class GpuToLLVMConversionPass
56 :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
// Delegates dialect registration to the tablegen'd base class.
60 Base::getDependentDialects(registry);
64 void runOnOperation()
override;
// Base conversion pattern, templated on the GPU op being lowered. The bulk of
// this fragment is a table of FunctionCallBuilder members: each one pairs an
// mgpu* runtime wrapper name ("mgpuStreamCreate", "mgpuSpMV", ...) with its
// LLVM return/argument types, so derived patterns can emit llvm.call ops.
// NOTE(review): most interior lines of this class are elided in this
// extraction; code kept byte-identical.
67 template <
typename OpTy>
70 explicit ConvertOpToGpuRuntimeCallPattern(
// Computes the number of elements: a constant for static shapes, otherwise
// stride(0) * size(0) from the memref descriptor (fragment — condition and
// the static branch are partly elided).
78 return type.hasStaticShape()
80 rewriter, loc, indexType, type.getNumElements())
83 : rewriter.
create<LLVM::MulOp>(loc,
84 desc.
stride(rewriter, loc, 0),
85 desc.
size(rewriter, loc, 0));
// Cached LLVM types derived from the type converter's context / pointer width.
88 MLIRContext *context = &this->getTypeConverter()->getContext();
98 context, this->getTypeConverter()->getPointerBitwidth(0));
// --- Module / stream / event runtime wrappers ---
103 {llvmPointerType , llvmInt64Type }};
105 "mgpuModuleUnload", llvmVoidType, {llvmPointerType }};
107 "mgpuModuleGetFunction",
131 "mgpuStreamCreate", llvmPointerType , {}};
133 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
135 "mgpuStreamSynchronize",
139 "mgpuStreamWaitEvent",
141 {llvmPointerType , llvmPointerType }};
143 "mgpuEventCreate", llvmPointerType , {}};
145 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
147 "mgpuEventSynchronize",
153 {llvmPointerType , llvmPointerType }};
// --- Host memory registration / memcpy / memset / device selection ---
155 "mgpuMemHostRegisterMemRef",
161 "mgpuMemHostUnregisterMemRef",
175 {llvmPointerType , llvmPointerType }};
179 {llvmPointerType , llvmPointerType ,
192 {llvmPointerType , llvmInt32Type ,
196 "mgpuSetDefaultDevice",
// --- Sparse (cuSPARSE/cuSPARSELt-style) runtime wrappers: dense tensors,
// COO/CSR creation, SpMV/SpMM/SDDMM buffer-size and compute calls, SpGEMM ---
202 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
207 {llvmPointerType, llvmPointerType }};
211 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
216 {llvmPointerType, llvmPointerType }};
220 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
221 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
226 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
227 llvmPointerType, llvmInt32Type, llvmInt32Type,
232 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
233 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
234 llvmInt32Type, llvmPointerType }};
238 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
239 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
240 llvmInt32Type, llvmPointerType }};
244 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
245 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
246 llvmInt32Type, llvmInt32Type, llvmInt32Type,
251 {llvmPointerType, llvmPointerType }};
253 "mgpuSpMVBufferSize",
255 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
256 llvmInt32Type, llvmPointerType }};
260 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
261 llvmInt32Type, llvmPointerType, llvmPointerType }};
263 "mgpuSpMMBufferSize",
265 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
266 llvmPointerType, llvmInt32Type, llvmPointerType }};
270 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
271 llvmPointerType, llvmInt32Type, llvmPointerType,
274 "mgpuSDDMMBufferSize",
276 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
277 llvmPointerType, llvmInt32Type, llvmPointerType }};
281 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
282 llvmPointerType, llvmInt32Type, llvmPointerType,
285 "mgpuCreateCuSparseLtDnMat",
287 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
288 llvmInt32Type, llvmPointerType }};
290 "mgpuDestroyCuSparseLtSpMat",
292 {llvmPointerType, llvmPointerType }};
294 "mgpuDestroyCuSparseLtDnMat",
296 {llvmPointerType, llvmPointerType }};
298 "mgpuCusparseLtCreate2To4SpMat",
300 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
301 llvmInt32Type, llvmPointerType }};
303 "mgpuCuSparseLtSpMMBufferSize",
305 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
306 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
309 "mgpuCuSparseLtSpMM",
311 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
312 llvmPointerType, llvmPointerType, llvmPointerType }};
314 "mgpuSpGEMMCreateDescr",
318 "mgpuSpGEMMDestroyDescr",
320 {llvmPointerType , llvmPointerType }};
322 "mgpuSpGEMMWorkEstimation",
324 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
325 llvmPointerType , llvmPointerType , llvmPointerType ,
326 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
331 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
332 llvmPointerType , llvmPointerType , llvmPointerType ,
333 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
338 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
339 llvmPointerType , llvmPointerType , llvmPointerType ,
340 llvmInt32Type , llvmPointerType }};
344 {llvmPointerType , llvmPointerType , llvmPointerType ,
345 llvmPointerType , llvmPointerType }};
347 "mgpuSetCsrPointers",
349 {llvmPointerType , llvmPointerType ,
350 llvmPointerType , llvmPointerType ,
// Declarations of the concrete conversion patterns, one per GPU dialect op.
// Each derives from ConvertOpToGpuRuntimeCallPattern<OpTy> and declares
// matchAndRewrite; bodies are defined later in the file.
// NOTE(review): interior lines elided in this extraction; code byte-identical.

// gpu.host_register -> mgpuMemHostRegisterMemRef call.
356 class ConvertHostRegisterOpToGpuRuntimeCallPattern
357 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
359 ConvertHostRegisterOpToGpuRuntimeCallPattern(
361 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
365 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,

// gpu.host_unregister -> mgpuMemHostUnregisterMemRef call.
369 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
370 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
372 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
374 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
379 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,

// gpu.alloc -> runtime allocation call.
385 class ConvertAllocOpToGpuRuntimeCallPattern
386 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
389 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
393 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,

// gpu.dealloc -> runtime deallocation call.
399 class ConvertDeallocOpToGpuRuntimeCallPattern
400 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
402 ConvertDeallocOpToGpuRuntimeCallPattern(
404 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
408 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,

// async.yield of !gpu.async.token -> event create/record calls.
412 class ConvertAsyncYieldToGpuRuntimeCallPattern
413 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
415 ConvertAsyncYieldToGpuRuntimeCallPattern(
417 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
421 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,

// Synchronous gpu.wait (no result token).
427 class ConvertWaitOpToGpuRuntimeCallPattern
428 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
431 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
435 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,

// gpu.wait async (produces a token) — separate pattern from the sync form.
441 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
442 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
444 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
446 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
450 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,

// gpu.launch_func -> module load / get-function / launch-kernel call sequence.
// Carries extra state: the binary annotation name, the bare-pointer calling
// convention flag, and an optional cached symbol table for module lookups.
467 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
468 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
470 ConvertLaunchFuncOpToGpuRuntimeCallPattern(
472 bool kernelBarePtrCallConv,
SymbolTable *cachedModuleTable)
473 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
474 gpuBinaryAnnotation(gpuBinaryAnnotation),
475 kernelBarePtrCallConv(kernelBarePtrCallConv),
476 cachedModuleTable(cachedModuleTable) {}
479 Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
481 Value generateKernelNameConstant(StringRef moduleName, StringRef name,
485 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
489 bool kernelBarePtrCallConv;

// gpu.memcpy -> runtime memcpy call.
507 class ConvertMemcpyOpToGpuRuntimeCallPattern
508 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
511 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
515 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,

// gpu.memset -> runtime memset call.
521 class ConvertMemsetOpToGpuRuntimeCallPattern
522 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
525 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
529 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,

// gpu.set_default_device -> mgpuSetDefaultDevice call.
535 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
536 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
538 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
540 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
544 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,

// Boilerplate macro: declares one pattern class per sparse GPU op, all with
// the same shape as the hand-written classes above.
550 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
551 class Convert##op_name##ToGpuRuntimeCallPattern \
552 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
554 Convert##op_name##ToGpuRuntimeCallPattern( \
555 const LLVMTypeConverter &typeConverter) \
556 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
560 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
561 ConversionPatternRewriter &rewriter) const override; \
// Pass entry point: configures the LLVM lowering options, collects patterns
// from dialects implementing ConvertToLLVMPatternInterface, and marks GPU
// modules/launches with target attributes as legal (they are handled by the
// newer gpu-module-to-binary path instead).
// NOTE(review): interior lines elided in this extraction; code byte-identical.
588 void GpuToLLVMConversionPass::runOnOperation() {
592 options.useBarePtrCallConv = hostBarePtrCallConv;
601 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
604 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
// A gpu.module is legal (left alone) iff it carries a targets attribute.
609 [](gpu::GPUModuleOp module) ->
bool {
610 return module.getTargetsAttr() !=
nullptr;
// A gpu.launch_func is legal iff its kernel module has non-empty targets.
615 [&](gpu::LaunchFuncOp op) ->
bool {
617 symbolTable.
lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
620 (module && module.getTargetsAttr() &&
621 !module.getTargetsAttr().empty());
630 kernelBarePtrCallConv, &symbolTable);
640 auto function = [&] {
641 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
646 return builder.
create<LLVM::CallOp>(loc,
function, arguments);
663 llvm_unreachable(
"unsupported type");
669 if (llvm::isa<ComplexType>(type)) {
672 if (elementType.isBF16())
674 if (elementType.isF16())
676 if (elementType.isF32())
678 if (elementType.isF64())
680 if (elementType.isInteger(8))
682 if (elementType.isInteger(16))
684 if (elementType.isInteger(32))
702 llvm_unreachable(
"unsupported element type");
706 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
731 llvm_unreachable(
"cannot find spmat def");
736 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
749 if (!llvm::all_of(operands, [](
Value value) {
753 op,
"Cannot convert if operands aren't of LLVM type.");
759 gpu::AsyncOpInterface op) {
760 if (op.getAsyncDependencies().size() != 1)
762 op,
"Can only convert with exactly one async dependency.");
764 if (!op.getAsyncToken())
// Lowers gpu.host_register: promotes the unranked-memref operand to runtime
// arguments, appends the element size, and calls mgpuMemHostRegisterMemRef.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
770 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
771 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
773 auto *op = hostRegisterOp.getOperation();
779 auto memRefType = hostRegisterOp.getValue().getType();
780 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
783 auto arguments = getTypeConverter()->promoteOperands(
784 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
// Element size is passed as a trailing argument to the runtime wrapper.
785 arguments.push_back(elementSize);
786 hostRegisterCallBuilder.create(loc, rewriter, arguments);
// Lowers gpu.host_unregister: mirror image of the host_register lowering,
// ending in a call to mgpuMemHostUnregisterMemRef.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
792 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
793 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
795 Operation *op = hostUnregisterOp.getOperation();
801 auto memRefType = hostUnregisterOp.getValue().getType();
802 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
805 auto arguments = getTypeConverter()->promoteOperands(
806 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
807 arguments.push_back(elementSize);
808 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
// Lowers gpu.alloc: computes the byte size from the memref descriptor, calls
// the runtime allocator (with a host-shared flag), and rebuilds a memref
// descriptor around the returned pointer. Host-shared + async is rejected.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
814 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
815 gpu::AllocOp allocOp, OpAdaptor adaptor,
818 MemRefType memRefType = allocOp.getType();
821 !isConvertibleAndHasIdentityMaps(memRefType))
824 auto loc = allocOp.getLoc();
826 bool isShared = allocOp.getHostShared();
// Host-shared allocations cannot be performed asynchronously.
828 if (isShared && allocOp.getAsyncToken())
830 allocOp,
"Host Shared allocation cannot be done async");
839 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
840 shape, strides, sizeBytes);
844 auto nullPtr = rewriter.
create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
// Stream comes from the single async dependency, if present.
845 Value stream = adaptor.getAsyncDependencies().empty()
847 : adaptor.getAsyncDependencies().front();
849 auto isHostShared = rewriter.
create<mlir::LLVM::ConstantOp>(
853 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
// The runtime returns an already-aligned pointer, reused for both fields.
857 Value alignedPtr = allocatedPtr;
860 auto memRefDescriptor = this->createMemRefDescriptor(
861 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
// Async form additionally forwards the stream as the token replacement.
863 if (allocOp.getAsyncToken()) {
865 rewriter.
replaceOp(allocOp, {memRefDescriptor, stream});
867 rewriter.
replaceOp(allocOp, {memRefDescriptor});
// Lowers gpu.dealloc: frees the pointer on the stream of the (single) async
// dependency via the runtime dealloc wrapper.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
873 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
874 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
884 Value stream = adaptor.getAsyncDependencies().front();
885 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
892 return isa<gpu::AsyncTokenType>(value.
getType());
// Lowers async.yield of GPU tokens: for each yielded stream, records an event
// on it and yields the event instead; the now-unused streams are destroyed
// (deduplicated via a SmallDenseSet so each stream is destroyed once).
// NOTE(review): interior lines elided in this extraction; code byte-identical.
899 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
900 async::YieldOp yieldOp, OpAdaptor adaptor,
907 llvm::SmallDenseSet<Value> streams;
908 for (
auto &operand : yieldOp->getOpOperands()) {
911 auto idx = operand.getOperandNumber();
912 auto stream = adaptor.getOperands()[idx];
913 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
914 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
915 newOperands[idx] = event;
916 streams.insert(stream);
918 for (
auto stream : streams)
919 streamDestroyCallBuilder.create(loc, rewriter, {stream});
// In-place update of the yield's operands under the rewriter's notification.
922 [&] { yieldOp->setOperands(newOperands); });
928 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
930 return defOp.getCallee()->equals(functionName);
// Lowers the synchronous form of gpu.wait (no result token): each operand is
// synchronized then destroyed — streams via stream calls, events via event
// calls (the elided branch presumably distinguishes the two; TODO confirm).
// NOTE(review): interior lines elided in this extraction; code byte-identical.
938 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
939 gpu::WaitOp waitOp, OpAdaptor adaptor,
// Async form is handled by ConvertWaitAsyncOpToGpuRuntimeCallPattern.
941 if (waitOp.getAsyncToken())
946 for (
auto operand : adaptor.getOperands()) {
949 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
950 streamDestroyCallBuilder.create(loc, rewriter, {operand});
954 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
955 eventDestroyCallBuilder.create(loc, rewriter, {operand});
// Lowers the async form of gpu.wait: collects an event per dependency
// (recording one on stream dependencies), creates a fresh stream that waits
// on all events, then destroys the events.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
968 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
969 gpu::WaitOp waitOp, OpAdaptor adaptor,
// Sync form is handled by ConvertWaitOpToGpuRuntimeCallPattern.
971 if (!waitOp.getAsyncToken())
979 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
980 auto operand = std::get<1>(pair);
984 auto *defOp = std::get<0>(pair).getDefiningOp();
986 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
987 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
988 events.push_back(event);
// Operand already is an event in the elided else-branch case.
992 events.push_back(operand);
996 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
997 for (
auto event : events)
998 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
999 for (
auto event : events)
1000 eventDestroyCallBuilder.create(loc, rewriter, {
event});
// Builds the kernel-parameter array expected by the launch wrapper: allocas a
// struct holding all promoted kernel arguments plus a parallel array of
// pointers to each struct field, and returns the array pointer.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1019 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
1020 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
OpBuilder &builder)
const {
1021 auto loc = launchOp.getLoc();
1022 auto numKernelOperands = launchOp.getNumKernelOperands();
// Promote only the trailing kernel operands (non-kernel operands excluded).
1026 loc, launchOp.getOperands().take_back(numKernelOperands),
1027 adaptor.getOperands().take_back(numKernelOperands), builder,
1028 kernelBarePtrCallConv);
1029 auto numArguments = arguments.size();
1031 argumentTypes.reserve(numArguments);
1032 for (
auto argument : arguments)
1033 argumentTypes.push_back(argument.getType());
// One struct alloca for the argument values...
1036 auto one = builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
1038 builder.
create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
// ...and one array alloca of pointers, one slot per argument.
1041 builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
1042 auto arrayPtr = builder.
create<LLVM::AllocaOp>(
1043 loc, llvmPointerType, llvmPointerType, arraySize, 0);
// Per argument: store the value into its struct field, then store the field
// address into the pointer array.
1046 builder.
create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
1048 builder.
create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
1049 auto elementPtr = builder.
create<LLVM::GEPOp>(
1050 loc, llvmPointerType, llvmPointerType, arrayPtr,
1052 builder.
create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
// Emits a NUL-terminated global string constant "<module>_<name>_kernel_name"
// holding the kernel's name, for passing to mgpuModuleGetFunction.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1067 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
1068 StringRef moduleName, StringRef name,
Location loc,
1071 std::vector<char> kernelName(name.begin(), name.end());
// Runtime expects a C string, so append the terminator explicitly.
1072 kernelName.push_back(
'\0');
1074 std::string globalName =
1075 std::string(llvm::formatv(
"{0}_{1}_kernel_name", moduleName, name));
1077 loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
1078 LLVM::Linkage::Internal);
// Lowers gpu.launch_func. Two paths are visible:
//  1) If the kernel module carries a targets attribute, re-emit a (legal)
//     gpu.launch_func with converted operands and let the module-to-binary
//     flow handle it.
//  2) Legacy path: embed the serialized GPU binary as a global, load it at
//     runtime, look up the kernel by name, build the params array, and call
//     the launch wrapper; then synchronize/cleanup for the sync form.
// At most one async dependency is supported.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1098 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
1099 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
1104 if (launchOp.getAsyncDependencies().size() > 1)
1106 launchOp,
"Cannot convert with more than one async dependency.");
1111 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
1113 launchOp,
"Cannot convert non-async op with async dependencies.");
// Resolve the kernel's gpu.module, preferring the cached symbol table.
1119 gpu::GPUModuleOp kernelModule;
1120 if (cachedModuleTable)
1121 kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
1122 launchOp.getKernelModuleName());
1124 kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
1125 launchOp, launchOp.getKernelModuleName());
1126 assert(kernelModule &&
"expected a kernel module");
// Path 1: targets attribute present — recreate the launch with converted
// operands (stream, sizes, cluster size) instead of lowering to calls.
1129 if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
1131 if (!adaptor.getAsyncDependencies().empty())
1132 stream = adaptor.getAsyncDependencies().front();
1135 else if (launchOp.getAsyncToken())
1136 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1142 loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
1143 rewriter, kernelBarePtrCallConv);
1145 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1146 if (launchOp.hasClusterSize()) {
1149 adaptor.getClusterSizeZ()};
1151 rewriter.
create<gpu::LaunchFuncOp>(
1152 launchOp.getLoc(), launchOp.getKernelAttr(),
1154 adaptor.getGridSizeZ()},
1156 adaptor.getBlockSizeZ()},
1157 adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
1158 if (launchOp.getAsyncToken())
// Path 2: legacy binary-annotation flow. The serialized blob must be
// attached as a StringAttr under `gpuBinaryAnnotation`.
1166 kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
1168 kernelModule.emitOpError()
1169 <<
"missing " << gpuBinaryAnnotation <<
" attribute";
1177 binaryAttr.getValue(), LLVM::Linkage::Internal);
1180 auto gpuBlob = binaryAttr.getValue();
1181 auto gpuBlobSize = rewriter.
create<mlir::LLVM::ConstantOp>(
1184 static_cast<int64_t
>(gpuBlob.size())));
1187 moduleLoadCallBuilder.
create(loc, rewriter, {data, gpuBlobSize});
1190 auto paramsCount = rewriter.
create<mlir::LLVM::ConstantOp>(
1194 static_cast<int64_t
>(launchOp.getNumKernelOperands())));
// Look the kernel up by its (NUL-terminated) name inside the loaded module.
1198 auto kernelName = generateKernelNameConstant(
1199 launchOp.getKernelModuleName().getValue(),
1200 launchOp.getKernelName().getValue(), loc, rewriter);
1201 auto function = moduleGetFunctionCallBuilder.create(
1202 loc, rewriter, {module.getResult(), kernelName});
1203 Value zero = rewriter.
create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
// Reuse the dependency stream, or create a fresh one.
1205 adaptor.getAsyncDependencies().empty()
1206 ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
1207 : adaptor.getAsyncDependencies().front();
1209 auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
1210 auto nullpointer = rewriter.
create<LLVM::ZeroOp>(loc, llvmPointerType);
1211 Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
1212 ? launchOp.getDynamicSharedMemorySize()
1214 launchKernelCallBuilder.create(
1216 {
function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1217 adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1218 adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
1219 nullpointer, paramsCount});
// Sync form: wait for completion and release the stream and module.
1221 if (launchOp.getAsyncToken()) {
1228 streamSynchronizeCallBuilder.create(loc, rewriter, stream);
1229 streamDestroyCallBuilder.create(loc, rewriter, stream);
1232 moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
1239 LLVM::LLVMPointerType destinationType,
1242 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1243 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1244 sourcePtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
1247 destinationType.getAddressSpace()),
// Lowers gpu.memcpy: computes the byte count via a null-pointer GEP trick,
// extracts aligned src/dst pointers from the memref descriptors (with address
// space casts as needed), and calls the runtime memcpy on the stream.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1252 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1253 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1255 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1258 !isConvertibleAndHasIdentityMaps(memRefType) ||
1262 auto loc = memcpyOp.getLoc();
// GEP from a null base yields the size of N elements in bytes.
1268 Value nullPtr = rewriter.
create<LLVM::ZeroOp>(loc, elementPtrType);
1270 loc, elementPtrType,
1271 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1277 srcDesc.alignedPtr(rewriter, loc),
1278 *getTypeConverter());
1280 loc, rewriter, llvmPointerType,
1282 *getTypeConverter());
1284 auto stream = adaptor.getAsyncDependencies().front();
1285 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
// Lowers gpu.memset: only 16- or 32-bit int/float fill values are supported;
// the value is bitcast to an integer of matching width and the 16- or 32-bit
// runtime memset wrapper is invoked accordingly.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1292 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1293 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1295 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1298 !isConvertibleAndHasIdentityMaps(memRefType) ||
1302 auto loc = memsetOp.getLoc();
1304 Type valueType = adaptor.getValue().getType();
1307 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1309 memsetOp,
"value must be a 16 or 32 bit int or float");
1313 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1319 rewriter.
create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1321 dstDesc.alignedPtr(rewriter, loc),
1322 *getTypeConverter());
1324 auto stream = adaptor.getAsyncDependencies().front();
// Select the width-matched runtime wrapper.
1326 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1327 builder.
create(loc, rewriter, {dst, value, numElements, stream});
// Lowers gpu.set_default_device to a call to mgpuSetDefaultDevice with the
// converted device index.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1333 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1334 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1337 setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.getDevIndex()});
1342 template <
typename T>
1345 return builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type,
1346 static_cast<int32_t
>(tValue));
1349 template <
typename T>
1352 return builder.
create<LLVM::ConstantOp>(
1353 loc, llvmFloat32Type,
// Lowers gpu.create_dn_tensor: 2-D tensors become dense-matrix handles (with
// a cuSPARSELt variant that allocas a 16-byte-aligned handle buffer in the
// elided condition's branch), 1-D tensors become dense-vector handles.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1357 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1358 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1364 auto stream = adaptor.getAsyncDependencies().front();
1367 Type dType = op.getMemref().getType().getElementType();
1371 for (
Value dim : adaptor.getDims()) {
1372 dims.push_back(dim);
1382 if (dims.size() == 2) {
// cuSPARSELt path: the handle is an opaque buffer allocated on the stack.
1384 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1386 handle = rewriter.
create<LLVM::AllocaOp>(
1387 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1388 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1390 createLtDnMatCallBuilder
1392 {handle, dims[0], dims[1], pTensor, dtp, stream})
1396 createDnMatCallBuilder
1397 .
create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1401 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1402 handle = createDnVecCallBuilder
1403 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1406 rewriter.
replaceOp(op, {handle, stream});
// Lowers gpu.destroy_dn_tensor: dispatches on the rank recorded by the
// defining create_dn_tensor — 2-D goes to the (possibly cuSPARSELt) dense-mat
// destroy wrappers, 1-D to the dense-vec destroy wrapper.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1410 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1411 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1417 auto stream = adaptor.getAsyncDependencies().front();
// Rank is recovered from the producing CreateDnTensorOp's dims.
1418 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1420 for (
Value dim : definingOp.getDims()) {
1421 dims.push_back(dim);
1423 if (dims.size() == 2) {
1427 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1428 {adaptor.getDnTensor(), stream});
1430 destroyDnMatCallBuilder.create(loc, rewriter,
1431 {adaptor.getDnTensor(), stream});
1434 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1435 destroyDnVecCallBuilder.create(loc, rewriter,
1436 {adaptor.getDnTensor(), stream});
// Lowers gpu.create_coo: extracts index/value element types and the data
// pointers, then calls the COO-creation runtime wrapper; result is the sparse
// handle plus the stream token.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1442 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1443 gpu::CreateCooOp op, OpAdaptor adaptor,
1449 auto stream = adaptor.getAsyncDependencies().front();
1457 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1459 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1463 createCooCallBuilder
1464 .create(loc, rewriter,
1465 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1466 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1468 rewriter.
replaceOp(op, {handle, stream});
// Lowers gpu.create_coo_aos (array-of-structs COO: one combined index
// buffer): same shape as the plain COO lowering but with a single idxs
// pointer.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1472 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1473 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1479 auto stream = adaptor.getAsyncDependencies().front();
1483 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1485 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1489 createCooAoSCallBuilder
1490 .create(loc, rewriter,
1491 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1492 pIdxs, pValues, itp, dtp, stream})
1494 rewriter.
replaceOp(op, {handle, stream});
// Lowers gpu.create_csr: passes row-position, column-index and value buffers
// (with their element-type codes) to the CSR-creation runtime wrapper.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1498 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1499 gpu::CreateCsrOp op, OpAdaptor adaptor,
1505 auto stream = adaptor.getAsyncDependencies().front();
1513 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1515 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1517 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1522 createCsrCallBuilder
1523 .create(loc, rewriter,
1524 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1525 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1527 rewriter.
replaceOp(op, {handle, stream});
// Lowers gpu.create_2to4_spmat (2:4 structured sparsity, cuSPARSELt): allocas
// a 16-byte-aligned opaque handle buffer and calls the create wrapper.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1531 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1532 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1538 auto stream = adaptor.getAsyncDependencies().front();
1542 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1546 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1549 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1550 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1552 create2To4SpMatCallBuilder
1554 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1556 rewriter.
replaceOp(op, {handle, stream});
// Lowers gpu.destroy_sp_mat: cuSPARSELt handles use their dedicated destroy
// wrapper (branch condition elided); others use the generic one.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1560 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1561 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1567 auto stream = adaptor.getAsyncDependencies().front();
1570 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1571 {adaptor.getSpmat(), stream});
1574 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
// Lowers gpu.spmv_buffer_size: queries the workspace size for SpMV via
// mgpuSpMVBufferSize and returns it with the stream token.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1580 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1581 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1590 auto stream = adaptor.getAsyncDependencies().front();
1591 auto bufferSize = spMVBufferSizeCallBuilder
1592 .create(loc, rewriter,
1593 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1594 adaptor.getDnY(), computeType, stream})
1596 rewriter.
replaceOp(op, {bufferSize, stream});
// Lowers gpu.spmv: performs sparse matrix × dense vector via the runtime
// wrapper, passing the caller-provided workspace buffer pointer.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1600 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1601 gpu::SpMVOp op, OpAdaptor adaptor,
1610 auto stream = adaptor.getAsyncDependencies().front();
1613 spMVCallBuilder.create(loc, rewriter,
1614 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1615 adaptor.getDnY(), computeType, pBuf, stream});
// Lowers gpu.spmm_buffer_size. Two paths are visible:
//  - cuSPARSELt: the wrapper writes three sizes into a stack buffer, which
//    are then loaded individually and returned as three results.
//  - cuSPARSE: a single buffer size is returned directly by the wrapper.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1620 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1621 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1629 auto stream = adaptor.getAsyncDependencies().front();
// cuSPARSELt path: alloca room for three output sizes, 16-byte aligned.
1638 auto bufferSize = rewriter.
create<LLVM::AllocaOp>(
1639 loc, llvmPointerType, llvmPointerType, three, 16);
1640 createCuSparseLtSpMMBufferSizeBuilder
1642 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1643 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
// Load back the three sizes written by the wrapper (slots 0, 1, 2).
1647 auto bufferSizePtr1 = rewriter.
create<LLVM::GEPOp>(
1648 loc, llvmPointerType, llvmPointerType, bufferSize,
1651 auto bufferSizePtr2 = rewriter.
create<LLVM::GEPOp>(
1652 loc, llvmPointerType, llvmPointerType, bufferSize,
1656 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1658 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1660 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1662 rewriter.
replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
// cuSPARSE path: single returned size.
1667 createSpMMBufferSizeCallBuilder
1668 .create(loc, rewriter,
1669 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1670 adaptor.getDnmatC(), computeType, stream})
1672 rewriter.
replaceOp(op, {bufferSize, stream});
// Lowers gpu.sddmm_buffer_size: queries the SDDMM workspace size via
// mgpuSDDMMBufferSize and returns it with the stream token.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1677 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1678 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1688 auto stream = adaptor.getAsyncDependencies().front();
1690 createSDDMMBufferSizeCallBuilder
1691 .create(loc, rewriter,
1692 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1693 adaptor.getSpmatC(), computeType, stream})
1695 rewriter.
replaceOp(op, {bufferSize, stream});
// Lowers gpu.spmm. cuSPARSELt path consumes the three workspace buffers
// produced by its buffer-size query; the cuSPARSE path takes a single buffer.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1699 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1700 gpu::SpMMOp op, OpAdaptor adaptor,
1711 auto stream = adaptor.getAsyncDependencies().front();
// Collect one pointer per converted workspace buffer operand.
1716 for (
Value buffer : adaptor.getBuffers()) {
1718 pBufs.push_back(pBuf);
1720 createCuSparseLtSpMMBuilder.create(
1722 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1723 pBufs[0], pBufs[1], pBufs[2], stream});
1727 createSpMMCallBuilder.create(loc, rewriter,
1728 {modeA, modeB, adaptor.getSpmatA(),
1729 adaptor.getDnmatB(), adaptor.getDnmatC(),
1730 computeType, pBuf, stream});
1736 template <
typename T>
// Lowers gpu.sddmm: sampled dense-dense matrix multiplication into a sparse
// output, via the mgpuSDDMM runtime wrapper with a workspace buffer.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1743 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1744 gpu::SDDMMOp op, OpAdaptor adaptor,
1754 auto stream = adaptor.getAsyncDependencies().front();
1757 createSDDMMCallBuilder.create(loc, rewriter,
1758 {modeA, modeB, adaptor.getDnmatA(),
1759 adaptor.getDnmatB(), adaptor.getSpmatC(),
1760 computeType, pBuf, stream});
// Lowers gpu.spgemm_create_descr: allocates an SpGEMM descriptor via
// mgpuSpGEMMCreateDescr and yields it with the stream token.
// NOTE(review): interior lines elided in this extraction (including the
// leading return type on the previous source line); code byte-identical.
1766 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1767 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1773 auto stream = adaptor.getAsyncDependencies().front();
1774 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1776 rewriter.
replaceOp(op, {descr, stream});
// Lowers gpu.spgemm_destroy_descr: releases the SpGEMM descriptor via
// mgpuSpGEMMDestroyDescr.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1781 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1782 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1788 auto stream = adaptor.getAsyncDependencies().front();
1789 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1790 {adaptor.getDesc(), stream});
// Lowers gpu.spgemm_work_estimation_or_compute: dispatches on the op's kind
// attribute to either the work-estimation or the compute runtime wrapper;
// both return an updated required buffer size.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1796 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1797 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1807 auto stream = adaptor.getAsyncDependencies().front();
1811 Value bufferSizeNew;
1813 if (adaptor.getKind() ==
1814 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1816 createSpGEMMWorkEstimationBuilder
1817 .create(loc, rewriter,
1818 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1819 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1820 adaptor.getBufferSz(), pBuf, stream})
1824 createSpGEMMComputeBuilder
1825 .create(loc, rewriter,
1826 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1827 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1828 adaptor.getBufferSz(), pBuf, stream})
1831 rewriter.
replaceOp(op, {bufferSizeNew, stream});
// Lowers gpu.spgemm_copy: finalizes the SpGEMM result into the destination
// sparse matrix via the copy runtime wrapper.
// NOTE(review): interior lines elided in this extraction; code byte-identical.
1835 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1836 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1846 auto stream = adaptor.getAsyncDependencies().front();
1847 createSpGEMMCopyBuilder.create(loc, rewriter,
1848 {adaptor.getDesc(), modeA, modeB,
1849 adaptor.getSpmatA(), adaptor.getSpmatB(),
1850 adaptor.getSpmatC(), computeType, stream});
1855 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1856 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1862 auto stream = adaptor.getAsyncDependencies().front();
1866 auto buffer = rewriter.
create<LLVM::AllocaOp>(
1867 loc, llvmPointerType, llvmInt64Type, three, 16);
1869 auto rowsPtr = rewriter.
create<LLVM::GEPOp>(
1870 loc, llvmPointerType, llvmPointerType, buffer,
1873 auto colsPtr = rewriter.
create<LLVM::GEPOp>(
1874 loc, llvmPointerType, llvmPointerType, buffer,
1877 auto nnzsPtr = rewriter.
create<LLVM::GEPOp>(
1878 loc, llvmPointerType, llvmPointerType, buffer,
1881 createSpMatGetSizeBuilder.
create(
1882 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1883 auto rows = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1884 auto cols = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1885 auto nnzs = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1887 rewriter.
replaceOp(op, {rows, cols, nnzs, stream});
1891 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1892 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1898 auto stream = adaptor.getAsyncDependencies().front();
1905 createSetCsrPointersBuilder.create(
1906 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1911 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1912 gpu::CreateCscOp op, OpAdaptor adaptor,
1918 auto stream = adaptor.getAsyncDependencies().front();
1926 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1928 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1930 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1935 createCscCallBuilder
1936 .create(loc, rewriter,
1937 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1938 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1940 rewriter.
replaceOp(op, {handle, stream});
1944 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1945 gpu::CreateBsrOp op, OpAdaptor adaptor,
1951 auto stream = adaptor.getAsyncDependencies().front();
1959 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1961 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1963 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1968 createBsrCallBuilder
1969 .create(loc, rewriter,
1970 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1971 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1972 pColIdxs, pValues, ptp, itp, dtp, stream})
1974 rewriter.
replaceOp(op, {handle, stream});
1980 StringRef gpuBinaryAnnotation,
1981 bool kernelBarePtrCallConv,
1983 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1984 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1985 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1986 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1988 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1989 ConvertDeallocOpToGpuRuntimeCallPattern,
1990 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1991 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1992 ConvertMemcpyOpToGpuRuntimeCallPattern,
1993 ConvertMemsetOpToGpuRuntimeCallPattern,
1994 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1995 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1996 ConvertWaitOpToGpuRuntimeCallPattern,
1997 ConvertAsyncYieldToGpuRuntimeCallPattern,
1998 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1999 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
2000 ConvertCreateCooOpToGpuRuntimeCallPattern,
2001 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
2002 ConvertCreateCsrOpToGpuRuntimeCallPattern,
2003 ConvertCreateCscOpToGpuRuntimeCallPattern,
2004 ConvertCreateBsrOpToGpuRuntimeCallPattern,
2005 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
2006 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
2007 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
2008 ConvertSpMVOpToGpuRuntimeCallPattern,
2009 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
2010 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
2011 ConvertSpMMOpToGpuRuntimeCallPattern,
2012 ConvertSDDMMOpToGpuRuntimeCallPattern,
2013 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
2014 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
2015 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
2016 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
2017 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
2018 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
2019 patterns.
add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
2020 converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
2021 patterns.
add<EraseGpuModuleOpPattern>(&converter.
getContext());
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static constexpr const char * kGpuBinaryStorageSuffix
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static MLIRContext * getContext(OpFoldResult val)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static int64_t getNumElements(ShapedType type)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
IntegerAttr getI8IntegerAttr(int8_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name, ArrayRef< Type > elements, bool isPacked=false)
Gets a new identified struct with the given body.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module containing surrounding the insertio...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef gpuBinaryAnnotation={}, bool kernelBarePtrCallConv=false, SymbolTable *cachedModuleTable=nullptr)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Utility class for the GPU dialect to represent triples of Values accessible through ....