#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "gpu-to-llvm"

#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  // Run the dialect converter on the module.
  void runOnOperation() override;
};
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    if (type.hasStaticShape())
      return ConvertToLLVMPattern::createIndexAttrConstant(
          rewriter, loc, indexType, type.getNumElements());
    // Compute the number of elements by multiplying all dimension sizes.
    uint64_t rank = type.getRank();
    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
    for (unsigned i = 1; i < rank; i++)
      numElements = LLVM::MulOp::create(rewriter, loc, numElements,
                                        desc.size(rewriter, loc, /*pos=*/i));
    return numElements;
  }
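  // For example, for memref<16x16xf32> this folds to a single index constant
  // (256), while for memref<?x4xf32> it emits a multiply of the sizes read
  // back from the memref descriptor.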
  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt16Type /* value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* devIndex */}};
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /* void *spgemmDescr */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /* void *spgemmDescr */, llvmInt32Type, llvmInt32Type,
       llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmIntPtrType, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /* void *spgemmDescr */, llvmInt32Type, llvmInt32Type,
       llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmIntPtrType, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /* void *spgemmDescr */, llvmInt32Type, llvmInt32Type,
       llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType,
       llvmPointerType /* void *stream */}};
};
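// A FunctionCallBuilder pairs a runtime function name with an LLVM function
// type; its create() declares the function at module scope on first use and
// then emits the call. As an illustrative sketch (using the builders defined
// above, from within some pattern's matchAndRewrite):
//
//   Value stream =
//       streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
//   streamSynchronizeCallBuilder.create(loc, rewriter, {stream});
//   streamDestroyCallBuilder.create(loc, rewriter, {stream});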
/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call.
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.host_unregister operations into a GPU
/// runtime call.
class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
  }

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime call.
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call.
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert async.yield ops that carry GPU async tokens
/// into GPU runtime calls.
class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert the synchronous gpu.wait operation into GPU
/// runtime calls.
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert the asynchronous gpu.wait operation into GPU
/// runtime calls.
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to legalize gpu.launch_func operations by lowering their
/// operands and attaching an explicit stream.
class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv,
                              bool kernelIntersperseSizeCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv),
        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
  bool kernelIntersperseSizeCallConv;
};
/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call.
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call.
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.set_default_device operations into a GPU
/// runtime call.
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// Generic rewriting rule for operation on sparse matrices.
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)      \
  class Convert##op_name##ToGpuRuntimeCallPattern                    \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {      \
  public:                                                            \
    Convert##op_name##ToGpuRuntimeCallPattern(                       \
        const LLVMTypeConverter &typeConverter)                      \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(            \
              typeConverter) {}                                      \
                                                                     \
  private:                                                           \
    LogicalResult                                                    \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,              \
                    ConversionPatternRewriter &rewriter)             \
        const override;                                              \
  };
void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();

  // Perform progressive lowering of vector transfer operations.
  {
    RewritePatternSet patterns(context);
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    vector::populateVectorTransferLoweringPatterns(patterns,
                                                   /*maxTransferRank=*/1);
    vector::populateVectorFromElementsUnrollPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }

  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate patterns from every dialect that implements the
  // ConvertToLLVMPatternInterface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                   patterns);
  }

  // GPU modules and binaries stay legal; modules are converted later by
  // gpu-module-to-binary.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept launch_func ops as legal once their operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These patterns aren't covered by the ConvertToLLVMPatternInterface.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}
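// Usage sketch: with a standard mlir-opt build, this pass typically runs after
// kernel outlining, with device code handled separately by
// gpu-module-to-binary, e.g.:
//
//   mlir-opt in.mlir --gpu-kernel-outlining --gpu-to-llvm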
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return LLVM::LLVMFuncOp::create(OpBuilder::atBlockEnd(module.getBody()),
                                    loc, functionName, functionType);
  }();
  return LLVM::CallOp::create(builder, loc, function, arguments);
}
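// For example, the first call emitted through streamCreateCallBuilder in a
// module produces roughly:
//
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
//   ...
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr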
// Corresponding to cusparseLtDatatype_t defined in cusparseLt.h.
static int32_t getCuSparseLtDataTypeFrom(Type type) {
  // ... (float-type cases elided) ...
  llvm_unreachable("unsupported type");
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  // ... (real-valued cases elided) ...
  llvm_unreachable("unsupported element type");
}

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// Returns whether the sparse matrix handle was created with 2:4 sparsity.
static bool is2To4Sparsity(Value spMat) {
  if (spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (spMat.getDefiningOp<gpu::CreateCooOp>() ||
      spMat.getDefiningOp<gpu::CreateCooAoSOp>() ||
      spMat.getDefiningOp<gpu::CreateCsrOp>() ||
      spMat.getDefiningOp<gpu::CreateCscOp>() ||
      spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  llvm_unreachable("cannot find spmat def");
}

// Returns whether the value is consumed by a SpMM whose A operand has 2:4
// sparsity, in which case the cuSPARSELt path is taken.
static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    if (!spmmOp)
      continue;
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}
// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant values and
  // dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = mlir::LLVM::ConstantOp::create(
      rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}
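// Roughly, an async gpu.alloc lowers to a runtime call on the dependency's
// stream:
//
//   %ptr = llvm.call @mgpuMemAlloc(%sizeBytes, %stream, %isHostShared)
//
// with %ptr packed into a MemRef descriptor and %stream standing in for the
// async token in the replacement.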
LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}
static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}
// Converts !gpu.async.token operands of async.yield to runtime calls: tokens
// are lowered to streams inside an async.execute region but passed between
// regions as events.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}
// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}
// Converts the synchronous gpu.wait to runtime calls. The converted op
// synchronizes the host with the stream/event operands, then destroys them;
// it assumes the operands are not used afterwards.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the definition of the original token.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is already an event.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
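// In effect, `%t = gpu.wait async [%t0, %t1]` records an event on each
// dependency's stream, creates a fresh stream that waits on those events via
// mgpuStreamWaitEvent, destroys the events, and yields the new stream as %t.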
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.getAsyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, so it must not be used afterwards.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty())
    stream = adaptor.getAsyncDependencies().front();
  // If the async keyword is present and there are no dependencies, create a
  // stream to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

  // Lower the kernel operands to match the kernel parameters.
  OperandRange origArguments = launchOp.getKernelOperands();
  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;

  // Intersperse size information if requested.
  if (kernelIntersperseSizeCallConv) {
    if (origArguments.size() != llvmArguments.size()) {
      // This shouldn't happen if the bare-pointer calling convention is used.
      return rewriter.notifyMatchFailure(
          launchOp,
          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
    }

    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
      if (!memrefTy)
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref.");

      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat())
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0)
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");

      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
                            static_cast<uint64_t>(memrefTy.getNumElements());

      Value sizeArg = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
      llvmArgumentsWithSizes.push_back(llvmArg);
      llvmArgumentsWithSizes.push_back(sizeArg);
    }
  }

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  gpu::LaunchFuncOp::create(
      rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(),
      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
      stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}
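// With kernelIntersperseSizeCallConv, a kernel taking (memref<16xf32>,
// memref<8xi32>) is relaunched with (%ptr0, 64, %ptr1, 32): each bare pointer
// is followed by its static buffer size in bytes as an index constant.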
static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
  Value gepPtr = LLVM::GEPOp::create(
      rewriter, loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);

  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}
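// The null-pointer GEP in the memcpy lowering above is the standard "sizeof"
// idiom: offsetting a null element pointer by numElements and converting the
// result with ptrtoint yields numElements * sizeof(element) in bytes without
// hard-coding a target-specific size.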
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.getValue().getType();
  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
  // Ints and floats of 16 or 32 bit width are allowed.
  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}
LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}
template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
                                  static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return LLVM::ConstantOp::create(
      builder, loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}
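// For example, genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(type))
// materializes the cuSPARSE enum value as an llvm.mlir.constant of type i32;
// the sparse patterns below encode modes, data types, and flags this way.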
LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  if (dims.size() == 2) {
    // Use the cuSPARSELt create call if the dense matrix participates in a
    // SpMM with a 2:4 sparse operand.
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                               rewriter.getIndexAttr(11032));
      handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
                                      llvmInt8Type, handleSz, /*alignment=*/16);
      handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);

      createLtDnMatCallBuilder.create(
          loc, rewriter, {handle, dims[0], dims[1], pTensor, dtp, stream});
    } else {
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
  SmallVector<Value, 4> dims;
  for (Value dim : definingOp.getDims()) {
    dims.push_back(dim);
  }
  if (dims.size() == 2) {
    // Use the cuSPARSELt destroy call if the dense matrix participates in a
    // SpMM with a 2:4 sparse operand.
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
                                           {adaptor.getDnTensor(), stream});
    } else {
      destroyDnMatCallBuilder.create(loc, rewriter,
                                     {adaptor.getDnTensor(), stream});
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    destroyDnVecCallBuilder.create(loc, rewriter,
                                   {adaptor.getDnTensor(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooAoSOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooAoSCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pMat =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType =
      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseLtDataTypeFrom(dType));

  auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           rewriter.getIndexAttr(44104));
  Value handle = LLVM::AllocaOp::create(
      rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
  handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);

  create2To4SpMatCallBuilder.create(
      loc, rewriter,
      {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream});
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroySpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Use the cuSPARSELt destroy call if the spmat has 2:4 sparsity.
  if (is2To4Sparsity(op.getSpmat())) {
    destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
                                         {adaptor.getSpmat(), stream});
  } else {
    destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize = spMVBufferSizeCallBuilder
                        .create(loc, rewriter,
                                {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                                 adaptor.getDnY(), computeType, stream})
                        .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}
LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  spMVCallBuilder.create(loc, rewriter,
                         {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                          adaptor.getDnY(), computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  // The cuSPARSELt path returns three buffer sizes through an out-parameter.
  if (is2To4Sparsity(op.getSpmatA())) {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
    auto pruneFlag =
        genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
    auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                          rewriter.getIndexAttr(3));
    auto bufferSize =
        LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
                               three, /*alignment=*/16);
    createCuSparseLtSpMMBufferSizeBuilder.create(
        loc, rewriter,
        {bufferSize, modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
         adaptor.getDnmatC(), computeType, pruneFlag, stream});

    auto bufferSizePtr1 = LLVM::GEPOp::create(
        rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                            rewriter.getIndexAttr(1))});
    auto bufferSizePtr2 = LLVM::GEPOp::create(
        rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                            rewriter.getIndexAttr(2))});
    auto bufferSize0 =
        LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
    auto bufferSize1 =
        LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
    auto bufferSize2 =
        LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);

    rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
  } else {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
    auto bufferSize =
        createSpMMBufferSizeCallBuilder
            .create(loc, rewriter,
                    {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
                     adaptor.getDnmatC(), computeType, stream})
            .getResult();
    rewriter.replaceOp(op, {bufferSize, stream});
  }
  return success();
}
LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize =
      createSDDMMBufferSizeCallBuilder
          .create(loc, rewriter,
                  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
                   adaptor.getSpmatC(), computeType, stream})
          .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to the cuSPARSELt path if the A operand has 2:4 sparsity.
  if (is2To4Sparsity(op.getSpmatA())) {
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
  converter.addConversion([&converter](T) -> Type {
    return LLVM::LLVMPointerType::get(&converter.getContext());
  });
}
LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  createSDDMMCallBuilder.create(loc, rewriter,
                                {modeA, modeB, adaptor.getDnmatA(),
                                 adaptor.getDnmatB(), adaptor.getSpmatC(),
                                 computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
                    .getResult();
  rewriter.replaceOp(op, {descr, stream});
  return success();
}
LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
                                         {adaptor.getDesc(), stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();

  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  Value bufferSizeNew;

  if (adaptor.getKind() ==
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
    bufferSizeNew =
        createSpGEMMWorkEstimationBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  } else {
    bufferSizeNew =
        createSpGEMMComputeBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  }
  rewriter.replaceOp(op, {bufferSizeNew, stream});
  return success();
}
LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMCopyBuilder.create(loc, rewriter,
                                 {adaptor.getDesc(), modeA, modeB,
                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
                                  adaptor.getSpmatC(), computeType, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();

  auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(3));
  auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
                                       llvmInt64Type, three, /*alignment=*/16);

  auto rowsPtr = LLVM::GEPOp::create(
      rewriter, loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                          rewriter.getIndexAttr(0))});
  auto colsPtr = LLVM::GEPOp::create(
      rewriter, loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                          rewriter.getIndexAttr(1))});
  auto nnzsPtr = LLVM::GEPOp::create(
      rewriter, loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                          rewriter.getIndexAttr(2))});
  createSpMatGetSizeBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
  auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
  auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
  auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);

  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
  return success();
}
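// This is the usual out-parameter idiom for runtime calls that return several
// values: stack-allocate an i64[3], pass a pointer to each slot, then load the
// rows/cols/nnz results back as SSA values after the call.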
LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  createSetCsrPointersBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCscOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCscCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateBsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createBsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
                   adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
                   pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);

  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertHostUnregisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern,
               ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
               ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
               ConvertCreateCooOpToGpuRuntimeCallPattern,
               ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
               ConvertCreateCsrOpToGpuRuntimeCallPattern,
               ConvertCreateCscOpToGpuRuntimeCallPattern,
               ConvertCreateBsrOpToGpuRuntimeCallPattern,
               ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
               ConvertDestroySpMatOpToGpuRuntimeCallPattern,
               ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMVOpToGpuRuntimeCallPattern,
               ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMMOpToGpuRuntimeCallPattern,
               ConvertSDDMMOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
               ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
                                            kernelIntersperseSizeCallConv);
}
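// A downstream pass can reuse these patterns directly. A minimal sketch,
// assuming a ModuleOp-anchored pass and default options:
//
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateGpuToLLVMConversionPatterns(converter, patterns);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<LLVM::LLVMDialect>();
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();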
namespace {
/// Implement the ConvertToLLVM interface for gpu.module ops.
struct GPUModuleOpConvertToLLVMInterface
    : public ConvertToLLVMOpInterface::ExternalModel<
          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
  /// Get the conversion attributes from the module's target attribute.
  void getConvertToLLVMConversionAttrs(
      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
};
} // namespace

void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
  auto module = cast<gpu::GPUModuleOp>(op);
  ArrayAttr targetsAttr = module.getTargetsAttr();
  // Fail if there are no target attributes or more than one target.
  if (!targetsAttr || targetsAttr.size() != 1)
    return;
  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
    attrs.push_back(patternAttr);
}

void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
  });
}