#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    if (type.hasStaticShape())
      return ConvertToLLVMPattern::createIndexAttrConstant(
          rewriter, loc, indexType, type.getNumElements());
    // Compute the number of elements by multiplying all dimension sizes.
    uint64_t rank = type.getRank();
    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
    for (unsigned i = 1; i < rank; i++)
      numElements = rewriter.create<LLVM::MulOp>(
          loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
    return numElements;
  }
  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));
104 "mgpuStreamCreate", llvmPointerType , {}};
106 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
108 "mgpuStreamSynchronize",
112 "mgpuStreamWaitEvent",
114 {llvmPointerType , llvmPointerType }};
116 "mgpuEventCreate", llvmPointerType , {}};
118 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
120 "mgpuEventSynchronize",
126 {llvmPointerType , llvmPointerType }};
128 "mgpuMemHostRegisterMemRef",
134 "mgpuMemHostUnregisterMemRef",
148 {llvmPointerType , llvmPointerType }};
152 {llvmPointerType , llvmPointerType ,
165 {llvmPointerType , llvmInt32Type ,
169 "mgpuSetDefaultDevice",
175 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
180 {llvmPointerType, llvmPointerType }};
184 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
189 {llvmPointerType, llvmPointerType }};
193 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
194 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
199 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
200 llvmPointerType, llvmInt32Type, llvmInt32Type,
205 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
206 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
207 llvmInt32Type, llvmPointerType }};
211 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
212 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
213 llvmInt32Type, llvmPointerType }};
217 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
218 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
219 llvmInt32Type, llvmInt32Type, llvmInt32Type,
224 {llvmPointerType, llvmPointerType }};
226 "mgpuSpMVBufferSize",
228 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
229 llvmInt32Type, llvmPointerType }};
233 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
234 llvmInt32Type, llvmPointerType, llvmPointerType }};
236 "mgpuSpMMBufferSize",
238 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
239 llvmPointerType, llvmInt32Type, llvmPointerType }};
243 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
244 llvmPointerType, llvmInt32Type, llvmPointerType,
247 "mgpuSDDMMBufferSize",
249 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
250 llvmPointerType, llvmInt32Type, llvmPointerType }};
254 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
255 llvmPointerType, llvmInt32Type, llvmPointerType,
258 "mgpuCreateCuSparseLtDnMat",
260 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
261 llvmInt32Type, llvmPointerType }};
263 "mgpuDestroyCuSparseLtSpMat",
265 {llvmPointerType, llvmPointerType }};
267 "mgpuDestroyCuSparseLtDnMat",
269 {llvmPointerType, llvmPointerType }};
271 "mgpuCusparseLtCreate2To4SpMat",
273 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
274 llvmInt32Type, llvmPointerType }};
276 "mgpuCuSparseLtSpMMBufferSize",
278 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
279 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
282 "mgpuCuSparseLtSpMM",
284 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
285 llvmPointerType, llvmPointerType, llvmPointerType }};
287 "mgpuSpGEMMCreateDescr",
291 "mgpuSpGEMMDestroyDescr",
293 {llvmPointerType , llvmPointerType }};
295 "mgpuSpGEMMWorkEstimation",
297 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
298 llvmPointerType , llvmPointerType , llvmPointerType ,
299 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
304 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
305 llvmPointerType , llvmPointerType , llvmPointerType ,
306 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
311 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
312 llvmPointerType , llvmPointerType , llvmPointerType ,
313 llvmInt32Type , llvmPointerType }};
317 {llvmPointerType , llvmPointerType , llvmPointerType ,
318 llvmPointerType , llvmPointerType }};
320 "mgpuSetCsrPointers",
322 {llvmPointerType , llvmPointerType ,
323 llvmPointerType , llvmPointerType ,
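// For reference (illustrative, not part of the upstream file): a builder such
// as
//   FunctionCallBuilder streamCreateCallBuilder = {
//       "mgpuStreamCreate", llvmPointerType, {}};
// describes an external runtime entry point; on first use it materializes a
// declaration along the lines of
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
// and each `create` call then emits an `llvm.call` to it. The actual symbol
// definitions live in the GPU runtime wrapper library, not in this pass.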
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(
            typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv,
                              bool kernelIntersperseSizeCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv),
        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
  bool kernelIntersperseSizeCallConv;
};

class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operations on sparse matrices.
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)               \
  class Convert##op_name##ToGpuRuntimeCallPattern                             \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
  public:                                                                     \
    Convert##op_name##ToGpuRuntimeCallPattern(                                \
        const LLVMTypeConverter &typeConverter)                               \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
                                                                              \
  private:                                                                    \
    LogicalResult                                                             \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
                    ConversionPatternRewriter &rewriter) const override;      \
  };

DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)

} // namespace
void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();

  // Perform progressive lowering of vector transfer operations.
  {
    RewritePatternSet patterns(context);
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    vector::populateVectorTransferLoweringPatterns(patterns,
                                                   /*maxTransferRank=*/1);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }

  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate patterns from every dialect that implements the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                   patterns);
  }

  // Preserve GPU modules and binaries; modules can be converted later by
  // `gpu-module-to-binary`.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept LaunchFuncOps as legal once their operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}
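// Typical usage inside the rewrite patterns below (illustrative): a call like
//   streamSynchronizeCallBuilder.create(loc, rewriter, {stream});
// reuses the existing `llvm.func @mgpuStreamSynchronize` declaration from the
// module if present, otherwise inserts one at the module end, and then emits
// the `llvm.call`.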
// Corresponding to cusparseLtDatatype_t defined in cusparseLt.h.
static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_LT_R_16F
  if (type.isBF16())
    return 1; // CUSPARSE_LT_R_16BF
  llvm_unreachable("unsupported type");
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I
  llvm_unreachable("unsupported element type");
}

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3;   // CUSPARSE_INDEX_64I
}
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// Note: cusparseLt does not work for CUDA architectures below 8.0 and will
// trigger a runtime error in that case; ideally this lowering would warn when
// cusparseLt calls are generated for such targets.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the defining op of spMat for diagnosis.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}
static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    // If the other operand has 2:4 sparsity then we should use cusparseLt.
    if (!spmmOp)
      continue;
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}
static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}
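// Illustrative shape of the IR this helper accepts (example is ours, not from
// the upstream tests): exactly one async dependency and a produced token, e.g.
//   %memref, %token = gpu.alloc async [%dep] (%size) : memref<?xf32>
// The synchronous forms of these ops are rejected by the sparse and memory
// patterns below.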
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant values and
  // dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}
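// Rough shape of the resulting IR (illustrative, types abbreviated):
//   %memref, %token = gpu.alloc async [%dep] (%n) : memref<?xf32>
// becomes
//   %size = ... size in bytes computed from the descriptor ...
//   %ptr  = llvm.call @mgpuMemAlloc(%size, %dep, %isHostShared)
// with %ptr packed into a fresh memref descriptor and %dep reused as the
// converted async token (the stream).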
LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}
// Converts `async.yield` ops with !gpu.async.token operands: the tokens are
// lowered to streams inside an async.execute region but passed between
// regions as events, so each yielded stream gets an event recorded on it.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands().begin(),
                                    adaptor.getOperands().end());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}
// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}
// Converts `gpu.wait` to runtime calls. The converted op synchronizes the
// host with the stream/event operands and then destroys them, assuming they
// are not used afterwards or elsewhere.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}
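// Illustrative lowering of the synchronous form:
//   gpu.wait [%token]
// becomes, when the converted %token was produced by @mgpuStreamCreate,
//   llvm.call @mgpuStreamSynchronize(%stream)
//   llvm.call @mgpuStreamDestroy(%stream)
// and the event-based pair of calls otherwise.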
// Converts `gpu.wait async` to runtime calls: an event per dependency, a
// fresh stream that waits on all events, and the stream as the new token.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is already an event.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
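// Illustrative result for `%t = gpu.wait async [%t0, %t1]`: one
// @mgpuEventCreate/@mgpuEventRecord pair per dependency, a fresh
// @mgpuStreamCreate stream, one @mgpuStreamWaitEvent call per event, the
// events destroyed, and the stream yielded as the converted token.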
// Legalize the op's operands.
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.getAsyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is
  // no use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty())
    stream = adaptor.getAsyncDependencies().front();
  // If the async keyword is present and there are no dependencies, then a
  // stream must be created to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

  // Lower the kernel operands to match kernel parameters. Note: If
  // `useBarePtrCallConv` is set in the type converter's options,
  // `kernelBarePtrCallConv` will be ignored.
  OperandRange origArguments = launchOp.getKernelOperands();
  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;

  // Intersperse size information if requested.
  if (kernelIntersperseSizeCallConv) {
    if (origArguments.size() != llvmArguments.size()) {
      // This shouldn't happen if the bare-pointer calling convention is used.
      return rewriter.notifyMatchFailure(
          launchOp,
          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
    }

    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
      if (!memrefTy) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref.");
      }

      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat()) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");
      }

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");
      }

      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
                            static_cast<uint64_t>(memrefTy.getNumElements());

      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
      llvmArgumentsWithSizes.push_back(sizeArg);
    }
  }

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  rewriter.create<gpu::LaunchFuncOp>(
      launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(),
      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
      stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}
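// With `kernelIntersperseSizeCallConv`, an original argument list such as
//   (%a : memref<8xf32>, %b : memref<4xi32>)
// is rewritten (illustratively) to
//   (%a_ptr, %c32, %b_ptr, %c16)
// i.e. each bare pointer is followed by its static size in bytes, which is
// why one-to-many LLVM expansion and dynamic shapes are rejected above.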
static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}
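// The size-in-bytes computation above uses the classic null-pointer GEP
// idiom: `gep %null[%numElements]` followed by `ptrtoint` yields
// numElements * sizeof(element) without hard-coding the element size.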
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.getValue().getType();
  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
  // Ints and floats of 16 or 32 bit width are allowed.
  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}
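// Note the fill value is reinterpreted, not converted: an f16/f32 value is
// bitcast to i16/i32 so a single mgpuMemset16/mgpuMemset32 entry point can
// serve both the integer and the float case.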
LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}
template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                          static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return builder.create<LLVM::ConstantOp>(
      loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}
LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // Use the cusparseLt create call if the dnmat is used with a 2:4 sparse
  // matrix; otherwise fall back to plain cuSPARSE.
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      auto handleSz = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(11032));
      handle = rewriter.create<LLVM::AllocaOp>(
          loc, llvmPointerType, llvmInt8Type, handleSz, 16);
      handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

      createLtDnMatCallBuilder
          .create(loc, rewriter,
                  {handle, dims[0], dims[1], pTensor, dtp, stream})
          .getResult();
    } else {
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
  SmallVector<Value, 4> dims;
  for (Value dim : definingOp.getDims()) {
    dims.push_back(dim);
  }
  if (dims.size() == 2) {
    // Use the cusparseLt destroy call if the dnmat is used with a 2:4 sparse
    // matrix.
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
                                           {adaptor.getDnTensor(), stream});
    } else {
      destroyDnMatCallBuilder.create(loc, rewriter,
                                     {adaptor.getDnTensor(), stream});
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    destroyDnVecCallBuilder.create(loc, rewriter,
                                   {adaptor.getDnTensor(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooAoSOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooAoSCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pMat =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType =
      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseLtDataTypeFrom(dType));

  // Allocate the opaque cusparseLt handle on the stack.
  auto handleSz = rewriter.create<LLVM::ConstantOp>(
      loc, getIndexType(), rewriter.getIndexAttr(44104));
  Value handle = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt8Type, handleSz, 16);
  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

  create2To4SpMatCallBuilder
      .create(loc, rewriter,
              {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
      .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroySpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
  if (is2To4Sparsity(op.getSpmat())) {
    destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
                                         {adaptor.getSpmat(), stream});
  } else {
    destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize = spMVBufferSizeCallBuilder
                        .create(loc, rewriter,
                                {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                                 adaptor.getDnY(), computeType, stream})
                        .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}
LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  spMVCallBuilder.create(loc, rewriter,
                         {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                          adaptor.getDnY(), computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  // Lower to cusparseLt if applicable.
  if (is2To4Sparsity(op.getSpmatA())) {
    auto pruneFlag =
        genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
    auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(3));
    auto bufferSize = rewriter.create<LLVM::AllocaOp>(
        loc, llvmPointerType, llvmPointerType, three, 16);
    createCuSparseLtSpMMBufferSizeBuilder
        .create(loc, rewriter,
                {bufferSize, modeA, modeB, adaptor.getSpmatA(),
                 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
                 pruneFlag, stream})
        .getResult();

    auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(1))});
    auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(2))});
    auto bufferSize0 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
    auto bufferSize1 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
    auto bufferSize2 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);

    rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
  } else {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
    auto bufferSize =
        createSpMMBufferSizeCallBuilder
            .create(loc, rewriter,
                    {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
                     adaptor.getDnmatC(), computeType, stream})
            .getResult();
    rewriter.replaceOp(op, {bufferSize, stream});
  }
  return success();
}
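// The cusparseLt path reports three buffer sizes (roughly: a workspace plus
// two compression-related buffers), which is why it allocas three words and
// loads them back as bufferSize0/1/2; the plain cuSPARSE path returns a
// single size.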
LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize =
      createSDDMMBufferSizeCallBuilder
          .create(loc, rewriter,
                  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
                   adaptor.getSpmatC(), computeType, stream})
          .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));

  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to cusparseLt if applicable.
  if (is2To4Sparsity(op.getSpmatA())) {
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
  converter.addConversion([&converter](T) -> Type {
    return LLVM::LLVMPointerType::get(&converter.getContext());
  });
}

LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  createSDDMMCallBuilder.create(loc, rewriter,
                                {modeA, modeB, adaptor.getDnmatA(),
                                 adaptor.getDnmatB(), adaptor.getSpmatC(),
                                 computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
                    .getResult();
  rewriter.replaceOp(op, {descr, stream});
  return success();
}
LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
                                         {adaptor.getDesc(), stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();

  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  Value bufferSizeNew;

  if (adaptor.getKind() ==
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
    bufferSizeNew =
        createSpGEMMWorkEstimationBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  } else {
    bufferSizeNew =
        createSpGEMMComputeBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  }
  rewriter.replaceOp(op, {bufferSizeNew, stream});
  return success();
}
LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMCopyBuilder.create(loc, rewriter,
                                 {adaptor.getDesc(), modeA, modeB,
                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
                                  adaptor.getSpmatC(), computeType, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();

  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(3));
  auto buffer = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt64Type, three, 16);

  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(0))});
  auto colsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(1))});
  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(2))});
  createSpMatGetSizeBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);

  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
  return success();
}
LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  createSetCsrPointersBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCscOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCscCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateBsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createBsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
                   adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
                   pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);

  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertHostUnregisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern,
               ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
               ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
               ConvertCreateCooOpToGpuRuntimeCallPattern,
               ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
               ConvertCreateCsrOpToGpuRuntimeCallPattern,
               ConvertCreateCscOpToGpuRuntimeCallPattern,
               ConvertCreateBsrOpToGpuRuntimeCallPattern,
               ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
               ConvertDestroySpMatOpToGpuRuntimeCallPattern,
               ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMVOpToGpuRuntimeCallPattern,
               ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMMOpToGpuRuntimeCallPattern,
               ConvertSDDMMOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
               ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
                                            kernelIntersperseSizeCallConv);
}
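// Typical driver usage (illustrative sketch; the variable names are ours):
// this is what `mlir-opt --gpu-to-llvm` runs, and out-of-tree pipelines can
// do the same programmatically:
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<LLVM::LLVMDialect>();
//   populateGpuToLLVMConversionPatterns(converter, patterns);
//   (void)applyPartialConversion(module, target, std::move(patterns));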
namespace {
/// Implement the ConvertToLLVMOpInterface on gpu.module to expose its target
/// attribute as a conversion attribute.
struct GPUModuleOpConvertToLLVMInterface
    : public ConvertToLLVMOpInterface::ExternalModel<
          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
  /// Get the conversion attributes from the module's target attribute.
  void getConvertToLLVMConversionAttrs(
      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
};
} // namespace

void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
  auto module = cast<gpu::GPUModuleOp>(op);
  ArrayAttr targetsAttr = module.getTargetsAttr();
  // Fail if there are no target attributes or there is more than one target.
  if (!targetsAttr || targetsAttr.size() != 1)
    return;
  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
    attrs.push_back(patternAttr);
}

void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
  });
}