40#include "llvm/ADT/STLExtras.h"
42#define DEBUG_TYPE "gpu-to-llvm"
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
// Pass that lowers GPU runtime operations (streams, events, allocation,
// kernel launches) into LLVM-dialect calls to the mgpu* runtime wrappers.
// NOTE(review): this chunk is a lossy extraction — interior lines are
// missing and "&registry" below is mojibake'd to "®istry"; restore from
// the upstream file before compiling.
52class GpuToLLVMConversionPass
56 void getDependentDialects(DialectRegistry ®istry)
const final {
// Defers to the generated base class for the default dependent dialects.
57 Base::getDependentDialects(registry);
61 void runOnOperation()
override;
// Base class for all patterns in this file: a ConvertOpToLLVMPattern that
// additionally carries the mgpu* runtime-call builders declared below.
// NOTE(review): the enclosing class header line is missing from this
// extraction (numbering jumps 64 -> 67).
64template <
typename OpTy>
67 explicit ConvertOpToGpuRuntimeCallPattern(
68 const LLVMTypeConverter &typeConverter)
69 : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
// Returns the total number of elements of a memref as an LLVM value.
// Static shapes fold to a constant; dynamic shapes multiply the sizes
// read from the memref descriptor.
// NOTE(review): the static-shape return statement is split across missing
// lines (75 -> 77) in this extraction.
72 Value
getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73 MemRefType type, MemRefDescriptor desc)
const {
75 if (type.hasStaticShape())
77 rewriter, loc, indexType, type.getNumElements());
// Dynamic case: product of per-dimension sizes from the descriptor.
79 uint64_t rank = type.getRank();
80 Value numElements = desc.
size(rewriter, loc, 0);
81 for (
unsigned i = 1; i < rank; i++)
82 numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83 desc.
size(rewriter, loc, i));
// Cached LLVM types used to spell out the mgpu* runtime function
// signatures below. llvmIntPtrType is sized to the converter's
// address-space-0 pointer bitwidth.
87 MLIRContext *context = &this->getTypeConverter()->
getContext();
89 Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91 Type llvmInt8Type = IntegerType::get(context, 8);
92 Type llvmInt16Type = IntegerType::get(context, 16);
93 Type llvmInt32Type = IntegerType::get(context, 32);
94 Type llvmInt64Type = IntegerType::get(context, 64);
95 Type llvmFloat32Type = Float32Type::get(context);
96 Type llvmIntPtrType = IntegerType::get(
97 context, this->getTypeConverter()->getPointerBitwidth(0));
// Table of FunctionCallBuilder members describing the C runtime wrapper
// functions (mgpu*) this lowering emits calls to: name, result type,
// argument types. NOTE(review): the original inline /* param name */
// comments were stripped by the extraction (hence stray spaces before
// commas/braces), and several result-type lines are missing entirely.

