40#include "llvm/ADT/STLExtras.h"
42#define DEBUG_TYPE "gpu-to-llvm"
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
// Conversion pass lowering GPU dialect host-side ops to LLVM dialect /
// GPU runtime calls. NOTE(review): this listing is chunked — several
// original lines of the class body are elided here.
52class GpuToLLVMConversionPass
56  void getDependentDialects(DialectRegistry &registry)
const final {
    // Register the base pass's dependent dialects; elided lines presumably
    // add more — confirm against the full source.
57    Base::getDependentDialects(registry);
61  void runOnOperation()
override;
// Common base for all GPU-op -> runtime-call rewrite patterns. Caches the
// LLVM types used by the runtime ABI and one FunctionCallBuilder per
// "mgpu*" runtime entry point. NOTE(review): chunked listing — some member
// lines (notably several builder name strings / return types) are elided.
64template <
typename OpTy>
67  explicit ConvertOpToGpuRuntimeCallPattern(
68      const LLVMTypeConverter &typeConverter)
69      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
  // Returns the element count of `type`: a constant when the shape is
  // static, otherwise the product of the descriptor's dynamic sizes.
72  Value
getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73                 MemRefType type, MemRefDescriptor desc)
const {
75    if (type.hasStaticShape())
77          rewriter, loc, indexType, type.getNumElements());
79    uint64_t rank = type.getRank();
80    Value numElements = desc.
size(rewriter, loc, 0);
81    for (
unsigned i = 1; i < rank; i++)
82      numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83                                        desc.
size(rewriter, loc, i));
  // Cached LLVM types shared by all call builders below.
87  MLIRContext *context = &this->getTypeConverter()->
getContext();
89  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91  Type llvmInt8Type = IntegerType::get(context, 8);
92  Type llvmInt16Type = IntegerType::get(context, 16);
93  Type llvmInt32Type = IntegerType::get(context, 32);
94  Type llvmInt64Type = IntegerType::get(context, 64);
95  Type llvmFloat32Type = Float32Type::get(context);
  // Index-sized integer, matching the target pointer width (addrspace 0).
96  Type llvmIntPtrType = IntegerType::get(
97      context, this->getTypeConverter()->getPointerBitwidth(0));
  // Stream / event lifecycle runtime entry points.
