39 #include "llvm/ADT/STLExtras.h"
40 #include "llvm/Support/Error.h"
41 #include "llvm/Support/FormatVariadic.h"
43 #define DEBUG_TYPE "gpu-to-llvm"
46 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
47 #include "mlir/Conversion/Passes.h.inc"
// Pass that lowers GPU-dialect host-side ops to LLVM dialect plus calls into
// the mgpu* runtime wrapper library (see the FunctionCallBuilder table below).
// NOTE(review): this listing is a sampled extract — original lines are elided
// between the numbered fragments; do not assume the span is contiguous code.
55 class GpuToLLVMConversionPass
56 :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
// Dialect registration is delegated to the tablegen-generated pass base.
60 Base::getDependentDialects(registry);
// Pass entry point; the definition appears further down in this file.
64 void runOnOperation()
override;
// CRTP-style base for all patterns in this file: a ConversionPattern over a
// single GPU op type that rewrites the op into calls to mgpu* runtime
// wrapper functions. It caches the LLVM types and one FunctionCallBuilder
// per runtime entry point.
67 template <
typename OpTy>
70 explicit ConvertOpToGpuRuntimeCallPattern(
// Helper: number of elements of a (memref) value. Static shapes use the
// compile-time element count; otherwise size is derived from the runtime
// descriptor (stride(0) * size(0) — assumes an identity-layout 1-D view of
// the buffer; elided lines presumably handle the dynamic branch setup).
78 return type.hasStaticShape()
80 rewriter, loc, indexType, type.getNumElements())
83 : rewriter.
create<LLVM::MulOp>(loc,
84 desc.
stride(rewriter, loc, 0),
85 desc.
size(rewriter, loc, 0));
// Cached LLVM types, built from the pattern's type converter context.
88 MLIRContext *context = &this->getTypeConverter()->getContext();
// Pointer-sized integer type (intptr) for size arguments.
98 context, this->getTypeConverter()->getPointerBitwidth(0));
// ---- Module load/unload and kernel lookup wrappers ----
103 {llvmPointerType , llvmInt64Type }};
105 "mgpuModuleUnload", llvmVoidType, {llvmPointerType }};
107 "mgpuModuleGetFunction",
// ---- Stream / event lifecycle wrappers ----
131 "mgpuStreamCreate", llvmPointerType , {}};
133 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
135 "mgpuStreamSynchronize",
139 "mgpuStreamWaitEvent",
141 {llvmPointerType , llvmPointerType }};
143 "mgpuEventCreate", llvmPointerType , {}};
145 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
147 "mgpuEventSynchronize",
153 {llvmPointerType , llvmPointerType }};
// ---- Host memory registration and alloc/dealloc/memcpy/memset wrappers ----
155 "mgpuMemHostRegisterMemRef",
161 "mgpuMemHostUnregisterMemRef",
175 {llvmPointerType , llvmPointerType }};
179 {llvmPointerType , llvmPointerType ,
192 {llvmPointerType , llvmInt32Type ,
196 "mgpuSetDefaultDevice",
// ---- Sparse (cuSPARSE-style) dense/sparse descriptor creation wrappers.
// Argument lists mix intptr sizes, data pointers, and i32 type enums. ----
202 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
207 {llvmPointerType, llvmPointerType }};
211 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
216 {llvmPointerType, llvmPointerType }};
220 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
221 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
226 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
227 llvmPointerType, llvmInt32Type, llvmInt32Type,
232 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
233 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
234 llvmInt32Type, llvmPointerType }};
238 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
239 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
240 llvmInt32Type, llvmPointerType }};
244 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
245 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
246 llvmInt32Type, llvmInt32Type, llvmInt32Type,
251 {llvmPointerType, llvmPointerType }};
// ---- SpMV / SpMM / SDDMM buffer-size and compute wrappers ----
253 "mgpuSpMVBufferSize",
255 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
256 llvmInt32Type, llvmPointerType }};
260 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
261 llvmInt32Type, llvmPointerType, llvmPointerType }};
263 "mgpuSpMMBufferSize",
265 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
266 llvmPointerType, llvmInt32Type, llvmPointerType }};
270 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
271 llvmPointerType, llvmInt32Type, llvmPointerType,
274 "mgpuSDDMMBufferSize",
276 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
277 llvmPointerType, llvmInt32Type, llvmPointerType }};
281 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
282 llvmPointerType, llvmInt32Type, llvmPointerType,
// ---- cuSPARSELt (2:4 structured sparsity) wrappers ----
285 "mgpuCreateCuSparseLtDnMat",
287 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
288 llvmInt32Type, llvmPointerType }};
290 "mgpuDestroyCuSparseLtSpMat",
292 {llvmPointerType, llvmPointerType }};
294 "mgpuDestroyCuSparseLtDnMat",
296 {llvmPointerType, llvmPointerType }};
298 "mgpuCusparseLtCreate2To4SpMat",
300 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
301 llvmInt32Type, llvmPointerType }};
303 "mgpuCuSparseLtSpMMBufferSize",
305 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
306 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
309 "mgpuCuSparseLtSpMM",
311 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
312 llvmPointerType, llvmPointerType, llvmPointerType }};
// ---- SpGEMM wrappers ----
314 "mgpuSpGEMMCreateDescr",
318 "mgpuSpGEMMDestroyDescr",
320 {llvmPointerType , llvmPointerType }};
322 "mgpuSpGEMMWorkEstimation",
324 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
325 llvmPointerType , llvmPointerType , llvmPointerType ,
326 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
331 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
332 llvmPointerType , llvmPointerType , llvmPointerType ,
333 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
338 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
339 llvmPointerType , llvmPointerType , llvmPointerType ,
340 llvmInt32Type , llvmPointerType }};
344 {llvmPointerType , llvmPointerType , llvmPointerType ,
345 llvmPointerType , llvmPointerType }};
347 "mgpuSetCsrPointers",
349 {llvmPointerType , llvmPointerType ,
350 {llvmPointerType , llvmPointerType ,
// One ConversionPattern subclass per GPU op. Each declares matchAndRewrite;
// the definitions follow later in the file.
// Lowers gpu.host_register to mgpuMemHostRegisterMemRef.
356 class ConvertHostRegisterOpToGpuRuntimeCallPattern
357 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
359 ConvertHostRegisterOpToGpuRuntimeCallPattern(
361 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
365 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
// Lowers gpu.host_unregister to mgpuMemHostUnregisterMemRef.
369 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
370 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
372 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
374 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
379 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
// Lowers gpu.alloc to a runtime allocation call.
385 class ConvertAllocOpToGpuRuntimeCallPattern
386 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
389 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
393 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
// Lowers gpu.dealloc to a runtime deallocation call.
399 class ConvertDeallocOpToGpuRuntimeCallPattern
400 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
402 ConvertDeallocOpToGpuRuntimeCallPattern(
404 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
408 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
// Rewrites async.yield of stream-typed values into event record/destroy.
412 class ConvertAsyncYieldToGpuRuntimeCallPattern
413 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
415 ConvertAsyncYieldToGpuRuntimeCallPattern(
417 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
421 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
// Lowers synchronous gpu.wait (no async token produced).
427 class ConvertWaitOpToGpuRuntimeCallPattern
428 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
431 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
435 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
// Lowers async gpu.wait (produces an !gpu.async.token).
441 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
442 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
444 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
446 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
450 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
// Lowers gpu.launch_func to module-load / get-function / launch-kernel calls.
// Carries the binary annotation name, the bare-pointer calling convention
// flag, and an optional cached SymbolTable used to resolve kernel modules.
467 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
468 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
470 ConvertLaunchFuncOpToGpuRuntimeCallPattern(
472 bool kernelBarePtrCallConv,
SymbolTable *cachedModuleTable)
473 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
474 gpuBinaryAnnotation(gpuBinaryAnnotation),
475 kernelBarePtrCallConv(kernelBarePtrCallConv),
476 cachedModuleTable(cachedModuleTable) {}
479 Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
481 Value generateKernelNameConstant(StringRef moduleName, StringRef name,
485 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
489 bool kernelBarePtrCallConv;
// Lowers gpu.memcpy to mgpuMemcpy.
507 class ConvertMemcpyOpToGpuRuntimeCallPattern
508 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
511 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
515 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
// Lowers gpu.memset to a 16- or 32-bit memset runtime call.
521 class ConvertMemsetOpToGpuRuntimeCallPattern
522 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
525 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
529 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
// Lowers gpu.set_default_device to mgpuSetDefaultDevice.
535 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
536 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
538 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
540 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
544 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
// Macro that stamps out the boilerplate pattern class for the sparse ops.
550 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
551 class Convert##op_name##ToGpuRuntimeCallPattern \
552 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
554 Convert##op_name##ToGpuRuntimeCallPattern( \
555 const LLVMTypeConverter &typeConverter) \
556 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
560 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
561 ConversionPatternRewriter &rewriter) const override; \
// Pass driver: configures the LLVM lowering options, collects conversion
// patterns from dialect interfaces, marks gpu.module / gpu.launch_func legal
// when they carry target attributes, and runs the conversion.
588 void GpuToLLVMConversionPass::runOnOperation() {
// Host-side bare-pointer calling convention is taken from the pass option.
592 options.useBarePtrCallConv = hostBarePtrCallConv;
// Let each loaded dialect contribute its own to-LLVM patterns.
601 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
604 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
// gpu.module ops with a `targets` attribute are handled by the new
// offloading flow and stay legal here.
609 [](gpu::GPUModuleOp module) ->
bool {
610 return module.getTargetsAttr() !=
nullptr;
// Likewise, launches into a targeted kernel module remain legal.
615 [&](gpu::LaunchFuncOp op) ->
bool {
617 symbolTable.
lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
620 (module && module.getTargetsAttr() &&
621 !module.getTargetsAttr().empty());
// Launch pattern gets the cached symbol table to avoid repeated lookups.
630 kernelBarePtrCallConv, &symbolTable);
// FunctionCallBuilder::create — declares the runtime function on first use
// and emits an llvm.call to it.
640 auto function = [&] {
641 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
646 return builder.
create<LLVM::CallOp>(loc,
function, arguments);
// Maps an MLIR element type to the runtime's data-type enum; unknown types
// are a hard error.
663 llvm_unreachable(
"unsupported type");
// Complex element types dispatch on their element type first.
669 if (llvm::isa<ComplexType>(type)) {
672 if (elementType.isBF16())
674 if (elementType.isF16())
676 if (elementType.isF32())
678 if (elementType.isF64())
680 if (elementType.isInteger(8))
682 if (elementType.isInteger(16))
684 if (elementType.isInteger(32))
702 llvm_unreachable(
"unsupported element type");
// Whether a 2:4 sparse matrix was created with pruning enabled — read off
// the defining gpu.create_2to4_spmat op.
706 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
731 llvm_unreachable(
"cannot find spmat def");
736 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
// Guard shared by all patterns: all operands must already be LLVM-typed.
749 if (!llvm::all_of(operands, [](
Value value) {
753 op,
"Cannot convert if operands aren't of LLVM type.");
// Async ops must have exactly one dependency and produce a token.
759 gpu::AsyncOpInterface op) {
760 if (op.getAsyncDependencies().size() != 1)
762 op,
"Can only convert with exactly one async dependency.");
764 if (!op.getAsyncToken())
// gpu.host_register -> mgpuMemHostRegisterMemRef(promoted memref args,
// element size). Operands are promoted to the bare-pointer/struct form by
// the type converter before the call.
770 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
771 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
773 auto *op = hostRegisterOp.getOperation();
779 auto memRefType = hostRegisterOp.getValue().getType();
780 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
783 auto arguments = getTypeConverter()->promoteOperands(
784 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
// Runtime call takes the element size as a trailing argument.
785 arguments.push_back(elementSize);
786 hostRegisterCallBuilder.create(loc, rewriter, arguments);
// gpu.host_unregister -> mgpuMemHostUnregisterMemRef, mirroring the
// registration lowering above.
792 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
793 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
795 Operation *op = hostUnregisterOp.getOperation();
801 auto memRefType = hostUnregisterOp.getValue().getType();
802 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
805 auto arguments = getTypeConverter()->promoteOperands(
806 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
807 arguments.push_back(elementSize);
808 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
// gpu.alloc -> runtime allocation. Computes the byte size from the memref
// descriptor machinery, passes the stream (if async) and a host-shared
// flag, then materializes a memref descriptor around the returned pointer.
814 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
815 gpu::AllocOp allocOp, OpAdaptor adaptor,
818 MemRefType memRefType = allocOp.getType();
// Only identity-layout memrefs with convertible element types are handled.
821 !isConvertibleAndHasIdentityMaps(memRefType))
824 auto loc = allocOp.getLoc();
826 bool isShared = allocOp.getHostShared();
// Host-shared allocation is synchronous by construction.
828 if (isShared && allocOp.getAsyncToken())
830 allocOp,
"Host Shared allocation cannot be done async");
839 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
840 shape, strides, sizeBytes);
844 auto nullPtr = rewriter.
create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
// Stream is the single async dependency, or null when synchronous.
845 Value stream = adaptor.getAsyncDependencies().empty()
847 : adaptor.getAsyncDependencies().front();
849 auto isHostShared = rewriter.
create<mlir::LLVM::ConstantOp>(
853 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
// Allocated pointer doubles as the aligned pointer in the descriptor.
857 Value alignedPtr = allocatedPtr;
860 auto memRefDescriptor = this->createMemRefDescriptor(
861 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
// Async form additionally forwards the stream as the token replacement.
863 if (allocOp.getAsyncToken()) {
865 rewriter.
replaceOp(allocOp, {memRefDescriptor, stream});
867 rewriter.
replaceOp(allocOp, {memRefDescriptor});
// gpu.dealloc -> runtime deallocation on the dependency stream.
873 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
874 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
884 Value stream = adaptor.getAsyncDependencies().front();
885 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
// Helper: is this SSA value a !gpu.async.token?
892 return isa<gpu::AsyncTokenType>(value.
getType());
// async.yield: convert yielded streams into recorded events so the async
// region hands back events rather than raw streams; each distinct stream is
// destroyed afterwards.
899 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
900 async::YieldOp yieldOp, OpAdaptor adaptor,
907 llvm::SmallDenseSet<Value> streams;
908 for (
auto &operand : yieldOp->getOpOperands()) {
911 auto idx = operand.getOperandNumber();
912 auto stream = adaptor.getOperands()[idx];
913 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
914 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
915 newOperands[idx] = event;
916 streams.insert(stream);
// Streams are reference-owned here; destroy each one exactly once.
918 for (
auto stream : streams)
919 streamDestroyCallBuilder.create(loc, rewriter, {stream});
921 rewriter.
modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
// Helper: does this pointer value come from a call to `functionName`?
// Used below to tell streams apart from events.
927 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
929 return defOp.getCallee()->equals(functionName);
// Synchronous gpu.wait: synchronize-then-destroy each dependency, choosing
// the stream or event runtime calls based on which mgpu* call produced the
// operand (see the isDefinedByCallTo helper above this function).
937 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
938 gpu::WaitOp waitOp, OpAdaptor adaptor,
// This pattern only handles the non-async form.
940 if (waitOp.getAsyncToken())
945 for (
auto operand : adaptor.getOperands()) {
948 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
949 streamDestroyCallBuilder.create(loc, rewriter, {operand});
953 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
954 eventDestroyCallBuilder.create(loc, rewriter, {operand});
// Async gpu.wait: create a fresh stream that waits on an event per
// dependency, then destroy the temporary events.
967 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
968 gpu::WaitOp waitOp, OpAdaptor adaptor,
// This pattern only handles the async (token-producing) form.
970 if (!waitOp.getAsyncToken())
978 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
979 auto operand = std::get<1>(pair);
// Stream-typed dependencies get an event recorded on them first.
983 auto *defOp = std::get<0>(pair).getDefiningOp();
985 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
986 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
987 events.push_back(event);
991 events.push_back(operand);
// The new stream waits on every collected event.
995 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
996 for (
auto event : events)
997 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
998 for (
auto event : events)
999 eventDestroyCallBuilder.create(loc, rewriter, {
event});
// Packs the kernel operands into the void** params array expected by the
// launch runtime call: a struct holding the promoted arguments plus an
// array of pointers to each struct field.
1018 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
1019 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
OpBuilder &builder)
const {
1020 auto loc = launchOp.getLoc();
1021 auto numKernelOperands = launchOp.getNumKernelOperands();
// Promote only the trailing kernel operands (leading operands are launch
// configuration), honoring the bare-pointer calling convention flag.
1025 loc, launchOp.getOperands().take_back(numKernelOperands),
1026 adaptor.getOperands().take_back(numKernelOperands), builder,
1027 kernelBarePtrCallConv);
1028 auto numArguments = arguments.size();
1030 argumentTypes.reserve(numArguments);
1031 for (
auto argument : arguments)
1032 argumentTypes.push_back(argument.getType());
// One alloca for the argument struct, one for the pointer array.
1035 auto one = builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
1037 builder.
create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
1040 builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
1041 auto arrayPtr = builder.
create<LLVM::AllocaOp>(
1042 loc, llvmPointerType, llvmPointerType, arraySize, 0);
// Store each argument into its struct field and record the field address
// in the pointer array.
1044 const auto index =
static_cast<int32_t
>(en.index());
1046 builder.
create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
1048 builder.
create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
1050 builder.
create<LLVM::GEPOp>(loc, llvmPointerType, llvmPointerType,
1052 builder.
create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
// Emits (or reuses) a global null-terminated string holding the kernel name,
// named "<module>_<kernel>_kernel_name", for mgpuModuleGetFunction.
1067 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
1068 StringRef moduleName, StringRef name,
Location loc,
// Explicit NUL terminator: the runtime expects a C string.
1071 std::vector<char> kernelName(name.begin(), name.end());
1072 kernelName.push_back(
'\0');
1074 std::string globalName =
1075 std::string(llvm::formatv(
"{0}_{1}_kernel_name", moduleName, name));
1077 loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
1078 LLVM::Linkage::Internal);
// gpu.launch_func lowering. Two paths: if the kernel module carries target
// attributes, re-emit a gpu.launch_func with lowered operands (handled by
// the offloading flow); otherwise embed the serialized GPU binary and emit
// moduleLoad / getFunction / launchKernel / (sync+destroy) / moduleUnload.
1098 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
1099 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
// At most one async dependency is supported.
1104 if (launchOp.getAsyncDependencies().size() > 1)
1106 launchOp,
"Cannot convert with more than one async dependency.");
// A non-async launch must not have dependencies.
1111 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
1113 launchOp,
"Cannot convert non-async op with async dependencies.");
// Resolve the kernel's gpu.module, via the cached symbol table when one
// was supplied to the pattern.
1119 gpu::GPUModuleOp kernelModule;
1120 if (cachedModuleTable)
1121 kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
1122 launchOp.getKernelModuleName());
1124 kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
1125 launchOp, launchOp.getKernelModuleName());
1126 assert(kernelModule &&
"expected a kernel module");
// Path 1: targeted module — keep the launch as a gpu.launch_func over the
// converted operands and let the offloading interface finish the job.
1129 if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
1131 if (!adaptor.getAsyncDependencies().empty())
1132 stream = adaptor.getAsyncDependencies().front();
1135 else if (launchOp.getAsyncToken())
1136 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1142 loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
1143 rewriter, kernelBarePtrCallConv);
1145 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1146 if (launchOp.hasClusterSize()) {
1149 adaptor.getClusterSizeZ()};
1151 rewriter.
create<gpu::LaunchFuncOp>(
1152 launchOp.getLoc(), launchOp.getKernelAttr(),
1154 adaptor.getGridSizeZ()},
1156 adaptor.getBlockSizeZ()},
1157 adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
1158 if (launchOp.getAsyncToken())
// Path 2: legacy flow — the serialized binary lives in a string attribute
// named by gpuBinaryAnnotation on the kernel module.
1166 kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
1168 kernelModule.emitOpError()
1169 <<
"missing " << gpuBinaryAnnotation <<
" attribute";
// Embed the blob as an internal global and load it at runtime.
1177 binaryAttr.getValue(), LLVM::Linkage::Internal);
1180 auto gpuBlob = binaryAttr.getValue();
1181 auto gpuBlobSize = rewriter.
create<mlir::LLVM::ConstantOp>(
1184 static_cast<int64_t
>(gpuBlob.size())));
1187 moduleLoadCallBuilder.
create(loc, rewriter, {data, gpuBlobSize});
1190 auto paramsCount = rewriter.
create<mlir::LLVM::ConstantOp>(
1194 static_cast<int64_t
>(launchOp.getNumKernelOperands())));
// Kernel handle = getFunction(module, "<name>").
1198 auto kernelName = generateKernelNameConstant(
1199 launchOp.getKernelModuleName().getValue(),
1200 launchOp.getKernelName().getValue(), loc, rewriter);
1201 auto function = moduleGetFunctionCallBuilder.create(
1202 loc, rewriter, {module.getResult(), kernelName});
1203 Value zero = rewriter.
create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
// Reuse the dependency stream, or create a fresh one.
1205 adaptor.getAsyncDependencies().empty()
1206 ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
1207 : adaptor.getAsyncDependencies().front();
1209 auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
1210 auto nullpointer = rewriter.
create<LLVM::ZeroOp>(loc, llvmPointerType);
1211 Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
1212 ? launchOp.getDynamicSharedMemorySize()
1214 launchKernelCallBuilder.create(
1216 {
function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1217 adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1218 adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
1219 nullpointer, paramsCount});
// Sync path tears the stream down immediately; the module is unloaded in
// both cases.
1221 if (launchOp.getAsyncToken()) {
1228 streamSynchronizeCallBuilder.create(loc, rewriter, stream);
1229 streamDestroyCallBuilder.create(loc, rewriter, stream);
1232 moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
// Helper: cast a pointer into the destination pointer type's address space,
// inserting an llvm.addrspacecast only when the spaces actually differ.
1239 LLVM::LLVMPointerType destinationType,
1242 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1243 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1244 sourcePtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
1247 destinationType.getAddressSpace()),
// gpu.memcpy -> mgpuMemcpy(dst, src, sizeBytes, stream). The byte size is
// computed from a GEP off a null pointer (sizeof trick); src/dst pointers
// are address-space-cast to the generic space as needed.
1252 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1253 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1255 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
// Identity-layout memrefs only.
1258 !isConvertibleAndHasIdentityMaps(memRefType) ||
1262 auto loc = memcpyOp.getLoc();
1268 Value nullPtr = rewriter.
create<LLVM::ZeroOp>(loc, elementPtrType);
1270 loc, elementPtrType,
1271 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1277 srcDesc.alignedPtr(rewriter, loc),
1278 *getTypeConverter());
1280 loc, rewriter, llvmPointerType,
1282 *getTypeConverter());
1284 auto stream = adaptor.getAsyncDependencies().front();
1285 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
// gpu.memset -> mgpuMemset32 / mgpuMemset16 depending on the value width;
// only 16- and 32-bit int/float fill values are supported.
1292 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1293 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1295 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1298 !isConvertibleAndHasIdentityMaps(memRefType) ||
1302 auto loc = memsetOp.getLoc();
1304 Type valueType = adaptor.getValue().getType();
1307 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1309 memsetOp,
"value must be a 16 or 32 bit int or float");
// The fill value is reinterpreted as a same-width integer for the call.
1313 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1319 rewriter.
create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1321 dstDesc.alignedPtr(rewriter, loc),
1322 *getTypeConverter());
1324 auto stream = adaptor.getAsyncDependencies().front();
1326 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1327 builder.
create(loc, rewriter, {dst, value, numElements, stream});
// gpu.set_default_device -> mgpuSetDefaultDevice(index).
1333 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1334 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1337 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1338 {adaptor.getDevIndex()});
// Helper: materialize a host value as an i32 LLVM constant.
1343 template <
typename T>
1346 return builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type,
1347 static_cast<int32_t
>(tValue));
// Helper: materialize a host value as an f32 LLVM constant.
1350 template <
typename T>
1353 return builder.
create<LLVM::ConstantOp>(
1354 loc, llvmFloat32Type,
// gpu.create_dn_tensor -> dense vector/matrix descriptor creation. 2-D
// tensors go through the cuSPARSELt path (alloca'd opaque handle) when the
// elided condition selects it, otherwise the plain cuSPARSE dnmat builder;
// 1-D tensors use the dnvec builder.
1358 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1359 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1365 auto stream = adaptor.getAsyncDependencies().front();
1368 Type dType = op.getMemref().getType().getElementType();
1372 for (
Value dim : adaptor.getDims()) {
1373 dims.push_back(dim);
1383 if (dims.size() == 2) {
// cuSPARSELt handle is an opaque 16-byte-aligned alloca'd buffer.
1385 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1387 handle = rewriter.
create<LLVM::AllocaOp>(
1388 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1389 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1391 createLtDnMatCallBuilder
1393 {handle, dims[0], dims[1], pTensor, dtp, stream})
1397 createDnMatCallBuilder
1398 .
create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1402 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1403 handle = createDnVecCallBuilder
1404 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1407 rewriter.
replaceOp(op, {handle, stream});
// gpu.destroy_dn_tensor: pick the matching destroy builder based on the
// rank recorded on the defining create_dn_tensor op.
1411 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1412 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1418 auto stream = adaptor.getAsyncDependencies().front();
1419 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1421 for (
Value dim : definingOp.getDims()) {
1422 dims.push_back(dim);
1424 if (dims.size() == 2) {
1428 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1429 {adaptor.getDnTensor(), stream});
1431 destroyDnMatCallBuilder.create(loc, rewriter,
1432 {adaptor.getDnTensor(), stream});
1435 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1436 destroyDnVecCallBuilder.create(loc, rewriter,
1437 {adaptor.getDnTensor(), stream});
// gpu.create_coo -> mgpuCreateCoo(rows, cols, nnz, rowIdxs, colIdxs,
// values, indexTypeEnum, dataTypeEnum, stream).
1443 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1444 gpu::CreateCooOp op, OpAdaptor adaptor,
1450 auto stream = adaptor.getAsyncDependencies().front();
// Index/data type enums are derived from the memref element types.
1458 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1460 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1464 createCooCallBuilder
1465 .create(loc, rewriter,
1466 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1467 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1469 rewriter.
replaceOp(op, {handle, stream});
// gpu.create_coo_aos -> mgpuCreateCooAoS (single interleaved index buffer).
1473 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1474 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1480 auto stream = adaptor.getAsyncDependencies().front();
1484 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1486 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1490 createCooAoSCallBuilder
1491 .create(loc, rewriter,
1492 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1493 pIdxs, pValues, itp, dtp, stream})
1495 rewriter.
replaceOp(op, {handle, stream});
// gpu.create_csr -> mgpuCreateCsr; position/index/data type enums are all
// passed separately.
1499 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1500 gpu::CreateCsrOp op, OpAdaptor adaptor,
1506 auto stream = adaptor.getAsyncDependencies().front();
1514 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1516 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1518 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1523 createCsrCallBuilder
1524 .create(loc, rewriter,
1525 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1526 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1528 rewriter.
replaceOp(op, {handle, stream});
// gpu.create_2to4_spmat -> cuSPARSELt 2:4 structured matrix; the opaque
// handle is an alloca'd 16-byte-aligned buffer filled by the runtime.
1532 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1533 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1539 auto stream = adaptor.getAsyncDependencies().front();
1543 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1547 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1550 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1551 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1553 create2To4SpMatCallBuilder
1555 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1557 rewriter.
replaceOp(op, {handle, stream});
// gpu.destroy_sp_mat: dispatch to the cuSPARSELt or plain destroy call
// (the selecting condition is on elided lines).
1561 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1562 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1568 auto stream = adaptor.getAsyncDependencies().front();
1571 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1572 {adaptor.getSpmat(), stream});
1575 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
// gpu.spmv_buffer_size -> mgpuSpMVBufferSize; result is {bufferSize, stream}.
1581 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1582 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1591 auto stream = adaptor.getAsyncDependencies().front();
1592 auto bufferSize = spMVBufferSizeCallBuilder
1593 .create(loc, rewriter,
1594 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1595 adaptor.getDnY(), computeType, stream})
1597 rewriter.
replaceOp(op, {bufferSize, stream});
// gpu.spmv -> mgpuSpMV with an external workspace buffer.
1601 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1602 gpu::SpMVOp op, OpAdaptor adaptor,
1611 auto stream = adaptor.getAsyncDependencies().front();
1614 spMVCallBuilder.create(loc, rewriter,
1615 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1616 adaptor.getDnY(), computeType, pBuf, stream});
// gpu.spmm_buffer_size. cuSPARSELt path returns three sizes written by the
// runtime into an alloca'd array of three words; the plain path returns one.
1621 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1622 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1630 auto stream = adaptor.getAsyncDependencies().front();
1639 auto bufferSize = rewriter.
create<LLVM::AllocaOp>(
1640 loc, llvmPointerType, llvmPointerType, three, 16);
1641 createCuSparseLtSpMMBufferSizeBuilder
1643 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1644 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
// Read the three sizes back out of the scratch array.
1648 auto bufferSizePtr1 = rewriter.
create<LLVM::GEPOp>(
1649 loc, llvmPointerType, llvmPointerType, bufferSize,
1652 auto bufferSizePtr2 = rewriter.
create<LLVM::GEPOp>(
1653 loc, llvmPointerType, llvmPointerType, bufferSize,
1657 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1659 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1661 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1663 rewriter.
replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
// Plain cuSPARSE path: single buffer size returned by the call.
1668 createSpMMBufferSizeCallBuilder
1669 .create(loc, rewriter,
1670 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1671 adaptor.getDnmatC(), computeType, stream})
1673 rewriter.
replaceOp(op, {bufferSize, stream});
// gpu.sddmm_buffer_size -> mgpuSDDMMBufferSize.
1678 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1679 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1689 auto stream = adaptor.getAsyncDependencies().front();
1691 createSDDMMBufferSizeCallBuilder
1692 .create(loc, rewriter,
1693 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1694 adaptor.getSpmatC(), computeType, stream})
1696 rewriter.
replaceOp(op, {bufferSize, stream});
// gpu.spmm. cuSPARSELt variant consumes the three workspace buffers
// produced by the buffer-size call above; the plain variant takes one.
1700 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1701 gpu::SpMMOp op, OpAdaptor adaptor,
1712 auto stream = adaptor.getAsyncDependencies().front();
1717 for (
Value buffer : adaptor.getBuffers()) {
1719 pBufs.push_back(pBuf);
1721 createCuSparseLtSpMMBuilder.create(
1723 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1724 pBufs[0], pBufs[1], pBufs[2], stream});
1728 createSpMMCallBuilder.create(loc, rewriter,
1729 {modeA, modeB, adaptor.getSpmatA(),
1730 adaptor.getDnmatB(), adaptor.getDnmatC(),
1731 computeType, pBuf, stream});
1737 template <
typename T>
// gpu.sddmm -> mgpuSDDMM with an external workspace buffer.
1744 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1745 gpu::SDDMMOp op, OpAdaptor adaptor,
1755 auto stream = adaptor.getAsyncDependencies().front();
1758 createSDDMMCallBuilder.create(loc, rewriter,
1759 {modeA, modeB, adaptor.getDnmatA(),
1760 adaptor.getDnmatB(), adaptor.getSpmatC(),
1761 computeType, pBuf, stream});
// gpu.spgemm_create_descr -> mgpuSpGEMMCreateDescr; yields {descr, stream}.
1767 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1768 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1774 auto stream = adaptor.getAsyncDependencies().front();
1775 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1777 rewriter.
replaceOp(op, {descr, stream});
// gpu.spgemm_destroy_descr -> mgpuSpGEMMDestroyDescr.
1782 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1783 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1789 auto stream = adaptor.getAsyncDependencies().front();
1790 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1791 {adaptor.getDesc(), stream});
// gpu.spgemm_work_estimation_or_compute: dispatches on the op's `kind`
// attribute to either the work-estimation or the compute runtime call; both
// return the updated buffer size.
1797 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1798 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1808 auto stream = adaptor.getAsyncDependencies().front();
1812 Value bufferSizeNew;
1814 if (adaptor.getKind() ==
1815 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1817 createSpGEMMWorkEstimationBuilder
1818 .create(loc, rewriter,
1819 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1820 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1821 adaptor.getBufferSz(), pBuf, stream})
1825 createSpGEMMComputeBuilder
1826 .create(loc, rewriter,
1827 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1828 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1829 adaptor.getBufferSz(), pBuf, stream})
1832 rewriter.
replaceOp(op, {bufferSizeNew, stream});
// gpu.spgemm_copy -> mgpuSpGEMMCopy.
1836 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1837 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1847 auto stream = adaptor.getAsyncDependencies().front();
1848 createSpGEMMCopyBuilder.create(loc, rewriter,
1849 {adaptor.getDesc(), modeA, modeB,
1850 adaptor.getSpmatA(), adaptor.getSpmatB(),
1851 adaptor.getSpmatC(), computeType, stream});
1856 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1857 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1863 auto stream = adaptor.getAsyncDependencies().front();
1867 auto buffer = rewriter.
create<LLVM::AllocaOp>(
1868 loc, llvmPointerType, llvmInt64Type, three, 16);
1870 auto rowsPtr = rewriter.
create<LLVM::GEPOp>(
1871 loc, llvmPointerType, llvmPointerType, buffer,
1874 auto colsPtr = rewriter.
create<LLVM::GEPOp>(
1875 loc, llvmPointerType, llvmPointerType, buffer,
1878 auto nnzsPtr = rewriter.
create<LLVM::GEPOp>(
1879 loc, llvmPointerType, llvmPointerType, buffer,
1882 createSpMatGetSizeBuilder.
create(
1883 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1884 auto rows = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1885 auto cols = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1886 auto nnzs = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1888 rewriter.
replaceOp(op, {rows, cols, nnzs, stream});
1892 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1893 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1899 auto stream = adaptor.getAsyncDependencies().front();
1906 createSetCsrPointersBuilder.create(
1907 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1912 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1913 gpu::CreateCscOp op, OpAdaptor adaptor,
1919 auto stream = adaptor.getAsyncDependencies().front();
1927 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1929 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1931 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1936 createCscCallBuilder
1937 .create(loc, rewriter,
1938 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1939 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1941 rewriter.
replaceOp(op, {handle, stream});
1945 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1946 gpu::CreateBsrOp op, OpAdaptor adaptor,
1952 auto stream = adaptor.getAsyncDependencies().front();
1960 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1962 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1964 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1969 createBsrCallBuilder
1970 .create(loc, rewriter,
1971 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1972 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1973 pColIdxs, pValues, ptp, itp, dtp, stream})
1975 rewriter.
replaceOp(op, {handle, stream});
1981 StringRef gpuBinaryAnnotation,
1982 bool kernelBarePtrCallConv,
1984 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1985 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1986 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1987 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1989 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1990 ConvertDeallocOpToGpuRuntimeCallPattern,
1991 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1992 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1993 ConvertMemcpyOpToGpuRuntimeCallPattern,
1994 ConvertMemsetOpToGpuRuntimeCallPattern,
1995 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1996 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1997 ConvertWaitOpToGpuRuntimeCallPattern,
1998 ConvertAsyncYieldToGpuRuntimeCallPattern,
1999 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
2000 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
2001 ConvertCreateCooOpToGpuRuntimeCallPattern,
2002 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
2003 ConvertCreateCsrOpToGpuRuntimeCallPattern,
2004 ConvertCreateCscOpToGpuRuntimeCallPattern,
2005 ConvertCreateBsrOpToGpuRuntimeCallPattern,
2006 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
2007 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
2008 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
2009 ConvertSpMVOpToGpuRuntimeCallPattern,
2010 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
2011 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
2012 ConvertSpMMOpToGpuRuntimeCallPattern,
2013 ConvertSDDMMOpToGpuRuntimeCallPattern,
2014 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
2015 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
2016 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
2017 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
2018 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
2019 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
2020 patterns.
add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
2021 converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
2022 patterns.
add<EraseGpuModuleOpPattern>(&converter.
getContext());
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static constexpr const char * kGpuBinaryStorageSuffix
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static MLIRContext * getContext(OpFoldResult val)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static int64_t getNumElements(ShapedType type)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
IntegerAttr getI8IntegerAttr(int8_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name, ArrayRef< Type > elements, bool isPacked=false)
Gets a new identified struct with the given body.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module containing surrounding the insertio...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef gpuBinaryAnnotation={}, bool kernelBarePtrCallConv=false, SymbolTable *cachedModuleTable=nullptr)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Utility class for the GPU dialect to represent triples of Values accessible through ....