40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/Support/Error.h"
42 #include "llvm/Support/FormatVariadic.h"
44 #define DEBUG_TYPE "gpu-to-llvm"
47 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
48 #include "mlir/Conversion/Passes.h.inc"
54 class GpuToLLVMConversionPass
55 :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
59 Base::getDependentDialects(registry);
63 void runOnOperation()
override;
66 template <
typename OpTy>
69 explicit ConvertOpToGpuRuntimeCallPattern(
77 return type.hasStaticShape()
79 rewriter, loc, indexType, type.getNumElements())
82 : rewriter.create<LLVM::MulOp>(loc,
83 desc.stride(rewriter, loc, 0),
84 desc.size(rewriter, loc, 0));
87 MLIRContext *context = &this->getTypeConverter()->getContext();
97 context, this->getTypeConverter()->getPointerBitwidth(0));
100 "mgpuStreamCreate", llvmPointerType , {}};
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
104 "mgpuStreamSynchronize",
108 "mgpuStreamWaitEvent",
110 {llvmPointerType , llvmPointerType }};
112 "mgpuEventCreate", llvmPointerType , {}};
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
116 "mgpuEventSynchronize",
122 {llvmPointerType , llvmPointerType }};
124 "mgpuMemHostRegisterMemRef",
130 "mgpuMemHostUnregisterMemRef",
144 {llvmPointerType , llvmPointerType }};
148 {llvmPointerType , llvmPointerType ,
161 {llvmPointerType , llvmInt32Type ,
165 "mgpuSetDefaultDevice",
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
176 {llvmPointerType, llvmPointerType }};
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
185 {llvmPointerType, llvmPointerType }};
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType }};
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType }};
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
220 {llvmPointerType, llvmPointerType }};
222 "mgpuSpMVBufferSize",
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType }};
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType }};
232 "mgpuSpMMBufferSize",
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType }};
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
243 "mgpuSDDMMBufferSize",
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType }};
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
254 "mgpuCreateCuSparseLtDnMat",
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType }};
259 "mgpuDestroyCuSparseLtSpMat",
261 {llvmPointerType, llvmPointerType }};
263 "mgpuDestroyCuSparseLtDnMat",
265 {llvmPointerType, llvmPointerType }};
267 "mgpuCusparseLtCreate2To4SpMat",
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType }};
272 "mgpuCuSparseLtSpMMBufferSize",
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
278 "mgpuCuSparseLtSpMM",
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType }};
283 "mgpuSpGEMMCreateDescr",
287 "mgpuSpGEMMDestroyDescr",
289 {llvmPointerType , llvmPointerType }};
291 "mgpuSpGEMMWorkEstimation",
293 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
294 llvmPointerType , llvmPointerType , llvmPointerType ,
295 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
300 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
301 llvmPointerType , llvmPointerType , llvmPointerType ,
302 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
307 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
308 llvmPointerType , llvmPointerType , llvmPointerType ,
309 llvmInt32Type , llvmPointerType }};
313 {llvmPointerType , llvmPointerType , llvmPointerType ,
314 llvmPointerType , llvmPointerType }};
316 "mgpuSetCsrPointers",
318 {llvmPointerType , llvmPointerType ,
319 llvmPointerType , llvmPointerType ,
325 class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
338 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
354 class ConvertAllocOpToGpuRuntimeCallPattern
355 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
368 class ConvertDeallocOpToGpuRuntimeCallPattern
369 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
371 ConvertDeallocOpToGpuRuntimeCallPattern(
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
381 class ConvertAsyncYieldToGpuRuntimeCallPattern
382 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
390 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
396 class ConvertWaitOpToGpuRuntimeCallPattern
397 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
400 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
404 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
410 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
413 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
415 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
419 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
424 class LegalizeLaunchFuncOpPattern
425 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
428 bool kernelBarePtrCallConv)
429 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
430 kernelBarePtrCallConv(kernelBarePtrCallConv) {}
434 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
437 bool kernelBarePtrCallConv;
442 class ConvertMemcpyOpToGpuRuntimeCallPattern
443 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
446 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
450 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
456 class ConvertMemsetOpToGpuRuntimeCallPattern
457 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
460 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
464 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
470 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
471 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
473 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
475 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
479 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
485 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
486 class Convert##op_name##ToGpuRuntimeCallPattern \
487 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
489 Convert##op_name##ToGpuRuntimeCallPattern( \
490 const LLVMTypeConverter &typeConverter) \
491 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
495 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
496 ConversionPatternRewriter &rewriter) const override; \
523 void GpuToLLVMConversionPass::runOnOperation() {
526 options.useBarePtrCallConv = hostBarePtrCallConv;
535 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
538 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
543 target.
addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
546 [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
554 kernelBarePtrCallConv);
564 auto function = [&] {
565 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
570 return builder.
create<LLVM::CallOp>(loc,
function, arguments);
587 llvm_unreachable(
"unsupported type");
593 if (llvm::isa<ComplexType>(type)) {
595 auto elementType = cast<ComplexType>(type).getElementType();
596 if (elementType.isBF16())
598 if (elementType.isF16())
600 if (elementType.isF32())
602 if (elementType.isF64())
604 if (elementType.isInteger(8))
606 if (elementType.isInteger(16))
608 if (elementType.isInteger(32))
626 llvm_unreachable(
"unsupported element type");
630 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
655 llvm_unreachable(
"cannot find spmat def");
660 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
673 if (!llvm::all_of(operands, [](
Value value) {
677 op,
"Cannot convert if operands aren't of LLVM type.");
683 gpu::AsyncOpInterface op) {
684 if (op.getAsyncDependencies().size() != 1)
686 op,
"Can only convert with exactly one async dependency.");
688 if (!op.getAsyncToken())
694 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
695 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
697 auto *op = hostRegisterOp.getOperation();
703 auto memRefType = hostRegisterOp.getValue().getType();
704 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
707 auto arguments = getTypeConverter()->promoteOperands(
708 loc, op->getOperands(), adaptor.getOperands(), rewriter);
709 arguments.push_back(elementSize);
710 hostRegisterCallBuilder.create(loc, rewriter, arguments);
716 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
717 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
719 Operation *op = hostUnregisterOp.getOperation();
725 auto memRefType = hostUnregisterOp.getValue().getType();
726 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
729 auto arguments = getTypeConverter()->promoteOperands(
730 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
731 arguments.push_back(elementSize);
732 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
738 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
739 gpu::AllocOp allocOp, OpAdaptor adaptor,
742 MemRefType memRefType = allocOp.getType();
744 if (failed(
areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
745 !isConvertibleAndHasIdentityMaps(memRefType))
748 auto loc = allocOp.getLoc();
750 bool isShared = allocOp.getHostShared();
752 if (isShared && allocOp.getAsyncToken())
754 allocOp,
"Host Shared allocation cannot be done async");
763 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
764 shape, strides, sizeBytes);
768 auto nullPtr = rewriter.
create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
769 Value stream = adaptor.getAsyncDependencies().empty()
771 : adaptor.getAsyncDependencies().front();
773 auto isHostShared = rewriter.
create<mlir::LLVM::ConstantOp>(
777 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
781 Value alignedPtr = allocatedPtr;
784 auto memRefDescriptor = this->createMemRefDescriptor(
785 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
787 if (allocOp.getAsyncToken()) {
789 rewriter.
replaceOp(allocOp, {memRefDescriptor, stream});
791 rewriter.
replaceOp(allocOp, {memRefDescriptor});
797 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
798 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
800 if (failed(
areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
808 Value stream = adaptor.getAsyncDependencies().front();
809 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
816 return isa<gpu::AsyncTokenType>(value.
getType());
823 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
824 async::YieldOp yieldOp, OpAdaptor adaptor,
831 llvm::SmallDenseSet<Value> streams;
832 for (
auto &operand : yieldOp->getOpOperands()) {
835 auto idx = operand.getOperandNumber();
836 auto stream = adaptor.getOperands()[idx];
837 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
838 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
839 newOperands[idx] = event;
840 streams.insert(stream);
842 for (
auto stream : streams)
843 streamDestroyCallBuilder.create(loc, rewriter, {stream});
845 rewriter.
modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
851 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
853 return *defOp.getCallee() == functionName;
861 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
862 gpu::WaitOp waitOp, OpAdaptor adaptor,
864 if (waitOp.getAsyncToken())
869 for (
auto operand : adaptor.getOperands()) {
872 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
873 streamDestroyCallBuilder.create(loc, rewriter, {operand});
877 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
878 eventDestroyCallBuilder.create(loc, rewriter, {operand});
891 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
892 gpu::WaitOp waitOp, OpAdaptor adaptor,
894 if (!waitOp.getAsyncToken())
902 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
903 auto operand = std::get<1>(pair);
907 auto *defOp = std::get<0>(pair).getDefiningOp();
909 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
910 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
911 events.push_back(event);
915 events.push_back(operand);
919 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
920 for (
auto event : events)
921 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
922 for (
auto event : events)
923 eventDestroyCallBuilder.create(loc, rewriter, {
event});
930 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
931 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
933 if (failed(
areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
936 if (launchOp.getAsyncDependencies().size() > 1)
938 launchOp,
"Cannot convert with more than one async dependency.");
943 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
945 launchOp,
"Cannot convert non-async op with async dependencies.");
950 if (!adaptor.getAsyncDependencies().empty())
951 stream = adaptor.getAsyncDependencies().front();
954 else if (launchOp.getAsyncToken())
955 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
960 loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
961 kernelBarePtrCallConv);
963 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
964 if (launchOp.hasClusterSize()) {
967 adaptor.getClusterSizeZ()};
969 rewriter.
create<gpu::LaunchFuncOp>(
970 launchOp.getLoc(), launchOp.getKernelAttr(),
972 adaptor.getGridSizeZ()},
974 adaptor.getBlockSizeZ()},
975 adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
976 if (launchOp.getAsyncToken())
985 LLVM::LLVMPointerType destinationType,
988 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
989 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
990 sourcePtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
993 destinationType.getAddressSpace()),
998 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
999 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1001 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1003 if (failed(
areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1004 !isConvertibleAndHasIdentityMaps(memRefType) ||
1008 auto loc = memcpyOp.getLoc();
1014 Value nullPtr = rewriter.
create<LLVM::ZeroOp>(loc, elementPtrType);
1016 loc, elementPtrType,
1017 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1020 rewriter.
create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1023 srcDesc.alignedPtr(rewriter, loc),
1024 *getTypeConverter());
1026 loc, rewriter, llvmPointerType,
1028 *getTypeConverter());
1030 auto stream = adaptor.getAsyncDependencies().front();
1031 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1038 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1039 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1041 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1043 if (failed(
areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1044 !isConvertibleAndHasIdentityMaps(memRefType) ||
1048 auto loc = memsetOp.getLoc();
1050 Type valueType = adaptor.getValue().getType();
1053 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1055 memsetOp,
"value must be a 16 or 32 bit int or float");
1059 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1065 rewriter.
create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1067 dstDesc.alignedPtr(rewriter, loc),
1068 *getTypeConverter());
1070 auto stream = adaptor.getAsyncDependencies().front();
1072 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1073 builder.
create(loc, rewriter, {dst, value, numElements, stream});
1079 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1080 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1083 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1084 {adaptor.getDevIndex()});
1089 template <
typename T>
1092 return builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type,
1093 static_cast<int32_t
>(tValue));
1096 template <
typename T>
1099 return builder.
create<LLVM::ConstantOp>(
1100 loc, llvmFloat32Type,
1104 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1105 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1111 auto stream = adaptor.getAsyncDependencies().front();
1114 Type dType = op.getMemref().getType().getElementType();
1118 for (
Value dim : adaptor.getDims()) {
1119 dims.push_back(dim);
1129 if (dims.size() == 2) {
1131 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1133 handle = rewriter.
create<LLVM::AllocaOp>(
1134 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1135 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1137 createLtDnMatCallBuilder
1139 {handle, dims[0], dims[1], pTensor, dtp, stream})
1143 createDnMatCallBuilder
1144 .
create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1148 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1149 handle = createDnVecCallBuilder
1150 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1153 rewriter.
replaceOp(op, {handle, stream});
1157 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1158 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1164 auto stream = adaptor.getAsyncDependencies().front();
1165 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1167 for (
Value dim : definingOp.getDims()) {
1168 dims.push_back(dim);
1170 if (dims.size() == 2) {
1174 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1175 {adaptor.getDnTensor(), stream});
1177 destroyDnMatCallBuilder.create(loc, rewriter,
1178 {adaptor.getDnTensor(), stream});
1181 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1182 destroyDnVecCallBuilder.create(loc, rewriter,
1183 {adaptor.getDnTensor(), stream});
1189 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1190 gpu::CreateCooOp op, OpAdaptor adaptor,
1196 auto stream = adaptor.getAsyncDependencies().front();
1204 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1206 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1210 createCooCallBuilder
1211 .create(loc, rewriter,
1212 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1213 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1215 rewriter.
replaceOp(op, {handle, stream});
1219 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1220 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1226 auto stream = adaptor.getAsyncDependencies().front();
1230 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1232 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1236 createCooAoSCallBuilder
1237 .create(loc, rewriter,
1238 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1239 pIdxs, pValues, itp, dtp, stream})
1241 rewriter.
replaceOp(op, {handle, stream});
1245 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1246 gpu::CreateCsrOp op, OpAdaptor adaptor,
1252 auto stream = adaptor.getAsyncDependencies().front();
1260 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1262 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1264 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1269 createCsrCallBuilder
1270 .create(loc, rewriter,
1271 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1272 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1274 rewriter.
replaceOp(op, {handle, stream});
1278 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1279 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1285 auto stream = adaptor.getAsyncDependencies().front();
1289 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1293 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1296 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1297 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1299 create2To4SpMatCallBuilder
1301 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1303 rewriter.
replaceOp(op, {handle, stream});
1307 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1308 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1314 auto stream = adaptor.getAsyncDependencies().front();
1317 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1318 {adaptor.getSpmat(), stream});
1321 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1327 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1328 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1337 auto stream = adaptor.getAsyncDependencies().front();
1338 auto bufferSize = spMVBufferSizeCallBuilder
1339 .create(loc, rewriter,
1340 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1341 adaptor.getDnY(), computeType, stream})
1343 rewriter.
replaceOp(op, {bufferSize, stream});
1347 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1348 gpu::SpMVOp op, OpAdaptor adaptor,
1357 auto stream = adaptor.getAsyncDependencies().front();
1360 spMVCallBuilder.create(loc, rewriter,
1361 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1362 adaptor.getDnY(), computeType, pBuf, stream});
1367 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1368 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1376 auto stream = adaptor.getAsyncDependencies().front();
1383 auto three = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1385 auto bufferSize = rewriter.
create<LLVM::AllocaOp>(
1386 loc, llvmPointerType, llvmPointerType, three, 16);
1387 createCuSparseLtSpMMBufferSizeBuilder
1389 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1390 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1394 auto bufferSizePtr1 = rewriter.
create<LLVM::GEPOp>(
1395 loc, llvmPointerType, llvmPointerType, bufferSize,
1398 auto bufferSizePtr2 = rewriter.
create<LLVM::GEPOp>(
1399 loc, llvmPointerType, llvmPointerType, bufferSize,
1403 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1405 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1407 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1409 rewriter.
replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1414 createSpMMBufferSizeCallBuilder
1415 .create(loc, rewriter,
1416 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1417 adaptor.getDnmatC(), computeType, stream})
1419 rewriter.
replaceOp(op, {bufferSize, stream});
1424 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1425 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1435 auto stream = adaptor.getAsyncDependencies().front();
1437 createSDDMMBufferSizeCallBuilder
1438 .create(loc, rewriter,
1439 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1440 adaptor.getSpmatC(), computeType, stream})
1442 rewriter.
replaceOp(op, {bufferSize, stream});
1446 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1447 gpu::SpMMOp op, OpAdaptor adaptor,
1458 auto stream = adaptor.getAsyncDependencies().front();
1463 for (
Value buffer : adaptor.getBuffers()) {
1465 pBufs.push_back(pBuf);
1467 createCuSparseLtSpMMBuilder.create(
1469 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1470 pBufs[0], pBufs[1], pBufs[2], stream});
1474 createSpMMCallBuilder.create(loc, rewriter,
1475 {modeA, modeB, adaptor.getSpmatA(),
1476 adaptor.getDnmatB(), adaptor.getDnmatC(),
1477 computeType, pBuf, stream});
1483 template <
typename T>
1490 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1491 gpu::SDDMMOp op, OpAdaptor adaptor,
1501 auto stream = adaptor.getAsyncDependencies().front();
1504 createSDDMMCallBuilder.create(loc, rewriter,
1505 {modeA, modeB, adaptor.getDnmatA(),
1506 adaptor.getDnmatB(), adaptor.getSpmatC(),
1507 computeType, pBuf, stream});
1513 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1514 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1520 auto stream = adaptor.getAsyncDependencies().front();
1521 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1523 rewriter.
replaceOp(op, {descr, stream});
1528 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1529 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1535 auto stream = adaptor.getAsyncDependencies().front();
1536 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1537 {adaptor.getDesc(), stream});
1543 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1544 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1554 auto stream = adaptor.getAsyncDependencies().front();
1558 Value bufferSizeNew;
1560 if (adaptor.getKind() ==
1561 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1563 createSpGEMMWorkEstimationBuilder
1564 .create(loc, rewriter,
1565 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1566 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1567 adaptor.getBufferSz(), pBuf, stream})
1571 createSpGEMMComputeBuilder
1572 .create(loc, rewriter,
1573 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1574 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1575 adaptor.getBufferSz(), pBuf, stream})
1578 rewriter.
replaceOp(op, {bufferSizeNew, stream});
1582 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1583 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1593 auto stream = adaptor.getAsyncDependencies().front();
1594 createSpGEMMCopyBuilder.create(loc, rewriter,
1595 {adaptor.getDesc(), modeA, modeB,
1596 adaptor.getSpmatA(), adaptor.getSpmatB(),
1597 adaptor.getSpmatC(), computeType, stream});
1602 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1603 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1609 auto stream = adaptor.getAsyncDependencies().front();
1611 auto three = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1613 auto buffer = rewriter.
create<LLVM::AllocaOp>(
1614 loc, llvmPointerType, llvmInt64Type, three, 16);
1616 auto rowsPtr = rewriter.
create<LLVM::GEPOp>(
1617 loc, llvmPointerType, llvmPointerType, buffer,
1620 auto colsPtr = rewriter.
create<LLVM::GEPOp>(
1621 loc, llvmPointerType, llvmPointerType, buffer,
1624 auto nnzsPtr = rewriter.
create<LLVM::GEPOp>(
1625 loc, llvmPointerType, llvmPointerType, buffer,
1628 createSpMatGetSizeBuilder.
create(
1629 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1630 auto rows = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1631 auto cols = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1632 auto nnzs = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1638 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1639 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1645 auto stream = adaptor.getAsyncDependencies().front();
1652 createSetCsrPointersBuilder.create(
1653 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1658 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1659 gpu::CreateCscOp op, OpAdaptor adaptor,
1665 auto stream = adaptor.getAsyncDependencies().front();
1673 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1675 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1677 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1682 createCscCallBuilder
1683 .create(loc, rewriter,
1684 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1685 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1687 rewriter.
replaceOp(op, {handle, stream});
1691 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1692 gpu::CreateBsrOp op, OpAdaptor adaptor,
1698 auto stream = adaptor.getAsyncDependencies().front();
1706 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1708 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1710 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1715 createBsrCallBuilder
1716 .create(loc, rewriter,
1717 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1718 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1719 pColIdxs, pValues, ptp, itp, dtp, stream})
1721 rewriter.
replaceOp(op, {handle, stream});
1727 bool kernelBarePtrCallConv) {
1728 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1729 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1730 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1731 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1733 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1734 ConvertDeallocOpToGpuRuntimeCallPattern,
1735 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1736 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1737 ConvertMemcpyOpToGpuRuntimeCallPattern,
1738 ConvertMemsetOpToGpuRuntimeCallPattern,
1739 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1740 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1741 ConvertWaitOpToGpuRuntimeCallPattern,
1742 ConvertAsyncYieldToGpuRuntimeCallPattern,
1743 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1744 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1745 ConvertCreateCooOpToGpuRuntimeCallPattern,
1746 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1747 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1748 ConvertCreateCscOpToGpuRuntimeCallPattern,
1749 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1750 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1751 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1752 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1753 ConvertSpMVOpToGpuRuntimeCallPattern,
1754 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1755 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1756 ConvertSpMMOpToGpuRuntimeCallPattern,
1757 ConvertSDDMMOpToGpuRuntimeCallPattern,
1758 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1759 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1760 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1761 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1762 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1763 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1764 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
1772 struct GPUModuleOpConvertToLLVMInterface
1773 :
public ConvertToLLVMOpInterface::ExternalModel<
1774 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1776 void getConvertToLLVMConversionAttrs(
1781 void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1783 auto module = cast<gpu::GPUModuleOp>(op);
1784 ArrayAttr targetsAttr = module.getTargetsAttr();
1786 if (!targetsAttr || targetsAttr.size() != 1)
1788 if (
auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1789 attrs.push_back(patternAttr);
1794 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
IntegerAttr getI8IntegerAttr(int8_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
operand_range getOperands()
Returns an iterator on the underlying Value's.
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Include the generated interface declarations.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
Utility class for the GPU dialect to represent triples of Values accessible through ....