40#include "llvm/ADT/STLExtras.h"
42#define DEBUG_TYPE "gpu-to-llvm"
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
// Pass class for the GPU-to-LLVM conversion. It derives from the generated
// GpuToLLVMConversionPassBase and overrides runOnOperation().
// NOTE(review): this listing is a lossy extract - the leading numbers are
// residue of an outer numbering and some lines (e.g. closing braces) are not
// shown.
52class GpuToLLVMConversionPass
53 :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
// NOTE(review): `®istry` looks like a mis-encoded `&registry` - confirm
// against the real file.
56 void getDependentDialects(DialectRegistry ®istry)
const final {
// Let the generated base class register its dialects first.
57 Base::getDependentDialects(registry);
61 void runOnOperation()
override;
// Constructor of the templated runtime-call pattern (the class header line is
// not visible in this extract). It forwards the type converter and the
// pattern benefit (default 1) to ConvertOpToLLVMPattern<OpTy>.
64template <
typename OpTy>
67 explicit ConvertOpToGpuRuntimeCallPattern(
68 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
69 : ConvertOpToLLVMPattern<OpTy>(typeConverter, benefit) {}
// Computes the element count of a memref as a Value.
// For a statically shaped type the count comes from type.getNumElements();
// otherwise the per-dimension sizes read from the descriptor are multiplied
// together.
72 Value
getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73 MemRefType type, MemRefDescriptor desc)
const {
// Static case: the count is known from the type (the call consuming these
// arguments is not visible in this extract).
75 if (type.hasStaticShape())
77 rewriter, loc, indexType, type.getNumElements());
// Dynamic case: start from the size of dimension 0 and fold in dimensions
// 1..rank-1 with LLVM::MulOp.
79 uint64_t rank = type.getRank();
80 Value numElements = desc.
size(rewriter, loc, 0);
81 for (
unsigned i = 1; i < rank; i++)
82 numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83 desc.
size(rewriter, loc, i));
// Types shared by the call builders below, all created in the type
// converter's MLIRContext: void, pointer, i8/i16/i32/i64 and f32.
87 MLIRContext *context = &this->getTypeConverter()->
getContext();
89 Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91 Type llvmInt8Type = IntegerType::get(context, 8);
92 Type llvmInt16Type = IntegerType::get(context, 16);
93 Type llvmInt32Type = IntegerType::get(context, 32);
94 Type llvmInt64Type = IntegerType::get(context, 64);
95 Type llvmFloat32Type = Float32Type::get(context);
// Integer type whose width is the converter's pointer bit width
// (getPointerBitwidth(0); the argument is presumably the address space -
// confirm).
96 Type llvmIntPtrType = IntegerType::get(
97 context, this->getTypeConverter()->getPointerBitwidth(0));
// Call builders for the runtime wrapper functions (callee names start with
// "mgpu"). Where the whole initializer is visible it lists the callee name, a
// type, and a braced list of types - presumably name, result type and
// argument types; confirm against FunctionCallBuilder. Many name/result lines
// are missing from this extract.
//
// Stream and event management.
99 FunctionCallBuilder streamCreateCallBuilder = {
100 "mgpuStreamCreate", llvmPointerType , {}};
101 FunctionCallBuilder streamDestroyCallBuilder = {
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
103 FunctionCallBuilder streamSynchronizeCallBuilder = {
104 "mgpuStreamSynchronize",
107 FunctionCallBuilder streamWaitEventCallBuilder = {
108 "mgpuStreamWaitEvent",
110 {llvmPointerType , llvmPointerType }};
111 FunctionCallBuilder eventCreateCallBuilder = {
112 "mgpuEventCreate", llvmPointerType , {}};
113 FunctionCallBuilder eventDestroyCallBuilder = {
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
115 FunctionCallBuilder eventSynchronizeCallBuilder = {
116 "mgpuEventSynchronize",
119 FunctionCallBuilder eventRecordCallBuilder = {
122 {llvmPointerType , llvmPointerType }};
// Host registration, allocation, copy and fill.
123 FunctionCallBuilder hostRegisterCallBuilder = {
124 "mgpuMemHostRegisterMemRef",
129 FunctionCallBuilder hostUnregisterCallBuilder = {
130 "mgpuMemHostUnregisterMemRef",
135 FunctionCallBuilder allocCallBuilder = {
141 FunctionCallBuilder deallocCallBuilder = {
144 {llvmPointerType , llvmPointerType }};
145 FunctionCallBuilder memcpyCallBuilder = {
148 {llvmPointerType , llvmPointerType ,
151 FunctionCallBuilder memset16CallBuilder = {
158 FunctionCallBuilder memset32CallBuilder = {
161 {llvmPointerType , llvmInt32Type ,
164 FunctionCallBuilder setDefaultDeviceCallBuilder = {
165 "mgpuSetDefaultDevice",
// Dense vector / dense matrix handles.
168 FunctionCallBuilder createDnVecCallBuilder = {
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
173 FunctionCallBuilder destroyDnVecCallBuilder = {
176 {llvmPointerType, llvmPointerType }};
177 FunctionCallBuilder createDnMatCallBuilder = {
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
182 FunctionCallBuilder destroyDnMatCallBuilder = {
185 {llvmPointerType, llvmPointerType }};
// Sparse matrix handles (COO, COO-AoS, CSR, CSC, BSR).
186 FunctionCallBuilder createCooCallBuilder = {
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
192 FunctionCallBuilder createCooAoSCallBuilder = {
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
198 FunctionCallBuilder createCsrCallBuilder = {
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType }};
204 FunctionCallBuilder createCscCallBuilder = {
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType }};
210 FunctionCallBuilder createBsrCallBuilder = {
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
217 FunctionCallBuilder destroySpMatCallBuilder = {
220 {llvmPointerType, llvmPointerType }};
// SpMV / SpMM / SDDMM buffer-size queries and computations.
221 FunctionCallBuilder spMVBufferSizeCallBuilder = {
222 "mgpuSpMVBufferSize",
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType }};
226 FunctionCallBuilder spMVCallBuilder = {
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType }};
231 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232 "mgpuSpMMBufferSize",
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType }};
236 FunctionCallBuilder createSpMMCallBuilder = {
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
242 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243 "mgpuSDDMMBufferSize",
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType }};
247 FunctionCallBuilder createSDDMMCallBuilder = {
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
// Builders whose callee names contain "CuSparseLt" / "CusparseLt".
253 FunctionCallBuilder createLtDnMatCallBuilder = {
254 "mgpuCreateCuSparseLtDnMat",
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType }};
258 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259 "mgpuDestroyCuSparseLtSpMat",
261 {llvmPointerType, llvmPointerType }};
262 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263 "mgpuDestroyCuSparseLtDnMat",
265 {llvmPointerType, llvmPointerType }};
266 FunctionCallBuilder create2To4SpMatCallBuilder = {
267 "mgpuCusparseLtCreate2To4SpMat",
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType }};
271 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272 "mgpuCuSparseLtSpMMBufferSize",
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
277 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278 "mgpuCuSparseLtSpMM",
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType }};
// SpGEMM descriptor, work estimation, compute and copy; sparse matrix
// size query and CSR pointer update.
282 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283 "mgpuSpGEMMCreateDescr",
286 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287 "mgpuSpGEMMDestroyDescr",
289 {llvmPointerType , llvmPointerType }};
290 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291 "mgpuSpGEMMWorkEstimation",
293 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
294 llvmPointerType , llvmPointerType , llvmPointerType ,
295 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
297 FunctionCallBuilder createSpGEMMComputeBuilder = {
300 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
301 llvmPointerType , llvmPointerType , llvmPointerType ,
302 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
304 FunctionCallBuilder createSpGEMMCopyBuilder = {
307 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
308 llvmPointerType , llvmPointerType , llvmPointerType ,
309 llvmInt32Type , llvmPointerType }};
310 FunctionCallBuilder createSpMatGetSizeBuilder = {
313 {llvmPointerType , llvmPointerType , llvmPointerType ,
314 llvmPointerType , llvmPointerType }};
315 FunctionCallBuilder createSetCsrPointersBuilder = {
316 "mgpuSetCsrPointers",
318 {llvmPointerType , llvmPointerType ,
319 llvmPointerType , llvmPointerType ,
// Conversion pattern for gpu::HostRegisterOp. The constructor forwards the
// type converter to the base pattern; matchAndRewrite is only declared here.
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::HostUnregisterOp. The constructor forwards the
// type converter to the base pattern; matchAndRewrite is only declared here.
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::AllocOp. The constructor forwards the type
// converter to the base pattern; matchAndRewrite is only declared here.
354class ConvertAllocOpToGpuRuntimeCallPattern
355 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
357 ConvertAllocOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::DeallocOp. The constructor forwards the type
// converter to the base pattern; matchAndRewrite is only declared here.
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for async::YieldOp. Unlike the sibling patterns, its
// constructor also accepts a PatternBenefit (default 1); the rest of the
// base-class initializer is not visible in this extract.
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter,
391 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::WaitOp. Its matchAndRewrite definition rejects
// ops that carry an async token ("Cannot convert async op.").
397class ConvertWaitOpToGpuRuntimeCallPattern
398 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
400 ConvertWaitOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
401 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
405 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::WaitOp. Its matchAndRewrite definition only
// accepts ops that carry an async token ("Can only convert async op.").
411class ConvertWaitAsyncOpToGpuRuntimeCallPattern
412 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
414 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
415 const LLVMTypeConverter &typeConverter)
416 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
420 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter)
const override;
// Pattern that rewrites gpu::LaunchFuncOp. In addition to the type converter
// it stores two flags that its matchAndRewrite consults:
//   kernelBarePtrCallConv         - passed to promoteOperands for the kernel
//                                   arguments;
//   kernelIntersperseSizeCallConv - when set, a size argument is appended
//                                   after each kernel argument.
425class LegalizeLaunchFuncOpPattern
426 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
428 LegalizeLaunchFuncOpPattern(
const LLVMTypeConverter &typeConverter,
429 bool kernelBarePtrCallConv,
430 bool kernelIntersperseSizeCallConv)
431 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
432 kernelBarePtrCallConv(kernelBarePtrCallConv),
433 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
437 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
438 ConversionPatternRewriter &rewriter)
const override;
440 bool kernelBarePtrCallConv;
441 bool kernelIntersperseSizeCallConv;
// Conversion pattern for gpu::MemcpyOp. The constructor forwards the type
// converter to the base pattern; matchAndRewrite is only declared here.
446class ConvertMemcpyOpToGpuRuntimeCallPattern
447 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
449 ConvertMemcpyOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
450 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
454 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
455 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::MemsetOp. The constructor forwards the type
// converter to the base pattern; matchAndRewrite is only declared here.
460class ConvertMemsetOpToGpuRuntimeCallPattern
461 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
463 ConvertMemsetOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
464 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
468 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter)
const override;
// Conversion pattern for gpu::SetDefaultDeviceOp; matchAndRewrite is only
// declared here. The tail of the base-class initializer is not visible in
// this extract.
474class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
475 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
477 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
478 const LLVMTypeConverter &typeConverter)
479 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
483 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter)
const override;
// Declares class Convert<op_name>ToGpuRuntimeCallPattern, derived from
// ConvertOpToGpuRuntimeCallPattern<gpu::op_name>, with a constructor taking
// only the type converter and a matchAndRewrite override declaration.
// Some macro lines are not visible in this extract.
489#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
490 class Convert##op_name##ToGpuRuntimeCallPattern \
491 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
493 Convert##op_name##ToGpuRuntimeCallPattern( \
494 const LLVMTypeConverter &typeConverter) \
495 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
499 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
500 ConversionPatternRewriter &rewriter) const override; \
// Pass entry point. Visible steps: add the vector.from_elements unrolling
// patterns to a preliminary pattern set (how it is applied is not visible),
// signalling pass failure on error; then build the LLVM lowering options,
// conversion target and type converter, collect patterns, and run a partial
// conversion on the current operation.
527void GpuToLLVMConversionPass::runOnOperation() {
538 vector::populateVectorFromElementsUnrollPatterns(patterns);
540 return signalPassFailure();
// Host-side calling convention comes from the hostBarePtrCallConv setting.
543 LowerToLLVMOptions
options(context);
544 options.useBarePtrCallConv = hostBarePtrCallConv;
545 RewritePatternSet patterns(context);
546 ConversionTarget
target(*context);
547 target.addLegalDialect<LLVM::LLVMDialect>();
548 LLVMTypeConverter converter(context,
options);
// Let dialects implementing ConvertToLLVMPatternInterface contribute their
// patterns (the surrounding loop is not visible in this extract).
553 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
556 iface->populateConvertToLLVMConversionPatterns(
target, converter, patterns);
// GPU modules and binaries stay legal; a launch is legal once the type
// converter reports it as legal.
561 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
563 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
564 [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
572 kernelBarePtrCallConv,
573 kernelIntersperseSizeCallConv);
576 applyPartialConversion(getOperation(),
target, std::move(patterns))))
// Body of a call-building helper (its signature is not visible in this
// extract). It finds the ModuleOp enclosing the builder's current block,
// looks up an LLVM function named `functionName` in it, and emits an
// LLVM::CallOp to the resulting function with `arguments`. The fallback used
// when the lookup fails is not shown.
582 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
583 auto function = [&] {
584 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
589 return LLVM::CallOp::create(builder, loc, function, arguments);
// Tail of a helper whose header and earlier branches are not visible:
// reaching this point is treated as a programming error.
606 llvm_unreachable(
"unsupported type");
// Fragment of a helper that dispatches on an element type (header and the
// values returned per branch are not visible in this extract). For a
// ComplexType it tests the complex element type against bf16, f16, f32, f64
// and 8/16/32-bit integers; an unmatched type ends in llvm_unreachable.
612 if (llvm::isa<ComplexType>(type)) {
614 auto elementType = cast<ComplexType>(type).getElementType();
615 if (elementType.isBF16())
617 if (elementType.isF16())
619 if (elementType.isF32())
621 if (elementType.isF64())
623 if (elementType.isInteger(8))
625 if (elementType.isInteger(16))
627 if (elementType.isInteger(32))
645 llvm_unreachable(
"unsupported element type");
// Fragments of sparse-matrix helper functions (headers not visible).
// Reads the prune flag from the gpu::Create2To4SpMatOp that defines `spMat`.
// NOTE(review): getDefiningOp<...>() is dereferenced without a visible null
// check - confirm callers guarantee the defining op kind.
649 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
// Reached when no defining sparse-matrix op was found by the (unseen) code
// above.
674 llvm_unreachable(
"cannot find spmat def");
// Tries to view a user operation as gpu::SpMMOp (null if it is not one).
679 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
// Fragment of a helper (header and predicate body not visible) that checks a
// condition over all `operands` and reports a match failure on `op` when the
// check does not hold.
691 ConversionPatternRewriter &rewriter) {
692 if (!llvm::all_of(operands, [](
Value value) {
695 return rewriter.notifyMatchFailure(
696 op,
"Cannot convert if operands aren't of LLVM type.");
// Fragment of a helper on gpu::AsyncOpInterface (header not visible). It
// reports a match failure unless the op has exactly one async dependency and
// produces an async token.
702 gpu::AsyncOpInterface op) {
703 if (op.getAsyncDependencies().size() != 1)
704 return rewriter.notifyMatchFailure(
705 op,
"Can only convert with exactly one async dependency.");
707 if (!op.getAsyncToken())
708 return rewriter.notifyMatchFailure(op,
"Can convert only async version.");
// Rewrites gpu::HostRegisterOp into a call through hostRegisterCallBuilder
// and erases the op. The call arguments are the promoted operands followed by
// `elementSize` (its computation is not visible in this extract).
713LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter)
const {
716 auto *op = hostRegisterOp.getOperation();
720 Location loc = op->getLoc();
// The operand type is cast to UnrankedMemRefType unconditionally.
722 auto memRefType = hostRegisterOp.getValue().getType();
723 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
726 auto arguments = getTypeConverter()->promoteOperands(
727 loc, op->getOperands(), adaptor.getOperands(), rewriter);
728 arguments.push_back(elementSize);
729 hostRegisterCallBuilder.create(loc, rewriter, arguments);
731 rewriter.eraseOp(op);
// Rewrites gpu::HostUnregisterOp into a call through
// hostUnregisterCallBuilder and erases the op. The call arguments are the
// promoted operands followed by `elementSize` (its computation is not visible
// in this extract).
735LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
737 ConversionPatternRewriter &rewriter)
const {
738 Operation *op = hostUnregisterOp.getOperation();
742 Location loc = op->
getLoc();
// The operand type is cast to UnrankedMemRefType unconditionally.
744 auto memRefType = hostUnregisterOp.getValue().getType();
745 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
748 auto arguments = getTypeConverter()->promoteOperands(
749 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
750 arguments.push_back(elementSize);
751 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
753 rewriter.eraseOp(op);
// Rewrites gpu::AllocOp into a call through allocCallBuilder and replaces the
// op with a memref descriptor (plus the stream when the op has an async
// token). Match failures visible here: host-shared allocation combined with
// an async token, and more than one async dependency.
757LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758 gpu::AllocOp allocOp, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter)
const {
761 MemRefType memRefType = allocOp.getType();
764 !isConvertibleAndHasIdentityMaps(memRefType))
767 auto loc = allocOp.getLoc();
769 bool isShared = allocOp.getHostShared();
771 if (isShared && allocOp.getAsyncToken())
772 return rewriter.notifyMatchFailure(
773 allocOp,
"Host Shared allocation cannot be done async");
774 if (adaptor.getAsyncDependencies().size() > 1)
775 return rewriter.notifyMatchFailure(
776 allocOp,
"Can convert with at most one async dependency.");
// Compute shape, strides and total byte size for the descriptor.
780 SmallVector<Value, 4> shape;
781 SmallVector<Value, 4> strides;
783 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
784 shape, strides, sizeBytes);
// Choose the stream from the async dependencies; the value used when there
// are none is on a line not visible here (a null pointer is created just
// above, presumably for that purpose).
788 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
789 Value stream = adaptor.getAsyncDependencies().empty()
791 : adaptor.getAsyncDependencies().front();
// The host-shared flag is passed to the runtime as an i8 constant.
793 auto isHostShared = mlir::LLVM::ConstantOp::create(
794 rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
797 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
// If the memref's memory space differs from the returned pointer's address
// space, cast the pointer to the memref's address space.
802 unsigned dstAddrSpace = memRefType.getMemorySpaceAsInt();
803 unsigned srcAddrSpace =
804 cast<LLVM::LLVMPointerType>(allocatedPtr.
getType()).getAddressSpace();
805 if (dstAddrSpace != srcAddrSpace) {
807 LLVM::LLVMPointerType::get(rewriter.getContext(), dstAddrSpace);
809 LLVM::AddrSpaceCastOp::create(rewriter, loc, targetPtrTy, allocatedPtr);
// The allocated pointer is used as the aligned pointer as well.
813 Value alignedPtr = allocatedPtr;
816 auto memRefDescriptor = this->createMemRefDescriptor(
817 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
// Async form yields {descriptor, stream}; otherwise only the descriptor.
819 if (allocOp.getAsyncToken()) {
821 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
823 rewriter.replaceOp(allocOp, {memRefDescriptor});
// Rewrites gpu::DeallocOp into a call through deallocCallBuilder with the
// descriptor's allocated pointer and a stream. Fails to match when there is
// more than one async dependency. With an async token the op is replaced by
// the stream; the other visible path erases the op.
829LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
830 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
831 ConversionPatternRewriter &rewriter)
const {
834 if (adaptor.getAsyncDependencies().size() > 1)
835 return rewriter.notifyMatchFailure(
836 deallocOp,
"Can convert with at most one async dependency.");
838 Location loc = deallocOp.getLoc();
841 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
// Stream selection: first async dependency when present; the alternative
// operand is on a line not visible here (presumably the null pointer).
842 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
843 Value stream = adaptor.getAsyncDependencies().empty()
845 : adaptor.getAsyncDependencies().front();
846 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
848 if (deallocOp.getAsyncToken()) {
850 rewriter.replaceOp(deallocOp, {stream});
853 rewriter.eraseOp(deallocOp);
// Body of a predicate (header not visible): true when the value's type is
// gpu::AsyncTokenType.
859 return isa<gpu::AsyncTokenType>(value.
getType());
// Rewrites async::YieldOp operands in place: for the operands selected by an
// (unseen) filter, an event is created and recorded on the operand's stream,
// and the event replaces the operand. Each distinct stream is then destroyed.
// The guard that produces the "no gpu async token operand" failure is not
// visible in this extract.
874LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
875 async::YieldOp yieldOp, OpAdaptor adaptor,
876 ConversionPatternRewriter &rewriter)
const {
878 return rewriter.notifyMatchFailure(yieldOp,
"no gpu async token operand");
880 Location loc = yieldOp.getLoc();
// Start from the converted operands; selected entries are overwritten below.
881 SmallVector<Value, 4> newOperands(adaptor.getOperands());
// Set of streams, so each one is destroyed only once.
882 llvm::SmallDenseSet<Value> streams;
883 for (
auto &operand : yieldOp->getOpOperands()) {
886 auto idx = operand.getOperandNumber();
887 auto stream = adaptor.getOperands()[idx];
888 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
889 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
890 newOperands[idx] = event;
891 streams.insert(stream);
893 for (
auto stream : streams)
894 streamDestroyCallBuilder.create(loc, rewriter, {stream});
896 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
// Fragment of a predicate (header and the definition of `defOp` are not
// visible): asserts the value has LLVM pointer type and compares the callee
// of `defOp` with `functionName`.
902 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
904 return *defOp.getCallee() == functionName;
// Rewrites a gpu::WaitOp that has no async token. For each converted operand
// it emits either stream synchronize + destroy or event synchronize +
// destroy (the condition choosing between them is not visible in this
// extract), then erases the op.
912LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
913 gpu::WaitOp waitOp, OpAdaptor adaptor,
914 ConversionPatternRewriter &rewriter)
const {
915 if (waitOp.getAsyncToken())
916 return rewriter.notifyMatchFailure(waitOp,
"Cannot convert async op.");
918 Location loc = waitOp.getLoc();
920 for (
auto operand : adaptor.getOperands()) {
// Operand treated as a stream.
923 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
924 streamDestroyCallBuilder.create(loc, rewriter, {operand});
// Operand treated as an event.
928 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
929 eventDestroyCallBuilder.create(loc, rewriter, {operand});
933 rewriter.eraseOp(waitOp);
// Rewrites a gpu::WaitOp that has an async token. It collects one event per
// dependency, creates a new stream, makes that stream wait on every event,
// destroys the events, and replaces the op with the new stream.
942LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
943 gpu::WaitOp waitOp, OpAdaptor adaptor,
944 ConversionPatternRewriter &rewriter)
const {
945 if (!waitOp.getAsyncToken())
946 return rewriter.notifyMatchFailure(waitOp,
"Can only convert async op.");
948 Location loc = waitOp.getLoc();
// Remember where to continue after inserting events next to the producers.
950 auto insertionPoint = rewriter.saveInsertionPoint();
951 SmallVector<Value, 1> events;
// Pairs each original dependency with its converted operand.
953 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
954 auto operand = std::get<1>(pair);
// One path (condition not visible) creates and records a new event right
// after the op defining the original dependency.
958 auto *defOp = std::get<0>(pair).getDefiningOp();
959 rewriter.setInsertionPointAfter(defOp);
960 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
961 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
962 events.push_back(event);
// The other path uses the converted operand itself as the event.
966 events.push_back(operand);
969 rewriter.restoreInsertionPoint(insertionPoint);
970 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
971 for (
auto event : events)
972 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
973 for (
auto event : events)
974 eventDestroyCallBuilder.create(loc, rewriter, {
event});
975 rewriter.replaceOp(waitOp, {stream});
// Rewrites gpu::LaunchFuncOp into a new gpu::LaunchFuncOp whose operands are
// the converted values. Visible steps:
//   1. reject a non-async op that has async dependencies;
//   2. pick or create the stream for the launch;
//   3. promote the kernel operands, optionally appending a byte-size argument
//      after each one;
//   4. create the new launch op, carrying over cluster size and the
//      cooperative flag;
//   5. replace the original op with the stream (async) or erase it.
981LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
982 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
983 ConversionPatternRewriter &rewriter)
const {
990 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
991 return rewriter.notifyMatchFailure(
992 launchOp,
"Cannot convert non-async op with async dependencies.");
994 Location loc = launchOp.getLoc();
// With dependencies, the first converted dependency is used as the stream.
996 Value stream = Value();
997 if (!adaptor.getAsyncDependencies().empty()) {
998 stream = adaptor.getAsyncDependencies().front();
// Remaining dependencies are turned into events the stream waits on.
1001 if (adaptor.getAsyncDependencies().size() > 1) {
1002 auto insertionPoint = rewriter.saveInsertionPoint();
1003 SmallVector<Value, 4> events;
1004 for (
auto [origDep, convertedDep] :
1005 llvm::zip(launchOp.getAsyncDependencies().drop_front(),
1006 adaptor.getAsyncDependencies().drop_front())) {
// NOTE(review): the start of this condition is not visible in this extract,
// so it cannot be told here which case pushes the converted dependency
// directly and which records a new event - confirm against the full file.
1008 streamCreateCallBuilder.functionName)) {
1009 events.push_back(convertedDep);
1012 Operation *defOp = origDep.getDefiningOp();
1013 rewriter.setInsertionPointAfter(defOp);
1015 eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
1016 eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
1017 events.push_back(event);
1019 rewriter.restoreInsertionPoint(insertionPoint);
1020 for (Value event : events)
1021 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
1022 for (Value event : events)
1023 eventDestroyCallBuilder.create(loc, rewriter, {
event});
// Async op without dependencies: create a fresh stream.
1028 else if (launchOp.getAsyncToken())
1029 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
// Unranked memref kernel arguments are rejected when either the kernel flag
// or the converter options request the bare-pointer convention.
1034 OperandRange origArguments = launchOp.getKernelOperands();
1035 bool effectiveBarePtr = kernelBarePtrCallConv ||
1036 getTypeConverter()->getOptions().useBarePtrCallConv;
1037 if (effectiveBarePtr) {
1038 for (Value arg : origArguments) {
1039 if (isa<UnrankedMemRefType>(arg.getType()))
1040 return rewriter.notifyMatchFailure(
1041 loc,
"unranked memref kernel argument is not supported with "
1042 "the bare-pointer calling convention");
1045 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
1046 loc, origArguments, adaptor.getKernelOperands(), rewriter,
1047 kernelBarePtrCallConv);
1048 SmallVector<Value, 8> llvmArgumentsWithSizes;
// Optional mode: follow every argument with its static size in bytes.
1051 if (kernelIntersperseSizeCallConv) {
// Requires a one-to-one mapping between original and promoted arguments.
1052 if (origArguments.size() != llvmArguments.size()) {
1054 return rewriter.notifyMatchFailure(
1056 "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1059 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1060 for (
auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1061 auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1063 return rewriter.notifyMatchFailure(
1064 launchOp,
"Operand to launch op is not a memref.");
1067 if (!memrefTy.hasStaticShape() ||
1068 !memrefTy.getElementType().isIntOrFloat()) {
1069 return rewriter.notifyMatchFailure(
1070 launchOp,
"Operand to launch op is not a memref with a static "
1071 "shape and an integer or float element type.");
1074 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1075 if (bitwidth % 8 != 0) {
1076 return rewriter.notifyMatchFailure(
1077 launchOp,
"Operand to launch op is not a memref with a "
1078 "byte-aligned element type.");
// Size in bytes = (element bit width / 8) * element count, in 64-bit math.
1081 uint64_t staticSize =
static_cast<uint64_t
>(bitwidth / 8) *
1082 static_cast<uint64_t
>(memrefTy.getNumElements());
1084 Value sizeArg = LLVM::ConstantOp::create(
1085 rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1086 llvmArgumentsWithSizes.push_back(llvmArg);
1087 llvmArgumentsWithSizes.push_back(sizeArg);
// Cluster size is forwarded only when the original op has one.
1091 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1092 if (launchOp.hasClusterSize()) {
1094 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1095 adaptor.getClusterSizeZ()};
// Uses the size-interspersed argument list when it was populated.
1097 auto newLaunchOp = gpu::LaunchFuncOp::create(
1098 rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1099 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1100 adaptor.getGridSizeZ()},
1101 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1102 adaptor.getBlockSizeZ()},
1103 adaptor.getDynamicSharedMemorySize(),
1104 llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1105 stream, clusterSize);
1106 if (launchOp.getCooperative())
1107 newLaunchOp.setCooperative(
true);
1108 if (launchOp.getAsyncToken())
1109 rewriter.replaceOp(launchOp, {stream});
1111 rewriter.eraseOp(launchOp);
// Fragment of a pointer-cast helper (header and return not visible): when the
// source pointer's address space differs from the destination type's, the
// source pointer is replaced by an LLVM::AddrSpaceCastOp result in the
// destination address space.
1116 ConversionPatternRewriter &rewriter,
1117 LLVM::LLVMPointerType destinationType,
1120 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1121 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1122 sourcePtr = LLVM::AddrSpaceCastOp::create(
1124 LLVM::LLVMPointerType::get(rewriter.getContext(),
1125 destinationType.getAddressSpace()),
// Rewrites gpu::MemcpyOp into a call through memcpyCallBuilder with
// {dst, src, sizeBytes, stream} and replaces the op with the stream.
// The element count is taken from the source memref.
1130LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1131 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1132 ConversionPatternRewriter &rewriter)
const {
1133 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1136 !isConvertibleAndHasIdentityMaps(memRefType) ||
1140 auto loc = memcpyOp.getLoc();
1142 MemRefDescriptor srcDesc(adaptor.getSrc());
1143 Value numElements =
getNumElements(rewriter, loc, memRefType, srcDesc);
// A GEP from a null element pointer followed by ptrtoint; the GEP index is
// not visible here (presumably numElements, giving the size in bytes).
1145 Type elementPtrType = getElementPtrType(memRefType);
1146 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1147 Value gepPtr = LLVM::GEPOp::create(
1148 rewriter, loc, elementPtrType,
1149 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1152 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
// Source and destination use the aligned pointers of their descriptors.
1155 srcDesc.alignedPtr(rewriter, loc),
1156 *getTypeConverter());
1158 loc, rewriter, llvmPointerType,
1159 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1160 *getTypeConverter());
1162 auto stream = adaptor.getAsyncDependencies().front();
1163 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1165 rewriter.replaceOp(memcpyOp, {stream});
// Rewrites gpu::MemsetOp into a call through memset32CallBuilder or
// memset16CallBuilder, chosen by the value's bit width, and replaces the op
// with the stream. Only 16- or 32-bit int/float values are accepted.
1170LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1171 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1172 ConversionPatternRewriter &rewriter)
const {
1173 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1176 !isConvertibleAndHasIdentityMaps(memRefType) ||
1180 auto loc = memsetOp.getLoc();
1182 Type valueType = adaptor.getValue().getType();
// `bitWidth` and `valueTypeWidth` are defined on lines not visible here.
1185 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1186 return rewriter.notifyMatchFailure(
1187 memsetOp,
"value must be a 16 or 32 bit int or float");
// The value is bitcast to an integer type of the same width.
1191 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1193 MemRefDescriptor dstDesc(adaptor.getDst());
1194 Value numElements =
getNumElements(rewriter, loc, memRefType, dstDesc);
1197 LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1199 dstDesc.alignedPtr(rewriter, loc),
1200 *getTypeConverter());
1202 auto stream = adaptor.getAsyncDependencies().front();
1203 FunctionCallBuilder builder =
1204 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1205 builder.
create(loc, rewriter, {dst, value, numElements, stream});
1207 rewriter.replaceOp(memsetOp, {stream});
// Rewrites gpu::SetDefaultDeviceOp into a call through
// setDefaultDeviceCallBuilder with the converted device index, and replaces
// the op with that call.
1211LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1212 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1213 ConversionPatternRewriter &rewriter)
const {
1214 Location loc = op.getLoc();
1215 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1216 {adaptor.getDevIndex()});
1217 rewriter.replaceOp(op, call);
// Templated helper (signature not visible in this extract): builds an
// LLVM::ConstantOp of type llvmInt32Type from `tValue` cast to int32_t.
1221template <
typename T>
1224 return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1225 static_cast<int32_t
>(tValue));
// Templated helper (signature and value argument not visible in this
// extract): builds an LLVM::ConstantOp of type llvmFloat32Type.
1228template <
typename T>
1231 return LLVM::ConstantOp::create(
1232 builder, loc, llvmFloat32Type,
// Rewrites gpu::CreateDnTensorOp into a runtime call that produces a dense
// tensor handle and replaces the op with {handle, stream}.
// Two-dimensional tensors go through createLtDnMatCallBuilder or
// createDnMatCallBuilder (the condition choosing between them is not visible
// in this extract); the remaining path asserts a 1-D tensor and uses
// createDnVecCallBuilder.
1236LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1237 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1238 ConversionPatternRewriter &rewriter)
const {
1242 Location loc = op.getLoc();
1243 auto stream = adaptor.getAsyncDependencies().front();
1245 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1246 Type dType = op.getMemref().
getType().getElementType();
// Copy the converted dimension operands into a local vector.
1249 SmallVector<Value, 4> dims;
1250 for (Value dim : adaptor.getDims()) {
1251 dims.push_back(dim);
1261 if (dims.size() == 2) {
// Stack-allocate 11032 i8 elements for the handle; the trailing 16 is
// presumably the alignment - confirm. The meaning of 11032 is not stated
// in the visible code.
1263 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1264 rewriter.getIndexAttr(11032));
1265 handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1266 llvmInt8Type, handleSz, 16);
1267 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1269 createLtDnMatCallBuilder
1270 .create(loc, rewriter,
1271 {handle, dims[0], dims[1], pTensor, dtp, stream})
1275 createDnMatCallBuilder
1276 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1280 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1281 handle = createDnVecCallBuilder
1282 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1285 rewriter.replaceOp(op, {handle, stream});
// Rewrites gpu::DestroyDnTensorOp into the matching destroy call and replaces
// the op with the stream. The dimensionality is taken from the
// gpu::CreateDnTensorOp that defines the tensor: 2-D tensors use
// destroyCuSparseLtDnMatBuilder or destroyDnMatCallBuilder (the condition
// choosing between them is not visible here); the remaining path asserts a
// 1-D tensor and uses destroyDnVecCallBuilder.
1289LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1290 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1291 ConversionPatternRewriter &rewriter)
const {
1295 Location loc = op.getLoc();
1296 auto stream = adaptor.getAsyncDependencies().front();
// NOTE(review): no null check on definingOp is visible before getDims() -
// confirm the tensor is always produced by gpu::CreateDnTensorOp.
1297 auto definingOp = op.getDnTensor().
getDefiningOp<gpu::CreateDnTensorOp>();
1298 SmallVector<Value, 4> dims;
1299 for (Value dim : definingOp.getDims()) {
1300 dims.push_back(dim);
1302 if (dims.size() == 2) {
1306 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1307 {adaptor.getDnTensor(), stream});
1309 destroyDnMatCallBuilder.create(loc, rewriter,
1310 {adaptor.getDnTensor(), stream});
1313 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1314 destroyDnVecCallBuilder.create(loc, rewriter,
1315 {adaptor.getDnTensor(), stream});
1317 rewriter.replaceOp(op, {stream});
// Rewrites gpu::CreateCooOp into a call through createCooCallBuilder with
// rows, cols, nnz, the three allocated buffer pointers, two type codes and
// the stream, then replaces the op with {handle, stream}.
1321LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1322 gpu::CreateCooOp op, OpAdaptor adaptor,
1323 ConversionPatternRewriter &rewriter)
const {
1327 Location loc = op.getLoc();
1328 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointers of the row-index, column-index and value memrefs.
1330 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1332 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1334 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Element types of the column-index and value memrefs (the variables they
// initialize are on lines not visible in this extract).
1336 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1338 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1342 createCooCallBuilder
1343 .create(loc, rewriter,
1344 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1345 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1347 rewriter.replaceOp(op, {handle, stream});
// Rewrites a gpu::CreateCooAoSOp into a call built by createCooAoSCallBuilder
// and replaces the op with {handle, stream}. Unlike the gpu::CreateCooOp
// lowering, a single index memref (getIdxs) is passed instead of separate
// row and column index memrefs.
//
// NOTE(review): the embedded line numbers skip (1353 -> 1357, 1359 -> 1361,
// 1364 -> 1368, ...), so the statements that bind pValues, dType, itp/dtp and
// handle are missing from this listing.
1351LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1352 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1353 ConversionPatternRewriter &rewriter)
const {
1357 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1358 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointers of the index and value memref descriptors.
1359 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1361 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Element types of the index and value memrefs.
1362 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1364 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
// Argument order: rows, cols, nnz, the two pointers, itp, dtp, stream.
1368 createCooAoSCallBuilder
1369 .create(loc, rewriter,
1370 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1371 pIdxs, pValues, itp, dtp, stream})
1373 rewriter.replaceOp(op, {handle, stream});
// Rewrites a gpu::CreateCsrOp into a call built by createCsrCallBuilder and
// replaces the op with {handle, stream}.
//
// NOTE(review): the embedded line numbers skip (1379 -> 1383, 1384 -> 1386,
// 1396 -> 1401, ...), so the statements that bind pRowPos/pColIdxs/pValues,
// the element types, ptp/itp/dtp and handle are missing from this listing.
1377LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1378 gpu::CreateCsrOp op, OpAdaptor adaptor,
1379 ConversionPatternRewriter &rewriter)
const {
1383 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1384 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointers of the row-position, column-index and value memref
// descriptors.
1386 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1388 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1390 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Element types of the same three memrefs, in the same order.
1392 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1394 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1396 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
// Argument order: rows, cols, nnz, the three pointers, ptp, itp, dtp, stream.
1401 createCsrCallBuilder
1402 .create(loc, rewriter,
1403 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1404 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1406 rewriter.replaceOp(op, {handle, stream});
// Rewrites a gpu::Create2To4SpMatOp: allocates storage for the handle with an
// LLVM alloca, passes it to create2To4SpMatCallBuilder together with the
// matrix shape and data pointer, and replaces the op with {handle, stream}.
//
// NOTE(review): the embedded line numbers skip (1412 -> 1416, 1417 -> 1419,
// 1421 -> 1425, ...), so the statements that bind pMat, dType and dtp are
// missing from this listing.
1410LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1411 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1412 ConversionPatternRewriter &rewriter)
const {
1416 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1417 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointer and element type of the dense input memref.
1419 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1421 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
// The handle is an alloca of 44104 elements of llvmInt8Type.
// NOTE(review): 44104 is presumably the size the runtime expects for this
// handle, and the trailing 16 presumably the alignment - confirm against the
// runtime wrapper.
1425 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1426 rewriter.getIndexAttr(44104));
1427 Value handle = LLVM::AllocaOp::create(
1428 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1429 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
// Argument order: handle, rows, cols, data pointer, dtp, stream.
1431 create2To4SpMatCallBuilder
1432 .create(loc, rewriter,
1433 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1435 rewriter.replaceOp(op, {handle, stream});
// Rewrites a gpu::DestroySpMatOp into a runtime destroy call and replaces the
// op with its stream value. Two destroy builders appear below.
//
// NOTE(review): the embedded line numbers skip (1441 -> 1445, 1446 -> 1449,
// 1450 -> 1453), so the condition that selects between the two builders is
// missing from this listing.
1439LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1440 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1441 ConversionPatternRewriter &rewriter)
const {
1445 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1446 auto stream = adaptor.getAsyncDependencies().front();
1449 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1450 {adaptor.getSpmat(), stream});
1453 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
// The op's only replacement value is the stream.
1455 rewriter.replaceOp(op, {stream});
// Rewrites a gpu::SpMVBufferSizeOp into a call built by
// spMVBufferSizeCallBuilder and replaces the op with {bufferSize, stream}.
//
// NOTE(review): the embedded line numbers skip (1461 -> 1465 -> 1469,
// 1473 -> 1475), so the statements that bind modeA and computeType are
// missing from this listing.
1459LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1460 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1461 ConversionPatternRewriter &rewriter)
const {
1465 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1469 auto stream = adaptor.getAsyncDependencies().front();
// Argument order: modeA, sparse matrix A, dense x, dense y, computeType,
// stream.
1470 auto bufferSize = spMVBufferSizeCallBuilder
1471 .create(loc, rewriter,
1472 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1473 adaptor.getDnY(), computeType, stream})
1475 rewriter.replaceOp(op, {bufferSize, stream});
// Rewrites a gpu::SpMVOp into a call built by spMVCallBuilder and replaces
// the op with its stream value.
//
// NOTE(review): the embedded line numbers skip (1481 -> 1485 -> 1489 -> 1491),
// so the statements that bind modeA, computeType and pBuf are missing from
// this listing.
1479LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1480 gpu::SpMVOp op, OpAdaptor adaptor,
1481 ConversionPatternRewriter &rewriter)
const {
1485 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1489 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointer of the workspace buffer's memref descriptor.
1491 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
// Argument order: modeA, sparse matrix A, dense x, dense y, computeType,
// buffer pointer, stream.
1492 spMVCallBuilder.create(loc, rewriter,
1493 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1494 adaptor.getDnY(), computeType, pBuf, stream});
1495 rewriter.replaceOp(op, {stream});
// Rewrites a gpu::SpMMBufferSizeOp. Two lowerings are visible:
//  * one calls createCuSparseLtSpMMBufferSizeBuilder with a 3-element scratch
//    alloca, loads three values back from it and replaces the op with
//    {bufferSize0, bufferSize1, bufferSize2, stream};
//  * the other calls createSpMMBufferSizeCallBuilder and replaces the op with
//    {bufferSize, stream}.
//
// NOTE(review): the embedded line numbers skip (1501 -> 1505 -> 1508 -> 1515,
// 1523 -> 1527, 1542 -> 1547, ...), so the condition separating the two
// lowerings and the statements that bind modeA, modeB, computeType and the
// bufferSize* values are missing from this listing.
1499LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1500 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1501 ConversionPatternRewriter &rewriter)
const {
1505 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1508 auto stream = adaptor.getAsyncDependencies().front();
// Scratch alloca whose element type is llvmPointerType; `three` is the
// index constant 3.
1515 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1516 rewriter.getIndexAttr(3));
1518 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1520 createCuSparseLtSpMMBufferSizeBuilder
1521 .create(loc, rewriter,
1522 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1523 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
// Addresses of elements 1 and 2 of the scratch alloca (element 0 is the
// base pointer itself).
1527 auto bufferSizePtr1 = LLVM::GEPOp::create(
1528 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1529 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1530 rewriter.getIndexAttr(1))});
1531 auto bufferSizePtr2 = LLVM::GEPOp::create(
1532 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1533 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1534 rewriter.getIndexAttr(2))});
// The three slots are read back as llvmInt64Type.
// NOTE(review): the slots are stepped over as llvmPointerType but loaded as
// i64; the two agree only when pointers are 64 bits wide - confirm.
1536 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1538 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1540 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1542 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
// Second lowering: a single call and a single buffer size.
1547 createSpMMBufferSizeCallBuilder
1548 .create(loc, rewriter,
1549 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1550 adaptor.getDnmatC(), computeType, stream})
1552 rewriter.replaceOp(op, {bufferSize, stream});
// Rewrites a gpu::SDDMMBufferSizeOp into a call built by
// createSDDMMBufferSizeCallBuilder and replaces the op with
// {bufferSize, stream}.
//
// NOTE(review): the embedded line numbers skip (1559 -> 1563 -> 1568 -> 1570,
// 1573 -> 1575), so the statements that bind modeA, modeB, computeType and
// bufferSize are missing from this listing.
1557LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1558 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1559 ConversionPatternRewriter &rewriter)
const {
1563 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1568 auto stream = adaptor.getAsyncDependencies().front();
// Argument order: modeA, modeB, dense A, dense B, sparse C, computeType,
// stream.
1570 createSDDMMBufferSizeCallBuilder
1571 .create(loc, rewriter,
1572 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1573 adaptor.getSpmatC(), computeType, stream})
1575 rewriter.replaceOp(op, {bufferSize, stream});
// Rewrites a gpu::SpMMOp into a runtime call and replaces the op with its
// stream value. Two lowerings are visible: one passes three buffer pointers
// to createCuSparseLtSpMMBuilder, the other passes only the first buffer
// pointer to createSpMMCallBuilder.
//
// NOTE(review): the embedded line numbers skip (1581 -> 1585 -> 1591 -> 1595,
// 1600 -> 1602, 1603 -> 1605, ...), so the condition separating the two
// lowerings and the statements that bind modeA, modeB and computeType are
// missing from this listing.
1579LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1580 gpu::SpMMOp op, OpAdaptor adaptor,
1581 ConversionPatternRewriter &rewriter)
const {
1585 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1591 auto stream = adaptor.getAsyncDependencies().front();
// Collect the allocated pointer of every buffer operand.
1595 SmallVector<Value> pBufs;
1596 for (Value buffer : adaptor.getBuffers()) {
1597 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1598 pBufs.push_back(pBuf);
// This call indexes pBufs[0..2], so it needs at least three buffers.
1600 createCuSparseLtSpMMBuilder.create(
1602 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1603 pBufs[0], pBufs[1], pBufs[2], stream});
// Second lowering: only the first buffer is used.
1605 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1606 .allocatedPtr(rewriter, loc);
1607 createSpMMCallBuilder.create(loc, rewriter,
1608 {modeA, modeB, adaptor.getSpmatA(),
1609 adaptor.getDnmatB(), adaptor.getDnmatC(),
1610 computeType, pBuf, stream});
1612 rewriter.replaceOp(op, {stream});
// Template helper, parameterised on a source type T, that registers a type
// conversion on `converter`: any type matched as T is converted to an
// LLVM::LLVMPointerType created in the converter's context.
//
// NOTE(review): the embedded line numbers skip (1616 -> 1618), so the
// function header is missing from this listing; it is presumably
// `addOpaquePointerConversion(LLVMTypeConverter &converter)` - confirm.
1616template <
typename T>
// The lambda captures `converter` by reference, so the converter must
// outlive the registered conversion.
1618 converter.addConversion([&converter](T) ->
Type {
1619 return LLVM::LLVMPointerType::get(&converter.
getContext());
// Rewrites a gpu::SDDMMOp into a call built by createSDDMMCallBuilder and
// replaces the op with its stream value.
//
// NOTE(review): the embedded line numbers skip (1625 -> 1629 -> 1634 -> 1636),
// so the statements that bind modeA, modeB, computeType and pBuf are missing
// from this listing.
1623LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1624 gpu::SDDMMOp op, OpAdaptor adaptor,
1625 ConversionPatternRewriter &rewriter)
const {
1629 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1634 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointer of the workspace buffer's memref descriptor.
1636 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
// Argument order: modeA, modeB, dense A, dense B, sparse C, computeType,
// buffer pointer, stream.
1637 createSDDMMCallBuilder.create(loc, rewriter,
1638 {modeA, modeB, adaptor.getDnmatA(),
1639 adaptor.getDnmatB(), adaptor.getSpmatC(),
1640 computeType, pBuf, stream});
1641 rewriter.replaceOp(op, {stream});
// Rewrites a gpu::SpGEMMCreateDescrOp into a call built by
// createSpGEMMCreateDescrBuilder, which receives only the stream, and
// replaces the op with {descr, stream}.
//
// NOTE(review): the embedded line numbers skip (1648 -> 1652, 1654 -> 1656),
// so the return type line and the tail of the `descr` initializer are missing
// from this listing.
1646ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1647 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1648 ConversionPatternRewriter &rewriter)
const {
1652 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1653 auto stream = adaptor.getAsyncDependencies().front();
1654 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1656 rewriter.replaceOp(op, {descr, stream});
// Rewrites a gpu::SpGEMMDestroyDescrOp into a call built by
// createSpGEMMDestroyDescrBuilder and replaces the op with its stream value.
//
// NOTE(review): the embedded line numbers skip (1663 -> 1667), so the return
// type line and any statements between the signature and `loc` are missing
// from this listing.
1661ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1662 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1663 ConversionPatternRewriter &rewriter)
const {
1667 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1668 auto stream = adaptor.getAsyncDependencies().front();
// Argument order: descriptor, stream.
1669 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1670 {adaptor.getDesc(), stream});
1671 rewriter.replaceOp(op, {stream});
// Rewrites a gpu::SpGEMMWorkEstimationOrComputeOp. When the op's kind is
// WORK_ESTIMATION the call is built by createSpGEMMWorkEstimationBuilder;
// the other visible call is built by createSpGEMMComputeBuilder. Both receive
// the same argument list, and the op is replaced with
// {bufferSizeNew, stream}.
//
// NOTE(review): the embedded line numbers skip (1678 -> 1682 -> 1687 -> 1690,
// 1694 -> 1696, 1700 -> 1704, ...), so the statements that bind modeA, modeB,
// computeType and pBuf, and the assignments to bufferSizeNew, are missing
// from this listing.
1676ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1677 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1678 ConversionPatternRewriter &rewriter)
const {
1682 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1687 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointer of the workspace buffer's memref descriptor.
1690 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1691 Value bufferSizeNew;
// Dispatch on the op's kind attribute.
1693 if (adaptor.getKind() ==
1694 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
// Argument order (both calls): descriptor, modeA, modeB, sparse A, B, C,
// computeType, buffer size, buffer pointer, stream.
1696 createSpGEMMWorkEstimationBuilder
1697 .create(loc, rewriter,
1698 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1699 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1700 adaptor.getBufferSz(), pBuf, stream})
1704 createSpGEMMComputeBuilder
1705 .create(loc, rewriter,
1706 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1707 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1708 adaptor.getBufferSz(), pBuf, stream})
1711 rewriter.replaceOp(op, {bufferSizeNew, stream});
// Rewrites a gpu::SpGEMMCopyOp into a call built by createSpGEMMCopyBuilder
// and replaces the op with its stream value.
//
// NOTE(review): the embedded line numbers skip (1717 -> 1721 -> 1726), so the
// statements that bind modeA, modeB and computeType are missing from this
// listing.
1715LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1716 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1717 ConversionPatternRewriter &rewriter)
const {
1721 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1726 auto stream = adaptor.getAsyncDependencies().front();
// Argument order: descriptor, modeA, modeB, sparse A, B, C, computeType,
// stream.
1727 createSpGEMMCopyBuilder.create(loc, rewriter,
1728 {adaptor.getDesc(), modeA, modeB,
1729 adaptor.getSpmatA(), adaptor.getSpmatB(),
1730 adaptor.getSpmatC(), computeType, stream});
1731 rewriter.replaceOp(op, {stream});
// Rewrites a gpu::SpMatGetSizeOp: allocates a 3-element scratch buffer,
// passes pointers to its three slots to createSpMatGetSizeBuilder, loads the
// slots back as 64-bit integers and replaces the op with
// {rows, cols, nnzs, stream}.
//
// NOTE(review): the embedded line numbers skip (1737 -> 1741, 1742 -> 1744,
// 1747 -> 1749, ...), so some statements are missing from this listing.
1735LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1736 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1737 ConversionPatternRewriter &rewriter)
const {
1741 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1742 auto stream = adaptor.getAsyncDependencies().front();
// Scratch buffer: an alloca of three llvmInt64Type elements.
1744 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1745 rewriter.getIndexAttr(3));
1746 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1747 llvmInt64Type, three, 16);
// Slot addresses at indices 0, 1 and 2.
// NOTE(review): the buffer holds i64 elements, but these GEPs use
// llvmPointerType as the element type; the offsets match only when pointers
// are 64 bits wide - confirm.
1749 auto rowsPtr = LLVM::GEPOp::create(
1750 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1751 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1752 rewriter.getIndexAttr(0))});
1753 auto colsPtr = LLVM::GEPOp::create(
1754 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1755 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1756 rewriter.getIndexAttr(1))});
1757 auto nnzsPtr = LLVM::GEPOp::create(
1758 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1759 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1760 rewriter.getIndexAttr(2))});
// The call receives the three slot pointers; the loads that follow read the
// slots back after it.
1761 createSpMatGetSizeBuilder.create(
1762 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1763 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1764 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1765 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1767 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
// Rewrites a gpu::SetCsrPointersOp into a call built by
// createSetCsrPointersBuilder and replaces the op with its stream value.
//
// NOTE(review): the embedded line numbers skip (1773 -> 1777, 1778 -> 1780,
// 1780 -> 1782, ...), so the statements that bind pPos, pCrd and pVal are
// missing from this listing.
1771LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1772 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1773 ConversionPatternRewriter &rewriter)
const {
1777 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1778 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointers of the positions, coordinates and values memref
// descriptors.
1780 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1782 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1784 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Argument order: sparse matrix, pPos, pCrd, pVal, stream.
1785 createSetCsrPointersBuilder.create(
1786 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1787 rewriter.replaceOp(op, {stream});
// Rewrites a gpu::CreateCscOp into a call built by createCscCallBuilder and
// replaces the op with {handle, stream}.
//
// NOTE(review): the embedded line numbers skip (1793 -> 1797, 1798 -> 1800,
// 1810 -> 1815, ...), so the statements that bind pColPos/pRowIdxs/pValues,
// the element types, ptp/itp/dtp and handle are missing from this listing.
1791LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1792 gpu::CreateCscOp op, OpAdaptor adaptor,
1793 ConversionPatternRewriter &rewriter)
const {
1797 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1798 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointers of the column-position, row-index and value memref
// descriptors.
1800 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1802 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1804 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Element types of the same three memrefs, in the same order.
1806 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1808 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1810 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
// Argument order: rows, cols, nnz, the three pointers, ptp, itp, dtp, stream.
1815 createCscCallBuilder
1816 .create(loc, rewriter,
1817 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1818 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1820 rewriter.replaceOp(op, {handle, stream});
// Rewrites a gpu::CreateBsrOp into a call built by createBsrCallBuilder and
// replaces the op with {handle, stream}.
//
// NOTE(review): the embedded line numbers skip (1826 -> 1830, 1831 -> 1833,
// 1843 -> 1848, ...), so the statements that bind pRowPos/pColIdxs/pValues,
// the element types, ptp/itp/dtp and handle are missing from this listing.
1824LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1825 gpu::CreateBsrOp op, OpAdaptor adaptor,
1826 ConversionPatternRewriter &rewriter)
const {
1830 Location loc = op.getLoc();
// The first async dependency is forwarded to the runtime call as the stream.
1831 auto stream = adaptor.getAsyncDependencies().front();
// Allocated pointers of the block-row-position, block-column-index and value
// memref descriptors.
1833 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1835 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1837 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Element types of the same three memrefs, in the same order.
1839 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1841 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1843 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
// Argument order: brows, bcols, bnnz, row block size, column block size, the
// three pointers, ptp, itp, dtp, stream.
1848 createBsrCallBuilder
1849 .create(loc, rewriter,
1850 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1851 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1852 pColIdxs, pValues, ptp, itp, dtp, stream})
1854 rewriter.replaceOp(op, {handle, stream});
// Body of the pattern-population routine: adds the GPU-runtime-call
// conversion patterns listed below to `patterns`, all constructed with
// `converter`. The two bool parameters are forwarded only to
// LegalizeLaunchFuncOpPattern.
//
// NOTE(review): the embedded line numbers skip (1860 -> 1870 -> 1873), so the
// start of the signature and several statements are missing from this
// listing; the function is presumably populateGpuToLLVMConversionPatterns -
// confirm.
1860 bool kernelBarePtrCallConv,
bool kernelIntersperseSizeCallConv) {
// NOTE(review): the remaining arguments of this add<> call are not visible.
1870 patterns.
add<ConvertAsyncYieldToGpuRuntimeCallPattern>(converter,
// Patterns that take only the type converter.
1873 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1874 ConvertDeallocOpToGpuRuntimeCallPattern,
1875 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1876 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1877 ConvertMemcpyOpToGpuRuntimeCallPattern,
1878 ConvertMemsetOpToGpuRuntimeCallPattern,
1879 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1880 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1881 ConvertWaitOpToGpuRuntimeCallPattern,
1882 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1883 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1884 ConvertCreateCooOpToGpuRuntimeCallPattern,
1885 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1886 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1887 ConvertCreateCscOpToGpuRuntimeCallPattern,
1888 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1889 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1890 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1891 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1892 ConvertSpMVOpToGpuRuntimeCallPattern,
1893 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1894 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1895 ConvertSpMMOpToGpuRuntimeCallPattern,
1896 ConvertSDDMMOpToGpuRuntimeCallPattern,
1897 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1898 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1899 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1900 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1901 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1902 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
// The launch-func pattern additionally receives both calling-convention
// flags.
1903 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1904 kernelIntersperseSizeCallConv);
// External model that implements ConvertToLLVMOpInterface for
// gpu::GPUModuleOp.
//
// NOTE(review): the embedded line numbers skip (1914 -> 1916 -> 1921), so the
// rest of this declaration, including the parameter list of
// getConvertToLLVMConversionAttrs, is missing from this listing.
1912struct GPUModuleOpConvertToLLVMInterface
1913 :
public ConvertToLLVMOpInterface::ExternalModel<
1914 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1916 void getConvertToLLVMConversionAttrs(
// Appends to `attrs` the module's single target attribute when that attribute
// implements ConvertToLLVMAttrInterface.
//
// `op` is cast to gpu::GPUModuleOp with `cast`, which assumes the caller
// passes an op of that type.
//
// NOTE(review): the embedded line numbers skip (1924 -> 1926 -> 1928), so the
// statement guarded by the first `if` is missing from this listing; it is
// presumably an early return - confirm.
1921void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1922 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs)
const {
1923 auto module = cast<gpu::GPUModuleOp>(op);
1924 ArrayAttr targetsAttr =
module.getTargetsAttr();
// Guard: the targets attribute is absent or does not hold exactly one entry.
1926 if (!targetsAttr || targetsAttr.size() != 1)
// Only the first (and only) target is inspected.
1928 if (
auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1929 attrs.push_back(patternAttr);
// Attaches GPUModuleOpConvertToLLVMInterface to gpu::GPUModuleOp in the
// context `ctx`.
// NOTE(review): the enclosing function header is missing from this listing;
// this is presumably the body of registerConvertGpuToLLVMInterface - confirm.
1934 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getIntegerType(unsigned width)
FloatAttr getF32FloatAttr(float value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const