40#include "llvm/ADT/STLExtras.h"
42#define DEBUG_TYPE "gpu-to-llvm"
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
52class GpuToLLVMConversionPass
53 :
public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
56 void getDependentDialects(DialectRegistry ®istry)
const final {
57 Base::getDependentDialects(registry);
61 void runOnOperation()
override;
64template <
typename OpTy>
67 explicit ConvertOpToGpuRuntimeCallPattern(
68 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
69 : ConvertOpToLLVMPattern<OpTy>(typeConverter, benefit) {}
72 Value
getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73 MemRefType type, MemRefDescriptor desc)
const {
75 if (type.hasStaticShape())
77 rewriter, loc, indexType, type.getNumElements());
79 uint64_t rank = type.getRank();
80 Value numElements = desc.
size(rewriter, loc, 0);
81 for (
unsigned i = 1; i < rank; i++)
82 numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83 desc.
size(rewriter, loc, i));
87 MLIRContext *context = &this->getTypeConverter()->
getContext();
89 Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91 Type llvmInt8Type = IntegerType::get(context, 8);
92 Type llvmInt16Type = IntegerType::get(context, 16);
93 Type llvmInt32Type = IntegerType::get(context, 32);
94 Type llvmInt64Type = IntegerType::get(context, 64);
95 Type llvmFloat32Type = Float32Type::get(context);
96 Type llvmIntPtrType = IntegerType::get(
97 context, this->getTypeConverter()->getPointerBitwidth(0));
99 FunctionCallBuilder streamCreateCallBuilder = {
100 "mgpuStreamCreate", llvmPointerType , {}};
101 FunctionCallBuilder streamDestroyCallBuilder = {
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType }};
103 FunctionCallBuilder streamSynchronizeCallBuilder = {
104 "mgpuStreamSynchronize",
107 FunctionCallBuilder streamWaitEventCallBuilder = {
108 "mgpuStreamWaitEvent",
110 {llvmPointerType , llvmPointerType }};
111 FunctionCallBuilder eventCreateCallBuilder = {
112 "mgpuEventCreate", llvmPointerType , {}};
113 FunctionCallBuilder eventDestroyCallBuilder = {
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType }};
115 FunctionCallBuilder eventSynchronizeCallBuilder = {
116 "mgpuEventSynchronize",
119 FunctionCallBuilder eventRecordCallBuilder = {
122 {llvmPointerType , llvmPointerType }};
123 FunctionCallBuilder hostRegisterCallBuilder = {
124 "mgpuMemHostRegisterMemRef",
129 FunctionCallBuilder hostUnregisterCallBuilder = {
130 "mgpuMemHostUnregisterMemRef",
135 FunctionCallBuilder allocCallBuilder = {
141 FunctionCallBuilder deallocCallBuilder = {
144 {llvmPointerType , llvmPointerType }};
145 FunctionCallBuilder memcpyCallBuilder = {
148 {llvmPointerType , llvmPointerType ,
151 FunctionCallBuilder memset16CallBuilder = {
158 FunctionCallBuilder memset32CallBuilder = {
161 {llvmPointerType , llvmInt32Type ,
164 FunctionCallBuilder setDefaultDeviceCallBuilder = {
165 "mgpuSetDefaultDevice",
168 FunctionCallBuilder createDnVecCallBuilder = {
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
173 FunctionCallBuilder destroyDnVecCallBuilder = {
176 {llvmPointerType, llvmPointerType }};
177 FunctionCallBuilder createDnMatCallBuilder = {
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
182 FunctionCallBuilder destroyDnMatCallBuilder = {
185 {llvmPointerType, llvmPointerType }};
186 FunctionCallBuilder createCooCallBuilder = {
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
192 FunctionCallBuilder createCooAoSCallBuilder = {
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
198 FunctionCallBuilder createCsrCallBuilder = {
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType }};
204 FunctionCallBuilder createCscCallBuilder = {
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType }};
210 FunctionCallBuilder createBsrCallBuilder = {
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
217 FunctionCallBuilder destroySpMatCallBuilder = {
220 {llvmPointerType, llvmPointerType }};
221 FunctionCallBuilder spMVBufferSizeCallBuilder = {
222 "mgpuSpMVBufferSize",
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType }};
226 FunctionCallBuilder spMVCallBuilder = {
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType }};
231 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232 "mgpuSpMMBufferSize",
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType }};
236 FunctionCallBuilder createSpMMCallBuilder = {
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
242 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243 "mgpuSDDMMBufferSize",
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType }};
247 FunctionCallBuilder createSDDMMCallBuilder = {
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
253 FunctionCallBuilder createLtDnMatCallBuilder = {
254 "mgpuCreateCuSparseLtDnMat",
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType }};
258 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259 "mgpuDestroyCuSparseLtSpMat",
261 {llvmPointerType, llvmPointerType }};
262 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263 "mgpuDestroyCuSparseLtDnMat",
265 {llvmPointerType, llvmPointerType }};
266 FunctionCallBuilder create2To4SpMatCallBuilder = {
267 "mgpuCusparseLtCreate2To4SpMat",
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType }};
271 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272 "mgpuCuSparseLtSpMMBufferSize",
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
277 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278 "mgpuCuSparseLtSpMM",
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType }};
282 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283 "mgpuSpGEMMCreateDescr",
286 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287 "mgpuSpGEMMDestroyDescr",
289 {llvmPointerType , llvmPointerType }};
290 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291 "mgpuSpGEMMWorkEstimation",
293 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
294 llvmPointerType , llvmPointerType , llvmPointerType ,
295 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
297 FunctionCallBuilder createSpGEMMComputeBuilder = {
300 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
301 llvmPointerType , llvmPointerType , llvmPointerType ,
302 llvmInt32Type , llvmIntPtrType , llvmPointerType ,
304 FunctionCallBuilder createSpGEMMCopyBuilder = {
307 {llvmPointerType , llvmInt32Type , llvmInt32Type ,
308 llvmPointerType , llvmPointerType , llvmPointerType ,
309 llvmInt32Type , llvmPointerType }};
310 FunctionCallBuilder createSpMatGetSizeBuilder = {
313 {llvmPointerType , llvmPointerType , llvmPointerType ,
314 llvmPointerType , llvmPointerType }};
315 FunctionCallBuilder createSetCsrPointersBuilder = {
316 "mgpuSetCsrPointers",
318 {llvmPointerType , llvmPointerType ,
319 llvmPointerType , llvmPointerType ,
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter)
const override;
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 :
public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter)
const override;
354class ConvertAllocOpToGpuRuntimeCallPattern
355 :
public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
357 ConvertAllocOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override;
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 :
public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter)
const override;
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 :
public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter,
391 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter)
const override;
397class ConvertWaitOpToGpuRuntimeCallPattern
398 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
400 ConvertWaitOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
401 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
405 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter)
const override;
411class ConvertWaitAsyncOpToGpuRuntimeCallPattern
412 :
public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
414 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
415 const LLVMTypeConverter &typeConverter)
416 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
420 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter)
const override;
425class LegalizeLaunchFuncOpPattern
426 :
public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
428 LegalizeLaunchFuncOpPattern(
const LLVMTypeConverter &typeConverter,
429 bool kernelBarePtrCallConv,
430 bool kernelIntersperseSizeCallConv)
431 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
432 kernelBarePtrCallConv(kernelBarePtrCallConv),
433 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
437 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
438 ConversionPatternRewriter &rewriter)
const override;
440 bool kernelBarePtrCallConv;
441 bool kernelIntersperseSizeCallConv;
446class ConvertMemcpyOpToGpuRuntimeCallPattern
447 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
449 ConvertMemcpyOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
450 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
454 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
455 ConversionPatternRewriter &rewriter)
const override;
460class ConvertMemsetOpToGpuRuntimeCallPattern
461 :
public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
463 ConvertMemsetOpToGpuRuntimeCallPattern(
const LLVMTypeConverter &typeConverter)
464 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
468 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter)
const override;
474class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
475 :
public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
477 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
478 const LLVMTypeConverter &typeConverter)
479 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
483 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter)
const override;
489#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
490 class Convert##op_name##ToGpuRuntimeCallPattern \
491 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
493 Convert##op_name##ToGpuRuntimeCallPattern( \
494 const LLVMTypeConverter &typeConverter) \
495 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
499 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
500 ConversionPatternRewriter &rewriter) const override; \
527void GpuToLLVMConversionPass::runOnOperation() {
538 vector::populateVectorFromElementsUnrollPatterns(patterns);
540 return signalPassFailure();
543 LowerToLLVMOptions
options(context);
544 options.useBarePtrCallConv = hostBarePtrCallConv;
545 RewritePatternSet patterns(context);
546 ConversionTarget
target(*context);
547 target.addLegalDialect<LLVM::LLVMDialect>();
548 LLVMTypeConverter converter(context,
options);
553 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
556 iface->populateConvertToLLVMConversionPatterns(
target, converter, patterns);
561 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
563 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
564 [&](gpu::LaunchFuncOp op) ->
bool {
return converter.isLegal(op); });
572 kernelBarePtrCallConv,
573 kernelIntersperseSizeCallConv);
576 applyPartialConversion(getOperation(),
target, std::move(patterns))))
582 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
583 auto function = [&] {
584 if (
auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(
functionName))
589 return LLVM::CallOp::create(builder, loc, function, arguments);
606 llvm_unreachable(
"unsupported type");
612 if (llvm::isa<ComplexType>(type)) {
614 auto elementType = cast<ComplexType>(type).getElementType();
615 if (elementType.isBF16())
617 if (elementType.isF16())
619 if (elementType.isF32())
621 if (elementType.isF64())
623 if (elementType.isInteger(8))
625 if (elementType.isInteger(16))
627 if (elementType.isInteger(32))
645 llvm_unreachable(
"unsupported element type");
649 return spMat.
getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
674 llvm_unreachable(
"cannot find spmat def");
679 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
691 ConversionPatternRewriter &rewriter) {
692 if (!llvm::all_of(operands, [](
Value value) {
695 return rewriter.notifyMatchFailure(
696 op,
"Cannot convert if operands aren't of LLVM type.");
702 gpu::AsyncOpInterface op) {
703 if (op.getAsyncDependencies().size() != 1)
704 return rewriter.notifyMatchFailure(
705 op,
"Can only convert with exactly one async dependency.");
707 if (!op.getAsyncToken())
708 return rewriter.notifyMatchFailure(op,
"Can convert only async version.");
// Lowers gpu::HostRegisterOp to a runtime call emitted through
// hostRegisterCallBuilder (runtime symbol "mgpuMemHostRegisterMemRef").
// The op's operands are promoted to their LLVM form, an extra `elementSize`
// argument is appended, the call is emitted, and the original op is erased
// (nothing replaces it).
// cast<UnrankedMemRefType> asserts that the registered value is an unranked
// memref.
// NOTE(review): `elementSize` is defined on a line not shown here; presumably
// it is the byte size of `elementType` — confirm.
713LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter)
const {
716 auto *op = hostRegisterOp.getOperation();
720 Location loc = op->getLoc();
722 auto memRefType = hostRegisterOp.getValue().getType();
723 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
726 auto arguments = getTypeConverter()->promoteOperands(
727 loc, op->getOperands(), adaptor.getOperands(), rewriter);
728 arguments.push_back(elementSize);
729 hostRegisterCallBuilder.create(loc, rewriter, arguments);
731 rewriter.eraseOp(op);
// Lowers gpu::HostUnregisterOp to a runtime call emitted through
// hostUnregisterCallBuilder (runtime symbol "mgpuMemHostUnregisterMemRef").
// This mirrors the host-register lowering: operands are promoted, an extra
// `elementSize` argument is appended, the call is emitted, and the op is
// erased.
// cast<UnrankedMemRefType> asserts that the value is an unranked memref.
// NOTE(review): `elementSize` is defined on a line not shown here; presumably
// it is the byte size of `elementType` — confirm.
735LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
737 ConversionPatternRewriter &rewriter)
const {
738 Operation *op = hostUnregisterOp.getOperation();
742 Location loc = op->
getLoc();
744 auto memRefType = hostUnregisterOp.getValue().getType();
745 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
748 auto arguments = getTypeConverter()->promoteOperands(
749 loc, op->
getOperands(), adaptor.getOperands(), rewriter);
750 arguments.push_back(elementSize);
751 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
753 rewriter.eraseOp(op);
// Lowers gpu::AllocOp to a runtime allocation call emitted through
// allocCallBuilder with arguments {sizeBytes, stream, isHostShared}.
// - `isHostShared` is passed as an i8 constant.
// - Host-shared allocations are rejected when the op is async.
// - Shape, strides and sizeBytes come from getMemRefDescriptorSizes.
// - The resulting pointer is used as both the allocated and the aligned
//   pointer of the new memref descriptor.
// The op is replaced by the descriptor, plus the stream when the op produces
// an async token.
757LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758 gpu::AllocOp allocOp, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter)
const {
761 MemRefType memRefType = allocOp.getType();
764 !isConvertibleAndHasIdentityMaps(memRefType))
767 auto loc = allocOp.getLoc();
769 bool isShared = allocOp.getHostShared();
771 if (isShared && allocOp.getAsyncToken())
772 return rewriter.notifyMatchFailure(
773 allocOp,
"Host Shared allocation cannot be done async");
779 SmallVector<Value, 4> shape;
780 SmallVector<Value, 4> strides;
782 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
783 shape, strides, sizeBytes);
// Without async dependencies, the stream operand falls back to a value chosen
// on a line not shown (presumably this null pointer — confirm).
787 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
788 Value stream = adaptor.getAsyncDependencies().empty()
790 : adaptor.getAsyncDependencies().front();
792 auto isHostShared = mlir::LLVM::ConstantOp::create(
793 rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
796 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
800 Value alignedPtr = allocatedPtr;
803 auto memRefDescriptor = this->createMemRefDescriptor(
804 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
// The async form also yields the stream as the op's async token result.
806 if (allocOp.getAsyncToken()) {
808 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
810 rewriter.replaceOp(allocOp, {memRefDescriptor});
// Lowers gpu::DeallocOp to a runtime call emitted through deallocCallBuilder.
// The call receives the memref's *allocated* (not aligned) pointer and the
// stream taken from the first async dependency. The op is replaced by that
// stream, which serves as its async token.
// NOTE(review): `.front()` requires at least one async dependency; presumably
// this is enforced by a match check on a line not shown — confirm.
816LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
817 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter)
const {
823 Location loc = deallocOp.getLoc();
826 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
827 Value stream = adaptor.getAsyncDependencies().front();
828 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
830 rewriter.replaceOp(deallocOp, {stream});
835 return isa<gpu::AsyncTokenType>(value.
getType());
// Rewrites async::YieldOp so that GPU streams are not yielded directly.
// For each selected operand, the pattern:
//   1. creates an event (eventCreateCallBuilder),
//   2. records it on the operand's converted stream (eventRecordCallBuilder),
//   3. substitutes the event for that operand.
// The operand-selection condition is on a line not shown; presumably it picks
// operands of gpu async token type — confirm.
// Each distinct stream is destroyed exactly once afterwards; the
// SmallDenseSet deduplicates streams yielded more than once. The yield is
// updated in place.
// A match failure ("no gpu async token operand") is reported under a
// condition on a line not shown.
850LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
851 async::YieldOp yieldOp, OpAdaptor adaptor,
852 ConversionPatternRewriter &rewriter)
const {
854 return rewriter.notifyMatchFailure(yieldOp,
"no gpu async token operand");
856 Location loc = yieldOp.getLoc();
857 SmallVector<Value, 4> newOperands(adaptor.getOperands());
858 llvm::SmallDenseSet<Value> streams;
859 for (
auto &operand : yieldOp->getOpOperands()) {
862 auto idx = operand.getOperandNumber();
863 auto stream = adaptor.getOperands()[idx];
864 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
865 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
866 newOperands[idx] = event;
867 streams.insert(stream);
869 for (
auto stream : streams)
870 streamDestroyCallBuilder.create(loc, rewriter, {stream});
872 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
878 assert(isa<LLVM::LLVMPointerType>(value.
getType()));
880 return *defOp.getCallee() == functionName;
// Lowers the synchronous form of gpu::WaitOp; the async form is rejected.
// Each converted operand is synchronized and then destroyed:
// - on one branch it is treated as a stream (streamSynchronize + streamDestroy),
// - on the other as an event (eventSynchronize + eventDestroy).
// The branch condition is on a line not shown; presumably it checks whether
// the operand was produced by the stream-create runtime call — confirm.
// The wait op is erased afterwards.
888LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
889 gpu::WaitOp waitOp, OpAdaptor adaptor,
890 ConversionPatternRewriter &rewriter)
const {
891 if (waitOp.getAsyncToken())
892 return rewriter.notifyMatchFailure(waitOp,
"Cannot convert async op.");
894 Location loc = waitOp.getLoc();
896 for (
auto operand : adaptor.getOperands()) {
899 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
900 streamDestroyCallBuilder.create(loc, rewriter, {operand});
904 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
905 eventDestroyCallBuilder.create(loc, rewriter, {operand});
909 rewriter.eraseOp(waitOp);
// Lowers the async form of gpu::WaitOp; the synchronous form is rejected.
// One event is collected per dependency:
// - In the first branch, a fresh event is created and recorded on the
//   converted operand, right after the defining op of the original token.
// - Otherwise the converted operand is pushed as-is.
// The branch condition is on a line not shown; presumably it tests whether the
// operand came from the stream-create call, so pass-through operands are
// already events — confirm.
// A new stream is then created at the saved insertion point and made to wait
// on every collected event. The events are destroyed, and the new stream
// replaces the op's async token.
918LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
919 gpu::WaitOp waitOp, OpAdaptor adaptor,
920 ConversionPatternRewriter &rewriter)
const {
921 if (!waitOp.getAsyncToken())
922 return rewriter.notifyMatchFailure(waitOp,
"Can only convert async op.");
924 Location loc = waitOp.getLoc();
926 auto insertionPoint = rewriter.saveInsertionPoint();
927 SmallVector<Value, 1> events;
929 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
930 auto operand = std::get<1>(pair);
934 auto *defOp = std::get<0>(pair).getDefiningOp();
935 rewriter.setInsertionPointAfter(defOp);
936 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
937 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
938 events.push_back(event);
942 events.push_back(operand);
// Emit the new stream and its waits back at the wait op's position.
945 rewriter.restoreInsertionPoint(insertionPoint);
946 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
947 for (
auto event : events)
948 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
949 for (
auto event : events)
950 eventDestroyCallBuilder.create(loc, rewriter, {
event});
951 rewriter.replaceOp(waitOp, {stream});
// Rewrites gpu::LaunchFuncOp into a new gpu::LaunchFuncOp whose kernel
// operands are LLVM values and whose async dependencies are collapsed into a
// single stream.
// - A non-async launch that still has async dependencies is rejected.
// - The first async dependency becomes the launch stream. Each further
//   dependency is turned into an event that the stream waits on.
// - An async launch with no dependencies gets a freshly created stream.
// - With the bare-pointer convention (kernel option or type-converter option),
//   unranked memref kernel arguments are rejected.
// - With kernelIntersperseSizeCallConv, each argument is followed by an
//   index-typed constant holding the static byte size of its memref.
// - Cluster size and the cooperative flag are carried over.
// The old op is replaced by the stream when it had an async token, and erased
// otherwise.
957LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
958 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
959 ConversionPatternRewriter &rewriter)
const {
966 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
967 return rewriter.notifyMatchFailure(
968 launchOp,
"Cannot convert non-async op with async dependencies.");
970 Location loc = launchOp.getLoc();
// A null stream is passed to the new launch when there is neither an async
// dependency nor an async token.
972 Value stream = Value();
973 if (!adaptor.getAsyncDependencies().empty()) {
974 stream = adaptor.getAsyncDependencies().front();
977 if (adaptor.getAsyncDependencies().size() > 1) {
978 auto insertionPoint = rewriter.saveInsertionPoint();
979 SmallVector<Value, 4> events;
980 for (
auto [origDep, convertedDep] :
981 llvm::zip(launchOp.getAsyncDependencies().drop_front(),
982 adaptor.getAsyncDependencies().drop_front())) {
984 streamCreateCallBuilder.functionName)) {
// NOTE(review): `convertedDep` is pushed into `events` unchanged on this
// branch, which is only correct if it is already an event (not a stream).
// The start of the condition is on a line not shown; confirm it selects
// dependencies NOT produced by the stream-create call.
985 events.push_back(convertedDep);
988 Operation *defOp = origDep.getDefiningOp();
989 rewriter.setInsertionPointAfter(defOp);
991 eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
992 eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
993 events.push_back(event);
995 rewriter.restoreInsertionPoint(insertionPoint);
996 for (Value event : events)
997 streamWaitEventCallBuilder.create(loc, rewriter, {stream,
event});
998 for (Value event : events)
999 eventDestroyCallBuilder.create(loc, rewriter, {
event});
1004 else if (launchOp.getAsyncToken())
1005 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
// The unranked-memref check uses the combined bare-pointer setting, while
// promoteOperands receives only the kernel flag (presumably it also consults
// the converter's own option — confirm).
1010 OperandRange origArguments = launchOp.getKernelOperands();
1011 bool effectiveBarePtr = kernelBarePtrCallConv ||
1012 getTypeConverter()->getOptions().useBarePtrCallConv;
1013 if (effectiveBarePtr) {
1014 for (Value arg : origArguments) {
1015 if (isa<UnrankedMemRefType>(arg.getType()))
1016 return rewriter.notifyMatchFailure(
1017 loc,
"unranked memref kernel argument is not supported with "
1018 "the bare-pointer calling convention");
1021 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
1022 loc, origArguments, adaptor.getKernelOperands(), rewriter,
1023 kernelBarePtrCallConv);
1024 SmallVector<Value, 8> llvmArgumentsWithSizes;
// Size interspersing requires a 1:1 mapping between original and promoted
// arguments. Every argument must be a statically shaped memref of a
// byte-aligned int/float element type; its size is (bitwidth / 8) * numElements.
1027 if (kernelIntersperseSizeCallConv) {
1028 if (origArguments.size() != llvmArguments.size()) {
1030 return rewriter.notifyMatchFailure(
1032 "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1035 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1036 for (
auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1037 auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1039 return rewriter.notifyMatchFailure(
1040 launchOp,
"Operand to launch op is not a memref.");
1043 if (!memrefTy.hasStaticShape() ||
1044 !memrefTy.getElementType().isIntOrFloat()) {
1045 return rewriter.notifyMatchFailure(
1046 launchOp,
"Operand to launch op is not a memref with a static "
1047 "shape and an integer or float element type.");
1050 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1051 if (bitwidth % 8 != 0) {
1052 return rewriter.notifyMatchFailure(
1053 launchOp,
"Operand to launch op is not a memref with a "
1054 "byte-aligned element type.");
1057 uint64_t staticSize =
static_cast<uint64_t
>(bitwidth / 8) *
1058 static_cast<uint64_t
>(memrefTy.getNumElements());
1060 Value sizeArg = LLVM::ConstantOp::create(
1061 rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1062 llvmArgumentsWithSizes.push_back(llvmArg);
1063 llvmArgumentsWithSizes.push_back(sizeArg);
// The cluster size is forwarded only when the original launch specifies one.
1067 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1068 if (launchOp.hasClusterSize()) {
1070 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1071 adaptor.getClusterSizeZ()};
// The sized argument list is used when it was populated; otherwise the plain
// promoted arguments are used.
1073 auto newLaunchOp = gpu::LaunchFuncOp::create(
1074 rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1075 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1076 adaptor.getGridSizeZ()},
1077 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1078 adaptor.getBlockSizeZ()},
1079 adaptor.getDynamicSharedMemorySize(),
1080 llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1081 stream, clusterSize);
1082 if (launchOp.getCooperative())
1083 newLaunchOp.setCooperative(
true);
1084 if (launchOp.getAsyncToken())
1085 rewriter.replaceOp(launchOp, {stream});
1087 rewriter.eraseOp(launchOp);
1092 ConversionPatternRewriter &rewriter,
1093 LLVM::LLVMPointerType destinationType,
1096 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.
getType());
1097 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1098 sourcePtr = LLVM::AddrSpaceCastOp::create(
1100 LLVM::LLVMPointerType::get(rewriter.getContext(),
1101 destinationType.getAddressSpace()),
// Lowers gpu::MemcpyOp to a runtime call emitted through memcpyCallBuilder
// with arguments {dst, src, sizeBytes, stream}. The stream is the first async
// dependency and also replaces the op's async token.
// The byte count comes from a GEP over the converted element type, starting
// at a null pointer, followed by ptrtoint to the index type. The GEP index is
// on a line not shown; presumably it is numElements, giving the total byte
// size — confirm.
// Both aligned pointers are passed through a helper (name not shown) that
// targets llvmPointerType.
1106LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1107 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1108 ConversionPatternRewriter &rewriter)
const {
1109 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1112 !isConvertibleAndHasIdentityMaps(memRefType) ||
1116 auto loc = memcpyOp.getLoc();
1118 MemRefDescriptor srcDesc(adaptor.getSrc());
1119 Value numElements =
getNumElements(rewriter, loc, memRefType, srcDesc);
// "sizeof via null-pointer GEP": the offset of the indexed element from null
// equals its byte offset.
1121 Type elementPtrType = getElementPtrType(memRefType);
1122 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1123 Value gepPtr = LLVM::GEPOp::create(
1124 rewriter, loc, elementPtrType,
1125 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1128 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1131 srcDesc.alignedPtr(rewriter, loc),
1132 *getTypeConverter());
1134 loc, rewriter, llvmPointerType,
1135 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1136 *getTypeConverter());
1138 auto stream = adaptor.getAsyncDependencies().front();
1139 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1141 rewriter.replaceOp(memcpyOp, {stream});
// Lowers gpu::MemsetOp to a 16- or 32-bit runtime memset call.
// - The fill value must be a 16- or 32-bit int or float.
// - It is bitcast to i16/i32, and memset16CallBuilder or memset32CallBuilder
//   is chosen to match.
// - The call receives {dst, value, numElements, stream}: numElements is an
//   element count (not bytes), and the stream is the first async dependency.
// - The stream replaces the op's async token.
// NOTE(review): the width check reads `bitWidth` while the type and callee
// selection read `valueTypeWidth`. Both are defined on lines not shown; confirm
// they are the same value. Also confirm the width is not obtained via
// getIntOrFloatBitWidth() before the isIntOrFloat() test, since that query
// asserts on non-int/float types.
1146LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1147 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter)
const {
1149 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1152 !isConvertibleAndHasIdentityMaps(memRefType) ||
1156 auto loc = memsetOp.getLoc();
1158 Type valueType = adaptor.getValue().getType();
1161 if (!valueType.
isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1162 return rewriter.notifyMatchFailure(
1163 memsetOp,
"value must be a 16 or 32 bit int or float");
1167 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1169 MemRefDescriptor dstDesc(adaptor.getDst());
1170 Value numElements =
getNumElements(rewriter, loc, memRefType, dstDesc);
1173 LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1175 dstDesc.alignedPtr(rewriter, loc),
1176 *getTypeConverter());
1178 auto stream = adaptor.getAsyncDependencies().front();
// NOTE(review): this copies the selected FunctionCallBuilder by value; binding
// a const reference would avoid the copy if FunctionCallBuilder::create is const.
1179 FunctionCallBuilder builder =
1180 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1181 builder.
create(loc, rewriter, {dst, value, numElements, stream});
1183 rewriter.replaceOp(memsetOp, {stream});
// Lowers gpu::SetDefaultDeviceOp to a runtime call emitted through
// setDefaultDeviceCallBuilder (runtime symbol "mgpuSetDefaultDevice"). The
// call receives the converted device index, and the op is replaced by that
// call.
1187LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1188 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1189 ConversionPatternRewriter &rewriter)
const {
1190 Location loc = op.getLoc();
1191 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1192 {adaptor.getDevIndex()});
1193 rewriter.replaceOp(op, call);
1197template <
typename T>
1200 return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1201 static_cast<int32_t
>(tValue));
1204template <
typename T>
1207 return LLVM::ConstantOp::create(
1208 builder, loc, llvmFloat32Type,
// Lowers gpu::CreateDnTensorOp to a runtime handle-creation call. Every call
// receives the memref's allocated pointer, a type code `dtp` (computed on a
// line not shown, presumably from `dType`), and the stream taken from the
// first async dependency.
// - 2-D tensors take one of two paths; the condition is on a line not shown,
//   presumably whether the tensor feeds a cuSparseLt (2:4) SpMM — confirm:
//     * cuSparseLt path: createLtDnMatCallBuilder, with a caller-provided
//       stack buffer as the handle.
//     * regular path: createDnMatCallBuilder, which returns the handle.
// - 1-D tensors use createDnVecCallBuilder.
// - Other ranks only trip an assert.
// The op is replaced by {handle, stream}.
1212LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1213 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1214 ConversionPatternRewriter &rewriter)
const {
1218 Location loc = op.getLoc();
1219 auto stream = adaptor.getAsyncDependencies().front();
1221 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1222 Type dType = op.getMemref().
getType().getElementType();
1225 SmallVector<Value, 4> dims;
1226 for (Value dim : adaptor.getDims()) {
1227 dims.push_back(dim);
1237 if (dims.size() == 2) {
// 11032 bytes, 16-byte aligned: presumably the size of the opaque cuSparseLt
// dense-matrix handle expected by the runtime wrapper — confirm against it.
1239 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1240 rewriter.getIndexAttr(11032));
1241 handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1242 llvmInt8Type, handleSz, 16);
// The alloca above already yields llvmPointerType, so this bitcast does not
// change the type.
1243 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1245 createLtDnMatCallBuilder
1246 .create(loc, rewriter,
1247 {handle, dims[0], dims[1], pTensor, dtp, stream})
1251 createDnMatCallBuilder
1252 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1256 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1257 handle = createDnVecCallBuilder
1258 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1261 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu::DestroyDnTensorOp by choosing the destroy call that matches how
// the handle was created. The tensor's rank is read from the dims of its
// defining gpu::CreateDnTensorOp:
// - 2-D: destroyCuSparseLtDnMatBuilder or destroyDnMatCallBuilder. The
//   condition choosing between them is on a line not shown.
// - 1-D: destroyDnVecCallBuilder.
// The stream (first async dependency) replaces the op's async token.
// NOTE(review): getDefiningOp<gpu::CreateDnTensorOp>() returns null if the
// handle is not produced directly by that op (e.g. a block argument), and
// `definingOp.getDims()` is called without a null check here — confirm that
// callers guarantee the producer.
1265LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1266 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1267 ConversionPatternRewriter &rewriter)
const {
1271 Location loc = op.getLoc();
1272 auto stream = adaptor.getAsyncDependencies().front();
1273 auto definingOp = op.getDnTensor().
getDefiningOp<gpu::CreateDnTensorOp>();
1274 SmallVector<Value, 4> dims;
1275 for (Value dim : definingOp.getDims()) {
1276 dims.push_back(dim);
1278 if (dims.size() == 2) {
1282 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1283 {adaptor.getDnTensor(), stream});
1285 destroyDnMatCallBuilder.create(loc, rewriter,
1286 {adaptor.getDnTensor(), stream});
1289 assert(dims.size() == 1 &&
"Only 1D and 2D tensors are supported");
1290 destroyDnVecCallBuilder.create(loc, rewriter,
1291 {adaptor.getDnTensor(), stream});
1293 rewriter.replaceOp(op, {stream});
// Lowers gpu::CreateCooOp to a runtime call emitted through
// createCooCallBuilder. The call receives, in order:
// - rows, cols, nnz;
// - the allocated pointers of the row-index, column-index and value memrefs;
// - type codes `itp`/`dtp`;
// - the stream (first async dependency).
// The type codes are computed on lines not shown, presumably from the
// column-index and value element types. The row-index element type is not
// queried separately.
// The op is replaced by {handle, stream}.
1297LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1298 gpu::CreateCooOp op, OpAdaptor adaptor,
1299 ConversionPatternRewriter &rewriter)
const {
1303 Location loc = op.getLoc();
1304 auto stream = adaptor.getAsyncDependencies().front();
1306 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1308 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1310 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1312 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1314 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1318 createCooCallBuilder
1319 .create(loc, rewriter,
1320 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1321 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1323 rewriter.replaceOp(op, {handle, stream});
// Lowers gpu::CreateCooAoSOp, the COO variant with one combined index memref,
// to a runtime call emitted through createCooAoSCallBuilder. The call
// receives, in order:
// - rows, cols, nnz;
// - the allocated pointers of the index and value memrefs;
// - type codes `itp`/`dtp`;
// - the stream (first async dependency).
// The type codes are computed on lines not shown, presumably from `iType` and
// the value element type.
// The op is replaced by {handle, stream}.
1327LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1328 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1329 ConversionPatternRewriter &rewriter)
const {
1333 Location loc = op.getLoc();
1334 auto stream = adaptor.getAsyncDependencies().front();
1335 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1337 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1338 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1340 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1344 createCooAoSCallBuilder
1345 .create(loc, rewriter,
1346 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1347 pIdxs, pValues, itp, dtp, stream})
1349 rewriter.replaceOp(op, {handle, stream});
1353LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1354 gpu::CreateCsrOp op, OpAdaptor adaptor,
1355 ConversionPatternRewriter &rewriter)
const {
1359 Location loc = op.getLoc();
1360 auto stream = adaptor.getAsyncDependencies().front();
1362 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1364 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1366 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1368 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1370 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1372 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1377 createCsrCallBuilder
1378 .create(loc, rewriter,
1379 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1380 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1382 rewriter.replaceOp(op, {handle, stream});
1386LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1387 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1388 ConversionPatternRewriter &rewriter)
const {
1392 Location loc = op.getLoc();
1393 auto stream = adaptor.getAsyncDependencies().front();
1395 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1397 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1401 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1402 rewriter.getIndexAttr(44104));
1403 Value handle = LLVM::AllocaOp::create(
1404 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, 16);
1405 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1407 create2To4SpMatCallBuilder
1408 .create(loc, rewriter,
1409 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1411 rewriter.replaceOp(op, {handle, stream});
1415LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1416 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1417 ConversionPatternRewriter &rewriter)
const {
1421 Location loc = op.getLoc();
1422 auto stream = adaptor.getAsyncDependencies().front();
1425 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1426 {adaptor.getSpmat(), stream});
1429 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1431 rewriter.replaceOp(op, {stream});
1435LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1436 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1437 ConversionPatternRewriter &rewriter)
const {
1441 Location loc = op.getLoc();
1445 auto stream = adaptor.getAsyncDependencies().front();
1446 auto bufferSize = spMVBufferSizeCallBuilder
1447 .create(loc, rewriter,
1448 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1449 adaptor.getDnY(), computeType, stream})
1451 rewriter.replaceOp(op, {bufferSize, stream});
1455LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1456 gpu::SpMVOp op, OpAdaptor adaptor,
1457 ConversionPatternRewriter &rewriter)
const {
1461 Location loc = op.getLoc();
1465 auto stream = adaptor.getAsyncDependencies().front();
1467 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1468 spMVCallBuilder.create(loc, rewriter,
1469 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1470 adaptor.getDnY(), computeType, pBuf, stream});
1471 rewriter.replaceOp(op, {stream});
1475LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1476 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1477 ConversionPatternRewriter &rewriter)
const {
1481 Location loc = op.getLoc();
1484 auto stream = adaptor.getAsyncDependencies().front();
1491 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1492 rewriter.getIndexAttr(3));
1494 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1496 createCuSparseLtSpMMBufferSizeBuilder
1497 .create(loc, rewriter,
1498 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1499 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1503 auto bufferSizePtr1 = LLVM::GEPOp::create(
1504 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1505 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1506 rewriter.getIndexAttr(1))});
1507 auto bufferSizePtr2 = LLVM::GEPOp::create(
1508 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1509 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1510 rewriter.getIndexAttr(2))});
1512 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1514 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1516 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1518 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1523 createSpMMBufferSizeCallBuilder
1524 .create(loc, rewriter,
1525 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1526 adaptor.getDnmatC(), computeType, stream})
1528 rewriter.replaceOp(op, {bufferSize, stream});
1533LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1534 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1535 ConversionPatternRewriter &rewriter)
const {
1539 Location loc = op.getLoc();
1544 auto stream = adaptor.getAsyncDependencies().front();
1546 createSDDMMBufferSizeCallBuilder
1547 .create(loc, rewriter,
1548 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1549 adaptor.getSpmatC(), computeType, stream})
1551 rewriter.replaceOp(op, {bufferSize, stream});
1555LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1556 gpu::SpMMOp op, OpAdaptor adaptor,
1557 ConversionPatternRewriter &rewriter)
const {
1561 Location loc = op.getLoc();
1567 auto stream = adaptor.getAsyncDependencies().front();
1571 SmallVector<Value> pBufs;
1572 for (Value buffer : adaptor.getBuffers()) {
1573 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1574 pBufs.push_back(pBuf);
1576 createCuSparseLtSpMMBuilder.create(
1578 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1579 pBufs[0], pBufs[1], pBufs[2], stream});
1581 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1582 .allocatedPtr(rewriter, loc);
1583 createSpMMCallBuilder.create(loc, rewriter,
1584 {modeA, modeB, adaptor.getSpmatA(),
1585 adaptor.getDnmatB(), adaptor.getDnmatC(),
1586 computeType, pBuf, stream});
1588 rewriter.replaceOp(op, {stream});
1592template <
typename T>
1594 converter.addConversion([&converter](T) ->
Type {
1595 return LLVM::LLVMPointerType::get(&converter.
getContext());
1599LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1600 gpu::SDDMMOp op, OpAdaptor adaptor,
1601 ConversionPatternRewriter &rewriter)
const {
1605 Location loc = op.getLoc();
1610 auto stream = adaptor.getAsyncDependencies().front();
1612 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1613 createSDDMMCallBuilder.create(loc, rewriter,
1614 {modeA, modeB, adaptor.getDnmatA(),
1615 adaptor.getDnmatB(), adaptor.getSpmatC(),
1616 computeType, pBuf, stream});
1617 rewriter.replaceOp(op, {stream});
1622ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1623 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1624 ConversionPatternRewriter &rewriter)
const {
1628 Location loc = op.getLoc();
1629 auto stream = adaptor.getAsyncDependencies().front();
1630 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1632 rewriter.replaceOp(op, {descr, stream});
1637ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1638 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1639 ConversionPatternRewriter &rewriter)
const {
1643 Location loc = op.getLoc();
1644 auto stream = adaptor.getAsyncDependencies().front();
1645 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1646 {adaptor.getDesc(), stream});
1647 rewriter.replaceOp(op, {stream});
1652ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1653 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1654 ConversionPatternRewriter &rewriter)
const {
1658 Location loc = op.getLoc();
1663 auto stream = adaptor.getAsyncDependencies().front();
1666 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1667 Value bufferSizeNew;
1669 if (adaptor.getKind() ==
1670 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1672 createSpGEMMWorkEstimationBuilder
1673 .create(loc, rewriter,
1674 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1675 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1676 adaptor.getBufferSz(), pBuf, stream})
1680 createSpGEMMComputeBuilder
1681 .create(loc, rewriter,
1682 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1683 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1684 adaptor.getBufferSz(), pBuf, stream})
1687 rewriter.replaceOp(op, {bufferSizeNew, stream});
1691LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1692 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1693 ConversionPatternRewriter &rewriter)
const {
1697 Location loc = op.getLoc();
1702 auto stream = adaptor.getAsyncDependencies().front();
1703 createSpGEMMCopyBuilder.create(loc, rewriter,
1704 {adaptor.getDesc(), modeA, modeB,
1705 adaptor.getSpmatA(), adaptor.getSpmatB(),
1706 adaptor.getSpmatC(), computeType, stream});
1707 rewriter.replaceOp(op, {stream});
1711LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1712 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1713 ConversionPatternRewriter &rewriter)
const {
1717 Location loc = op.getLoc();
1718 auto stream = adaptor.getAsyncDependencies().front();
1720 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1721 rewriter.getIndexAttr(3));
1722 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1723 llvmInt64Type, three, 16);
1725 auto rowsPtr = LLVM::GEPOp::create(
1726 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1727 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1728 rewriter.getIndexAttr(0))});
1729 auto colsPtr = LLVM::GEPOp::create(
1730 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1731 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1732 rewriter.getIndexAttr(1))});
1733 auto nnzsPtr = LLVM::GEPOp::create(
1734 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1735 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1736 rewriter.getIndexAttr(2))});
1737 createSpMatGetSizeBuilder.create(
1738 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1739 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1740 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1741 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1743 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1747LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1748 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1749 ConversionPatternRewriter &rewriter)
const {
1753 Location loc = op.getLoc();
1754 auto stream = adaptor.getAsyncDependencies().front();
1756 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1758 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1760 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1761 createSetCsrPointersBuilder.create(
1762 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1763 rewriter.replaceOp(op, {stream});
1767LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1768 gpu::CreateCscOp op, OpAdaptor adaptor,
1769 ConversionPatternRewriter &rewriter)
const {
1773 Location loc = op.getLoc();
1774 auto stream = adaptor.getAsyncDependencies().front();
1776 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1778 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1780 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1782 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1784 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1786 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1791 createCscCallBuilder
1792 .create(loc, rewriter,
1793 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1794 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1796 rewriter.replaceOp(op, {handle, stream});
1800LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1801 gpu::CreateBsrOp op, OpAdaptor adaptor,
1802 ConversionPatternRewriter &rewriter)
const {
1806 Location loc = op.getLoc();
1807 auto stream = adaptor.getAsyncDependencies().front();
1809 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1811 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1813 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1815 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1817 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1819 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1824 createBsrCallBuilder
1825 .create(loc, rewriter,
1826 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1827 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1828 pColIdxs, pValues, ptp, itp, dtp, stream})
1830 rewriter.replaceOp(op, {handle, stream});
1836 bool kernelBarePtrCallConv,
bool kernelIntersperseSizeCallConv) {
1846 patterns.
add<ConvertAsyncYieldToGpuRuntimeCallPattern>(converter,
1849 patterns.
add<ConvertAllocOpToGpuRuntimeCallPattern,
1850 ConvertDeallocOpToGpuRuntimeCallPattern,
1851 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1852 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1853 ConvertMemcpyOpToGpuRuntimeCallPattern,
1854 ConvertMemsetOpToGpuRuntimeCallPattern,
1855 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1856 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1857 ConvertWaitOpToGpuRuntimeCallPattern,
1858 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1859 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1860 ConvertCreateCooOpToGpuRuntimeCallPattern,
1861 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1862 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1863 ConvertCreateCscOpToGpuRuntimeCallPattern,
1864 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1865 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1866 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1867 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1868 ConvertSpMVOpToGpuRuntimeCallPattern,
1869 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1870 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1871 ConvertSpMMOpToGpuRuntimeCallPattern,
1872 ConvertSDDMMOpToGpuRuntimeCallPattern,
1873 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1874 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1875 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1876 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1877 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1878 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1879 patterns.
add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1880 kernelIntersperseSizeCallConv);
1888struct GPUModuleOpConvertToLLVMInterface
1889 :
public ConvertToLLVMOpInterface::ExternalModel<
1890 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1892 void getConvertToLLVMConversionAttrs(
1897void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1898 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs)
const {
1899 auto module = cast<gpu::GPUModuleOp>(op);
1900 ArrayAttr targetsAttr =
module.getTargetsAttr();
1902 if (!targetsAttr || targetsAttr.size() != 1)
1904 if (
auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1905 attrs.push_back(patternAttr);
1910 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
IntegerType getIntegerType(unsigned width)
FloatAttr getF32FloatAttr(float value)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry ®istry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const