99  FunctionCallBuilder streamCreateCallBuilder = {
100      "mgpuStreamCreate", llvmPointerType , {}};
101  FunctionCallBuilder streamDestroyCallBuilder = {
102      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
103  FunctionCallBuilder streamSynchronizeCallBuilder = {
104      "mgpuStreamSynchronize",
107  FunctionCallBuilder streamWaitEventCallBuilder = {
108      "mgpuStreamWaitEvent",
110      {llvmPointerType , llvmPointerType }};
111  FunctionCallBuilder eventCreateCallBuilder = {
112      "mgpuEventCreate", llvmPointerType , {}};
113  FunctionCallBuilder eventDestroyCallBuilder = {
114      "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
115  FunctionCallBuilder eventSynchronizeCallBuilder = {
116      "mgpuEventSynchronize",
119  FunctionCallBuilder eventRecordCallBuilder = {
122      {llvmPointerType , llvmPointerType }};
  // Host memory registration and device memory management entry points.
123  FunctionCallBuilder hostRegisterCallBuilder = {
124      "mgpuMemHostRegisterMemRef",
129  FunctionCallBuilder hostUnregisterCallBuilder = {
130      "mgpuMemHostUnregisterMemRef",
135  FunctionCallBuilder allocCallBuilder = {
141  FunctionCallBuilder deallocCallBuilder = {
144      {llvmPointerType , llvmPointerType }};
145  FunctionCallBuilder memcpyCallBuilder = {
148      {llvmPointerType , llvmPointerType ,
151  FunctionCallBuilder memset16CallBuilder = {
158  FunctionCallBuilder memset32CallBuilder = {
161      {llvmPointerType , llvmInt32Type ,
164  FunctionCallBuilder setDefaultDeviceCallBuilder = {
165      "mgpuSetDefaultDevice",
  // Sparse (cuSPARSE / cuSPARSELt-style) runtime entry points.
168  FunctionCallBuilder createDnVecCallBuilder = {
171      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
173  FunctionCallBuilder destroyDnVecCallBuilder = {
176      {llvmPointerType, llvmPointerType }};
177  FunctionCallBuilder createDnMatCallBuilder = {
180      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
182  FunctionCallBuilder destroyDnMatCallBuilder = {
185      {llvmPointerType, llvmPointerType }};
186  FunctionCallBuilder createCooCallBuilder = {
189      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
192  FunctionCallBuilder createCooAoSCallBuilder = {
195      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196       llvmPointerType, llvmInt32Type, llvmInt32Type,
198  FunctionCallBuilder createCsrCallBuilder = {
201      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203       llvmInt32Type, llvmPointerType }};
204  FunctionCallBuilder createCscCallBuilder = {
207      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209       llvmInt32Type, llvmPointerType }};
210  FunctionCallBuilder createBsrCallBuilder = {
213      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215       llvmInt32Type, llvmInt32Type, llvmInt32Type,
217  FunctionCallBuilder destroySpMatCallBuilder = {
220      {llvmPointerType, llvmPointerType }};
221  FunctionCallBuilder spMVBufferSizeCallBuilder = {
222      "mgpuSpMVBufferSize",
224      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225       llvmInt32Type, llvmPointerType }};
226  FunctionCallBuilder spMVCallBuilder = {
229      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230       llvmInt32Type, llvmPointerType, llvmPointerType }};
231  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232      "mgpuSpMMBufferSize",
234      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235       llvmPointerType, llvmInt32Type, llvmPointerType }};
236  FunctionCallBuilder createSpMMCallBuilder = {
239      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240       llvmPointerType, llvmInt32Type, llvmPointerType,
242  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243      "mgpuSDDMMBufferSize",
245      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246       llvmPointerType, llvmInt32Type, llvmPointerType }};
247  FunctionCallBuilder createSDDMMCallBuilder = {
250      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251       llvmPointerType, llvmInt32Type, llvmPointerType,
253  FunctionCallBuilder createLtDnMatCallBuilder = {
254      "mgpuCreateCuSparseLtDnMat",
256      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257       llvmInt32Type, llvmPointerType }};
258  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259      "mgpuDestroyCuSparseLtSpMat",
261      {llvmPointerType, llvmPointerType }};
262  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263      "mgpuDestroyCuSparseLtDnMat",
265      {llvmPointerType, llvmPointerType }};
266  FunctionCallBuilder create2To4SpMatCallBuilder = {
267      "mgpuCusparseLtCreate2To4SpMat",
269      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270       llvmInt32Type, llvmPointerType }};
271  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272      "mgpuCuSparseLtSpMMBufferSize",
274      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
277  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278      "mgpuCuSparseLtSpMM",
280      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281       llvmPointerType, llvmPointerType, llvmPointerType }};
282  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283      "mgpuSpGEMMCreateDescr",
286  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287      "mgpuSpGEMMDestroyDescr",
289      {llvmPointerType , llvmPointerType }};
290  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291      "mgpuSpGEMMWorkEstimation",
293      {llvmPointerType , llvmInt32Type , llvmInt32Type ,
294       llvmPointerType , llvmPointerType , llvmPointerType ,
295       llvmInt32Type , llvmIntPtrType , llvmPointerType ,
297  FunctionCallBuilder createSpGEMMComputeBuilder = {
300      {llvmPointerType , llvmInt32Type , llvmInt32Type ,
301       llvmPointerType , llvmPointerType , llvmPointerType ,
302       llvmInt32Type , llvmIntPtrType , llvmPointerType ,
304  FunctionCallBuilder createSpGEMMCopyBuilder = {
307      {llvmPointerType , llvmInt32Type , llvmInt32Type ,
308       llvmPointerType , llvmPointerType , llvmPointerType ,
309       llvmInt32Type , llvmPointerType }};
310  FunctionCallBuilder createSpMatGetSizeBuilder = {
313      {llvmPointerType , llvmPointerType , llvmPointerType ,
314       llvmPointerType , llvmPointerType }};
315  FunctionCallBuilder createSetCsrPointersBuilder = {
316      "mgpuSetCsrPointers",
318      {llvmPointerType , llvmPointerType ,
319       llvmPointerType , llvmPointerType ,
// Pattern lowering gpu.host_register to a runtime registration call
// (implementation in matchAndRewrite further below).
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326    :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
328  ConvertHostRegisterOpToGpuRuntimeCallPattern(
329      const LLVMTypeConverter &typeConverter)
330      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
334  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering gpu.host_unregister to a runtime unregistration call.
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339    :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
341  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342      const LLVMTypeConverter &typeConverter)
343      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering gpu.alloc to a runtime allocation call.
354class ConvertAllocOpToGpuRuntimeCallPattern
355    :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
357  ConvertAllocOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
358      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
362  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering gpu.dealloc to a runtime deallocation call.
368class ConvertDeallocOpToGpuRuntimeCallPattern
369    :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
371  ConvertDeallocOpToGpuRuntimeCallPattern(
372      const LLVMTypeConverter &typeConverter)
373      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
377  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378                  ConversionPatternRewriter &rewriter)
const override;
// Pattern rewriting async.yield operands that are GPU streams into events
// (see matchAndRewrite below).
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382    :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
384  ConvertAsyncYieldToGpuRuntimeCallPattern(
385      const LLVMTypeConverter &typeConverter)
386      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
390  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
391                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering the synchronous (non-token-producing) form of gpu.wait.
396class ConvertWaitOpToGpuRuntimeCallPattern
397    :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
399  ConvertWaitOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
400      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
404  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
405                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering the async (token-producing) form of gpu.wait.
410class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411    :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
413  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414      const LLVMTypeConverter &typeConverter)
415      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
419  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
420                  ConversionPatternRewriter &rewriter)
const override;
// Pattern legalizing gpu.launch_func operands to LLVM types. The two flags
// select the kernel argument calling convention (bare pointers, and
// interspersed static byte sizes) — see matchAndRewrite below.
424class LegalizeLaunchFuncOpPattern
425    :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
427  LegalizeLaunchFuncOpPattern(
const LLVMTypeConverter &typeConverter,
428                              bool kernelBarePtrCallConv,
429                              bool kernelIntersperseSizeCallConv)
430      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
431        kernelBarePtrCallConv(kernelBarePtrCallConv),
432        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
436  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
437                  ConversionPatternRewriter &rewriter)
const override;
// Pass kernel memref arguments as bare pointers instead of descriptors.
439  bool kernelBarePtrCallConv;
// After each memref argument, also pass its static size in bytes.
440  bool kernelIntersperseSizeCallConv;
// Pattern lowering gpu.memcpy to a runtime memcpy call.
445class ConvertMemcpyOpToGpuRuntimeCallPattern
446    :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
448  ConvertMemcpyOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
449      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
453  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
454                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering gpu.memset to a 16- or 32-bit runtime memset call.
459class ConvertMemsetOpToGpuRuntimeCallPattern
460    :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
462  ConvertMemsetOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
463      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
467  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
468                  ConversionPatternRewriter &rewriter)
const override;
// Pattern lowering gpu.set_default_device to mgpuSetDefaultDevice.
473class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
474    :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
476  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
477      const LLVMTypeConverter &typeConverter)
478      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
482  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
483                  ConversionPatternRewriter &rewriter)
const override;
// Boilerplate generator: declares a Convert<op_name>ToGpuRuntimeCallPattern
// class for a given gpu dialect op. (No comments inside the macro — they
// would break the backslash line continuations.)
488#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)               \
489  class Convert##op_name##ToGpuRuntimeCallPattern                             \
490      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
492    Convert##op_name##ToGpuRuntimeCallPattern(                                \
493        const LLVMTypeConverter &typeConverter)                               \
494        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
498    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
499                    ConversionPatternRewriter &rewriter) const override;      \
// Pass entry point: sets up the LLVM type converter, conversion target and
// pattern set, then runs partial conversion. NOTE(review): chunked listing —
// lines between the visible statements are elided.
526void GpuToLLVMConversionPass::runOnOperation() {
537    vector::populateVectorFromElementsUnrollPatterns(patterns);
539      return signalPassFailure();
  // Honor the host-side bare-pointer option when building the converter.
542  LowerToLLVMOptions
options(context);
543  options.useBarePtrCallConv = hostBarePtrCallConv;
544  RewritePatternSet patterns(context);
545  ConversionTarget
target(*context);
546  target.addLegalDialect<LLVM::LLVMDialect>();
547  LLVMTypeConverter converter(context,
options);
  // Let every dialect implementing ConvertToLLVMPatternInterface contribute
  // its own conversion patterns.
552    auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
555    iface->populateConvertToLLVMConversionPatterns(
target, converter, patterns);
560  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // launch_func is legal only once its operand types are LLVM-compatible.
562  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
563      [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
571                                          kernelBarePtrCallConv,
572                                          kernelIntersperseSizeCallConv);
575          applyPartialConversion(getOperation(),
target, std::move(patterns))))
// NOTE(review): interior of FunctionCallBuilder::create — the signature is
// elided in this listing. Looks up (or presumably declares, in elided
// lines) the runtime function in the module, then emits the call.
581  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
582  auto function = [&] {
583    if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
588  return LLVM::CallOp::create(builder, loc, function, arguments);
// NOTE(review): fragments of the MLIR-type -> runtime sparse data-type
// mapping helpers; signatures and the returned enum values are elided.
// Complex types dispatch on their element type below.
605  llvm_unreachable(
"unsupported type");
611  if (llvm::isa<ComplexType>(type)) {
613    auto elementType = cast<ComplexType>(type).getElementType();
614    if (elementType.isBF16())
616    if (elementType.isF16())
618    if (elementType.isF32())
620    if (elementType.isF64())
622    if (elementType.isInteger(8))
624    if (elementType.isInteger(16))
626    if (elementType.isInteger(32))
644  llvm_unreachable(
"unsupported element type");
// NOTE(review): fragments of the 2:4-sparsity helpers. Reads the prune flag
// off the defining gpu.create_2to4_spmat op; callers must ensure the
// defining op exists (the unreachable guards that).
648  return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
673  llvm_unreachable(
"cannot find spmat def");
678    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
// NOTE(review): fragment — checks every operand has an LLVM-compatible
// type, otherwise reports a match failure.
690                             ConversionPatternRewriter &rewriter) {
691  if (!llvm::all_of(operands, [](
Value value) {
694    return rewriter.notifyMatchFailure(
695        op,
"Cannot convert if operands aren't of LLVM type.");
// NOTE(review): fragment — match guard requiring exactly one async
// dependency and a produced async token.
701                                          gpu::AsyncOpInterface op) {
702  if (op.getAsyncDependencies().size() != 1)
703    return rewriter.notifyMatchFailure(
704        op,
"Can only convert with exactly one async dependency.");
706  if (!op.getAsyncToken())
707    return rewriter.notifyMatchFailure(op,
"Can convert only async version.");
// Lowers gpu.host_register: promotes the unranked-memref operand, appends
// the element size, and calls mgpuMemHostRegisterMemRef.
712LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
713    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
714    ConversionPatternRewriter &rewriter)
const {
715  auto *op = hostRegisterOp.getOperation();
719  Location loc = op->getLoc();
721  auto memRefType = hostRegisterOp.getValue().getType();
722  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  // promoteOperands expands the memref descriptor into individual values.
725  auto arguments = getTypeConverter()->promoteOperands(
726      loc, op->getOperands(), adaptor.getOperands(), rewriter);
727  arguments.push_back(elementSize);
728  hostRegisterCallBuilder.create(loc, rewriter, arguments);
730  rewriter.eraseOp(op);
// Lowers gpu.host_unregister: mirror image of host_register, calling
// mgpuMemHostUnregisterMemRef.
734LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
735    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
736    ConversionPatternRewriter &rewriter)
const {
737  Operation *op = hostUnregisterOp.getOperation();
741  Location loc = op->
getLoc();
743  auto memRefType = hostUnregisterOp.getValue().getType();
744  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
747  auto arguments = getTypeConverter()->promoteOperands(
748      loc, op->
getOperands(), adaptor.getOperands(), rewriter);
749  arguments.push_back(elementSize);
750  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
752  rewriter.eraseOp(op);
// Lowers gpu.alloc: computes the byte size from the memref descriptor
// machinery, calls the runtime allocator (with a host-shared flag), and
// rebuilds a memref descriptor around the returned pointer.
756LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
757    gpu::AllocOp allocOp, OpAdaptor adaptor,
758    ConversionPatternRewriter &rewriter)
const {
760  MemRefType memRefType = allocOp.getType();
763      !isConvertibleAndHasIdentityMaps(memRefType))
766  auto loc = allocOp.getLoc();
768  bool isShared = allocOp.getHostShared();
  // Host-shared allocations are synchronous only.
770  if (isShared && allocOp.getAsyncToken())
771    return rewriter.notifyMatchFailure(
772        allocOp,
"Host Shared allocation cannot be done async");
778  SmallVector<Value, 4> shape;
779  SmallVector<Value, 4> strides;
781  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
782                           shape, strides, sizeBytes);
786  auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
  // Null stream when there are no async dependencies.
787  Value stream = adaptor.getAsyncDependencies().empty()
789                     : adaptor.getAsyncDependencies().front();
791  auto isHostShared = mlir::LLVM::ConstantOp::create(
792      rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
795      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
  // NOTE(review): aligned == allocated pointer here; alignment handling, if
  // any, is in elided lines — confirm against the full source.
799  Value alignedPtr = allocatedPtr;
802  auto memRefDescriptor = this->createMemRefDescriptor(
803      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
805  if (allocOp.getAsyncToken()) {
807    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
809    rewriter.replaceOp(allocOp, {memRefDescriptor});
// Lowers gpu.dealloc: frees the allocated pointer on the dependency stream
// and forwards the stream as the async token.
815LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
816    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
817    ConversionPatternRewriter &rewriter)
const {
822  Location loc = deallocOp.getLoc();
825      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
826  Value stream = adaptor.getAsyncDependencies().front();
827  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
829  rewriter.replaceOp(deallocOp, {stream});
// NOTE(review): helper fragment — true iff the value is a GPU async token.
834  return isa<gpu::AsyncTokenType>(value.
getType());
// Rewrites async.yield: each yielded GPU stream is replaced by a freshly
// recorded event, and the now-unused streams are destroyed.
841LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
842    async::YieldOp yieldOp, OpAdaptor adaptor,
843    ConversionPatternRewriter &rewriter)
const {
845    return rewriter.notifyMatchFailure(yieldOp,
"no gpu async token operand");
847  Location loc = yieldOp.getLoc();
848  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  // Collect streams in a set so each is destroyed only once.
849  llvm::SmallDenseSet<Value> streams;
850  for (
auto &operand : yieldOp->getOpOperands()) {
853    auto idx = operand.getOperandNumber();
854    auto stream = adaptor.getOperands()[idx];
855    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
856    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
857    newOperands[idx] = event;
858    streams.insert(stream);
860  for (
auto stream : streams)
861    streamDestroyCallBuilder.create(loc, rewriter, {stream});
863  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
// NOTE(review): helper fragment — checks whether a pointer value is
// produced by a call to the named runtime function.
869  assert(isa<LLVM::LLVMPointerType>(value.
getType()));
871  return *defOp.getCallee() == functionName;
// Lowers synchronous gpu.wait: synchronize-then-destroy each dependency,
// treating operands as streams or events (selection logic in elided lines).
879LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
880    gpu::WaitOp waitOp, OpAdaptor adaptor,
881    ConversionPatternRewriter &rewriter)
const {
882  if (waitOp.getAsyncToken())
883    return rewriter.notifyMatchFailure(waitOp,
"Cannot convert async op.");
885  Location loc = waitOp.getLoc();
887  for (
auto operand : adaptor.getOperands()) {
890      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
891      streamDestroyCallBuilder.create(loc, rewriter, {operand});
895      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
896      eventDestroyCallBuilder.create(loc, rewriter, {operand});
900  rewriter.eraseOp(waitOp);
// Lowers async gpu.wait: turns each dependency into an event recorded right
// after its defining op, then creates a new stream that waits on all events
// before the events are destroyed.
909LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
910    gpu::WaitOp waitOp, OpAdaptor adaptor,
911    ConversionPatternRewriter &rewriter)
const {
912  if (!waitOp.getAsyncToken())
913    return rewriter.notifyMatchFailure(waitOp,
"Can only convert async op.");
915  Location loc = waitOp.getLoc();
  // Events must be recorded at the producer, so temporarily move the
  // insertion point and restore it afterwards.
917  auto insertionPoint = rewriter.saveInsertionPoint();
918  SmallVector<Value, 1> events;
920       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
921    auto operand = std::get<1>(pair);
925      auto *defOp = std::get<0>(pair).getDefiningOp();
926      rewriter.setInsertionPointAfter(defOp);
927      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
928      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
929      events.push_back(event);
933      events.push_back(operand);
936  rewriter.restoreInsertionPoint(insertionPoint);
937  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
938  for (
auto event : events)
939    streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
940  for (
auto event : events)
941    eventDestroyCallBuilder.create(loc, rewriter, {
event});
942  rewriter.replaceOp(waitOp, {stream});
// Legalizes gpu.launch_func: resolves the launch stream from async
// dependencies (merging extras via events), promotes kernel operands to the
// selected calling convention (optionally interspersing static byte sizes),
// and re-creates the launch op with LLVM-typed operands.
948LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
949    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
950    ConversionPatternRewriter &rewriter)
const {
957  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
958    return rewriter.notifyMatchFailure(
959        launchOp,
"Cannot convert non-async op with async dependencies.");
961  Location loc = launchOp.getLoc();
963  Value stream = Value();
  // First dependency becomes the launch stream; remaining dependencies are
  // folded in as events the stream waits on.
964  if (!adaptor.getAsyncDependencies().empty()) {
965    stream = adaptor.getAsyncDependencies().front();
968    if (adaptor.getAsyncDependencies().size() > 1) {
969      auto insertionPoint = rewriter.saveInsertionPoint();
970      SmallVector<Value, 4> events;
971      for (
auto [origDep, convertedDep] :
972           llvm::zip(launchOp.getAsyncDependencies().drop_front(),
973                     adaptor.getAsyncDependencies().drop_front())) {
975                             streamCreateCallBuilder.functionName)) {
976          events.push_back(convertedDep);
979          Operation *defOp = origDep.getDefiningOp();
980          rewriter.setInsertionPointAfter(defOp);
982              eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
983          eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
984          events.push_back(event);
986      rewriter.restoreInsertionPoint(insertionPoint);
987      for (Value event : events)
988        streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
989      for (Value event : events)
990        eventDestroyCallBuilder.create(loc, rewriter, {
event});
  // Async launch with no dependencies gets a fresh stream.
995  else if (launchOp.getAsyncToken())
996    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1001  OperandRange origArguments = launchOp.getKernelOperands();
1002  bool effectiveBarePtr = kernelBarePtrCallConv ||
1003                          getTypeConverter()->getOptions().useBarePtrCallConv;
  // Bare-pointer convention cannot express unranked memrefs.
1004  if (effectiveBarePtr) {
1005    for (Value arg : origArguments) {
1006      if (isa<UnrankedMemRefType>(arg.getType()))
1007        return rewriter.notifyMatchFailure(
1008            loc,
"unranked memref kernel argument is not supported with "
1009                 "the bare-pointer calling convention");
1012  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
1013      loc, origArguments, adaptor.getKernelOperands(), rewriter,
1014      kernelBarePtrCallConv);
1015  SmallVector<Value, 8> llvmArgumentsWithSizes;
  // Intersperse-size convention: follow each (1:1-converted) memref
  // argument with its static size in bytes.
1018  if (kernelIntersperseSizeCallConv) {
1019    if (origArguments.size() != llvmArguments.size()) {
1021      return rewriter.notifyMatchFailure(
1023          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1026    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1027    for (
auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1028      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1030        return rewriter.notifyMatchFailure(
1031            launchOp,
"Operand to launch op is not a memref.");
1034      if (!memrefTy.hasStaticShape() ||
1035          !memrefTy.getElementType().isIntOrFloat()) {
1036        return rewriter.notifyMatchFailure(
1037            launchOp,
"Operand to launch op is not a memref with a static "
1038            "shape and an integer or float element type.");
1041      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1042      if (bitwidth % 8 != 0) {
1043        return rewriter.notifyMatchFailure(
1044            launchOp,
"Operand to launch op is not a memref with a "
1045            "byte-aligned element type.");
  // Compute in uint64_t to avoid overflow for large memrefs.
1048      uint64_t staticSize =
static_cast<uint64_t
>(bitwidth / 8) *
1049                            static_cast<uint64_t
>(memrefTy.getNumElements());
1051      Value sizeArg = LLVM::ConstantOp::create(
1052          rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1053      llvmArgumentsWithSizes.push_back(llvmArg);
1054      llvmArgumentsWithSizes.push_back(sizeArg);
1058  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1059  if (launchOp.hasClusterSize()) {
1061        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1062                        adaptor.getClusterSizeZ()};
1064  gpu::LaunchFuncOp::create(
1065      rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1066      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1067                      adaptor.getGridSizeZ()},
1068      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1069                      adaptor.getBlockSizeZ()},
1070      adaptor.getDynamicSharedMemorySize(),
1071      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1072      stream, clusterSize);
1073  if (launchOp.getAsyncToken())
1074    rewriter.replaceOp(launchOp, {stream});
1076    rewriter.eraseOp(launchOp);
// NOTE(review): fragment of an address-space cast helper — inserts an
// addrspacecast only when source and destination address spaces differ.
1081                                ConversionPatternRewriter &rewriter,
1082                                LLVM::LLVMPointerType destinationType,
1085  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1086  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1087    sourcePtr = LLVM::AddrSpaceCastOp::create(
1089        LLVM::LLVMPointerType::get(rewriter.getContext(),
1090                                   destinationType.getAddressSpace()),
// Lowers gpu.memcpy: computes the byte size via the null-pointer GEP
// idiom, address-space-casts src/dst aligned pointers, and calls the
// runtime memcpy on the dependency stream.
1095LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1096    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1097    ConversionPatternRewriter &rewriter)
const {
1098  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1101      !isConvertibleAndHasIdentityMaps(memRefType) ||
1105  auto loc = memcpyOp.getLoc();
1107  MemRefDescriptor srcDesc(adaptor.getSrc());
1108  Value numElements =
getNumElements(rewriter, loc, memRefType, srcDesc);
  // sizeof(elem) * numElements via GEP-from-null + ptrtoint.
1110  Type elementPtrType = getElementPtrType(memRefType);
1111  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1112  Value gepPtr = LLVM::GEPOp::create(
1113      rewriter, loc, elementPtrType,
1114      typeConverter->convertType(memRefType.getElementType()), nullPtr,
1117      LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1120                                   srcDesc.alignedPtr(rewriter, loc),
1121                                   *getTypeConverter());
1123      loc, rewriter, llvmPointerType,
1124      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1125      *getTypeConverter());
1127  auto stream = adaptor.getAsyncDependencies().front();
1128  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1130  rewriter.replaceOp(memcpyOp, {stream});
// Lowers gpu.memset: only 16/32-bit int-or-float fill values are
// supported; the value is bitcast to an integer of the matching width and
// the corresponding runtime memset variant is called.
1135LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1136    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1137    ConversionPatternRewriter &rewriter)
const {
1138  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1141      !isConvertibleAndHasIdentityMaps(memRefType) ||
1145  auto loc = memsetOp.getLoc();
1147  Type valueType = adaptor.getValue().getType();
1150  if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1151    return rewriter.notifyMatchFailure(
1152        memsetOp,
"value must be a 16 or 32 bit int or float");
1156  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1158  MemRefDescriptor dstDesc(adaptor.getDst());
1159  Value numElements =
getNumElements(rewriter, loc, memRefType, dstDesc);
1162      LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1164                                   dstDesc.alignedPtr(rewriter, loc),
1165                                   *getTypeConverter());
1167  auto stream = adaptor.getAsyncDependencies().front();
  // Dispatch to mgpuMemset32 or mgpuMemset16 by value width.
1168  FunctionCallBuilder builder =
1169      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1170  builder.
create(loc, rewriter, {dst, value, numElements, stream});
1172  rewriter.replaceOp(memsetOp, {stream});
// Lowers gpu.set_default_device to a single mgpuSetDefaultDevice call.
1176LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1177    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1178    ConversionPatternRewriter &rewriter)
const {
1179  Location loc = op.getLoc();
1180  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1181                                                 {adaptor.getDevIndex()});
1182  rewriter.replaceOp(op, call);
// NOTE(review): fragment — emits an i32 LLVM constant from any value
// convertible to int32_t.
1186template <
typename T>
1189  return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1190                                  static_cast<int32_t
>(tValue));
// NOTE(review): fragment — emits an f32 LLVM constant counterpart of the
// helper above it in the original source.
1193template <
typename T>
1196  return LLVM::ConstantOp::create(
1197      builder, loc, llvmFloat32Type,
// Lowers gpu.create_dn_tensor: 2-D tensors become dense matrices (with a
// cuSPARSELt path that first allocas an opaque handle buffer), 1-D tensors
// become dense vectors.
1201LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1202    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1203    ConversionPatternRewriter &rewriter)
const {
1207  Location loc = op.getLoc();
1208  auto stream = adaptor.getAsyncDependencies().front();
1210      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1211  Type dType = op.getMemref().
getType().getElementType();
1214  SmallVector<Value, 4> dims;
1215  for (Value dim : adaptor.getDims()) {
1216    dims.push_back(dim);
1226  if (dims.size() == 2) {
  // 11032 bytes, 16-byte aligned: host-side buffer for the opaque
  // cuSPARSELt dense-matrix handle.
1228      auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1229                                               rewriter.getIndexAttr(11032));
1230      handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1231                                      llvmInt8Type, handleSz, 16);
1232      handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1234      createLtDnMatCallBuilder
1235          .create(loc, rewriter,
1236                  {handle, dims[0], dims[1], pTensor, dtp, stream})
1240          createDnMatCallBuilder
1241              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1245    assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1246    handle = createDnVecCallBuilder
1247                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1250  rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.destroy_dn_tensor: recovers the rank from the defining
// create_dn_tensor op and calls the matching destroy entry point
// (cuSPARSELt matrix, plain matrix, or vector).
1254LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1255    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1256    ConversionPatternRewriter &rewriter)
const {
1260  Location loc = op.getLoc();
1261  auto stream = adaptor.getAsyncDependencies().front();
1262  auto definingOp = op.getDnTensor().
getDefiningOp<gpu::CreateDnTensorOp>();
1263  SmallVector<Value, 4> dims;
1264  for (Value dim : definingOp.getDims()) {
1265    dims.push_back(dim);
1267  if (dims.size() == 2) {
1271      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1272                                           {adaptor.getDnTensor(), stream});
1274      destroyDnMatCallBuilder.create(loc, rewriter,
1275                                     {adaptor.getDnTensor(), stream});
1278    assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1279    destroyDnVecCallBuilder.create(loc, rewriter,
1280                                   {adaptor.getDnTensor(), stream});
1282  rewriter.replaceOp(op, {stream});
// Lowers gpu.create_coo: extracts row/col/value buffer pointers from the
// memref descriptors and calls the COO creation runtime entry point.
1286LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1287    gpu::CreateCooOp op, OpAdaptor adaptor,
1288    ConversionPatternRewriter &rewriter)
const {
1292  Location loc = op.getLoc();
1293  auto stream = adaptor.getAsyncDependencies().front();
1295      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1297      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1299      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1301      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1303      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1307      createCooCallBuilder
1308          .create(loc, rewriter,
1309                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1310                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1312  rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_coo_aos: like create_coo but with a single interleaved
// (array-of-structs) index buffer.
1316LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1317    gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1318    ConversionPatternRewriter &rewriter)
const {
1322  Location loc = op.getLoc();
1323  auto stream = adaptor.getAsyncDependencies().front();
1324  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1326      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1327  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1329      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1333      createCooAoSCallBuilder
1334          .create(loc, rewriter,
1335                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1336                   pIdxs, pValues, itp, dtp, stream})
1338  rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_csr to a runtime call building a CSR sparse-matrix
// handle from row-position, column-index and values buffers. The ptp/itp/dtp
// operands passed to the call are runtime type codes derived from the three
// element types read below (the deriving statements were dropped by the
// extraction).
1342LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1343 gpu::CreateCsrOp op, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter)
const {
1348 Location loc = op.getLoc();
1349 auto stream = adaptor.getAsyncDependencies().front();
// Raw pointers to the CSR triplet of buffers.
1351 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1353 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1355 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1357 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1359 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1361 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1366 createCsrCallBuilder
1367 .create(loc, rewriter,
1368 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1369 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1371 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_2to4_spmat (2:4 structured sparsity, cuSPARSELt path).
// Unlike the other create ops, the handle is not returned by the runtime:
// a fixed-size opaque descriptor is alloca'd host-side and its pointer is
// passed into the creation call.
1375LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1376 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1377 ConversionPatternRewriter &rewriter)
const {
1381 Location loc = op.getLoc();
1382 auto stream = adaptor.getAsyncDependencies().front();
1384 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1386 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
// 44104 bytes of i8 storage for the descriptor; the trailing 16 is
// presumably the alloca alignment — TODO confirm against AllocaOp builder.
1390 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1391 rewriter.getIndexAttr(44104));
1392 Value handle = LLVM::AllocaOp::create(
1393 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1394 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1396 create2To4SpMatCallBuilder
1397 .create(loc, rewriter,
1398 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1400 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.destroy_sp_mat. Two destroy builders are invoked in the text
// below: a cuSPARSELt-specific one and the generic one — the branch choosing
// between them (presumably an is2To4Sparsity check) was dropped by the
// extraction.
1404LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1405 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1406 ConversionPatternRewriter &rewriter)
const {
1410 Location loc = op.getLoc();
1411 auto stream = adaptor.getAsyncDependencies().front();
1414 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1415 {adaptor.getSpmat(), stream});
1418 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1420 rewriter.replaceOp(op, {stream});
// Lowers gpu.spmv_buffer_size: queries the runtime for the workspace size an
// SpMV needs, replacing the op with {bufferSize, stream}. modeA/computeType
// are defined in statements dropped by the extraction.
1424LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1425 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1426 ConversionPatternRewriter &rewriter)
const {
1430 Location loc = op.getLoc();
1434 auto stream = adaptor.getAsyncDependencies().front();
1435 auto bufferSize = spMVBufferSizeCallBuilder
1436 .create(loc, rewriter,
1437 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1438 adaptor.getDnY(), computeType, stream})
1440 rewriter.replaceOp(op, {bufferSize, stream});
// Lowers gpu.spmv: sparse-matrix x dense-vector multiply via a runtime call,
// using the caller-provided workspace buffer's raw pointer.
1444LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1445 gpu::SpMVOp op, OpAdaptor adaptor,
1446 ConversionPatternRewriter &rewriter)
const {
1450 Location loc = op.getLoc();
1454 auto stream = adaptor.getAsyncDependencies().front();
// Workspace pointer extracted from the buffer memref descriptor.
1456 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1457 spMVCallBuilder.create(loc, rewriter,
1458 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1459 adaptor.getDnY(), computeType, pBuf, stream});
// SpMV yields no SSA result other than the async token.
1460 rewriter.replaceOp(op, {stream});
// Lowers gpu.spmm_buffer_size. Two paths are visible:
//  - cuSPARSELt (2:4 sparsity): the runtime fills a 3-slot alloca with three
//    distinct workspace sizes, which are loaded back and all returned;
//  - generic cuSPARSE: a single buffer-size query call.
// The branch condition selecting between them was dropped by the extraction.
1464LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1465 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1466 ConversionPatternRewriter &rewriter)
const {
1470 Location loc = op.getLoc();
1473 auto stream = adaptor.getAsyncDependencies().front();
// Stack slot holding three sizes, filled by the cuSPARSELt call below.
1480 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1481 rewriter.getIndexAttr(3));
1483 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1485 createCuSparseLtSpMMBufferSizeBuilder
1486 .create(loc, rewriter,
1487 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1488 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
// GEPs to slots 1 and 2 of the size array (slot 0 is the base pointer).
1492 auto bufferSizePtr1 = LLVM::GEPOp::create(
1493 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1494 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1495 rewriter.getIndexAttr(1))});
1496 auto bufferSizePtr2 = LLVM::GEPOp::create(
1497 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1498 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1499 rewriter.getIndexAttr(2))});
// Read the three i64 sizes back out of the alloca.
1501 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1503 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1505 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1507 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
// Generic (non-cuSPARSELt) path: one size from a single runtime query.
1512 createSpMMBufferSizeCallBuilder
1513 .create(loc, rewriter,
1514 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1515 adaptor.getDnmatC(), computeType, stream})
1517 rewriter.replaceOp(op, {bufferSize, stream});
// Lowers gpu.sddmm_buffer_size: queries the workspace size for a sampled
// dense-dense matrix multiply (dense A x dense B sampled into sparse C).
1522LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1523 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1524 ConversionPatternRewriter &rewriter)
const {
1528 Location loc = op.getLoc();
1533 auto stream = adaptor.getAsyncDependencies().front();
1535 createSDDMMBufferSizeCallBuilder
1536 .create(loc, rewriter,
1537 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1538 adaptor.getSpmatC(), computeType, stream})
1540 rewriter.replaceOp(op, {bufferSize, stream});
// Lowers gpu.spmm. Two paths are visible below: the cuSPARSELt path gathers
// raw pointers for all (three) workspace buffers and passes them explicitly;
// the generic path uses only the first buffer. The selecting branch was
// dropped by the extraction.
1544LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1545 gpu::SpMMOp op, OpAdaptor adaptor,
1546 ConversionPatternRewriter &rewriter)
const {
1550 Location loc = op.getLoc();
1556 auto stream = adaptor.getAsyncDependencies().front();
// cuSPARSELt path: collect raw pointers of every workspace buffer.
1560 SmallVector<Value> pBufs;
1561 for (Value buffer : adaptor.getBuffers()) {
1562 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1563 pBufs.push_back(pBuf);
// Indexing pBufs[0..2] implies exactly three buffers on this path.
1565 createCuSparseLtSpMMBuilder.create(
1567 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1568 pBufs[0], pBufs[1], pBufs[2], stream});
// Generic path: a single workspace buffer.
1570 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1571 .allocatedPtr(rewriter, loc);
1572 createSpMMCallBuilder.create(loc, rewriter,
1573 {modeA, modeB, adaptor.getSpmatA(),
1574 adaptor.getDnmatB(), adaptor.getDnmatC(),
1575 computeType, pBuf, stream});
1577 rewriter.replaceOp(op, {stream});
// Registers a type conversion mapping the (GPU dialect) type T to an opaque
// LLVM pointer in the converter's context — opaque handles survive lowering
// as plain pointers.
1581template <
typename T>
1583 converter.addConversion([&converter](T) ->
Type {
1584 return LLVM::LLVMPointerType::get(&converter.
getContext());
// Lowers gpu.sddmm: sampled dense-dense matrix multiply into sparse C, via a
// runtime call with the caller-provided workspace buffer.
1588LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1589 gpu::SDDMMOp op, OpAdaptor adaptor,
1590 ConversionPatternRewriter &rewriter)
const {
1594 Location loc = op.getLoc();
1599 auto stream = adaptor.getAsyncDependencies().front();
1601 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1602 createSDDMMCallBuilder.create(loc, rewriter,
1603 {modeA, modeB, adaptor.getDnmatA(),
1604 adaptor.getDnmatB(), adaptor.getSpmatC(),
1605 computeType, pBuf, stream});
1606 rewriter.replaceOp(op, {stream});
// Lowers gpu.spgemm_create_descr: asks the runtime for a fresh SpGEMM
// descriptor and returns it together with the async token.
1611ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1612 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1613 ConversionPatternRewriter &rewriter)
const {
1617 Location loc = op.getLoc();
1618 auto stream = adaptor.getAsyncDependencies().front();
1619 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1621 rewriter.replaceOp(op, {descr, stream});
// Lowers gpu.spgemm_destroy_descr: releases the SpGEMM descriptor through the
// runtime and forwards only the async token.
1626ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1627 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1628 ConversionPatternRewriter &rewriter)
const {
1632 Location loc = op.getLoc();
1633 auto stream = adaptor.getAsyncDependencies().front();
1634 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1635 {adaptor.getDesc(), stream});
1636 rewriter.replaceOp(op, {stream});
// Lowers gpu.spgemm_work_estimation_or_compute. The op's kind attribute
// selects between the runtime's work-estimation and compute phases; both take
// the same argument list (descriptor, modes, three sparse matrices, compute
// type, current buffer size, workspace pointer) and return an updated buffer
// size.
1641ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1642 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1643 ConversionPatternRewriter &rewriter)
const {
1647 Location loc = op.getLoc();
1652 auto stream = adaptor.getAsyncDependencies().front();
1655 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
// Filled by whichever of the two calls below runs.
1656 Value bufferSizeNew;
1658 if (adaptor.getKind() ==
1659 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1661 createSpGEMMWorkEstimationBuilder
1662 .create(loc, rewriter,
1663 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1664 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1665 adaptor.getBufferSz(), pBuf, stream})
// else: COMPUTE phase — same signature, different runtime entry point.
1669 createSpGEMMComputeBuilder
1670 .create(loc, rewriter,
1671 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1672 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1673 adaptor.getBufferSz(), pBuf, stream})
1676 rewriter.replaceOp(op, {bufferSizeNew, stream});
// Lowers gpu.spgemm_copy: copies the SpGEMM result into the output sparse
// matrix C via a runtime call; only the async token remains.
1680LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1681 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1682 ConversionPatternRewriter &rewriter)
const {
1686 Location loc = op.getLoc();
1691 auto stream = adaptor.getAsyncDependencies().front();
1692 createSpGEMMCopyBuilder.create(loc, rewriter,
1693 {adaptor.getDesc(), modeA, modeB,
1694 adaptor.getSpmatA(), adaptor.getSpmatB(),
1695 adaptor.getSpmatC(), computeType, stream});
1696 rewriter.replaceOp(op, {stream});
// Lowers gpu.spmat_get_size: allocas a 3 x i64 out-buffer, passes pointers to
// its slots (rows/cols/nnz) to the runtime query, then loads the three values
// back and returns {rows, cols, nnzs, stream}.
1700LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1701 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1702 ConversionPatternRewriter &rewriter)
const {
1706 Location loc = op.getLoc();
1707 auto stream = adaptor.getAsyncDependencies().front();
// Stack storage for the three result scalars.
1709 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1710 rewriter.getIndexAttr(3));
1711 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1712 llvmInt64Type, three, 16);
// One GEP per output slot: 0 = rows, 1 = cols, 2 = nnz.
1714 auto rowsPtr = LLVM::GEPOp::create(
1715 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1716 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1717 rewriter.getIndexAttr(0))});
1718 auto colsPtr = LLVM::GEPOp::create(
1719 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1720 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1721 rewriter.getIndexAttr(1))});
1722 auto nnzsPtr = LLVM::GEPOp::create(
1723 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1724 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1725 rewriter.getIndexAttr(2))});
1726 createSpMatGetSizeBuilder.create(
1727 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1728 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1729 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1730 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1732 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
// Lowers gpu.set_csr_pointers: points an existing sparse-matrix handle at new
// positions/coordinates/values buffers via a runtime call.
1736LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1737 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1738 ConversionPatternRewriter &rewriter)
const {
1742 Location loc = op.getLoc();
1743 auto stream = adaptor.getAsyncDependencies().front();
// Raw pointers for the three CSR constituent buffers.
1745 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1747 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1749 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1750 createSetCsrPointersBuilder.create(
1751 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1752 rewriter.replaceOp(op, {stream});
// Lowers gpu.create_csc: mirror of the CSR lowering with column-position and
// row-index buffers swapped in (CSC layout).
1756LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1757 gpu::CreateCscOp op, OpAdaptor adaptor,
1758 ConversionPatternRewriter &rewriter)
const {
1762 Location loc = op.getLoc();
1763 auto stream = adaptor.getAsyncDependencies().front();
1765 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1767 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1769 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Element types feed the ptp/itp/dtp runtime type codes used below.
1771 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1773 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1775 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1780 createCscCallBuilder
1781 .create(loc, rewriter,
1782 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1783 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1785 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_bsr: blocked-sparse-row handle creation. In addition to
// the block-level rows/cols/nnz, the op carries the row/column block sizes
// which are forwarded straight to the runtime call.
1789LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1790 gpu::CreateBsrOp op, OpAdaptor adaptor,
1791 ConversionPatternRewriter &rewriter)
const {
1795 Location loc = op.getLoc();
1796 auto stream = adaptor.getAsyncDependencies().front();
1798 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1800 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1802 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1804 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1806 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1808 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1813 createBsrCallBuilder
1814 .create(loc, rewriter,
1815 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1816 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1817 pColIdxs, pValues, ptp, itp, dtp, stream})
1819 rewriter.replaceOp(op, {handle, stream});
// populateGpuToLLVMConversionPatterns (signature head dropped by the
// extraction): registers every GPU-to-LLVM runtime-call pattern above on the
// pattern set, then adds the launch-func legalization pattern which also
// receives the two calling-convention flags.
1825 bool kernelBarePtrCallConv,
bool kernelIntersperseSizeCallConv) {
// One entry per Convert*OpToGpuRuntimeCallPattern defined in this file.
1831 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1832 ConvertDeallocOpToGpuRuntimeCallPattern,
1833 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1834 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1835 ConvertMemcpyOpToGpuRuntimeCallPattern,
1836 ConvertMemsetOpToGpuRuntimeCallPattern,
1837 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1838 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1839 ConvertWaitOpToGpuRuntimeCallPattern,
1840 ConvertAsyncYieldToGpuRuntimeCallPattern,
1841 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1842 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1843 ConvertCreateCooOpToGpuRuntimeCallPattern,
1844 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1845 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1846 ConvertCreateCscOpToGpuRuntimeCallPattern,
1847 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1848 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1849 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1850 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1851 ConvertSpMVOpToGpuRuntimeCallPattern,
1852 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1853 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1854 ConvertSpMMOpToGpuRuntimeCallPattern,
1855 ConvertSDDMMOpToGpuRuntimeCallPattern,
1856 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1857 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1858 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1859 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1860 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1861 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
// Launch-func legalization takes the bare-pointer / interspersed-size flags.
1862 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1863 kernelIntersperseSizeCallConv);
// External model attaching ConvertToLLVMOpInterface to gpu::GPUModuleOp, so
// the generic convert-to-llvm machinery can query conversion attributes from
// a GPU module (member declaration truncated by the extraction; definition
// follows out-of-line).
1871struct GPUModuleOpConvertToLLVMInterface
1872 :
public ConvertToLLVMOpInterface::ExternalModel<
1873 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1875 void getConvertToLLVMConversionAttrs(
// Collects ConvertToLLVM conversion attributes from a gpu.module: only when
// the module has exactly one target and that target implements
// ConvertToLLVMAttrInterface is it appended to `attrs`.
1880void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1881 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs)
const {
1882 auto module = cast<gpu::GPUModuleOp>(op);
1883 ArrayAttr targetsAttr =
module.getTargetsAttr();
// Bail out (early return dropped by extraction) unless exactly one target.
1885 if (!targetsAttr || targetsAttr.size() != 1)
1887 if (
auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1888 attrs.push_back(patternAttr);
// Fragment of registerConvertGpuToLLVMInterface (surrounding extension-
// registration lambda dropped by extraction): attaches the external model
// above to gpu::GPUModuleOp in the given context.
1893 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getIntegerType(unsigned width)
FloatAttr getF32FloatAttr(float value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const