GPUToLLVMConversion.cpp
//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func op into a sequence of
// GPU runtime calls. As most GPU runtimes do not have a stable published
// ABI, this pass uses a slim runtime layer that builds on top of the public
// API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//
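//
// For orientation, a schematic (not verbatim) example of the rewrite this
// pass performs: a host-side launch such as
//
//   gpu.launch_func @kernel_module::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : memref<?xf32>)
//
// becomes, roughly, a sequence of calls into the mgpu* wrapper runtime
// declared below (moduleLoad, moduleGetFunction, streamCreate, launchKernel,
// streamSynchronize, streamDestroy, moduleUnload); the exact argument lists
// are those of the FunctionCallBuilder declarations in this file. The module
// and kernel names above are illustrative only.
//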

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexAttrConstant(
                     rewriter, loc, indexType, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType, /* void* f */
          llvmIntPtrType,  /* intptr_t gridXDim */
          llvmIntPtrType,  /* intptr_t gridyDim */
          llvmIntPtrType,  /* intptr_t gridZDim */
          llvmIntPtrType,  /* intptr_t blockXDim */
          llvmIntPtrType,  /* intptr_t blockYDim */
          llvmIntPtrType,  /* intptr_t blockZDim */
          llvmInt32Type,   /* unsigned int sharedMemBytes */
          llvmPointerType, /* void *hstream */
          llvmPointerType, /* void **kernelParams */
          llvmPointerType, /* void **extra */
          llvmInt64Type    /* size_t paramsCount */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */,
       llvmInt16Type /* unsigned short value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
       llvmPointerType /*crd*/, llvmPointerType /*val*/,
       llvmPointerType /*void *stream*/}};
};
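
// As a concrete illustration (assumed, not taken from this file): a builder
// such as allocCallBuilder above corresponds to a C wrapper with a signature
// roughly like
//
//   extern "C" void *mgpuMemAlloc(intptr_t sizeBytes, void *stream,
//                                 bool isHostShared);
//
// provided by the vendor-specific runtime wrapper library (e.g. the CUDA or
// ROCm wrappers under mlir/lib/ExecutionEngine). The builder itself only
// records the callee name and the LLVM types used at the call site.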

/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
  }

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
453 
454 /// A rewrite patter to convert gpu.launch_func operations into a sequence of
455 /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
456 ///
457 /// In essence, a gpu.launch_func operations gets compiled into the following
458 /// sequence of runtime calls:
459 ///
460 /// * moduleLoad -- loads the module given the cubin / hsaco data
461 /// * moduleGetFunction -- gets a handle to the actual kernel function
462 /// * getStreamHelper -- initializes a new compute stream on GPU
463 /// * launchKernel -- launches the kernel on a stream
464 /// * streamSynchronize -- waits for operations on the stream to finish
465 ///
466 /// Intermediate data structures are allocated on the stack.
467 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
468  : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
469 public:
470  ConvertLaunchFuncOpToGpuRuntimeCallPattern(
471  const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation,
472  bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable)
473  : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
474  gpuBinaryAnnotation(gpuBinaryAnnotation),
475  kernelBarePtrCallConv(kernelBarePtrCallConv),
476  cachedModuleTable(cachedModuleTable) {}
477 
478 private:
479  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
480  OpBuilder &builder) const;
481  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
482  Location loc, OpBuilder &builder) const;
483 
485  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
486  ConversionPatternRewriter &rewriter) const override;
487 
488  llvm::SmallString<32> gpuBinaryAnnotation;
489  bool kernelBarePtrCallConv;
490  SymbolTable *cachedModuleTable;
491 };

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN, or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operation on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)               \
  class Convert##op_name##ToGpuRuntimeCallPattern                             \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
  public:                                                                     \
    Convert##op_name##ToGpuRuntimeCallPattern(                                \
        const LLVMTypeConverter &typeConverter)                               \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
                                                                              \
  private:                                                                    \
    LogicalResult                                                             \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
                    ConversionPatternRewriter &rewriter) const override;      \
  };

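// For reference, an invocation such as
// DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp) expands to a class
// declaration equivalent to:
//
//   class ConvertSpMVOpToGpuRuntimeCallPattern
//       : public ConvertOpToGpuRuntimeCallPattern<gpu::SpMVOp> {
//   public:
//     ConvertSpMVOpToGpuRuntimeCallPattern(
//         const LLVMTypeConverter &typeConverter)
//         : ConvertOpToGpuRuntimeCallPattern<gpu::SpMVOp>(typeConverter) {}
//
//   private:
//     LogicalResult
//     matchAndRewrite(gpu::SpMVOp op, OpAdaptor adaptor,
//                     ConversionPatternRewriter &rewriter) const override;
//   };
//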
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)

} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();
  SymbolTable symbolTable = SymbolTable(getOperation());
  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate all patterns from all dialects that implement the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
  }

  // Preserve GPU modules if they have target attributes.
  target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
      [](gpu::GPUModuleOp module) -> bool {
        return module.getTargetsAttr() != nullptr;
      });
  // Accept LaunchFuncOps as legal if they refer to GPU modules with targets
  // and their operands have already been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool {
        auto module =
            symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
        return converter.isLegal(op->getOperandTypes()) &&
               converter.isLegal(op->getResultTypes()) &&
               (module && module.getTargetsAttr() &&
                !module.getTargetsAttr().empty());
      });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
                                      kernelBarePtrCallConv, &symbolTable);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}
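
// For illustration (assumed, not part of the original file): given the
// declarations above, a call such as
//
//   streamCreateCallBuilder.create(loc, rewriter, {});
//
// inserts an `llvm.func @mgpuStreamCreate() -> !llvm.ptr` declaration at the
// end of the enclosing module (unless one already exists) and emits, roughly,
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//
// at the current insertion point.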

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3; // CUSPARSE_INDEX_64I
}

static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  llvm_unreachable("unsupported type");
  // TODO: add support for TF32
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = type.cast<ComplexType>().getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I

  llvm_unreachable("unsupported element type");
}
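
// Worked examples of the mapping above: an f32 operand maps to 0 (CUDA_R_32F),
// an f64 operand to 1 (CUDA_R_64F), and a complex<f32> operand to 4
// (CUDA_C_32F); index types i32 and i64 map to 2 (CUSPARSE_INDEX_32I) and
// 3 (CUSPARSE_INDEX_64I), respectively.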

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// TODO: We may want a warning or disablement at run time of the MLIR
// compiler: cuSPARSELt currently does not work for CUDA architectures < 8.0
// and will trigger a runtime error in the CUDA program. It would be good to
// at least emit a warning when the target architecture is < 8.0 and the user
// still wants to use cuSPARSELt, or to disable the cuSPARSELt calls for such
// architectures when lowering the GPU sparse dialect to LLVM calls.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the spMat defining op.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}

static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    // If the A operand is 2:4 (50%) sparse, we should use cusparseLt.
    if (!spmmOp)
      continue;
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}
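
// For example (illustrative IR, not from this file): an op such as
//
//   %memref, %token = gpu.alloc async [%dep] (%size) : memref<?xf32>
//
// satisfies this check (exactly one async dependency and an async token
// result), whereas a synchronous gpu.alloc without the async keyword, or one
// with several dependencies, is rejected with a match failure.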

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}
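
// For illustration (schematic, not verbatim output): an async allocation like
//
//   %memref, %token = gpu.alloc async [%stream] (%n) : memref<?xf32>
//
// is rewritten into, roughly,
//
//   %size = ... size in bytes computed from the memref shape ...
//   %host = llvm.mlir.constant(0 : i8) : i8
//   %ptr  = llvm.call @mgpuMemAlloc(%size, %stream, %host)
//             : (i64, !llvm.ptr, i8) -> !llvm.ptr
//   ... memref descriptor built around %ptr ...
//
// with the original async token replaced by the stream value.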

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token operands are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.getCallee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed. That is, it
// assumes that they are not used afterwards or elsewhere. Otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed. That is, it assumes that they are not used afterwards or
// elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
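
// Schematic example (illustrative, not verbatim output): a `gpu.wait async`
// whose converted operand is an event lowers to, roughly,
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//       : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%event) : (!llvm.ptr) -> ()
//
// and the op's !gpu.async.token result is replaced by %stream.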

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  // Note: If `useBarePtrCallConv` is set in the type converter's options,
  // the value of `kernelBarePtrCallConv` will be ignored.
  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
  auto structPtr =
      builder.create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
                                     /*alignment=*/0);
  auto arraySize =
      builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
  auto arrayPtr = builder.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmPointerType, arraySize, /*alignment=*/0);
  for (const auto &en : llvm::enumerate(arguments)) {
    Value fieldPtr =
        builder.create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
                                    ArrayRef<LLVM::GEPArg>{0, en.index()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, arrayPtr,
        ArrayRef<LLVM::GEPArg>{en.index()});
    builder.create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.getAsyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  gpu::GPUModuleOp kernelModule;
  if (cachedModuleTable)
    kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
        launchOp.getKernelModuleName());
  else
    kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
        launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  // If the module has targets, then just update the op operands.
  if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
    Value stream = Value();
    if (!adaptor.getAsyncDependencies().empty())
      stream = adaptor.getAsyncDependencies().front();
    // If the async keyword is present and there are no dependencies, then a
    // stream must be created to pass to subsequent operations.
    else if (launchOp.getAsyncToken())
      stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

    // Lower the kernel operands to match kernel parameters.
    // Note: If `useBarePtrCallConv` is set in the type converter's options,
    // the value of `kernelBarePtrCallConv` will be ignored.
    SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
        loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
        rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);

    std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
    if (launchOp.hasClusterSize()) {
      clusterSize =
          gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                          adaptor.getClusterSizeZ()};
    }
    rewriter.create<gpu::LaunchFuncOp>(
        launchOp.getLoc(), launchOp.getKernelAttr(),
        gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                        adaptor.getGridSizeZ()},
        gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                        adaptor.getBlockSizeZ()},
        adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
    if (launchOp.getAsyncToken())
      rewriter.replaceOp(launchOp, {stream});
    else
      rewriter.eraseOp(launchOp);
    return success();
  }

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  // Pass the binary size. SPIR-V requires the binary size.
  auto gpuBlob = binaryAttr.getValue();
  auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt64Type,
      mlir::IntegerAttr::get(llvmInt64Type,
                             static_cast<int64_t>(gpuBlob.size())));

  auto module =
      moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize});

  // Pass the count of the parameters to the runtime wrappers.
  auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt64Type,
      mlir::IntegerAttr::get(
          llvmInt64Type,
          static_cast<int64_t>(launchOp.getNumKernelOperands())));

  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(), kernelName});
  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
  Value stream =
      adaptor.getAsyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
          : adaptor.getAsyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType);
  Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
                                      ? launchOp.getDynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
       adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
       adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer, paramsCount});

  if (launchOp.getAsyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the synchronous
    // version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());

  return success();
}

static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.getValue().getType();
  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
  // Ints and floats of 16 or 32 bit width are allowed.
  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.getDevIndex()});
  rewriter.replaceOp(op, {});
  return success();
}

template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                          static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return builder.create<LLVM::ConstantOp>(
      loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}
1356 
1357 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1358  gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1359  ConversionPatternRewriter &rewriter) const {
1360  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1361  failed(isAsyncWithOneDependency(rewriter, op)))
1362  return failure();
1363  Location loc = op.getLoc();
1364  auto stream = adaptor.getAsyncDependencies().front();
1365  Value pTensor =
1366  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1367  Type dType = op.getMemref().getType().getElementType();
1368  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1369 
1370  SmallVector<Value, 4> dims;
1371  for (Value dim : adaptor.getDims()) {
1372  dims.push_back(dim);
1373  }
1374 
1375  Value handle;
1376  // TODO: For now, we track the use of the handle and lower it to cusparse /
1377  // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
1378  // used, we require two separate Creation ops to be the correct logic. In
1379  // future, we may add support to using one handle in sparse tensor / GPU
1380  // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
1381  // the dnmat is used with spmat with 2:4 sparsity
1382  if (dims.size() == 2) {
1383  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1384  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1385  loc, getIndexType(), rewriter.getIndexAttr(11032));
1386  handle = rewriter.create<LLVM::AllocaOp>(
1387  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1388  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1389 
1390  createLtDnMatCallBuilder
1391  .create(loc, rewriter,
1392  {handle, dims[0], dims[1], pTensor, dtp, stream})
1393  .getResult();
1394  } else {
1395  handle =
1396  createDnMatCallBuilder
1397  .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1398  .getResult();
1399  }
1400  } else {
1401  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1402  handle = createDnVecCallBuilder
1403  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1404  .getResult();
1405  }
1406  rewriter.replaceOp(op, {handle, stream});
1407  return success();
1408 }
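// Illustrative example (editorial note): a 2-D dense tensor created by
// gpu.create_dn_tensor and later consumed by a 2:4-sparse gpu.spmm takes the
// cusparseLt path above (a stack-allocated handle passed by pointer); any other
// 2-D tensor is lowered to the plain dnmat create call, and a 1-D tensor to the
// dnvec create call.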
1409 
1410 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1411  gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1412  ConversionPatternRewriter &rewriter) const {
1413  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1414  failed(isAsyncWithOneDependency(rewriter, op)))
1415  return failure();
1416  Location loc = op.getLoc();
1417  auto stream = adaptor.getAsyncDependencies().front();
1418  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1419  SmallVector<Value, 4> dims;
1420  for (Value dim : definingOp.getDims()) {
1421  dims.push_back(dim);
1422  }
1423  if (dims.size() == 2) {
1424  // Use the cusparseLt destroy call if the dnmat is used with a spmat that
1425  // has 2:4 sparsity.
1426  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1427  destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1428  {adaptor.getDnTensor(), stream});
1429  } else {
1430  destroyDnMatCallBuilder.create(loc, rewriter,
1431  {adaptor.getDnTensor(), stream});
1432  }
1433  } else {
1434  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1435  destroyDnVecCallBuilder.create(loc, rewriter,
1436  {adaptor.getDnTensor(), stream});
1437  }
1438  rewriter.replaceOp(op, {stream});
1439  return success();
1440 }
1441 
1442 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1443  gpu::CreateCooOp op, OpAdaptor adaptor,
1444  ConversionPatternRewriter &rewriter) const {
1445  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1446  failed(isAsyncWithOneDependency(rewriter, op)))
1447  return failure();
1448  Location loc = op.getLoc();
1449  auto stream = adaptor.getAsyncDependencies().front();
1450  Value pRowIdxs =
1451  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1452  Value pColIdxs =
1453  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1454  Value pValues =
1455  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1456  Type iType =
1457  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1458  Type dType =
1459  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1460  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1461  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1462  auto handle =
1463  createCooCallBuilder
1464  .create(loc, rewriter,
1465  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1466  pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1467  .getResult();
1468  rewriter.replaceOp(op, {handle, stream});
1469  return success();
1470 }
1471 
1472 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1473  gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1474  ConversionPatternRewriter &rewriter) const {
1475  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1476  failed(isAsyncWithOneDependency(rewriter, op)))
1477  return failure();
1478  Location loc = op.getLoc();
1479  auto stream = adaptor.getAsyncDependencies().front();
1480  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1481  Value pValues =
1482  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1483  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1484  Type dType =
1485  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1486  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1487  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1488  auto handle =
1489  createCooAoSCallBuilder
1490  .create(loc, rewriter,
1491  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1492  pIdxs, pValues, itp, dtp, stream})
1493  .getResult();
1494  rewriter.replaceOp(op, {handle, stream});
1495  return success();
1496 }
1497 
1498 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1499  gpu::CreateCsrOp op, OpAdaptor adaptor,
1500  ConversionPatternRewriter &rewriter) const {
1501  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1502  failed(isAsyncWithOneDependency(rewriter, op)))
1503  return failure();
1504  Location loc = op.getLoc();
1505  auto stream = adaptor.getAsyncDependencies().front();
1506  Value pRowPos =
1507  MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1508  Value pColIdxs =
1509  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1510  Value pValues =
1511  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1512  Type pType =
1513  llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1514  Type iType =
1515  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1516  Type dType =
1517  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1518  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1519  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1520  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1521  auto handle =
1522  createCsrCallBuilder
1523  .create(loc, rewriter,
1524  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1525  pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1526  .getResult();
1527  rewriter.replaceOp(op, {handle, stream});
1528  return success();
1529 }
1530 
1531 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1532  gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1533  ConversionPatternRewriter &rewriter) const {
1534  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1535  failed(isAsyncWithOneDependency(rewriter, op)))
1536  return failure();
1537  Location loc = op.getLoc();
1538  auto stream = adaptor.getAsyncDependencies().front();
1539  Value pMat =
1540  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1541  Type dType =
1542  llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1543  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1544 
1545  // CUDA runner asserts the size is 44104 bytes.
1546  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1547  loc, getIndexType(), rewriter.getIndexAttr(44104));
1548  Value handle = rewriter.create<LLVM::AllocaOp>(
1549  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1550  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1551 
1552  create2To4SpMatCallBuilder
1553  .create(loc, rewriter,
1554  {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1555  .getResult();
1556  rewriter.replaceOp(op, {handle, stream});
1557  return success();
1558 }
1559 
1560 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1561  gpu::DestroySpMatOp op, OpAdaptor adaptor,
1562  ConversionPatternRewriter &rewriter) const {
1563  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1564  failed(isAsyncWithOneDependency(rewriter, op)))
1565  return failure();
1566  Location loc = op.getLoc();
1567  auto stream = adaptor.getAsyncDependencies().front();
1568  // Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
1569  if (is2To4Sparsity(op.getSpmat())) {
1570  destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1571  {adaptor.getSpmat(), stream});
1572 
1573  } else {
1574  destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1575  }
1576  rewriter.replaceOp(op, {stream});
1577  return success();
1578 }
1579 
1580 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1581  gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1582  ConversionPatternRewriter &rewriter) const {
1583  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1584  failed(isAsyncWithOneDependency(rewriter, op)))
1585  return failure();
1586  Location loc = op.getLoc();
1587  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1588  auto computeType = genConstInt32From(
1589  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1590  auto stream = adaptor.getAsyncDependencies().front();
1591  auto bufferSize = spMVBufferSizeCallBuilder
1592  .create(loc, rewriter,
1593  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1594  adaptor.getDnY(), computeType, stream})
1595  .getResult();
1596  rewriter.replaceOp(op, {bufferSize, stream});
1597  return success();
1598 }
1599 
1600 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1601  gpu::SpMVOp op, OpAdaptor adaptor,
1602  ConversionPatternRewriter &rewriter) const {
1603  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1604  failed(isAsyncWithOneDependency(rewriter, op)))
1605  return failure();
1606  Location loc = op.getLoc();
1607  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1608  auto computeType = genConstInt32From(
1609  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1610  auto stream = adaptor.getAsyncDependencies().front();
1611  Value pBuf =
1612  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1613  spMVCallBuilder.create(loc, rewriter,
1614  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1615  adaptor.getDnY(), computeType, pBuf, stream});
1616  rewriter.replaceOp(op, {stream});
1617  return success();
1618 }
1619 
1620 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1621  gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1622  ConversionPatternRewriter &rewriter) const {
1623  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1624  failed(isAsyncWithOneDependency(rewriter, op)))
1625  return failure();
1626  Location loc = op.getLoc();
1627  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1628  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1629  auto stream = adaptor.getAsyncDependencies().front();
1630  Value bufferSize;
1631  if (is2To4Sparsity(op.getSpmatA())) {
1632  auto pruneFlag =
1633  genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1634  auto computeType = genConstInt32From(
1635  rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1636  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1637  rewriter.getIndexAttr(3));
1638  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1639  loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1640  createCuSparseLtSpMMBufferSizeBuilder
1641  .create(loc, rewriter,
1642  {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1643  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1644  pruneFlag, stream})
1645  .getResult();
1646 
1647  auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1648  loc, llvmPointerType, llvmPointerType, bufferSize,
1649  ValueRange{rewriter.create<LLVM::ConstantOp>(
1650  loc, getIndexType(), rewriter.getIndexAttr(1))});
1651  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1652  loc, llvmPointerType, llvmPointerType, bufferSize,
1653  ValueRange{rewriter.create<LLVM::ConstantOp>(
1654  loc, getIndexType(), rewriter.getIndexAttr(2))});
1655  auto bufferSize0 =
1656  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1657  auto bufferSize1 =
1658  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1659  auto bufferSize2 =
1660  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1661 
1662  rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1663  } else {
1664  auto computeType = genConstInt32From(
1665  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1666  bufferSize =
1667  createSpMMBufferSizeCallBuilder
1668  .create(loc, rewriter,
1669  {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1670  adaptor.getDnmatC(), computeType, stream})
1671  .getResult();
1672  rewriter.replaceOp(op, {bufferSize, stream});
1673  }
1674  return success();
1675 }
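// Illustrative note (editorial sketch): in the 2:4 (cusparseLt) case the
// wrapper call above writes three sizes into the stack-allocated 3-element
// array, and the pattern loads and returns all three; these correspond to the
// three buffers consumed by the gpu.spmm lowering further below. The plain
// cusparse case returns a single buffer size.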
1676 
1677 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1678  gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1679  ConversionPatternRewriter &rewriter) const {
1680  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1681  failed(isAsyncWithOneDependency(rewriter, op)))
1682  return failure();
1683  Location loc = op.getLoc();
1684  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1685  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1686  auto computeType = genConstInt32From(
1687  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1688  auto stream = adaptor.getAsyncDependencies().front();
1689  auto bufferSize =
1690  createSDDMMBufferSizeCallBuilder
1691  .create(loc, rewriter,
1692  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1693  adaptor.getSpmatC(), computeType, stream})
1694  .getResult();
1695  rewriter.replaceOp(op, {bufferSize, stream});
1696  return success();
1697 }
1698 
1699 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1700  gpu::SpMMOp op, OpAdaptor adaptor,
1701  ConversionPatternRewriter &rewriter) const {
1702  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1703  failed(isAsyncWithOneDependency(rewriter, op)))
1704  return failure();
1705  Location loc = op.getLoc();
1706  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1707  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1708  auto computeType = genConstInt32From(
1709  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1710 
1711  auto stream = adaptor.getAsyncDependencies().front();
1712 
1713  // Lower to cusparseLt if applicable
1714  if (is2To4Sparsity(op.getSpmatA())) {
1715  SmallVector<Value> pBufs;
1716  for (Value buffer : adaptor.getBuffers()) {
1717  Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1718  pBufs.push_back(pBuf);
1719  }
1720  createCuSparseLtSpMMBuilder.create(
1721  loc, rewriter,
1722  {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1723  pBufs[0], pBufs[1], pBufs[2], stream});
1724  } else {
1725  Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1726  .allocatedPtr(rewriter, loc);
1727  createSpMMCallBuilder.create(loc, rewriter,
1728  {modeA, modeB, adaptor.getSpmatA(),
1729  adaptor.getDnmatB(), adaptor.getDnmatC(),
1730  computeType, pBuf, stream});
1731  }
1732  rewriter.replaceOp(op, {stream});
1733  return success();
1734 }
1735 
1736 template <typename T>
1737 static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1738  converter.addConversion([&converter](T) -> Type {
1739  return LLVM::LLVMPointerType::get(&converter.getContext());
1740  });
1741 }
1742 
1743 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1744  gpu::SDDMMOp op, OpAdaptor adaptor,
1745  ConversionPatternRewriter &rewriter) const {
1746  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1747  failed(isAsyncWithOneDependency(rewriter, op)))
1748  return failure();
1749  Location loc = op.getLoc();
1750  auto computeType = genConstInt32From(
1751  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1752  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1753  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1754  auto stream = adaptor.getAsyncDependencies().front();
1755  Value pBuf =
1756  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1757  createSDDMMCallBuilder.create(loc, rewriter,
1758  {modeA, modeB, adaptor.getDnmatA(),
1759  adaptor.getDnmatB(), adaptor.getSpmatC(),
1760  computeType, pBuf, stream});
1761  rewriter.replaceOp(op, {stream});
1762  return success();
1763 }
1764 
1765 LogicalResult
1766 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1767  gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1768  ConversionPatternRewriter &rewriter) const {
1769  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1770  failed(isAsyncWithOneDependency(rewriter, op)))
1771  return failure();
1772  Location loc = op.getLoc();
1773  auto stream = adaptor.getAsyncDependencies().front();
1774  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1775  .getResult();
1776  rewriter.replaceOp(op, {descr, stream});
1777  return success();
1778 }
1779 
1780 LogicalResult
1781 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1782  gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1783  ConversionPatternRewriter &rewriter) const {
1784  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1785  failed(isAsyncWithOneDependency(rewriter, op)))
1786  return failure();
1787  Location loc = op.getLoc();
1788  auto stream = adaptor.getAsyncDependencies().front();
1789  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1790  {adaptor.getDesc(), stream});
1791  rewriter.replaceOp(op, {stream});
1792  return success();
1793 }
1794 
1795 LogicalResult
1796 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1797  gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1798  ConversionPatternRewriter &rewriter) const {
1799  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1800  failed(isAsyncWithOneDependency(rewriter, op)))
1801  return failure();
1802  Location loc = op.getLoc();
1803  auto computeType = genConstInt32From(
1804  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1805  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1806  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1807  auto stream = adaptor.getAsyncDependencies().front();
1808 
1809  Value pBuf =
1810  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1811  Value bufferSizeNew;
1812 
1813  if (adaptor.getKind() ==
1814  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1815  bufferSizeNew =
1816  createSpGEMMWorkEstimationBuilder
1817  .create(loc, rewriter,
1818  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1819  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1820  adaptor.getBufferSz(), pBuf, stream})
1821  .getResult();
1822  } else {
1823  bufferSizeNew =
1824  createSpGEMMComputeBuilder
1825  .create(loc, rewriter,
1826  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1827  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1828  adaptor.getBufferSz(), pBuf, stream})
1829  .getResult();
1830  }
1831  rewriter.replaceOp(op, {bufferSizeNew, stream});
1832  return success();
1833 }
1834 
1835 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1836  gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1837  ConversionPatternRewriter &rewriter) const {
1838  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1839  failed(isAsyncWithOneDependency(rewriter, op)))
1840  return failure();
1841  Location loc = op.getLoc();
1842  auto computeType = genConstInt32From(
1843  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1844  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1845  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1846  auto stream = adaptor.getAsyncDependencies().front();
1847  createSpGEMMCopyBuilder.create(loc, rewriter,
1848  {adaptor.getDesc(), modeA, modeB,
1849  adaptor.getSpmatA(), adaptor.getSpmatB(),
1850  adaptor.getSpmatC(), computeType, stream});
1851  rewriter.replaceOp(op, {stream});
1852  return success();
1853 }
1854 
1855 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1856  gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1857  ConversionPatternRewriter &rewriter) const {
1858  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1859  failed(isAsyncWithOneDependency(rewriter, op)))
1860  return failure();
1861  Location loc = op.getLoc();
1862  auto stream = adaptor.getAsyncDependencies().front();
1863 
1864  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1865  rewriter.getIndexAttr(3));
1866  auto buffer = rewriter.create<LLVM::AllocaOp>(
1867  loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1868 
1869  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1870  loc, llvmPointerType, llvmPointerType, buffer,
1871  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1872  rewriter.getIndexAttr(0))});
1873  auto colsPtr = rewriter.create<LLVM::GEPOp>(
1874  loc, llvmPointerType, llvmPointerType, buffer,
1875  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1876  rewriter.getIndexAttr(1))});
1877  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1878  loc, llvmPointerType, llvmPointerType, buffer,
1879  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1880  rewriter.getIndexAttr(2))});
1881  createSpMatGetSizeBuilder.create(
1882  loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1883  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1884  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1885  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1886 
1887  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1888  return success();
1889 }
1890 
1891 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1892  gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1893  ConversionPatternRewriter &rewriter) const {
1894  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1895  failed(isAsyncWithOneDependency(rewriter, op)))
1896  return failure();
1897  Location loc = op.getLoc();
1898  auto stream = adaptor.getAsyncDependencies().front();
1899  Value pPos =
1900  MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1901  Value pCrd =
1902  MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1903  Value pVal =
1904  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1905  createSetCsrPointersBuilder.create(
1906  loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1907  rewriter.replaceOp(op, {stream});
1908  return success();
1909 }
1910 
1911 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1912  gpu::CreateCscOp op, OpAdaptor adaptor,
1913  ConversionPatternRewriter &rewriter) const {
1914  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1915  failed(isAsyncWithOneDependency(rewriter, op)))
1916  return failure();
1917  Location loc = op.getLoc();
1918  auto stream = adaptor.getAsyncDependencies().front();
1919  Value pColPos =
1920  MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1921  Value pRowIdxs =
1922  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1923  Value pValues =
1924  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1925  Type pType =
1926  llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1927  Type iType =
1928  llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1929  Type dType =
1930  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1931  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1932  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1933  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1934  auto handle =
1935  createCscCallBuilder
1936  .create(loc, rewriter,
1937  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1938  pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1939  .getResult();
1940  rewriter.replaceOp(op, {handle, stream});
1941  return success();
1942 }
1943 
1944 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1945  gpu::CreateBsrOp op, OpAdaptor adaptor,
1946  ConversionPatternRewriter &rewriter) const {
1947  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1948  failed(isAsyncWithOneDependency(rewriter, op)))
1949  return failure();
1950  Location loc = op.getLoc();
1951  auto stream = adaptor.getAsyncDependencies().front();
1952  Value pRowPos =
1953  MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1954  Value pColIdxs =
1955  MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1956  Value pValues =
1957  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1958  Type pType =
1959  llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1960  Type iType =
1961  llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1962  Type dType =
1963  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1964  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1965  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1966  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1967  auto handle =
1968  createBsrCallBuilder
1969  .create(loc, rewriter,
1970  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1971  adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1972  pColIdxs, pValues, ptp, itp, dtp, stream})
1973  .getResult();
1974  rewriter.replaceOp(op, {handle, stream});
1975  return success();
1976 }
1977 
1978 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
1979  RewritePatternSet &patterns,
1980  StringRef gpuBinaryAnnotation,
1981  bool kernelBarePtrCallConv,
1982  SymbolTable *cachedModuleTable) {
1983  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1984  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1985  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1986  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1987 
1988  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1989  ConvertDeallocOpToGpuRuntimeCallPattern,
1990  ConvertHostRegisterOpToGpuRuntimeCallPattern,
1991  ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1992  ConvertMemcpyOpToGpuRuntimeCallPattern,
1993  ConvertMemsetOpToGpuRuntimeCallPattern,
1994  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1995  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1996  ConvertWaitOpToGpuRuntimeCallPattern,
1997  ConvertAsyncYieldToGpuRuntimeCallPattern,
1998  ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1999  ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
2000  ConvertCreateCooOpToGpuRuntimeCallPattern,
2001  ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
2002  ConvertCreateCsrOpToGpuRuntimeCallPattern,
2003  ConvertCreateCscOpToGpuRuntimeCallPattern,
2004  ConvertCreateBsrOpToGpuRuntimeCallPattern,
2005  ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
2006  ConvertDestroySpMatOpToGpuRuntimeCallPattern,
2007  ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
2008  ConvertSpMVOpToGpuRuntimeCallPattern,
2009  ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
2010  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
2011  ConvertSpMMOpToGpuRuntimeCallPattern,
2012  ConvertSDDMMOpToGpuRuntimeCallPattern,
2013  ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
2014  ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
2015  ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
2016  ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
2017  ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
2018  ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
2019  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
2020  converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
2021  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
2022 }
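// Usage sketch (illustrative, not part of this file): a conversion pass would
// typically populate these patterns inside its runOnOperation() and then run a
// partial conversion, roughly along these lines:
//
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   LLVMConversionTarget target(getContext());
//   populateGpuToLLVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     return signalPassFailure();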