39 #include "llvm/ADT/STLExtras.h"
40 #include "llvm/Support/Error.h"
41 #include "llvm/Support/FormatVariadic.h"
43 #define DEBUG_TYPE "gpu-to-llvm"
46 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
47 #include "mlir/Conversion/Passes.h.inc"
53 class GpuToLLVMConversionPass
54 :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
58 Base::getDependentDialects(registry);
62 void runOnOperation()
override;
65 template <
typename OpTy>
68 explicit ConvertOpToGpuRuntimeCallPattern(
76 return type.hasStaticShape()
78 rewriter, loc, indexType, type.getNumElements())
81 : rewriter.create<LLVM::MulOp>(loc,
82 desc.stride(rewriter, loc, 0),
83 desc.size(rewriter, loc, 0));
86 MLIRContext *context = &this->getTypeConverter()->getContext();
96 context, this->getTypeConverter()->getPointerBitwidth(0));
99 "mgpuStreamCreate", llvmPointerType , {}};
101 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
103 "mgpuStreamSynchronize",
107 "mgpuStreamWaitEvent",
109 {llvmPointerType , llvmPointerType }};
111 "mgpuEventCreate", llvmPointerType , {}};
113 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
115 "mgpuEventSynchronize",
121 {llvmPointerType , llvmPointerType }};
123 "mgpuMemHostRegisterMemRef",
129 "mgpuMemHostUnregisterMemRef",
143 {llvmPointerType , llvmPointerType }};
147 {llvmPointerType , llvmPointerType ,
160 {llvmPointerType , llvmInt32Type ,
164 "mgpuSetDefaultDevice",
170 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
175 {llvmPointerType, llvmPointerType }};
179 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
184 {llvmPointerType, llvmPointerType }};
188 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
189 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
194 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
195 llvmPointerType, llvmInt32Type, llvmInt32Type,
200 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
201 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
202 llvmInt32Type, llvmPointerType }};
206 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
207 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
208 llvmInt32Type, llvmPointerType }};
212 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
213 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
214 llvmInt32Type, llvmInt32Type, llvmInt32Type,
219 {llvmPointerType, llvmPointerType }};
221 "mgpuSpMVBufferSize",
223 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
224 llvmInt32Type, llvmPointerType }};
228 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
229 llvmInt32Type, llvmPointerType, llvmPointerType }};
231 "mgpuSpMMBufferSize",
233 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
234 llvmPointerType, llvmInt32Type, llvmPointerType }};
238 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
239 llvmPointerType, llvmInt32Type, llvmPointerType,
242 "mgpuSDDMMBufferSize",
244 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
245 llvmPointerType, llvmInt32Type, llvmPointerType }};
249 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
250 llvmPointerType, llvmInt32Type, llvmPointerType,
253 "mgpuCreateCuSparseLtDnMat",
255 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
256 llvmInt32Type, llvmPointerType }};
258 "mgpuDestroyCuSparseLtSpMat",
260 {llvmPointerType, llvmPointerType }};
262 "mgpuDestroyCuSparseLtDnMat",
264 {llvmPointerType, llvmPointerType }};
266 "mgpuCusparseLtCreate2To4SpMat",
268 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
269 llvmInt32Type, llvmPointerType }};
271 "mgpuCuSparseLtSpMMBufferSize",
273 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
274 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
277 "mgpuCuSparseLtSpMM",
279 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
280 llvmPointerType, llvmPointerType, llvmPointerType }};
282 "mgpuSpGEMMCreateDescr",
286 "mgpuSpGEMMDestroyDescr",
288 {llvmPointerType , llvmPointerType }};
290 "mgpuSpGEMMWorkEstimation",
292 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
293 llvmPointerType , llvmPointerType , llvmPointerType ,
294 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
299 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
300 llvmPointerType , llvmPointerType , llvmPointerType ,
301 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
306 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
307 llvmPointerType , llvmPointerType , llvmPointerType ,
308 llvmInt32Type , llvmPointerType }};
312 {llvmPointerType , llvmPointerType , llvmPointerType ,
313 llvmPointerType , llvmPointerType }};
315 "mgpuSetCsrPointers",
317 {llvmPointerType , llvmPointerType ,
318 llvmPointerType , llvmPointerType ,
324 class ConvertHostRegisterOpToGpuRuntimeCallPattern
325 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
333 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
337 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
338 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
347 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
353 class ConvertAllocOpToGpuRuntimeCallPattern
354 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
357 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
361 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
367 class ConvertDeallocOpToGpuRuntimeCallPattern
368 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370 ConvertDeallocOpToGpuRuntimeCallPattern(
372 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
376 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
380 class ConvertAsyncYieldToGpuRuntimeCallPattern
381 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
389 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
395 class ConvertWaitOpToGpuRuntimeCallPattern
396 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
399 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
403 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
409 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
410 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
412 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
418 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
423 class LegalizeLaunchFuncOpPattern
424 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
427 bool kernelBarePtrCallConv)
428 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
429 kernelBarePtrCallConv(kernelBarePtrCallConv) {}
433 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
436 bool kernelBarePtrCallConv;
441 class ConvertMemcpyOpToGpuRuntimeCallPattern
442 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
445 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
449 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
455 class ConvertMemsetOpToGpuRuntimeCallPattern
456 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
459 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
463 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
469 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
470 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
472 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
474 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
478 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
484 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
485 class Convert##op_name##ToGpuRuntimeCallPattern \
486 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
488 Convert##op_name##ToGpuRuntimeCallPattern( \
489 const LLVMTypeConverter &typeConverter) \
490 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
494 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
495 ConversionPatternRewriter &rewriter) const override; \
522 void GpuToLLVMConversionPass::runOnOperation() {
525 options.useBarePtrCallConv = hostBarePtrCallConv;
534 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
537 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
542 target.
addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
545 [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
553 kernelBarePtrCallConv);
563 auto function = [&] {
564 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
569 return builder.
create<LLVM::CallOp>(loc,
function, arguments);
586 llvm_unreachable(
"unsupported type");
592 if (llvm::isa<ComplexType>(type)) {
594 auto elementType = cast<ComplexType>(type).getElementType();
595 if (elementType.isBF16())
597 if (elementType.isF16())
599 if (elementType.isF32())
601 if (elementType.isF64())
603 if (elementType.isInteger(8))
605 if (elementType.isInteger(16))
607 if (elementType.isInteger(32))
625 llvm_unreachable(
"unsupported element type");
629 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
654 llvm_unreachable(
"cannot find spmat def");
659 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
672 if (!llvm::all_of(operands, [](
Value value) {
676 op,
"Cannot convert if operands aren't of LLVM type.");
682 gpu::AsyncOpInterface op) {
683 if (op.getAsyncDependencies().size() != 1)
685 op,
"Can only convert with exactly one async dependency.");
687 if (!op.getAsyncToken())
693 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
694 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
696 auto *op = hostRegisterOp.getOperation();
702 auto memRefType = hostRegisterOp.getValue().getType();
703 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
706 auto arguments = getTypeConverter()->promoteOperands(
707 loc, op->getOperands(), adaptor.getOperands(), rewriter);
708 arguments.push_back(elementSize);
709 hostRegisterCallBuilder.create(loc, rewriter, arguments);
715 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
716 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
718 Operation *op = hostUnregisterOp.getOperation();
724 auto memRefType = hostUnregisterOp.getValue().getType();
725 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
728 auto arguments = getTypeConverter()->promoteOperands(
729 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
730 arguments.push_back(elementSize);
731 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
737 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
738 gpu::AllocOp allocOp, OpAdaptor adaptor,
741 MemRefType memRefType = allocOp.getType();
743 if (failed(
areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
744 !isConvertibleAndHasIdentityMaps(memRefType))
747 auto loc = allocOp.getLoc();
749 bool isShared = allocOp.getHostShared();
751 if (isShared && allocOp.getAsyncToken())
753 allocOp,
"Host Shared allocation cannot be done async");
762 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
763 shape, strides, sizeBytes);
767 auto nullPtr = rewriter.
create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
768 Value stream = adaptor.getAsyncDependencies().empty()
770 : adaptor.getAsyncDependencies().front();
772 auto isHostShared = rewriter.
create<mlir::LLVM::ConstantOp>(
776 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
780 Value alignedPtr = allocatedPtr;
783 auto memRefDescriptor = this->createMemRefDescriptor(
784 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
786 if (allocOp.getAsyncToken()) {
788 rewriter.
replaceOp(allocOp, {memRefDescriptor, stream});
790 rewriter.
replaceOp(allocOp, {memRefDescriptor});
796 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
797 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
799 if (failed(
areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
807 Value stream = adaptor.getAsyncDependencies().front();
808 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
815 return isa<gpu::AsyncTokenType>(value.
getType());
822 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
823 async::YieldOp yieldOp, OpAdaptor adaptor,
830 llvm::SmallDenseSet<Value> streams;
831 for (
auto &operand : yieldOp->getOpOperands()) {
834 auto idx = operand.getOperandNumber();
835 auto stream = adaptor.getOperands()[idx];
836 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
837 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
838 newOperands[idx] = event;
839 streams.insert(stream);
841 for (
auto stream : streams)
842 streamDestroyCallBuilder.create(loc, rewriter, {stream});
844 rewriter.
modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
850 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
852 return *defOp.getCallee() == functionName;
860 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
861 gpu::WaitOp waitOp, OpAdaptor adaptor,
863 if (waitOp.getAsyncToken())
868 for (
auto operand : adaptor.getOperands()) {
871 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
872 streamDestroyCallBuilder.create(loc, rewriter, {operand});
876 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
877 eventDestroyCallBuilder.create(loc, rewriter, {operand});
890 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
891 gpu::WaitOp waitOp, OpAdaptor adaptor,
893 if (!waitOp.getAsyncToken())
901 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
902 auto operand = std::get<1>(pair);
906 auto *defOp = std::get<0>(pair).getDefiningOp();
908 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
909 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
910 events.push_back(event);
914 events.push_back(operand);
918 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
919 for (
auto event : events)
920 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
921 for (
auto event : events)
922 eventDestroyCallBuilder.create(loc, rewriter, {
event});
929 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
930 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
932 if (failed(
areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
935 if (launchOp.getAsyncDependencies().size() > 1)
937 launchOp,
"Cannot convert with more than one async dependency.");
942 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
944 launchOp,
"Cannot convert non-async op with async dependencies.");
949 if (!adaptor.getAsyncDependencies().empty())
950 stream = adaptor.getAsyncDependencies().front();
953 else if (launchOp.getAsyncToken())
954 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
959 loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
960 kernelBarePtrCallConv);
962 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
963 if (launchOp.hasClusterSize()) {
966 adaptor.getClusterSizeZ()};
968 rewriter.
create<gpu::LaunchFuncOp>(
969 launchOp.getLoc(), launchOp.getKernelAttr(),
971 adaptor.getGridSizeZ()},
973 adaptor.getBlockSizeZ()},
974 adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
975 if (launchOp.getAsyncToken())
984 LLVM::LLVMPointerType destinationType,
987 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
988 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
989 sourcePtr = rewriter.
create<LLVM::AddrSpaceCastOp>(
992 destinationType.getAddressSpace()),
997 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
998 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1000 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1002 if (failed(
areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1003 !isConvertibleAndHasIdentityMaps(memRefType) ||
1007 auto loc = memcpyOp.getLoc();
1013 Value nullPtr = rewriter.
create<LLVM::ZeroOp>(loc, elementPtrType);
1015 loc, elementPtrType,
1016 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1019 rewriter.
create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1022 srcDesc.alignedPtr(rewriter, loc),
1023 *getTypeConverter());
1025 loc, rewriter, llvmPointerType,
1027 *getTypeConverter());
1029 auto stream = adaptor.getAsyncDependencies().front();
1030 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1037 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1038 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1040 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1042 if (failed(
areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1043 !isConvertibleAndHasIdentityMaps(memRefType) ||
1047 auto loc = memsetOp.getLoc();
1049 Type valueType = adaptor.getValue().getType();
1052 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1054 memsetOp,
"value must be a 16 or 32 bit int or float");
1058 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1064 rewriter.
create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1066 dstDesc.alignedPtr(rewriter, loc),
1067 *getTypeConverter());
1069 auto stream = adaptor.getAsyncDependencies().front();
1071 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1072 builder.
create(loc, rewriter, {dst, value, numElements, stream});
1078 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1079 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1082 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1083 {adaptor.getDevIndex()});
1088 template <
typename T>
1091 return builder.
create<LLVM::ConstantOp>(loc, llvmInt32Type,
1092 static_cast<int32_t
>(tValue));
1095 template <
typename T>
1098 return builder.
create<LLVM::ConstantOp>(
1099 loc, llvmFloat32Type,
1103 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1104 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1110 auto stream = adaptor.getAsyncDependencies().front();
1113 Type dType = op.getMemref().getType().getElementType();
1117 for (
Value dim : adaptor.getDims()) {
1118 dims.push_back(dim);
1128 if (dims.size() == 2) {
1130 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1132 handle = rewriter.
create<LLVM::AllocaOp>(
1133 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1134 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1136 createLtDnMatCallBuilder
1138 {handle, dims[0], dims[1], pTensor, dtp, stream})
1142 createDnMatCallBuilder
1143 .
create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1147 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1148 handle = createDnVecCallBuilder
1149 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1152 rewriter.
replaceOp(op, {handle, stream});
1156 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1157 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1163 auto stream = adaptor.getAsyncDependencies().front();
1164 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1166 for (
Value dim : definingOp.getDims()) {
1167 dims.push_back(dim);
1169 if (dims.size() == 2) {
1173 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1174 {adaptor.getDnTensor(), stream});
1176 destroyDnMatCallBuilder.create(loc, rewriter,
1177 {adaptor.getDnTensor(), stream});
1180 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1181 destroyDnVecCallBuilder.create(loc, rewriter,
1182 {adaptor.getDnTensor(), stream});
1188 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1189 gpu::CreateCooOp op, OpAdaptor adaptor,
1195 auto stream = adaptor.getAsyncDependencies().front();
1203 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1205 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1209 createCooCallBuilder
1210 .create(loc, rewriter,
1211 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1212 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1214 rewriter.
replaceOp(op, {handle, stream});
1218 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1219 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1225 auto stream = adaptor.getAsyncDependencies().front();
1229 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1231 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1235 createCooAoSCallBuilder
1236 .create(loc, rewriter,
1237 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1238 pIdxs, pValues, itp, dtp, stream})
1240 rewriter.
replaceOp(op, {handle, stream});
1244 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1245 gpu::CreateCsrOp op, OpAdaptor adaptor,
1251 auto stream = adaptor.getAsyncDependencies().front();
1259 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1261 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1263 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1268 createCsrCallBuilder
1269 .create(loc, rewriter,
1270 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1271 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1273 rewriter.
replaceOp(op, {handle, stream});
1277 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1278 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1284 auto stream = adaptor.getAsyncDependencies().front();
1288 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1292 auto handleSz = rewriter.
create<LLVM::ConstantOp>(
1295 loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1296 handle = rewriter.
create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1298 create2To4SpMatCallBuilder
1300 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1302 rewriter.
replaceOp(op, {handle, stream});
1306 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1307 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1313 auto stream = adaptor.getAsyncDependencies().front();
1316 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1317 {adaptor.getSpmat(), stream});
1320 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1326 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1327 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1336 auto stream = adaptor.getAsyncDependencies().front();
1337 auto bufferSize = spMVBufferSizeCallBuilder
1338 .create(loc, rewriter,
1339 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1340 adaptor.getDnY(), computeType, stream})
1342 rewriter.
replaceOp(op, {bufferSize, stream});
1346 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1347 gpu::SpMVOp op, OpAdaptor adaptor,
1356 auto stream = adaptor.getAsyncDependencies().front();
1359 spMVCallBuilder.create(loc, rewriter,
1360 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1361 adaptor.getDnY(), computeType, pBuf, stream});
1366 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1367 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1375 auto stream = adaptor.getAsyncDependencies().front();
1382 auto three = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1384 auto bufferSize = rewriter.
create<LLVM::AllocaOp>(
1385 loc, llvmPointerType, llvmPointerType, three, 16);
1386 createCuSparseLtSpMMBufferSizeBuilder
1388 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1389 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1393 auto bufferSizePtr1 = rewriter.
create<LLVM::GEPOp>(
1394 loc, llvmPointerType, llvmPointerType, bufferSize,
1397 auto bufferSizePtr2 = rewriter.
create<LLVM::GEPOp>(
1398 loc, llvmPointerType, llvmPointerType, bufferSize,
1402 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1404 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1406 rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1408 rewriter.
replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1413 createSpMMBufferSizeCallBuilder
1414 .create(loc, rewriter,
1415 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1416 adaptor.getDnmatC(), computeType, stream})
1418 rewriter.
replaceOp(op, {bufferSize, stream});
1423 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1424 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1434 auto stream = adaptor.getAsyncDependencies().front();
1436 createSDDMMBufferSizeCallBuilder
1437 .create(loc, rewriter,
1438 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1439 adaptor.getSpmatC(), computeType, stream})
1441 rewriter.
replaceOp(op, {bufferSize, stream});
1445 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1446 gpu::SpMMOp op, OpAdaptor adaptor,
1457 auto stream = adaptor.getAsyncDependencies().front();
1462 for (
Value buffer : adaptor.getBuffers()) {
1464 pBufs.push_back(pBuf);
1466 createCuSparseLtSpMMBuilder.create(
1468 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1469 pBufs[0], pBufs[1], pBufs[2], stream});
1473 createSpMMCallBuilder.create(loc, rewriter,
1474 {modeA, modeB, adaptor.getSpmatA(),
1475 adaptor.getDnmatB(), adaptor.getDnmatC(),
1476 computeType, pBuf, stream});
1482 template <
typename T>
1489 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1490 gpu::SDDMMOp op, OpAdaptor adaptor,
1500 auto stream = adaptor.getAsyncDependencies().front();
1503 createSDDMMCallBuilder.create(loc, rewriter,
1504 {modeA, modeB, adaptor.getDnmatA(),
1505 adaptor.getDnmatB(), adaptor.getSpmatC(),
1506 computeType, pBuf, stream});
1512 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1513 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1519 auto stream = adaptor.getAsyncDependencies().front();
1520 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1522 rewriter.
replaceOp(op, {descr, stream});
1527 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1528 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1534 auto stream = adaptor.getAsyncDependencies().front();
1535 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1536 {adaptor.getDesc(), stream});
1542 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1543 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1553 auto stream = adaptor.getAsyncDependencies().front();
1557 Value bufferSizeNew;
1559 if (adaptor.getKind() ==
1560 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1562 createSpGEMMWorkEstimationBuilder
1563 .create(loc, rewriter,
1564 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1565 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1566 adaptor.getBufferSz(), pBuf, stream})
1570 createSpGEMMComputeBuilder
1571 .create(loc, rewriter,
1572 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1573 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1574 adaptor.getBufferSz(), pBuf, stream})
1577 rewriter.
replaceOp(op, {bufferSizeNew, stream});
1581 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1582 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1592 auto stream = adaptor.getAsyncDependencies().front();
1593 createSpGEMMCopyBuilder.create(loc, rewriter,
1594 {adaptor.getDesc(), modeA, modeB,
1595 adaptor.getSpmatA(), adaptor.getSpmatB(),
1596 adaptor.getSpmatC(), computeType, stream});
1601 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1602 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1608 auto stream = adaptor.getAsyncDependencies().front();
1610 auto three = rewriter.
create<LLVM::ConstantOp>(loc, getIndexType(),
1612 auto buffer = rewriter.
create<LLVM::AllocaOp>(
1613 loc, llvmPointerType, llvmInt64Type, three, 16);
1615 auto rowsPtr = rewriter.
create<LLVM::GEPOp>(
1616 loc, llvmPointerType, llvmPointerType, buffer,
1619 auto colsPtr = rewriter.
create<LLVM::GEPOp>(
1620 loc, llvmPointerType, llvmPointerType, buffer,
1623 auto nnzsPtr = rewriter.
create<LLVM::GEPOp>(
1624 loc, llvmPointerType, llvmPointerType, buffer,
1627 createSpMatGetSizeBuilder.
create(
1628 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1629 auto rows = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1630 auto cols = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1631 auto nnzs = rewriter.
create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1637 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1638 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1644 auto stream = adaptor.getAsyncDependencies().front();
1651 createSetCsrPointersBuilder.create(
1652 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1657 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1658 gpu::CreateCscOp op, OpAdaptor adaptor,
1664 auto stream = adaptor.getAsyncDependencies().front();
1672 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1674 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1676 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1681 createCscCallBuilder
1682 .create(loc, rewriter,
1683 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1684 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1686 rewriter.
replaceOp(op, {handle, stream});
1690 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1691 gpu::CreateBsrOp op, OpAdaptor adaptor,
1697 auto stream = adaptor.getAsyncDependencies().front();
1705 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1707 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1709 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1714 createBsrCallBuilder
1715 .create(loc, rewriter,
1716 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1717 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1718 pColIdxs, pValues, ptp, itp, dtp, stream})
1720 rewriter.
replaceOp(op, {handle, stream});
1726 bool kernelBarePtrCallConv) {
1727 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1728 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1729 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1730 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1732 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1733 ConvertDeallocOpToGpuRuntimeCallPattern,
1734 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1735 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1736 ConvertMemcpyOpToGpuRuntimeCallPattern,
1737 ConvertMemsetOpToGpuRuntimeCallPattern,
1738 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1739 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1740 ConvertWaitOpToGpuRuntimeCallPattern,
1741 ConvertAsyncYieldToGpuRuntimeCallPattern,
1742 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1743 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1744 ConvertCreateCooOpToGpuRuntimeCallPattern,
1745 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1746 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1747 ConvertCreateCscOpToGpuRuntimeCallPattern,
1748 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1749 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1750 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1751 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1752 ConvertSpMVOpToGpuRuntimeCallPattern,
1753 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1754 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1755 ConvertSpMMOpToGpuRuntimeCallPattern,
1756 ConvertSDDMMOpToGpuRuntimeCallPattern,
1757 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1758 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1759 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1760 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1761 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1762 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1763 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operations on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static MLIRContext * getContext(OpFoldResult val)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatAttr getF32FloatAttr(float value)
IntegerAttr getI8IntegerAttr(int8_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
operand_range getOperands()
Returns an iterator on the underlying Value's.
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Include the generated interface declarations.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu types.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z similarly to CUDA notation.