// --- Stream and event management -------------------------------------
99 FunctionCallBuilder streamCreateCallBuilder = {
100 "mgpuStreamCreate", llvmPointerType , {}};
101 FunctionCallBuilder streamDestroyCallBuilder = {
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
103 FunctionCallBuilder streamSynchronizeCallBuilder = {
104 "mgpuStreamSynchronize",
107 FunctionCallBuilder streamWaitEventCallBuilder = {
108 "mgpuStreamWaitEvent",
110 {llvmPointerType , llvmPointerType }};
111 FunctionCallBuilder eventCreateCallBuilder = {
112 "mgpuEventCreate", llvmPointerType , {}};
113 FunctionCallBuilder eventDestroyCallBuilder = {
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
115 FunctionCallBuilder eventSynchronizeCallBuilder = {
116 "mgpuEventSynchronize",
119 FunctionCallBuilder eventRecordCallBuilder = {
122 {llvmPointerType , llvmPointerType }};
// --- Host memory registration ----------------------------------------
123 FunctionCallBuilder hostRegisterCallBuilder = {
124 "mgpuMemHostRegisterMemRef",
129 FunctionCallBuilder hostUnregisterCallBuilder = {
130 "mgpuMemHostUnregisterMemRef",
// --- Device allocation / free / copy / fill --------------------------
135 FunctionCallBuilder allocCallBuilder = {
141 FunctionCallBuilder deallocCallBuilder = {
144 {llvmPointerType , llvmPointerType }};
145 FunctionCallBuilder memcpyCallBuilder = {
148 {llvmPointerType , llvmPointerType ,
151 FunctionCallBuilder memset16CallBuilder = {
158 FunctionCallBuilder memset32CallBuilder = {
161 {llvmPointerType , llvmInt32Type ,
164 FunctionCallBuilder setDefaultDeviceCallBuilder = {
165 "mgpuSetDefaultDevice",
// --- cuSPARSE dense/sparse descriptor create & destroy ---------------
168 FunctionCallBuilder createDnVecCallBuilder = {
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
173 FunctionCallBuilder destroyDnVecCallBuilder = {
176 {llvmPointerType, llvmPointerType }};
177 FunctionCallBuilder createDnMatCallBuilder = {
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
182 FunctionCallBuilder destroyDnMatCallBuilder = {
185 {llvmPointerType, llvmPointerType }};
186 FunctionCallBuilder createCooCallBuilder = {
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
192 FunctionCallBuilder createCooAoSCallBuilder = {
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
198 FunctionCallBuilder createCsrCallBuilder = {
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType }};
204 FunctionCallBuilder createCscCallBuilder = {
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType }};
210 FunctionCallBuilder createBsrCallBuilder = {
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
217 FunctionCallBuilder destroySpMatCallBuilder = {
220 {llvmPointerType, llvmPointerType }};
// --- SpMV / SpMM / SDDMM (buffer-size query + compute) ----------------
221 FunctionCallBuilder spMVBufferSizeCallBuilder = {
222 "mgpuSpMVBufferSize",
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType }};
226 FunctionCallBuilder spMVCallBuilder = {
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType }};
231 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232 "mgpuSpMMBufferSize",
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType }};
236 FunctionCallBuilder createSpMMCallBuilder = {
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
242 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243 "mgpuSDDMMBufferSize",
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType }};
247 FunctionCallBuilder createSDDMMCallBuilder = {
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
// --- cuSPARSELt (2:4 structured sparsity) -----------------------------
253 FunctionCallBuilder createLtDnMatCallBuilder = {
254 "mgpuCreateCuSparseLtDnMat",
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType }};
258 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259 "mgpuDestroyCuSparseLtSpMat",
261 {llvmPointerType, llvmPointerType }};
262 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263 "mgpuDestroyCuSparseLtDnMat",
265 {llvmPointerType, llvmPointerType }};
266 FunctionCallBuilder create2To4SpMatCallBuilder = {
267 "mgpuCusparseLtCreate2To4SpMat",
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType }};
271 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272 "mgpuCuSparseLtSpMMBufferSize",
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
277 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278 "mgpuCuSparseLtSpMM",
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType }};
// --- SpGEMM -----------------------------------------------------------
282 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283 "mgpuSpGEMMCreateDescr",
286 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287 "mgpuSpGEMMDestroyDescr",
289 {llvmPointerType , llvmPointerType }};
290 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291 "mgpuSpGEMMWorkEstimation",
293 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
294 llvmPointerType , llvmPointerType , llvmPointerType ,
295 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
297 FunctionCallBuilder createSpGEMMComputeBuilder = {
300 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
301 llvmPointerType , llvmPointerType , llvmPointerType ,
302 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
304 FunctionCallBuilder createSpGEMMCopyBuilder = {
307 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
308 llvmPointerType , llvmPointerType , llvmPointerType ,
309 llvmInt32Type , llvmPointerType }};
310 FunctionCallBuilder createSpMatGetSizeBuilder = {
313 {llvmPointerType , llvmPointerType , llvmPointerType ,
314 llvmPointerType , llvmPointerType }};
315 FunctionCallBuilder createSetCsrPointersBuilder = {
316 "mgpuSetCsrPointers",
318 {llvmPointerType , llvmPointerType ,
319 llvmPointerType , llvmPointerType ,
// Pattern declaration: lowers gpu.host_register to mgpuMemHostRegisterMemRef.
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers gpu.host_unregister to
// mgpuMemHostUnregisterMemRef.
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers gpu.alloc to the mgpu allocation wrapper.
354class ConvertAllocOpToGpuRuntimeCallPattern
355 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
357 ConvertAllocOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers gpu.dealloc to the mgpu deallocation wrapper.
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: rewrites async.yield operands that carry GPU async
// tokens (streams are converted to recorded events; see matchAndRewrite).
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
390 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
391 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers the synchronous (non-token-producing) form
// of gpu.wait to stream/event synchronize + destroy calls.
396class ConvertWaitOpToGpuRuntimeCallPattern
397 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
399 ConvertWaitOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
400 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
404 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers the async (token-producing) form of gpu.wait
// to a new stream that waits on events recorded after each dependency.
410class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
413 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414 const LLVMTypeConverter &typeConverter)
415 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
419 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: legalizes gpu.launch_func by converting its operands
// to LLVM types; the two flags control the kernel-argument calling
// convention (bare pointers, and interspersed buffer sizes).
424class LegalizeLaunchFuncOpPattern
425 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
427 LegalizeLaunchFuncOpPattern(
const LLVMTypeConverter &typeConverter,
428 bool kernelBarePtrCallConv,
429 bool kernelIntersperseSizeCallConv)
430 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
431 kernelBarePtrCallConv(kernelBarePtrCallConv),
432 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
436 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
437 ConversionPatternRewriter &rewriter)
const override;
// Pass kernel memref arguments as bare pointers instead of descriptors.
439 bool kernelBarePtrCallConv;
// After each memref argument, also pass its static byte size.
440 bool kernelIntersperseSizeCallConv;
// Pattern declaration: lowers gpu.memcpy to the mgpu memcpy wrapper.
445class ConvertMemcpyOpToGpuRuntimeCallPattern
446 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
448 ConvertMemcpyOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
449 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
453 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
454 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers gpu.memset to mgpuMemset16/mgpuMemset32.
459class ConvertMemsetOpToGpuRuntimeCallPattern
460 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
462 ConvertMemsetOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
463 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
467 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers gpu.set_default_device to mgpuSetDefaultDevice.
473class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
474 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
476 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
477 const LLVMTypeConverter &typeConverter)
478 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
482 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter)
const override;
// Macro that declares a Convert<Op>ToGpuRuntimeCallPattern class whose
// matchAndRewrite is defined out-of-line. The macro body uses backslash
// line continuations, so comments must stay outside it.
488#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
489 class Convert##op_name##ToGpuRuntimeCallPattern \
490 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
492 Convert##op_name##ToGpuRuntimeCallPattern( \
493 const LLVMTypeConverter &typeConverter) \
494 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
498 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
499 ConversionPatternRewriter &rewriter) const override; \
// Drives the conversion: sets up the LLVM type converter and conversion
// target, collects dialect-interface patterns plus the patterns declared
// in this file, and applies a partial conversion to the module.
// NOTE(review): many statements are missing from this extraction
// (e.g. 'patterns' appears used at 537 before its declaration at 544 —
// the original has two separate pattern sets).
526void GpuToLLVMConversionPass::runOnOperation() {
537 vector::populateVectorFromElementsUnrollPatterns(patterns);
539 return signalPassFailure();
542 LowerToLLVMOptions
options(context);
// Honor the host-side bare-pointer calling-convention pass option.
543 options.useBarePtrCallConv = hostBarePtrCallConv;
544 RewritePatternSet patterns(context);
545 ConversionTarget
target(*context);
546 target.addLegalDialect<LLVM::LLVMDialect>();
547 LLVMTypeConverter converter(context,
options);
// Let each loaded dialect contribute its own convert-to-LLVM patterns.
552 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
555 iface->populateConvertToLLVMConversionPatterns(
target, converter, patterns);
560 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
// launch_func is legal once its operand types are LLVM-compatible.
562 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
563 [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
571 kernelBarePtrCallConv,
572 kernelIntersperseSizeCallConv);
575 applyPartialConversion(getOperation(),
target, std::move(patterns))))
// Fragment of FunctionCallBuilder::create: looks up (or, on missing lines,
// declares) the runtime function in the enclosing module, then emits the
// llvm.call. NOTE(review): the function signature line and the declaration
// branch are absent from this extraction.
581 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
582 auto function = [&] {
583 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
588 return LLVM::CallOp::create(builder, loc, function, arguments);
// Fragments of the MLIR-type -> cuSPARSE data-type mapping helpers.
// Complex types dispatch on their element type; the return statements for
// each branch are missing from this extraction.
605 llvm_unreachable(
"unsupported type");
611 if (llvm::isa<ComplexType>(type)) {
613 auto elementType = cast<ComplexType>(type).getElementType();
614 if (elementType.isBF16())
616 if (elementType.isF16())
618 if (elementType.isF32())
620 if (elementType.isF64())
622 if (elementType.isInteger(8))
624 if (elementType.isInteger(16))
626 if (elementType.isInteger(32))
644 llvm_unreachable(
"unsupported element type");
// Fragments of the 2:4-sparsity helpers: reading the prune flag from the
// defining gpu.create_2to4_spmat, and checking whether a sparse matrix
// value feeds a gpu.spmm user. Surrounding control flow is missing.
648 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
673 llvm_unreachable(
"cannot find spmat def");
678 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
// Fragments of two shared match-guard helpers:
//  - areAllLLVMTypes: fails the match unless every operand already has an
//    LLVM-compatible type;
//  - isAsyncWithOneDependency: requires exactly one async dependency and
//    an async token result.
690 ConversionPatternRewriter &rewriter) {
691 if (!llvm::all_of(operands, [](
Value value) {
694 return rewriter.notifyMatchFailure(
695 op,
"Cannot convert if operands aren't of LLVM type.");
701 gpu::AsyncOpInterface op) {
702 if (op.getAsyncDependencies().size() != 1)
703 return rewriter.notifyMatchFailure(
704 op,
"Can only convert with exactly one async dependency.");
706 if (!op.getAsyncToken())
707 return rewriter.notifyMatchFailure(op,
"Can convert only async version.");
// Lowers gpu.host_register: promotes the unranked-memref operand and
// appends the element byte size, then calls mgpuMemHostRegisterMemRef.
// NOTE(review): the operand-type guard and the elementSize computation
// sit on lines missing from this extraction (716-718, 723-724).
712LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
713 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
714 ConversionPatternRewriter &rewriter)
const {
715 auto *op = hostRegisterOp.getOperation();
719 Location loc = op->getLoc();
721 auto memRefType = hostRegisterOp.getValue().getType();
722 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
725 auto arguments = getTypeConverter()->promoteOperands(
726 loc, op->getOperands(), adaptor.getOperands(), rewriter);
727 arguments.push_back(elementSize);
728 hostRegisterCallBuilder.create(loc, rewriter, arguments);
// The op produces no results; erase it after emitting the call.
730 rewriter.eraseOp(op);
// Lowers gpu.host_unregister symmetrically to host_register: promoted
// operands plus element byte size, then mgpuMemHostUnregisterMemRef.
734LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
735 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter)
const {
737 Operation *op = hostUnregisterOp.getOperation();
741 Location loc = op->
getLoc();
743 auto memRefType = hostUnregisterOp.getValue().getType();
744 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
747 auto arguments = getTypeConverter()->promoteOperands(
748 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
749 arguments.push_back(elementSize);
750 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
752 rewriter.eraseOp(op);
// Lowers gpu.alloc: computes the allocation size from the memref shape,
// calls the mgpu alloc wrapper (host-shared flag as an i8), and rebuilds a
// memref descriptor around the returned pointer. Async form forwards the
// stream as the result token.
756LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
757 gpu::AllocOp allocOp, OpAdaptor adaptor,
758 ConversionPatternRewriter &rewriter)
const {
760 MemRefType memRefType = allocOp.getType();
763 !isConvertibleAndHasIdentityMaps(memRefType))
766 auto loc = allocOp.getLoc();
768 bool isShared = allocOp.getHostShared();
// Host-shared allocations are synchronous by construction.
770 if (isShared && allocOp.getAsyncToken())
771 return rewriter.notifyMatchFailure(
772 allocOp,
"Host Shared allocation cannot be done async");
// Materialize shape/strides and the total size in bytes.
778 SmallVector<Value, 4> shape;
779 SmallVector<Value, 4> strides;
781 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
782 shape, strides, sizeBytes);
786 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
787 Value stream = adaptor.getAsyncDependencies().empty()
789 : adaptor.getAsyncDependencies().front();
791 auto isHostShared = mlir::LLVM::ConstantOp::create(
792 rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
795 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
// No separate aligned pointer is computed; reuse the allocated pointer.
799 Value alignedPtr = allocatedPtr;
802 auto memRefDescriptor = this->createMemRefDescriptor(
803 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
805 if (allocOp.getAsyncToken()) {
807 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
809 rewriter.replaceOp(allocOp, {memRefDescriptor});
// Lowers gpu.dealloc: frees the allocated pointer on the dependency stream
// and forwards the stream as the async token.
815LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
816 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter)
const {
822 Location loc = deallocOp.getLoc();
825 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
826 Value stream = adaptor.getAsyncDependencies().front();
827 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
829 rewriter.replaceOp(deallocOp, {stream});
// Helper (separate function): true iff the value is a gpu.async.token.
834 return isa<gpu::AsyncTokenType>(value.
getType());
// Rewrites async.yield: each yielded GPU stream is replaced by an event
// recorded on that stream, and each distinct stream is destroyed once
// (hence the SmallDenseSet de-duplication).
841LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
842 async::YieldOp yieldOp, OpAdaptor adaptor,
843 ConversionPatternRewriter &rewriter)
const {
845 return rewriter.notifyMatchFailure(yieldOp,
"no gpu async token operand");
847 Location loc = yieldOp.getLoc();
848 SmallVector<Value, 4> newOperands(adaptor.getOperands());
849 llvm::SmallDenseSet<Value> streams;
850 for (
auto &operand : yieldOp->getOpOperands()) {
853 auto idx = operand.getOperandNumber();
854 auto stream = adaptor.getOperands()[idx];
855 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
856 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
857 newOperands[idx] = event;
858 streams.insert(stream);
860 for (
auto stream : streams)
861 streamDestroyCallBuilder.create(loc, rewriter, {stream});
863 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
// Helper (separate function): checks a pointer value was produced by a
// call to the given runtime function.
869 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
871 return *defOp.getCallee() == functionName;
// Lowers the synchronous gpu.wait: each operand is either a stream
// (synchronize + destroy) or an event (synchronize + destroy); the
// stream/event discrimination happens on lines missing from this
// extraction (888-889, 892-894).
879LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
880 gpu::WaitOp waitOp, OpAdaptor adaptor,
881 ConversionPatternRewriter &rewriter)
const {
882 if (waitOp.getAsyncToken())
883 return rewriter.notifyMatchFailure(waitOp,
"Cannot convert async op.");
885 Location loc = waitOp.getLoc();
887 for (
auto operand : adaptor.getOperands()) {
890 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
891 streamDestroyCallBuilder.create(loc, rewriter, {operand});
895 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
896 eventDestroyCallBuilder.create(loc, rewriter, {operand});
900 rewriter.eraseOp(waitOp);
// Lowers the async gpu.wait: for stream dependencies, records an event
// right after the producing op; then creates a fresh stream that waits on
// all collected events, destroys the events, and yields the stream as the
// new async token.
909LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
910 gpu::WaitOp waitOp, OpAdaptor adaptor,
911 ConversionPatternRewriter &rewriter)
const {
912 if (!waitOp.getAsyncToken())
913 return rewriter.notifyMatchFailure(waitOp,
"Can only convert async op.");
915 Location loc = waitOp.getLoc();
917 auto insertionPoint = rewriter.saveInsertionPoint();
918 SmallVector<Value, 1> events;
920 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
921 auto operand = std::get<1>(pair);
// Record the event immediately after the op that produced the stream.
925 auto *defOp = std::get<0>(pair).getDefiningOp();
926 rewriter.setInsertionPointAfter(defOp);
927 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
928 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
929 events.push_back(event);
// Operand was already an event; use it as-is.
933 events.push_back(operand);
936 rewriter.restoreInsertionPoint(insertionPoint);
937 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
938 for (
auto event : events)
939 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
940 for (
auto event : events)
941 eventDestroyCallBuilder.create(loc, rewriter, {
event});
942 rewriter.replaceOp(waitOp, {stream});
// Legalizes gpu.launch_func: validates async dependencies, resolves the
// stream, promotes kernel operands per the selected calling convention
// (optionally interspersing static byte sizes after each memref argument),
// and recreates the launch with LLVM-typed operands.
948LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
949 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
950 ConversionPatternRewriter &rewriter)
const {
954 if (launchOp.getAsyncDependencies().size() > 1)
955 return rewriter.notifyMatchFailure(
956 launchOp,
"Cannot convert with more than one async dependency.");
961 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
962 return rewriter.notifyMatchFailure(
963 launchOp,
"Cannot convert non-async op with async dependencies.");
965 Location loc = launchOp.getLoc();
// Use the dependency stream if present; otherwise, for the async form,
// create a fresh stream for the launch.
967 Value stream = Value();
968 if (!adaptor.getAsyncDependencies().empty())
969 stream = adaptor.getAsyncDependencies().front();
972 else if (launchOp.getAsyncToken())
973 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
978 OperandRange origArguments = launchOp.getKernelOperands();
979 bool effectiveBarePtr = kernelBarePtrCallConv ||
980 getTypeConverter()->getOptions().useBarePtrCallConv;
// Bare-pointer convention cannot represent unranked memrefs.
981 if (effectiveBarePtr) {
982 for (Value arg : origArguments) {
983 if (isa<UnrankedMemRefType>(arg.getType()))
984 return rewriter.notifyMatchFailure(
985 loc,
"unranked memref kernel argument is not supported with "
986 "the bare-pointer calling convention");
989 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
990 loc, origArguments, adaptor.getKernelOperands(), rewriter,
991 kernelBarePtrCallConv);
992 SmallVector<Value, 8> llvmArgumentsWithSizes;
// Optionally add the static byte size after each memref argument; this
// requires a one-to-one original-to-LLVM argument mapping.
995 if (kernelIntersperseSizeCallConv) {
996 if (origArguments.size() != llvmArguments.size()) {
998 return rewriter.notifyMatchFailure(
1000 "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1003 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1004 for (
auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1005 auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1007 return rewriter.notifyMatchFailure(
1008 launchOp,
"Operand to launch op is not a memref.");
1011 if (!memrefTy.hasStaticShape() ||
1012 !memrefTy.getElementType().isIntOrFloat()) {
1013 return rewriter.notifyMatchFailure(
1014 launchOp,
"Operand to launch op is not a memref with a static "
1015 "shape and an integer or float element type.");
1018 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1019 if (bitwidth % 8 != 0) {
1020 return rewriter.notifyMatchFailure(
1021 launchOp,
"Operand to launch op is not a memref with a "
1022 "byte-aligned element type.");
1025 uint64_t staticSize =
static_cast<uint64_t
>(bitwidth / 8) *
1026 static_cast<uint64_t
>(memrefTy.getNumElements());
1028 Value sizeArg = LLVM::ConstantOp::create(
1029 rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1030 llvmArgumentsWithSizes.push_back(llvmArg);
1031 llvmArgumentsWithSizes.push_back(sizeArg);
1035 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1036 if (launchOp.hasClusterSize()) {
1038 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1039 adaptor.getClusterSizeZ()};
// Recreate the launch with converted operands and the resolved stream.
1041 gpu::LaunchFuncOp::create(
1042 rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1043 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1044 adaptor.getGridSizeZ()},
1045 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1046 adaptor.getBlockSizeZ()},
1047 adaptor.getDynamicSharedMemorySize(),
1048 llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1049 stream, clusterSize);
1050 if (launchOp.getAsyncToken())
1051 rewriter.replaceOp(launchOp, {stream});
1053 rewriter.eraseOp(launchOp);
// Fragment of a helper that address-space-casts a pointer to the
// destination pointer type when the address spaces differ.
// NOTE(review): the function header and return are missing here.
1058 ConversionPatternRewriter &rewriter,
1059 LLVM::LLVMPointerType destinationType,
1062 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1063 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1064 sourcePtr = LLVM::AddrSpaceCastOp::create(
1066 LLVM::LLVMPointerType::get(rewriter.getContext(),
1067 destinationType.getAddressSpace()),
// Lowers gpu.memcpy: computes the copy size in bytes via the
// null-pointer-GEP idiom (GEP from null by numElements, then ptrtoint),
// casts both pointers to the generic address space, and calls the mgpu
// memcpy wrapper on the dependency stream.
1072LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1073 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1074 ConversionPatternRewriter &rewriter)
const {
1075 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1078 !isConvertibleAndHasIdentityMaps(memRefType) ||
1082 auto loc = memcpyOp.getLoc();
1084 MemRefDescriptor srcDesc(adaptor.getSrc());
1085 Value numElements =
getNumElements(rewriter, loc, memRefType, srcDesc);
// sizeBytes = (T*)0 + numElements, reinterpreted as an integer.
1087 Type elementPtrType = getElementPtrType(memRefType);
1088 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1089 Value gepPtr = LLVM::GEPOp::create(
1090 rewriter, loc, elementPtrType,
1091 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1094 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1097 srcDesc.alignedPtr(rewriter, loc),
1098 *getTypeConverter());
1100 loc, rewriter, llvmPointerType,
1101 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1102 *getTypeConverter());
1104 auto stream = adaptor.getAsyncDependencies().front();
1105 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1107 rewriter.replaceOp(memcpyOp, {stream});
// Lowers gpu.memset: only 16- or 32-bit int/float fill values are
// supported; the value is bitcast to i16/i32 and the matching
// mgpuMemset16/mgpuMemset32 wrapper is called on the dependency stream.
1112LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1113 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter)
const {
1115 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1118 !isConvertibleAndHasIdentityMaps(memRefType) ||
1122 auto loc = memsetOp.getLoc();
1124 Type valueType = adaptor.getValue().getType();
1127 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1128 return rewriter.notifyMatchFailure(
1129 memsetOp,
"value must be a 16 or 32 bit int or float");
1133 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1135 MemRefDescriptor dstDesc(adaptor.getDst());
1136 Value numElements =
getNumElements(rewriter, loc, memRefType, dstDesc);
1139 LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1141 dstDesc.alignedPtr(rewriter, loc),
1142 *getTypeConverter());
1144 auto stream = adaptor.getAsyncDependencies().front();
1145 FunctionCallBuilder builder =
1146 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1147 builder.
create(loc, rewriter, {dst, value, numElements, stream});
1149 rewriter.replaceOp(memsetOp, {stream});
// Lowers gpu.set_default_device to a call to mgpuSetDefaultDevice with the
// device index, replacing the op with the call.
1153LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1154 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1155 ConversionPatternRewriter &rewriter)
const {
1156 Location loc = op.getLoc();
1157 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1158 {adaptor.getDevIndex()});
1159 rewriter.replaceOp(op, call);
// Fragments of helpers materializing an i32 / f32 LLVM constant from an
// arbitrary value T (used for cuSPARSE enum arguments and scalars).
// NOTE(review): function signatures are on missing lines.
1163template <
typename T>
1166 return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1167 static_cast<int32_t
>(tValue));
1170template <
typename T>
1173 return LLVM::ConstantOp::create(
1174 builder, loc, llvmFloat32Type,
// Lowers gpu.create_dn_tensor: 1-D tensors become dense-vector handles,
// 2-D tensors dense-matrix handles. The 2-D cuSPARSELt branch allocas an
// opaque 11032-byte handle buffer (magic size of the cuSPARSELt
// descriptor triple — confirm against the runtime wrapper).
1178LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1179 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1180 ConversionPatternRewriter &rewriter)
const {
1184 Location loc = op.getLoc();
1185 auto stream = adaptor.getAsyncDependencies().front();
1187 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1188 Type dType = op.getMemref().
getType().getElementType();
1191 SmallVector<Value, 4> dims;
1192 for (Value dim : adaptor.getDims()) {
1193 dims.push_back(dim);
1203 if (dims.size() == 2) {
1205 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1206 rewriter.getIndexAttr(11032));
1207 handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1208 llvmInt8Type, handleSz, 16);
1209 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1211 createLtDnMatCallBuilder
1212 .create(loc, rewriter,
1213 {handle, dims[0], dims[1], pTensor, dtp, stream})
1217 createDnMatCallBuilder
1218 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1222 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1223 handle = createDnVecCallBuilder
1224 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1227 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.destroy_dn_tensor: dispatches on the rank recorded by the
// defining create_dn_tensor (1-D vector vs 2-D matrix, with a cuSPARSELt
// sub-case for 2-D on lines missing here).
1231LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1232 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1233 ConversionPatternRewriter &rewriter)
const {
1237 Location loc = op.getLoc();
1238 auto stream = adaptor.getAsyncDependencies().front();
1239 auto definingOp = op.getDnTensor().
getDefiningOp<gpu::CreateDnTensorOp>();
1240 SmallVector<Value, 4> dims;
1241 for (Value dim : definingOp.getDims()) {
1242 dims.push_back(dim);
1244 if (dims.size() == 2) {
1248 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1249 {adaptor.getDnTensor(), stream});
1251 destroyDnMatCallBuilder.create(loc, rewriter,
1252 {adaptor.getDnTensor(), stream});
1255 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1256 destroyDnVecCallBuilder.create(loc, rewriter,
1257 {adaptor.getDnTensor(), stream});
1259 rewriter.replaceOp(op, {stream});
// Lowers gpu.create_coo: extracts row/col/value buffer pointers from the
// memref descriptors and calls mgpuCreateCoo with index/data type enums.
1263LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1264 gpu::CreateCooOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter)
const {
1269 Location loc = op.getLoc();
1270 auto stream = adaptor.getAsyncDependencies().front();
1272 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1274 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1276 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1278 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1280 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1284 createCooCallBuilder
1285 .create(loc, rewriter,
1286 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1287 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1289 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_coo_aos (array-of-structs COO: interleaved index
// buffer) to mgpuCreateCooAoS.
1293LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1294 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1295 ConversionPatternRewriter &rewriter)
const {
1299 Location loc = op.getLoc();
1300 auto stream = adaptor.getAsyncDependencies().front();
1301 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1303 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1304 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1306 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1310 createCooAoSCallBuilder
1311 .create(loc, rewriter,
1312 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1313 pIdxs, pValues, itp, dtp, stream})
1315 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_csr: extracts row-position, column-index, and value
// buffer pointers and calls mgpuCreateCsr with the three type enums
// (position, index, data).
1319LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1320 gpu::CreateCsrOp op, OpAdaptor adaptor,
1321 ConversionPatternRewriter &rewriter)
const {
1325 Location loc = op.getLoc();
1326 auto stream = adaptor.getAsyncDependencies().front();
1328 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1330 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1332 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1334 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1336 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1338 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1343 createCsrCallBuilder
1344 .create(loc, rewriter,
1345 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1346 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1348 rewriter.replaceOp(op, {handle, stream});
1352LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1353 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1354 ConversionPatternRewriter &rewriter)
const {
1358 Location loc = op.getLoc();
1359 auto stream = adaptor.getAsyncDependencies().front();
1361 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1363 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1367 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1368 rewriter.getIndexAttr(44104));
1369 Value handle = LLVM::AllocaOp::create(
1370 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1371 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1373 create2To4SpMatCallBuilder
1374 .create(loc, rewriter,
1375 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1377 rewriter.replaceOp(op, {handle, stream});
1381LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1382 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1383 ConversionPatternRewriter &rewriter)
const {
1387 Location loc = op.getLoc();
1388 auto stream = adaptor.getAsyncDependencies().front();
1391 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1392 {adaptor.getSpmat(), stream});
1395 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1397 rewriter.replaceOp(op, {stream});
1401LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1402 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1403 ConversionPatternRewriter &rewriter)
const {
1407 Location loc = op.getLoc();
1411 auto stream = adaptor.getAsyncDependencies().front();
1412 auto bufferSize = spMVBufferSizeCallBuilder
1413 .create(loc, rewriter,
1414 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1415 adaptor.getDnY(), computeType, stream})
1417 rewriter.replaceOp(op, {bufferSize, stream});
1421LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1422 gpu::SpMVOp op, OpAdaptor adaptor,
1423 ConversionPatternRewriter &rewriter)
const {
1427 Location loc = op.getLoc();
1431 auto stream = adaptor.getAsyncDependencies().front();
1433 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1434 spMVCallBuilder.create(loc, rewriter,
1435 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1436 adaptor.getDnY(), computeType, pBuf, stream});
1437 rewriter.replaceOp(op, {stream});
1441LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1442 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1443 ConversionPatternRewriter &rewriter)
const {
1447 Location loc = op.getLoc();
1450 auto stream = adaptor.getAsyncDependencies().front();
1457 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1458 rewriter.getIndexAttr(3));
1460 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1462 createCuSparseLtSpMMBufferSizeBuilder
1463 .create(loc, rewriter,
1464 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1465 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1469 auto bufferSizePtr1 = LLVM::GEPOp::create(
1470 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1471 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1472 rewriter.getIndexAttr(1))});
1473 auto bufferSizePtr2 = LLVM::GEPOp::create(
1474 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1475 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1476 rewriter.getIndexAttr(2))});
1478 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1480 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1482 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1484 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1489 createSpMMBufferSizeCallBuilder
1490 .create(loc, rewriter,
1491 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1492 adaptor.getDnmatC(), computeType, stream})
1494 rewriter.replaceOp(op, {bufferSize, stream});
1499LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1500 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1501 ConversionPatternRewriter &rewriter)
const {
1505 Location loc = op.getLoc();
1510 auto stream = adaptor.getAsyncDependencies().front();
1512 createSDDMMBufferSizeCallBuilder
1513 .create(loc, rewriter,
1514 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1515 adaptor.getSpmatC(), computeType, stream})
1517 rewriter.replaceOp(op, {bufferSize, stream});
1521LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1522 gpu::SpMMOp op, OpAdaptor adaptor,
1523 ConversionPatternRewriter &rewriter)
const {
1527 Location loc = op.getLoc();
1533 auto stream = adaptor.getAsyncDependencies().front();
1537 SmallVector<Value> pBufs;
1538 for (Value buffer : adaptor.getBuffers()) {
1539 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1540 pBufs.push_back(pBuf);
1542 createCuSparseLtSpMMBuilder.create(
1544 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1545 pBufs[0], pBufs[1], pBufs[2], stream});
1547 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1548 .allocatedPtr(rewriter, loc);
1549 createSpMMCallBuilder.create(loc, rewriter,
1550 {modeA, modeB, adaptor.getSpmatA(),
1551 adaptor.getDnmatB(), adaptor.getDnmatC(),
1552 computeType, pBuf, stream});
1554 rewriter.replaceOp(op, {stream});
1558template <
typename T>
1560 converter.addConversion([&converter](T) ->
Type {
1561 return LLVM::LLVMPointerType::get(&converter.
getContext());
1565LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1566 gpu::SDDMMOp op, OpAdaptor adaptor,
1567 ConversionPatternRewriter &rewriter)
const {
1571 Location loc = op.getLoc();
1576 auto stream = adaptor.getAsyncDependencies().front();
1578 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1579 createSDDMMCallBuilder.create(loc, rewriter,
1580 {modeA, modeB, adaptor.getDnmatA(),
1581 adaptor.getDnmatB(), adaptor.getSpmatC(),
1582 computeType, pBuf, stream});
1583 rewriter.replaceOp(op, {stream});
1588ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1589 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1590 ConversionPatternRewriter &rewriter)
const {
1594 Location loc = op.getLoc();
1595 auto stream = adaptor.getAsyncDependencies().front();
1596 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1598 rewriter.replaceOp(op, {descr, stream});
1603ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1604 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1605 ConversionPatternRewriter &rewriter)
const {
1609 Location loc = op.getLoc();
1610 auto stream = adaptor.getAsyncDependencies().front();
1611 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1612 {adaptor.getDesc(), stream});
1613 rewriter.replaceOp(op, {stream});
1618ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1619 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1620 ConversionPatternRewriter &rewriter)
const {
1624 Location loc = op.getLoc();
1629 auto stream = adaptor.getAsyncDependencies().front();
1632 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1633 Value bufferSizeNew;
1635 if (adaptor.getKind() ==
1636 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1638 createSpGEMMWorkEstimationBuilder
1639 .create(loc, rewriter,
1640 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1641 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1642 adaptor.getBufferSz(), pBuf, stream})
1646 createSpGEMMComputeBuilder
1647 .create(loc, rewriter,
1648 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1649 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1650 adaptor.getBufferSz(), pBuf, stream})
1653 rewriter.replaceOp(op, {bufferSizeNew, stream});
1657LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1658 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1659 ConversionPatternRewriter &rewriter)
const {
1663 Location loc = op.getLoc();
1668 auto stream = adaptor.getAsyncDependencies().front();
1669 createSpGEMMCopyBuilder.create(loc, rewriter,
1670 {adaptor.getDesc(), modeA, modeB,
1671 adaptor.getSpmatA(), adaptor.getSpmatB(),
1672 adaptor.getSpmatC(), computeType, stream});
1673 rewriter.replaceOp(op, {stream});
1677LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1678 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1679 ConversionPatternRewriter &rewriter)
const {
1683 Location loc = op.getLoc();
1684 auto stream = adaptor.getAsyncDependencies().front();
1686 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1687 rewriter.getIndexAttr(3));
1688 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1689 llvmInt64Type, three, 16);
1691 auto rowsPtr = LLVM::GEPOp::create(
1692 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1693 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1694 rewriter.getIndexAttr(0))});
1695 auto colsPtr = LLVM::GEPOp::create(
1696 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1697 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1698 rewriter.getIndexAttr(1))});
1699 auto nnzsPtr = LLVM::GEPOp::create(
1700 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1701 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1702 rewriter.getIndexAttr(2))});
1703 createSpMatGetSizeBuilder.create(
1704 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1705 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1706 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1707 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1709 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1713LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1714 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1715 ConversionPatternRewriter &rewriter)
const {
1719 Location loc = op.getLoc();
1720 auto stream = adaptor.getAsyncDependencies().front();
1722 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1724 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1726 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1727 createSetCsrPointersBuilder.create(
1728 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1729 rewriter.replaceOp(op, {stream});
1733LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1734 gpu::CreateCscOp op, OpAdaptor adaptor,
1735 ConversionPatternRewriter &rewriter)
const {
1739 Location loc = op.getLoc();
1740 auto stream = adaptor.getAsyncDependencies().front();
1742 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1744 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1746 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1748 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1750 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1752 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1757 createCscCallBuilder
1758 .create(loc, rewriter,
1759 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1760 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1762 rewriter.replaceOp(op, {handle, stream});
1766LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1767 gpu::CreateBsrOp op, OpAdaptor adaptor,
1768 ConversionPatternRewriter &rewriter)
const {
1772 Location loc = op.getLoc();
1773 auto stream = adaptor.getAsyncDependencies().front();
1775 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1777 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1779 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1781 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1783 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1785 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1790 createBsrCallBuilder
1791 .create(loc, rewriter,
1792 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1793 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1794 pColIdxs, pValues, ptp, itp, dtp, stream})
1796 rewriter.replaceOp(op, {handle, stream});
1802 bool kernelBarePtrCallConv,
bool kernelIntersperseSizeCallConv) {
1808 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1809 ConvertDeallocOpToGpuRuntimeCallPattern,
1810 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1811 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1812 ConvertMemcpyOpToGpuRuntimeCallPattern,
1813 ConvertMemsetOpToGpuRuntimeCallPattern,
1814 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1815 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1816 ConvertWaitOpToGpuRuntimeCallPattern,
1817 ConvertAsyncYieldToGpuRuntimeCallPattern,
1818 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1819 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1820 ConvertCreateCooOpToGpuRuntimeCallPattern,
1821 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1822 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1823 ConvertCreateCscOpToGpuRuntimeCallPattern,
1824 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1825 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1826 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1827 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1828 ConvertSpMVOpToGpuRuntimeCallPattern,
1829 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1830 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1831 ConvertSpMMOpToGpuRuntimeCallPattern,
1832 ConvertSDDMMOpToGpuRuntimeCallPattern,
1833 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1834 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1835 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1836 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1837 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1838 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1839 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1840 kernelIntersperseSizeCallConv);
1848struct GPUModuleOpConvertToLLVMInterface
1849 :
public ConvertToLLVMOpInterface::ExternalModel<
1850 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1852 void getConvertToLLVMConversionAttrs(
1857void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1858 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs)
const {
1859 auto module = cast<gpu::GPUModuleOp>(op);
1860 ArrayAttr targetsAttr =
module.getTargetsAttr();
1862 if (!targetsAttr || targetsAttr.size() != 1)
1864 if (
auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1865 attrs.push_back(patternAttr);
1870 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getIntegerType(unsigned width)
FloatAttr getF32FloatAttr(float value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const