40#include "llvm/ADT/STLExtras.h"
42#define DEBUG_TYPE "gpu-to-llvm"
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
// Pass wrapper that drives the GPU-to-LLVM lowering: it declares its
// dependent dialects and defers the actual work to runOnOperation().
// NOTE(review): this chunk is an extraction artifact — original file line
// numbers are fused into the text and interior lines (e.g. `public:`,
// closing braces) are elided; code below is kept byte-identical.
// NOTE(review): "®istry" below is mojibake for "&registry" — fix encoding.
52class GpuToLLVMConversionPass
53    :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
56  void getDependentDialects(DialectRegistry ®istry)
const final {
57    Base::getDependentDialects(registry);
61  void runOnOperation()
override;
// Base CRTP pattern for lowering a single GPU op to calls into the
// mgpu* runtime-wrapper library. It caches commonly used LLVM types
// (void, ptr, i8/i16/i32/i64, f32, pointer-width int) and a
// FunctionCallBuilder per runtime entry point ({name, result type,
// argument types}). getNumElements() computes the element count of a
// memref either as a constant (static shape) or as a product of the
// descriptor's dynamic sizes.
// NOTE(review): the stray spaces before commas (e.g. "llvmPointerType ,")
// are where /* argument name */ comments were stripped by extraction, and
// several initializer lines (function-name strings, trailing arg types)
// are elided — compare against the upstream file before editing.
64template <
typename OpTy>
 67  explicit ConvertOpToGpuRuntimeCallPattern(
 68      const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
 69      : ConvertOpToLLVMPattern<OpTy>(typeConverter, benefit) {}
 72  Value
getNumElements(ConversionPatternRewriter &rewriter, Location loc,
 73                        MemRefType type, MemRefDescriptor desc)
const {
 75    if (type.hasStaticShape())
 77          rewriter, loc, indexType, type.getNumElements());
 79    uint64_t rank = type.getRank();
 80    Value numElements = desc.
size(rewriter, loc, 0);
 81    for (
unsigned i = 1; i < rank; i++)
 82      numElements = LLVM::MulOp::create(rewriter, loc, numElements,
 83                                        desc.
size(rewriter, loc, i));
 87  MLIRContext *context = &this->getTypeConverter()->
getContext();
 89  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
 90  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
 91  Type llvmInt8Type = IntegerType::get(context, 8);
 92  Type llvmInt16Type = IntegerType::get(context, 16);
 93  Type llvmInt32Type = IntegerType::get(context, 32);
 94  Type llvmInt64Type = IntegerType::get(context, 64);
 95  Type llvmFloat32Type = Float32Type::get(context);
 96  Type llvmIntPtrType = IntegerType::get(
 97      context, this->getTypeConverter()->getPointerBitwidth(0));
 99  FunctionCallBuilder streamCreateCallBuilder = {
 100      "mgpuStreamCreate", llvmPointerType , {}};
 101  FunctionCallBuilder streamDestroyCallBuilder = {
 102      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
 103  FunctionCallBuilder streamSynchronizeCallBuilder = {
 104      "mgpuStreamSynchronize",
 107  FunctionCallBuilder streamWaitEventCallBuilder = {
 108      "mgpuStreamWaitEvent",
 110      {llvmPointerType , llvmPointerType }};
 111  FunctionCallBuilder eventCreateCallBuilder = {
 112      "mgpuEventCreate", llvmPointerType , {}};
 113  FunctionCallBuilder eventDestroyCallBuilder = {
 114      "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
 115  FunctionCallBuilder eventSynchronizeCallBuilder = {
 116      "mgpuEventSynchronize",
 119  FunctionCallBuilder eventRecordCallBuilder = {
 122      {llvmPointerType , llvmPointerType }};
 123  FunctionCallBuilder hostRegisterCallBuilder = {
 124      "mgpuMemHostRegisterMemRef",
 129  FunctionCallBuilder hostUnregisterCallBuilder = {
 130      "mgpuMemHostUnregisterMemRef",
 135  FunctionCallBuilder allocCallBuilder = {
 141  FunctionCallBuilder deallocCallBuilder = {
 144      {llvmPointerType , llvmPointerType }};
 145  FunctionCallBuilder memcpyCallBuilder = {
 148      {llvmPointerType , llvmPointerType ,
 151  FunctionCallBuilder memset16CallBuilder = {
 158  FunctionCallBuilder memset32CallBuilder = {
 161      {llvmPointerType , llvmInt32Type ,
 164  FunctionCallBuilder setDefaultDeviceCallBuilder = {
 165      "mgpuSetDefaultDevice",
 168  FunctionCallBuilder createDnVecCallBuilder = {
 171      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
 173  FunctionCallBuilder destroyDnVecCallBuilder = {
 176      {llvmPointerType, llvmPointerType }};
 177  FunctionCallBuilder createDnMatCallBuilder = {
 180      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
 182  FunctionCallBuilder destroyDnMatCallBuilder = {
 185      {llvmPointerType, llvmPointerType }};
 186  FunctionCallBuilder createCooCallBuilder = {
 189      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
 190       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
 192  FunctionCallBuilder createCooAoSCallBuilder = {
 195      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
 196       llvmPointerType, llvmInt32Type, llvmInt32Type,
 198  FunctionCallBuilder createCsrCallBuilder = {
 201      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
 202       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
 203       llvmInt32Type, llvmPointerType }};
 204  FunctionCallBuilder createCscCallBuilder = {
 207      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
 208       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
 209       llvmInt32Type, llvmPointerType }};
 210  FunctionCallBuilder createBsrCallBuilder = {
 213      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
 214       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
 215       llvmInt32Type, llvmInt32Type, llvmInt32Type,
 217  FunctionCallBuilder destroySpMatCallBuilder = {
 220      {llvmPointerType, llvmPointerType }};
 221  FunctionCallBuilder spMVBufferSizeCallBuilder = {
 222      "mgpuSpMVBufferSize",
 224      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
 225       llvmInt32Type, llvmPointerType }};
 226  FunctionCallBuilder spMVCallBuilder = {
 229      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
 230       llvmInt32Type, llvmPointerType, llvmPointerType }};
 231  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
 232      "mgpuSpMMBufferSize",
 234      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
 235       llvmPointerType, llvmInt32Type, llvmPointerType }};
 236  FunctionCallBuilder createSpMMCallBuilder = {
 239      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
 240       llvmPointerType, llvmInt32Type, llvmPointerType,
 242  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
 243      "mgpuSDDMMBufferSize",
 245      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
 246       llvmPointerType, llvmInt32Type, llvmPointerType }};
 247  FunctionCallBuilder createSDDMMCallBuilder = {
 250      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
 251       llvmPointerType, llvmInt32Type, llvmPointerType,
 253  FunctionCallBuilder createLtDnMatCallBuilder = {
 254      "mgpuCreateCuSparseLtDnMat",
 256      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
 257       llvmInt32Type, llvmPointerType }};
 258  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
 259      "mgpuDestroyCuSparseLtSpMat",
 261      {llvmPointerType, llvmPointerType }};
 262  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
 263      "mgpuDestroyCuSparseLtDnMat",
 265      {llvmPointerType, llvmPointerType }};
 266  FunctionCallBuilder create2To4SpMatCallBuilder = {
 267      "mgpuCusparseLtCreate2To4SpMat",
 269      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
 270       llvmInt32Type, llvmPointerType }};
 271  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
 272      "mgpuCuSparseLtSpMMBufferSize",
 274      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
 275       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
 277  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
 278      "mgpuCuSparseLtSpMM",
 280      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
 281       llvmPointerType, llvmPointerType, llvmPointerType }};
 282  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
 283      "mgpuSpGEMMCreateDescr",
 286  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
 287      "mgpuSpGEMMDestroyDescr",
 289      {llvmPointerType , llvmPointerType }};
 290  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
 291      "mgpuSpGEMMWorkEstimation",
 293      {llvmPointerType , llvmInt32Type , llvmInt32Type ,
 294       llvmPointerType , llvmPointerType , llvmPointerType ,
 295       llvmInt32Type , llvmIntPtrType , llvmPointerType ,
 297  FunctionCallBuilder createSpGEMMComputeBuilder = {
 300      {llvmPointerType , llvmInt32Type , llvmInt32Type ,
 301       llvmPointerType , llvmPointerType , llvmPointerType ,
 302       llvmInt32Type , llvmIntPtrType , llvmPointerType ,
 304  FunctionCallBuilder createSpGEMMCopyBuilder = {
 307      {llvmPointerType , llvmInt32Type , llvmInt32Type ,
 308       llvmPointerType , llvmPointerType , llvmPointerType ,
 309       llvmInt32Type , llvmPointerType }};
 310  FunctionCallBuilder createSpMatGetSizeBuilder = {
 313      {llvmPointerType , llvmPointerType , llvmPointerType ,
 314       llvmPointerType , llvmPointerType }};
 315  FunctionCallBuilder createSetCsrPointersBuilder = {
 316      "mgpuSetCsrPointers",
 318      {llvmPointerType , llvmPointerType ,
 319       llvmPointerType , llvmPointerType ,
// Declares the pattern lowering gpu.host_register to an mgpu runtime call.
// Implementation of matchAndRewrite is defined out-of-line below.
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326    :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
328  ConvertHostRegisterOpToGpuRuntimeCallPattern(
329      const LLVMTypeConverter &typeConverter)
330      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
334  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern lowering gpu.host_unregister to an mgpu runtime call.
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339    :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
341  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342      const LLVMTypeConverter &typeConverter)
343      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern lowering gpu.alloc to the mgpu allocation runtime call.
354class ConvertAllocOpToGpuRuntimeCallPattern
355    :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
357  ConvertAllocOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
358      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
362  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern lowering gpu.dealloc to the mgpu dealloc runtime call.
368class ConvertDeallocOpToGpuRuntimeCallPattern
369    :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
371  ConvertDeallocOpToGpuRuntimeCallPattern(
372      const LLVMTypeConverter &typeConverter)
373      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
377  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern rewriting async.yield operands that carry GPU async
// tokens (streams) into recorded events; see the out-of-line matchAndRewrite.
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382    :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
384  ConvertAsyncYieldToGpuRuntimeCallPattern(
385      const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
386      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter,
391  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
392                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern for the synchronous form of gpu.wait (no async token).
397class ConvertWaitOpToGpuRuntimeCallPattern
398    :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
400  ConvertWaitOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
401      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
405  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
406                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern for the async form of gpu.wait (produces a token);
// same target op as ConvertWaitOpToGpuRuntimeCallPattern, different guard.
411class ConvertWaitAsyncOpToGpuRuntimeCallPattern
412    :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
414  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
415      const LLVMTypeConverter &typeConverter)
416      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
420  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
421                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern legalizing gpu.launch_func operands for LLVM lowering.
// The two flags select the kernel-argument calling convention:
// kernelBarePtrCallConv passes memrefs as bare pointers;
// kernelIntersperseSizeCallConv interleaves a byte-size after each argument.
425class LegalizeLaunchFuncOpPattern
426    :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
428  LegalizeLaunchFuncOpPattern(
const LLVMTypeConverter &typeConverter,
429                              bool kernelBarePtrCallConv,
430                              bool kernelIntersperseSizeCallConv)
431      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
432        kernelBarePtrCallConv(kernelBarePtrCallConv),
433        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
437  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
438                  ConversionPatternRewriter &rewriter)
const override;
440  bool kernelBarePtrCallConv;
441  bool kernelIntersperseSizeCallConv;
// Declares the pattern lowering gpu.memcpy to the mgpu memcpy runtime call.
446class ConvertMemcpyOpToGpuRuntimeCallPattern
447    :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
449  ConvertMemcpyOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
450      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
454  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
455                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern lowering gpu.memset to the mgpu memset runtime calls.
460class ConvertMemsetOpToGpuRuntimeCallPattern
461    :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
463  ConvertMemsetOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
464      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
468  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
469                  ConversionPatternRewriter &rewriter)
const override;
// Declares the pattern lowering gpu.set_default_device to mgpuSetDefaultDevice.
474class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
475    :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
477  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
478      const LLVMTypeConverter &typeConverter)
479      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
483  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
484                  ConversionPatternRewriter &rewriter)
const override;
// Boilerplate generator: declares a Convert<Op>ToGpuRuntimeCallPattern class
// (constructor plus out-of-line matchAndRewrite) for a given gpu dialect op.
// NOTE(review): comments must not be inserted between the backslash-continued
// lines below, and several continuation lines are elided by extraction.
489#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)               \
490  class Convert##op_name##ToGpuRuntimeCallPattern                             \
491      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
493    Convert##op_name##ToGpuRuntimeCallPattern(                                \
494        const LLVMTypeConverter &typeConverter)                               \
495        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
499    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
500                    ConversionPatternRewriter &rewriter) const override;      \
// Entry point of the pass: builds LowerToLLVMOptions (honoring the
// hostBarePtrCallConv option), collects ConvertToLLVM patterns from all
// dialects implementing ConvertToLLVMPatternInterface, keeps gpu.module /
// gpu.binary legal, makes gpu.launch_func dynamically legal once its operands
// are converted, and runs partial conversion.
// NOTE(review): interior lines are elided (e.g. the patterns.add<...> call
// that takes kernelBarePtrCallConv); code kept byte-identical.
527void GpuToLLVMConversionPass::runOnOperation() {
538  vector::populateVectorFromElementsUnrollPatterns(patterns);
540    return signalPassFailure();
543  LowerToLLVMOptions
options(context);
544  options.useBarePtrCallConv = hostBarePtrCallConv;
545  RewritePatternSet patterns(context);
546  ConversionTarget
target(*context);
547  target.addLegalDialect<LLVM::LLVMDialect>();
548  LLVMTypeConverter converter(context,
options);
553    auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
556    iface->populateConvertToLLVMConversionPatterns(
target, converter, patterns);
561  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
563  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
564      [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
572                                            kernelBarePtrCallConv,
573                                            kernelIntersperseSizeCallConv);
576          applyPartialConversion(getOperation(),
target, std::move(patterns))))
// Fragment of FunctionCallBuilder::create (signature elided by extraction):
// looks up — and presumably declares, in the elided lines — the runtime
// function in the enclosing module, then emits an llvm.call to it.
582  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
583  auto function = [&] {
584    if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
589  return LLVM::CallOp::create(builder, loc, function, arguments);
// Fragments of the cuSPARSE data-type mapping helpers: map an MLIR element
// type (including complex-of-float/int) to a runtime data-type enum, and
// fetch the prune flag from the defining gpu.create_2to4_spmat op.
// NOTE(review): the return statements selecting each enum value are elided;
// only the type dispatch skeleton is visible here.
606  llvm_unreachable(
"unsupported type");
612  if (llvm::isa<ComplexType>(type)) {
614    auto elementType = cast<ComplexType>(type).getElementType();
615    if (elementType.isBF16())
617    if (elementType.isF16())
619    if (elementType.isF32())
621    if (elementType.isF64())
623    if (elementType.isInteger(8))
625    if (elementType.isInteger(16))
627    if (elementType.isInteger(32))
645  llvm_unreachable(
"unsupported element type");
649  return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
674  llvm_unreachable(
"cannot find spmat def");
// Fragments of the shared pattern guards: reject ops whose operands are not
// yet LLVM types, and require exactly one async dependency plus an async
// token before lowering an async GPU op.
679    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
691                                     ConversionPatternRewriter &rewriter) {
692  if (!llvm::all_of(operands, [](
Value value) {
695    return rewriter.notifyMatchFailure(
696        op,
"Cannot convert if operands aren't of LLVM type.");
702                                              gpu::AsyncOpInterface op) {
703  if (op.getAsyncDependencies().size() != 1)
704    return rewriter.notifyMatchFailure(
705        op,
"Can only convert with exactly one async dependency.");
707  if (!op.getAsyncToken())
708    return rewriter.notifyMatchFailure(op,
"Can convert only async version.");
// Lowers gpu.host_register: promotes the unranked-memref operand, appends the
// element size (computed in an elided line), calls mgpuMemHostRegisterMemRef,
// and erases the original op.
713LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
715    ConversionPatternRewriter &rewriter)
const {
716  auto *op = hostRegisterOp.getOperation();
720  Location loc = op->getLoc();
722  auto memRefType = hostRegisterOp.getValue().getType();
723  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
726  auto arguments = getTypeConverter()->promoteOperands(
727      loc, op->getOperands(), adaptor.getOperands(), rewriter);
728  arguments.push_back(elementSize);
729  hostRegisterCallBuilder.create(loc, rewriter, arguments);
731  rewriter.eraseOp(op);
// Lowers gpu.host_unregister: mirror image of the host_register lowering,
// calling mgpuMemHostUnregisterMemRef.
735LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
737    ConversionPatternRewriter &rewriter)
const {
738  Operation *op = hostUnregisterOp.getOperation();
742  Location loc = op->
getLoc();
744  auto memRefType = hostUnregisterOp.getValue().getType();
745  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
748  auto arguments = getTypeConverter()->promoteOperands(
749      loc, op->
getOperands(), adaptor.getOperands(), rewriter);
750  arguments.push_back(elementSize);
751  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
753  rewriter.eraseOp(op);
// Lowers gpu.alloc: computes the allocation size from the memref descriptor
// sizes, calls the alloc runtime builder with {sizeBytes, stream,
// isHostShared}, builds a memref descriptor around the returned pointer, and
// replaces the op (with or without the async token). Host-shared allocation
// combined with an async token is rejected.
757LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758    gpu::AllocOp allocOp, OpAdaptor adaptor,
759    ConversionPatternRewriter &rewriter)
const {
761  MemRefType memRefType = allocOp.getType();
764      !isConvertibleAndHasIdentityMaps(memRefType))
767  auto loc = allocOp.getLoc();
769  bool isShared = allocOp.getHostShared();
771  if (isShared && allocOp.getAsyncToken())
772    return rewriter.notifyMatchFailure(
773        allocOp,
"Host Shared allocation cannot be done async");
779  SmallVector<Value, 4> shape;
780  SmallVector<Value, 4> strides;
782  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
783                           shape, strides, sizeBytes);
787  auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
788  Value stream = adaptor.getAsyncDependencies().empty()
790                     : adaptor.getAsyncDependencies().front();
792  auto isHostShared = mlir::LLVM::ConstantOp::create(
793      rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
796      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
800  Value alignedPtr = allocatedPtr;
803  auto memRefDescriptor = this->createMemRefDescriptor(
804      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
806  if (allocOp.getAsyncToken()) {
808    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
810    rewriter.replaceOp(allocOp, {memRefDescriptor});
// Lowers gpu.dealloc: passes the allocated pointer and the (single) stream
// dependency to the dealloc runtime builder, then forwards the stream as the
// op's replacement async token.
816LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
817    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
818    ConversionPatternRewriter &rewriter)
const {
823  Location loc = deallocOp.getLoc();
826      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
827  Value stream = adaptor.getAsyncDependencies().front();
828  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
830  rewriter.replaceOp(deallocOp, {stream});
// Fragment of a predicate checking whether a value is a gpu.async.token.
835  return isa<gpu::AsyncTokenType>(value.
getType());
// Rewrites async.yield: each yielded GPU-stream operand is replaced by a
// freshly created and recorded event (so the token can outlive the stream),
// then every distinct source stream is destroyed and the yield's operands are
// updated in place.
850LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
851    async::YieldOp yieldOp, OpAdaptor adaptor,
852    ConversionPatternRewriter &rewriter)
const {
854    return rewriter.notifyMatchFailure(yieldOp,
"no gpu async token operand");
856  Location loc = yieldOp.getLoc();
857  SmallVector<Value, 4> newOperands(adaptor.getOperands());
858  llvm::SmallDenseSet<Value> streams;
859  for (
auto &operand : yieldOp->getOpOperands()) {
862    auto idx = operand.getOperandNumber();
863    auto stream = adaptor.getOperands()[idx];
864    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
865    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
866    newOperands[idx] = event;
867    streams.insert(stream);
869  for (
auto stream : streams)
870    streamDestroyCallBuilder.create(loc, rewriter, {stream});
872  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
// Fragment of a predicate testing whether a pointer value was produced by an
// llvm.call to a specific runtime function (asserts the value is a pointer,
// the defining-call lookup is elided).
878  assert(isa<LLVM::LLVMPointerType>(value.
getType()));
880  return *defOp.getCallee() == functionName;
// Lowers the synchronous gpu.wait: for each operand it synchronizes and
// destroys either the stream or the event (the branch condition distinguishing
// the two is elided), then erases the op. Async gpu.wait is rejected here and
// handled by ConvertWaitAsyncOpToGpuRuntimeCallPattern instead.
888LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
889    gpu::WaitOp waitOp, OpAdaptor adaptor,
890    ConversionPatternRewriter &rewriter)
const {
891  if (waitOp.getAsyncToken())
892    return rewriter.notifyMatchFailure(waitOp,
"Cannot convert async op.");
894  Location loc = waitOp.getLoc();
896  for (
auto operand : adaptor.getOperands()) {
899      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
900      streamDestroyCallBuilder.create(loc, rewriter, {operand});
904      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
905      eventDestroyCallBuilder.create(loc, rewriter, {operand});
909  rewriter.eraseOp(waitOp);
// Lowers the async gpu.wait: converts stream dependencies into events
// (recorded right after their defining op), creates a fresh stream, makes it
// wait on every event, destroys the events, and replaces the op with the new
// stream as its async token.
918LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
919    gpu::WaitOp waitOp, OpAdaptor adaptor,
920    ConversionPatternRewriter &rewriter)
const {
921  if (!waitOp.getAsyncToken())
922    return rewriter.notifyMatchFailure(waitOp,
"Can only convert async op.");
924  Location loc = waitOp.getLoc();
926  auto insertionPoint = rewriter.saveInsertionPoint();
927  SmallVector<Value, 1> events;
929       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
930    auto operand = std::get<1>(pair);
934    auto *defOp = std::get<0>(pair).getDefiningOp();
935    rewriter.setInsertionPointAfter(defOp);
936    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
937    eventRecordCallBuilder.create(loc, rewriter, {event, operand});
938    events.push_back(event);
942      events.push_back(operand);
945  rewriter.restoreInsertionPoint(insertionPoint);
946  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
947  for (
auto event : events)
948    streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
949  for (
auto event : events)
950    eventDestroyCallBuilder.create(loc, rewriter, {
event});
951  rewriter.replaceOp(waitOp, {stream});
// Legalizes gpu.launch_func: resolves the stream to launch on (first async
// dependency, folding extra dependencies into stream-wait events, or a fresh
// stream for a dependency-free async launch), promotes kernel operands per the
// selected calling convention, optionally interleaves static byte sizes after
// each memref argument, and re-creates the launch with converted operands.
// Bare-pointer mode rejects unranked memref arguments; size-interleaving mode
// requires 1:1 operand expansion and static, byte-aligned int/float memrefs.
957LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
958    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
959    ConversionPatternRewriter &rewriter)
const {
966  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
967    return rewriter.notifyMatchFailure(
968        launchOp,
"Cannot convert non-async op with async dependencies.");
970  Location loc = launchOp.getLoc();
972  Value stream = Value();
973  if (!adaptor.getAsyncDependencies().empty()) {
974    stream = adaptor.getAsyncDependencies().front();
977    if (adaptor.getAsyncDependencies().size() > 1) {
978      auto insertionPoint = rewriter.saveInsertionPoint();
979      SmallVector<Value, 4> events;
980      for (
auto [origDep, convertedDep] :
981           llvm::zip(launchOp.getAsyncDependencies().drop_front(),
982                     adaptor.getAsyncDependencies().drop_front())) {
984                              streamCreateCallBuilder.functionName)) {
985          events.push_back(convertedDep);
988          Operation *defOp = origDep.getDefiningOp();
989          rewriter.setInsertionPointAfter(defOp);
991              eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
992          eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
993          events.push_back(event);
995      rewriter.restoreInsertionPoint(insertionPoint);
996      for (Value event : events)
997        streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
998      for (Value event : events)
999        eventDestroyCallBuilder.create(loc, rewriter, {
event});
1004  else if (launchOp.getAsyncToken())
1005    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1010  OperandRange origArguments = launchOp.getKernelOperands();
1011  bool effectiveBarePtr = kernelBarePtrCallConv ||
1012                          getTypeConverter()->getOptions().useBarePtrCallConv;
1013  if (effectiveBarePtr) {
1014    for (Value arg : origArguments) {
1015      if (isa<UnrankedMemRefType>(arg.getType()))
1016        return rewriter.notifyMatchFailure(
1017            loc,
"unranked memref kernel argument is not supported with "
1018                 "the bare-pointer calling convention");
1021  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
1022      loc, origArguments, adaptor.getKernelOperands(), rewriter,
1023      kernelBarePtrCallConv);
1024  SmallVector<Value, 8> llvmArgumentsWithSizes;
1027  if (kernelIntersperseSizeCallConv) {
1028    if (origArguments.size() != llvmArguments.size()) {
1030      return rewriter.notifyMatchFailure(
1032          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1035    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1036    for (
auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1037      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1039        return rewriter.notifyMatchFailure(
1040            launchOp,
"Operand to launch op is not a memref.");
1043      if (!memrefTy.hasStaticShape() ||
1044          !memrefTy.getElementType().isIntOrFloat()) {
1045        return rewriter.notifyMatchFailure(
1046            launchOp,
"Operand to launch op is not a memref with a static "
1047            "shape and an integer or float element type.");
1050      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1051      if (bitwidth % 8 != 0) {
1052        return rewriter.notifyMatchFailure(
1053            launchOp,
"Operand to launch op is not a memref with a "
1054            "byte-aligned element type.");
1057      uint64_t staticSize =
static_cast<uint64_t
>(bitwidth / 8) *
1058                            static_cast<uint64_t
>(memrefTy.getNumElements());
1060      Value sizeArg = LLVM::ConstantOp::create(
1061          rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1062      llvmArgumentsWithSizes.push_back(llvmArg);
1063      llvmArgumentsWithSizes.push_back(sizeArg);
1067  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1068  if (launchOp.hasClusterSize()) {
1070        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1071                        adaptor.getClusterSizeZ()};
1073  gpu::LaunchFuncOp::create(
1074      rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1075      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1076                      adaptor.getGridSizeZ()},
1077      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1078                      adaptor.getBlockSizeZ()},
1079      adaptor.getDynamicSharedMemorySize(),
1080      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1081      stream, clusterSize);
1082  if (launchOp.getAsyncToken())
1083    rewriter.replaceOp(launchOp, {stream});
1085    rewriter.eraseOp(launchOp);
// Fragment of a pointer-cast helper (signature start elided): inserts an
// llvm.addrspacecast when the source pointer's address space differs from the
// destination pointer type's address space.
1090                                 ConversionPatternRewriter &rewriter,
1091                                 LLVM::LLVMPointerType destinationType,
1094  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1095  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1096    sourcePtr = LLVM::AddrSpaceCastOp::create(
1098        LLVM::LLVMPointerType::get(rewriter.getContext(),
1099                                   destinationType.getAddressSpace()),
// Lowers gpu.memcpy: computes the byte size via the null-pointer GEP +
// ptrtoint idiom (sizeof(elem) * numElements), address-space-casts the
// source/destination aligned pointers to generic pointers, and calls the
// memcpy runtime builder on the single stream dependency.
1104LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1105    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1106    ConversionPatternRewriter &rewriter)
const {
1107  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1110      !isConvertibleAndHasIdentityMaps(memRefType) ||
1114  auto loc = memcpyOp.getLoc();
1116  MemRefDescriptor srcDesc(adaptor.getSrc());
1117  Value numElements =
getNumElements(rewriter, loc, memRefType, srcDesc);
1119  Type elementPtrType = getElementPtrType(memRefType);
1120  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1121  Value gepPtr = LLVM::GEPOp::create(
1122      rewriter, loc, elementPtrType,
1123      typeConverter->convertType(memRefType.getElementType()), nullPtr,
1126      LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1129                                     srcDesc.alignedPtr(rewriter, loc),
1130                                     *getTypeConverter());
1132      loc, rewriter, llvmPointerType,
1133      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1134      *getTypeConverter());
1136  auto stream = adaptor.getAsyncDependencies().front();
1137  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1139  rewriter.replaceOp(memcpyOp, {stream});
// Lowers gpu.memset: only 16- or 32-bit int/float fill values are supported;
// the value is bitcast to i16/i32 and dispatched to the matching
// memset16/memset32 runtime builder with the destination pointer, element
// count, and stream.
1144LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1145    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1146    ConversionPatternRewriter &rewriter)
const {
1147  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1150      !isConvertibleAndHasIdentityMaps(memRefType) ||
1154  auto loc = memsetOp.getLoc();
1156  Type valueType = adaptor.getValue().getType();
1159  if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1160    return rewriter.notifyMatchFailure(
1161        memsetOp,
"value must be a 16 or 32 bit int or float");
1165  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1167  MemRefDescriptor dstDesc(adaptor.getDst());
1168  Value numElements =
getNumElements(rewriter, loc, memRefType, dstDesc);
1171      LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1173                                     dstDesc.alignedPtr(rewriter, loc),
1174                                     *getTypeConverter());
1176  auto stream = adaptor.getAsyncDependencies().front();
1177  FunctionCallBuilder builder =
1178      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1179  builder.
create(loc, rewriter, {dst, value, numElements, stream});
1181  rewriter.replaceOp(memsetOp, {stream});
// Lowers gpu.set_default_device to a single mgpuSetDefaultDevice call taking
// the device index, replacing the op with the call.
1185LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1186    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1187    ConversionPatternRewriter &rewriter)
const {
1188  Location loc = op.getLoc();
1189  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1190                                                 {adaptor.getDevIndex()});
1191  rewriter.replaceOp(op, call);
// Fragment of a helper materializing an i32 LLVM constant from an arbitrary
// value convertible to int32_t (signature line elided).
1195template <
typename T>
1198  return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1199                                  static_cast<int32_t
>(tValue));
// Fragment of the f32 analogue of genConstInt32From (signature and cast
// argument elided).
1202template <
typename T>
1205  return LLVM::ConstantOp::create(
1206      builder, loc, llvmFloat32Type,
// Lowers gpu.create_dn_tensor: 2-D tensors become dense-matrix handles (with
// a cuSPARSELt path that alloca's an 11032-byte, 16-aligned opaque handle
// buffer), 1-D tensors become dense-vector handles; the handle and stream
// replace the op's results.
// NOTE(review): the condition selecting the cuSPARSELt branch is elided here.
1210LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1211    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1212    ConversionPatternRewriter &rewriter)
const {
1216  Location loc = op.getLoc();
1217  auto stream = adaptor.getAsyncDependencies().front();
1219      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1220  Type dType = op.getMemref().
getType().getElementType();
1223  SmallVector<Value, 4> dims;
1224  for (Value dim : adaptor.getDims()) {
1225    dims.push_back(dim);
1235  if (dims.size() == 2) {
1237      auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1238                                               rewriter.getIndexAttr(11032));
1239      handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1240                                      llvmInt8Type, handleSz, 16);
1241      handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1243      createLtDnMatCallBuilder
1244          .create(loc, rewriter,
1245                  {handle, dims[0], dims[1], pTensor, dtp, stream})
1249      createDnMatCallBuilder
1250          .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1254    assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1255    handle = createDnVecCallBuilder
1256                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1259  rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.destroy_dn_tensor: recovers the rank from the defining
// gpu.create_dn_tensor op and dispatches to the matching destroy builder
// (cuSPARSELt vs plain dense-matrix for 2-D, dense-vector for 1-D), then
// forwards the stream.
1263LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1264    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1265    ConversionPatternRewriter &rewriter)
const {
1269  Location loc = op.getLoc();
1270  auto stream = adaptor.getAsyncDependencies().front();
1271  auto definingOp = op.getDnTensor().
getDefiningOp<gpu::CreateDnTensorOp>();
1272  SmallVector<Value, 4> dims;
1273  for (Value dim : definingOp.getDims()) {
1274    dims.push_back(dim);
1276  if (dims.size() == 2) {
1280      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1281                                           {adaptor.getDnTensor(), stream});
1283      destroyDnMatCallBuilder.create(loc, rewriter,
1284                                     {adaptor.getDnTensor(), stream});
1287    assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1288    destroyDnVecCallBuilder.create(loc, rewriter,
1289                                   {adaptor.getDnTensor(), stream});
1291  rewriter.replaceOp(op, {stream});
// Lowers gpu.create_coo: extracts the row/col index and value buffer pointers
// from the memref descriptors and calls the COO-creation runtime builder with
// the dims, nnz, and the (elided) index/data type codes.
1295LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1296    gpu::CreateCooOp op, OpAdaptor adaptor,
1297    ConversionPatternRewriter &rewriter)
const {
1301  Location loc = op.getLoc();
1302  auto stream = adaptor.getAsyncDependencies().front();
1304      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1306      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1308      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1310      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1312      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1316      createCooCallBuilder
1317          .create(loc, rewriter,
1318                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1319                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1321  rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_coo_aos (array-of-structs COO): like the COO lowering but
// with a single interleaved index buffer instead of separate row/col buffers.
1325LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1326    gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1327    ConversionPatternRewriter &rewriter)
const {
1331  Location loc = op.getLoc();
1332  auto stream = adaptor.getAsyncDependencies().front();
1333  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1335      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1336  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1338      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1342      createCooAoSCallBuilder
1343          .create(loc, rewriter,
1344                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1345                   pIdxs, pValues, itp, dtp, stream})
1347  rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_csr: row-position, column-index and value buffers are
// passed to the CSR creation runtime builder; the op is replaced with the
// returned handle and the stream.
// NOTE(review): excerpt with elided lines — assignment left-hand sides
// (pRowPos, pColIdxs, pValues, ptp, itp, dtp, handle) and 'return success();'
// are not visible.
1351LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1352 gpu::CreateCsrOp op, OpAdaptor adaptor,
1353 ConversionPatternRewriter &rewriter)
const {
1357 Location loc = op.getLoc();
1358 auto stream = adaptor.getAsyncDependencies().front();
1360 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1362 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1364 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
// Three element types: row-position, column-index, and value enums.
1366 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1368 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1370 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1375 createCsrCallBuilder
1376 .create(loc, rewriter,
1377 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1378 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1380 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_2to4_spmat (2:4 structured sparsity, cuSPARSELt). The
// descriptor is allocated on the stack as a fixed-size byte buffer and its
// pointer is handed to the runtime creation builder.
// NOTE(review): excerpt with elided lines — pMat/dtp left-hand sides and
// 'return success();' are not visible. The magic size 44104 presumably
// matches the runtime-side cuSPARSELt handle struct — confirm against the
// runtime wrappers before changing.
1384LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1385 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1386 ConversionPatternRewriter &rewriter)
const {
1390 Location loc = op.getLoc();
1391 auto stream = adaptor.getAsyncDependencies().front();
1393 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1395 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
// Stack-allocate the opaque descriptor storage, 16-byte aligned.
1399 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1400 rewriter.getIndexAttr(44104));
1401 Value handle = LLVM::AllocaOp::create(
1402 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1403 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1405 create2To4SpMatCallBuilder
1406 .create(loc, rewriter,
1407 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1409 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.destroy_sp_mat to a runtime destroy call; the cuSPARSELt
// builder is used for 2:4 matrices, the generic one otherwise (the branch
// condition between the two is elided in this excerpt).
// NOTE(review): excerpt with elided lines — 'return success();' and braces
// are not visible.
1413LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1414 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1415 ConversionPatternRewriter &rewriter)
const {
1419 Location loc = op.getLoc();
1420 auto stream = adaptor.getAsyncDependencies().front();
1423 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1424 {adaptor.getSpmat(), stream});
1427 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1429 rewriter.replaceOp(op, {stream});
// Lowers gpu.spmv_buffer_size to the SpMV buffer-size runtime query,
// replacing the op with the queried size and the stream.
// NOTE(review): excerpt with elided lines — the definitions of modeA and
// computeType, the call's result extraction, and 'return success();' are not
// visible.
1433LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1434 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1435 ConversionPatternRewriter &rewriter)
const {
1439 Location loc = op.getLoc();
1443 auto stream = adaptor.getAsyncDependencies().front();
1444 auto bufferSize = spMVBufferSizeCallBuilder
1445 .create(loc, rewriter,
1446 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1447 adaptor.getDnY(), computeType, stream})
1449 rewriter.replaceOp(op, {bufferSize, stream});
// Lowers gpu.spmv: passes the sparse matrix, dense vectors, compute type and
// the workspace buffer pointer to the SpMV runtime call.
// NOTE(review): excerpt with elided lines — modeA/computeType/pBuf
// definitions and 'return success();' are not visible.
1453LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1454 gpu::SpMVOp op, OpAdaptor adaptor,
1455 ConversionPatternRewriter &rewriter)
const {
1459 Location loc = op.getLoc();
1463 auto stream = adaptor.getAsyncDependencies().front();
1465 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1466 spMVCallBuilder.create(loc, rewriter,
1467 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1468 adaptor.getDnY(), computeType, pBuf, stream});
1469 rewriter.replaceOp(op, {stream});
// Lowers gpu.spmm_buffer_size. Two paths (the selecting condition is elided
// in this excerpt): the cuSPARSELt path allocates a 3-slot i64 stack array,
// lets the runtime fill it, and loads three buffer sizes back; the generic
// path queries a single buffer size.
// NOTE(review): excerpt with elided lines — modeA/modeB/computeType and
// several assignment left-hand sides (bufferSize, bufferSize0/1/2) plus
// 'return success();' are not visible.
1473LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1474 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1475 ConversionPatternRewriter &rewriter)
const {
1479 Location loc = op.getLoc();
1482 auto stream = adaptor.getAsyncDependencies().front();
// cuSPARSELt path: stack array of three sizes to be written by the runtime.
1489 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1490 rewriter.getIndexAttr(3));
1492 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1494 createCuSparseLtSpMMBufferSizeBuilder
1495 .create(loc, rewriter,
1496 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1497 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
// GEP to slots 1 and 2 of the size array (slot 0 is the base pointer).
1501 auto bufferSizePtr1 = LLVM::GEPOp::create(
1502 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1503 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1504 rewriter.getIndexAttr(1))});
1505 auto bufferSizePtr2 = LLVM::GEPOp::create(
1506 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1507 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1508 rewriter.getIndexAttr(2))});
1510 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1512 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1514 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1516 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
// Generic path: one buffer size from the plain SpMM query.
1521 createSpMMBufferSizeCallBuilder
1522 .create(loc, rewriter,
1523 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1524 adaptor.getDnmatC(), computeType, stream})
1526 rewriter.replaceOp(op, {bufferSize, stream});
// Lowers gpu.sddmm_buffer_size to the SDDMM buffer-size runtime query.
// NOTE(review): excerpt with elided lines — modeA/modeB/computeType, the
// bufferSize assignment and 'return success();' are not visible.
1531LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1532 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1533 ConversionPatternRewriter &rewriter)
const {
1537 Location loc = op.getLoc();
1542 auto stream = adaptor.getAsyncDependencies().front();
1544 createSDDMMBufferSizeCallBuilder
1545 .create(loc, rewriter,
1546 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1547 adaptor.getSpmatC(), computeType, stream})
1549 rewriter.replaceOp(op, {bufferSize, stream});
// Lowers gpu.spmm. The cuSPARSELt path collects all workspace buffer
// pointers (expects three) and calls the cuSPARSELt SpMM builder; the
// generic path passes just the first buffer to the plain SpMM call. The
// condition selecting between the two is elided in this excerpt.
// NOTE(review): excerpt with elided lines — modeA/modeB/computeType and
// 'return success();' are not visible.
1553LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1554 gpu::SpMMOp op, OpAdaptor adaptor,
1555 ConversionPatternRewriter &rewriter)
const {
1559 Location loc = op.getLoc();
1565 auto stream = adaptor.getAsyncDependencies().front();
1569 SmallVector<Value> pBufs;
1570 for (Value buffer : adaptor.getBuffers()) {
1571 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1572 pBufs.push_back(pBuf);
// cuSPARSELt path: indexes pBufs[0..2] — assumes exactly three buffers.
1574 createCuSparseLtSpMMBuilder.create(
1576 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1577 pBufs[0], pBufs[1], pBufs[2], stream});
1579 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1580 .allocatedPtr(rewriter, loc);
1581 createSpMMCallBuilder.create(loc, rewriter,
1582 {modeA, modeB, adaptor.getSpmatA(),
1583 adaptor.getDnmatB(), adaptor.getDnmatC(),
1584 computeType, pBuf, stream});
1586 rewriter.replaceOp(op, {stream});
// Registers a type conversion mapping T to an opaque LLVM pointer type.
// NOTE(review): the function signature line is elided in this excerpt; per
// the symbol index this is presumably
// 'static void addOpaquePointerConversion(LLVMTypeConverter &converter)' —
// confirm against the full source.
1590template <
typename T>
1592 converter.addConversion([&converter](T) ->
Type {
1593 return LLVM::LLVMPointerType::get(&converter.
getContext());
// Lowers gpu.sddmm: dense A/B, sparse C, compute type and workspace buffer
// pointer go to the SDDMM runtime call; the op is replaced by the stream.
// NOTE(review): excerpt with elided lines — modeA/modeB/computeType/pBuf
// definitions and 'return success();' are not visible.
1597LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1598 gpu::SDDMMOp op, OpAdaptor adaptor,
1599 ConversionPatternRewriter &rewriter)
const {
1603 Location loc = op.getLoc();
1608 auto stream = adaptor.getAsyncDependencies().front();
1610 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1611 createSDDMMCallBuilder.create(loc, rewriter,
1612 {modeA, modeB, adaptor.getDnmatA(),
1613 adaptor.getDnmatB(), adaptor.getSpmatC(),
1614 computeType, pBuf, stream});
1615 rewriter.replaceOp(op, {stream});
// Lowers gpu.spgemm_create_descr: creates an SpGEMM descriptor via the
// runtime and replaces the op with {descriptor, stream}.
// NOTE(review): excerpt with elided lines — the 'LogicalResult' return-type
// line, the call's result extraction and 'return success();' are not visible.
1620ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1621 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1622 ConversionPatternRewriter &rewriter)
const {
1626 Location loc = op.getLoc();
1627 auto stream = adaptor.getAsyncDependencies().front();
1628 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1630 rewriter.replaceOp(op, {descr, stream});
// Lowers gpu.spgemm_destroy_descr: releases the SpGEMM descriptor through
// the runtime and replaces the op with the stream.
// NOTE(review): excerpt with elided lines — the 'LogicalResult' return-type
// line and 'return success();' are not visible.
1635ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1636 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1637 ConversionPatternRewriter &rewriter)
const {
1641 Location loc = op.getLoc();
1642 auto stream = adaptor.getAsyncDependencies().front();
1643 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1644 {adaptor.getDesc(), stream});
1645 rewriter.replaceOp(op, {stream});
// Lowers gpu.spgemm_work_estimation_or_compute: dispatches on the op's kind
// attribute to either the work-estimation or the compute runtime builder,
// both of which report a new required buffer size.
// NOTE(review): excerpt with elided lines — the 'LogicalResult' return-type
// line, modeA/modeB/computeType/pBuf definitions, the assignments into
// bufferSizeNew, and 'return success();' are not visible.
1650ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1651 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1652 ConversionPatternRewriter &rewriter)
const {
1656 Location loc = op.getLoc();
1661 auto stream = adaptor.getAsyncDependencies().front();
1664 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1665 Value bufferSizeNew;
1667 if (adaptor.getKind() ==
1668 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1670 createSpGEMMWorkEstimationBuilder
1671 .create(loc, rewriter,
1672 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1673 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1674 adaptor.getBufferSz(), pBuf, stream})
// Compute phase uses the same argument list with a different builder.
1678 createSpGEMMComputeBuilder
1679 .create(loc, rewriter,
1680 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1681 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1682 adaptor.getBufferSz(), pBuf, stream})
1685 rewriter.replaceOp(op, {bufferSizeNew, stream});
// Lowers gpu.spgemm_copy: copies the SpGEMM result into the C matrix via the
// runtime and replaces the op with the stream.
// NOTE(review): excerpt with elided lines — modeA/modeB/computeType
// definitions and 'return success();' are not visible.
1689LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1690 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1691 ConversionPatternRewriter &rewriter)
const {
1695 Location loc = op.getLoc();
1700 auto stream = adaptor.getAsyncDependencies().front();
1701 createSpGEMMCopyBuilder.create(loc, rewriter,
1702 {adaptor.getDesc(), modeA, modeB,
1703 adaptor.getSpmatA(), adaptor.getSpmatB(),
1704 adaptor.getSpmatC(), computeType, stream});
1705 rewriter.replaceOp(op, {stream});
// Lowers gpu.spmat_get_size: stack-allocates three i64 out-slots, lets the
// runtime fill rows/cols/nnz through GEP'd pointers, then loads the values
// back and replaces the op with {rows, cols, nnz, stream}.
// NOTE(review): excerpt with elided lines — 'return success();' is not
// visible.
1709LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1710 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1711 ConversionPatternRewriter &rewriter)
const {
1715 Location loc = op.getLoc();
1716 auto stream = adaptor.getAsyncDependencies().front();
// Three i64 out-parameters, 16-byte aligned.
1718 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1719 rewriter.getIndexAttr(3));
1720 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1721 llvmInt64Type, three, 16);
1723 auto rowsPtr = LLVM::GEPOp::create(
1724 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1725 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1726 rewriter.getIndexAttr(0))});
1727 auto colsPtr = LLVM::GEPOp::create(
1728 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1729 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1730 rewriter.getIndexAttr(1))});
1731 auto nnzsPtr = LLVM::GEPOp::create(
1732 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1733 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1734 rewriter.getIndexAttr(2))});
1735 createSpMatGetSizeBuilder.create(
1736 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1737 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1738 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1739 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1741 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
// Lowers gpu.set_csr_pointers: rebinds a sparse matrix handle to new
// positions/coordinates/values buffers via the runtime.
// NOTE(review): excerpt with elided lines — pPos/pCrd/pVal left-hand sides
// and 'return success();' are not visible.
1745LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1746 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1747 ConversionPatternRewriter &rewriter)
const {
1751 Location loc = op.getLoc();
1752 auto stream = adaptor.getAsyncDependencies().front();
1754 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1756 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1758 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1759 createSetCsrPointersBuilder.create(
1760 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1761 rewriter.replaceOp(op, {stream});
// Lowers gpu.create_csc: column-position, row-index and value buffers go to
// the CSC creation runtime builder (mirror of the CSR lowering above).
// NOTE(review): excerpt with elided lines — assignment left-hand sides
// (pColPos, pRowIdxs, pValues, ptp, itp, dtp, handle) and 'return
// success();' are not visible.
1765LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1766 gpu::CreateCscOp op, OpAdaptor adaptor,
1767 ConversionPatternRewriter &rewriter)
const {
1771 Location loc = op.getLoc();
1772 auto stream = adaptor.getAsyncDependencies().front();
1774 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1776 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1778 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1780 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1782 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1784 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1789 createCscCallBuilder
1790 .create(loc, rewriter,
1791 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1792 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1794 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu.create_bsr: block-row-position, block-column-index and value
// buffers plus block dimensions go to the BSR creation runtime builder.
// NOTE(review): excerpt with elided lines — assignment left-hand sides
// (pRowPos, pColIdxs, pValues, ptp, itp, dtp, handle) and 'return
// success();' are not visible.
1798LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1799 gpu::CreateBsrOp op, OpAdaptor adaptor,
1800 ConversionPatternRewriter &rewriter)
const {
1804 Location loc = op.getLoc();
1805 auto stream = adaptor.getAsyncDependencies().front();
1807 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1809 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1811 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1813 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1815 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1817 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1822 createBsrCallBuilder
1823 .create(loc, rewriter,
1824 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1825 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1826 pColIdxs, pValues, ptp, itp, dtp, stream})
1828 rewriter.replaceOp(op, {handle, stream});
// Registers all GPU-to-LLVM runtime-call conversion patterns on the pattern
// set, then the launch-func legalization with the two calling-convention
// flags. NOTE(review): per the symbol index this is
// populateGpuToLLVMConversionPatterns(converter, patterns, ...); its first
// signature lines and several intervening lines are elided in this excerpt.
1834 bool kernelBarePtrCallConv,
bool kernelIntersperseSizeCallConv) {
1844 patterns.
add<ConvertAsyncYieldToGpuRuntimeCallPattern>(converter,
// One pattern per GPU runtime op; all share the same type converter.
1847 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1848 ConvertDeallocOpToGpuRuntimeCallPattern,
1849 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1850 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1851 ConvertMemcpyOpToGpuRuntimeCallPattern,
1852 ConvertMemsetOpToGpuRuntimeCallPattern,
1853 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1854 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1855 ConvertWaitOpToGpuRuntimeCallPattern,
1856 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1857 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1858 ConvertCreateCooOpToGpuRuntimeCallPattern,
1859 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1860 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1861 ConvertCreateCscOpToGpuRuntimeCallPattern,
1862 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1863 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1864 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1865 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1866 ConvertSpMVOpToGpuRuntimeCallPattern,
1867 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1868 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1869 ConvertSpMMOpToGpuRuntimeCallPattern,
1870 ConvertSDDMMOpToGpuRuntimeCallPattern,
1871 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1872 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1873 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1874 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1875 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1876 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
// Launch-func legalization takes the two calling-convention options.
1877 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1878 kernelIntersperseSizeCallConv);
// External model attaching the ConvertToLLVMOpInterface to gpu.module:
// exposes the module's (single) target attribute as a conversion attribute
// when it implements ConvertToLLVMAttrInterface.
// NOTE(review): excerpt with elided lines — the struct's closing brace and
// the method's early 'return' / closing brace are not visible.
1886struct GPUModuleOpConvertToLLVMInterface
1887 :
public ConvertToLLVMOpInterface::ExternalModel<
1888 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1890 void getConvertToLLVMConversionAttrs(
1895void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1896 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs)
const {
1897 auto module = cast<gpu::GPUModuleOp>(op);
1898 ArrayAttr targetsAttr =
module.getTargetsAttr();
// Only a single-target module contributes a conversion attribute.
1900 if (!targetsAttr || targetsAttr.size() != 1)
1902 if (
auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1903 attrs.push_back(patternAttr);
// Body fragment of registerConvertGpuToLLVMInterface (per the symbol index):
// attaches the external model above to gpu.module in the given context. The
// surrounding function signature and extension boilerplate are elided here.
1908 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operations on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getIntegerType(unsigned width)
FloatAttr getF32FloatAttr(float value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type converter.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const