MLIR  19.0.0git
GPUToLLVMConversion.cpp
Go to the documentation of this file.
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert gpu.launch_func op into a sequence of
10 // GPU runtime calls. As most GPU runtimes do not have a stable published
11 // ABI, this pass uses a slim runtime layer that builds on top of the public
12 // API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
15 
17 
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/Builders.h"
36 #include "mlir/IR/BuiltinOps.h"
37 #include "mlir/IR/BuiltinTypes.h"
38 
39 #include "llvm/ADT/STLExtras.h"
40 #include "llvm/Support/Error.h"
41 #include "llvm/Support/FormatVariadic.h"
42 
43 #define DEBUG_TYPE "gpu-to-llvm"
44 
45 namespace mlir {
46 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
47 #include "mlir/Conversion/Passes.h.inc"
48 } // namespace mlir
49 
50 using namespace mlir;
51 
// NOTE(review): the uses of this constant are outside the visible code; the
// name suggests it is the suffix appended to the symbol of a global holding
// serialized GPU binary data — confirm against its users.
52 static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
53 
54 namespace {
55 class GpuToLLVMConversionPass
56  : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
57 public:
58  using Base::Base;
59  void getDependentDialects(DialectRegistry &registry) const final {
60  Base::getDependentDialects(registry);
62  }
63  // Run the dialect converter on the module.
64  void runOnOperation() override;
65 };
66 
67 template <typename OpTy>
68 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
69 public:
70  explicit ConvertOpToGpuRuntimeCallPattern(
71  const LLVMTypeConverter &typeConverter)
72  : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
73 
74 protected:
76  MemRefType type, MemRefDescriptor desc) const {
78  return type.hasStaticShape()
80  rewriter, loc, indexType, type.getNumElements())
81  // For identity maps (verified by caller), the number of
82  // elements is stride[0] * size[0].
83  : rewriter.create<LLVM::MulOp>(loc,
84  desc.stride(rewriter, loc, 0),
85  desc.size(rewriter, loc, 0));
86  }
87 
88  MLIRContext *context = &this->getTypeConverter()->getContext();
89 
90  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
91  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
92  Type llvmInt8Type = IntegerType::get(context, 8);
93  Type llvmInt16Type = IntegerType::get(context, 16);
94  Type llvmInt32Type = IntegerType::get(context, 32);
95  Type llvmInt64Type = IntegerType::get(context, 64);
96  Type llvmFloat32Type = Float32Type::get(context);
97  Type llvmIntPtrType = IntegerType::get(
98  context, this->getTypeConverter()->getPointerBitwidth(0));
99 
100  FunctionCallBuilder moduleLoadCallBuilder = {
101  "mgpuModuleLoad",
102  llvmPointerType /* void *module */,
103  {llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}};
104  FunctionCallBuilder moduleUnloadCallBuilder = {
105  "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
106  FunctionCallBuilder moduleGetFunctionCallBuilder = {
107  "mgpuModuleGetFunction",
108  llvmPointerType /* void *function */,
109  {
110  llvmPointerType, /* void *module */
111  llvmPointerType /* char *name */
112  }};
113  FunctionCallBuilder launchKernelCallBuilder = {
114  "mgpuLaunchKernel",
115  llvmVoidType,
116  {
117  llvmPointerType, /* void* f */
118  llvmIntPtrType, /* intptr_t gridXDim */
119  llvmIntPtrType, /* intptr_t gridyDim */
120  llvmIntPtrType, /* intptr_t gridZDim */
121  llvmIntPtrType, /* intptr_t blockXDim */
122  llvmIntPtrType, /* intptr_t blockYDim */
123  llvmIntPtrType, /* intptr_t blockZDim */
124  llvmInt32Type, /* unsigned int sharedMemBytes */
125  llvmPointerType, /* void *hstream */
126  llvmPointerType, /* void **kernelParams */
127  llvmPointerType, /* void **extra */
128  llvmInt64Type /* size_t paramsCount */
129  }};
130  FunctionCallBuilder streamCreateCallBuilder = {
131  "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
132  FunctionCallBuilder streamDestroyCallBuilder = {
133  "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
134  FunctionCallBuilder streamSynchronizeCallBuilder = {
135  "mgpuStreamSynchronize",
136  llvmVoidType,
137  {llvmPointerType /* void *stream */}};
138  FunctionCallBuilder streamWaitEventCallBuilder = {
139  "mgpuStreamWaitEvent",
140  llvmVoidType,
141  {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
142  FunctionCallBuilder eventCreateCallBuilder = {
143  "mgpuEventCreate", llvmPointerType /* void *event */, {}};
144  FunctionCallBuilder eventDestroyCallBuilder = {
145  "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
146  FunctionCallBuilder eventSynchronizeCallBuilder = {
147  "mgpuEventSynchronize",
148  llvmVoidType,
149  {llvmPointerType /* void *event */}};
150  FunctionCallBuilder eventRecordCallBuilder = {
151  "mgpuEventRecord",
152  llvmVoidType,
153  {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
154  FunctionCallBuilder hostRegisterCallBuilder = {
155  "mgpuMemHostRegisterMemRef",
156  llvmVoidType,
157  {llvmIntPtrType /* intptr_t rank */,
158  llvmPointerType /* void *memrefDesc */,
159  llvmIntPtrType /* intptr_t elementSizeBytes */}};
160  FunctionCallBuilder hostUnregisterCallBuilder = {
161  "mgpuMemHostUnregisterMemRef",
162  llvmVoidType,
163  {llvmIntPtrType /* intptr_t rank */,
164  llvmPointerType /* void *memrefDesc */,
165  llvmIntPtrType /* intptr_t elementSizeBytes */}};
166  FunctionCallBuilder allocCallBuilder = {
167  "mgpuMemAlloc",
168  llvmPointerType /* void * */,
169  {llvmIntPtrType /* intptr_t sizeBytes */,
170  llvmPointerType /* void *stream */,
171  llvmInt8Type /* bool isHostShared */}};
172  FunctionCallBuilder deallocCallBuilder = {
173  "mgpuMemFree",
174  llvmVoidType,
175  {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
176  FunctionCallBuilder memcpyCallBuilder = {
177  "mgpuMemcpy",
178  llvmVoidType,
179  {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
180  llvmIntPtrType /* intptr_t sizeBytes */,
181  llvmPointerType /* void *stream */}};
182  FunctionCallBuilder memset16CallBuilder = {
183  "mgpuMemset16",
184  llvmVoidType,
185  {llvmPointerType /* void *dst */,
186  llvmInt16Type /* unsigned short value */,
187  llvmIntPtrType /* intptr_t sizeBytes */,
188  llvmPointerType /* void *stream */}};
189  FunctionCallBuilder memset32CallBuilder = {
190  "mgpuMemset32",
191  llvmVoidType,
192  {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
193  llvmIntPtrType /* intptr_t sizeBytes */,
194  llvmPointerType /* void *stream */}};
195  FunctionCallBuilder setDefaultDeviceCallBuilder = {
196  "mgpuSetDefaultDevice",
197  llvmVoidType,
198  {llvmInt32Type /* uint32_t devIndex */}};
199  FunctionCallBuilder createDnVecCallBuilder = {
200  "mgpuCreateDnVec",
201  llvmPointerType,
202  {llvmIntPtrType, llvmPointerType, llvmInt32Type,
203  llvmPointerType /* void *stream */}};
204  FunctionCallBuilder destroyDnVecCallBuilder = {
205  "mgpuDestroyDnVec",
206  llvmVoidType,
207  {llvmPointerType, llvmPointerType /* void *stream */}};
208  FunctionCallBuilder createDnMatCallBuilder = {
209  "mgpuCreateDnMat",
210  llvmPointerType,
211  {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
212  llvmPointerType /* void *stream */}};
213  FunctionCallBuilder destroyDnMatCallBuilder = {
214  "mgpuDestroyDnMat",
215  llvmVoidType,
216  {llvmPointerType, llvmPointerType /* void *stream */}};
217  FunctionCallBuilder createCooCallBuilder = {
218  "mgpuCreateCoo",
219  llvmPointerType,
220  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
221  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
222  llvmPointerType /* void *stream */}};
223  FunctionCallBuilder createCooAoSCallBuilder = {
224  "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
225  llvmPointerType,
226  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
227  llvmPointerType, llvmInt32Type, llvmInt32Type,
228  llvmPointerType /* void *stream */}};
229  FunctionCallBuilder createCsrCallBuilder = {
230  "mgpuCreateCsr",
231  llvmPointerType,
232  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
233  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
234  llvmInt32Type, llvmPointerType /* void *stream */}};
235  FunctionCallBuilder createCscCallBuilder = {
236  "mgpuCreateCsc",
237  llvmPointerType,
238  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
239  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
240  llvmInt32Type, llvmPointerType /* void *stream */}};
241  FunctionCallBuilder createBsrCallBuilder = {
242  "mgpuCreateBsr",
243  llvmPointerType,
244  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
245  llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
246  llvmInt32Type, llvmInt32Type, llvmInt32Type,
247  llvmPointerType /* void *stream */}};
248  FunctionCallBuilder destroySpMatCallBuilder = {
249  "mgpuDestroySpMat",
250  llvmVoidType,
251  {llvmPointerType, llvmPointerType /* void *stream */}};
252  FunctionCallBuilder spMVBufferSizeCallBuilder = {
253  "mgpuSpMVBufferSize",
254  llvmIntPtrType,
255  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
256  llvmInt32Type, llvmPointerType /* void *stream */}};
257  FunctionCallBuilder spMVCallBuilder = {
258  "mgpuSpMV",
259  llvmVoidType,
260  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
261  llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
262  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
263  "mgpuSpMMBufferSize",
264  llvmIntPtrType,
265  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
266  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
267  FunctionCallBuilder createSpMMCallBuilder = {
268  "mgpuSpMM",
269  llvmVoidType,
270  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
271  llvmPointerType, llvmInt32Type, llvmPointerType,
272  llvmPointerType /* void *stream */}};
273  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
274  "mgpuSDDMMBufferSize",
275  llvmIntPtrType,
276  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
277  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
278  FunctionCallBuilder createSDDMMCallBuilder = {
279  "mgpuSDDMM",
280  llvmVoidType,
281  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
282  llvmPointerType, llvmInt32Type, llvmPointerType,
283  llvmPointerType /* void *stream */}};
284  FunctionCallBuilder createLtDnMatCallBuilder = {
285  "mgpuCreateCuSparseLtDnMat",
286  llvmVoidType,
287  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
288  llvmInt32Type, llvmPointerType /* void *stream */}};
289  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
290  "mgpuDestroyCuSparseLtSpMat",
291  llvmVoidType,
292  {llvmPointerType, llvmPointerType /* void *stream */}};
293  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
294  "mgpuDestroyCuSparseLtDnMat",
295  llvmVoidType,
296  {llvmPointerType, llvmPointerType /* void *stream */}};
297  FunctionCallBuilder create2To4SpMatCallBuilder = {
298  "mgpuCusparseLtCreate2To4SpMat",
299  llvmVoidType,
300  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
301  llvmInt32Type, llvmPointerType /* void *stream */}};
302  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
303  "mgpuCuSparseLtSpMMBufferSize",
304  llvmVoidType,
305  {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
306  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
307  llvmPointerType /*void *stream*/}};
308  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
309  "mgpuCuSparseLtSpMM",
310  llvmVoidType,
311  {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
312  llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
313  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
314  "mgpuSpGEMMCreateDescr",
315  llvmPointerType,
316  {llvmPointerType /*void *stream*/}};
317  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
318  "mgpuSpGEMMDestroyDescr",
319  llvmVoidType,
320  {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
321  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
322  "mgpuSpGEMMWorkEstimation",
323  llvmIntPtrType,
324  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
325  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
326  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
327  llvmPointerType /*void *stream*/}};
328  FunctionCallBuilder createSpGEMMComputeBuilder = {
329  "mgpuSpGEMMCompute",
330  llvmIntPtrType,
331  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
332  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
333  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
334  llvmPointerType /*void *stream*/}};
335  FunctionCallBuilder createSpGEMMCopyBuilder = {
336  "mgpuSpGEMMCopy",
337  llvmVoidType,
338  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
339  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
340  llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
341  FunctionCallBuilder createSpMatGetSizeBuilder = {
342  "mgpuSpMatGetSize",
343  llvmVoidType,
344  {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
345  llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
346  FunctionCallBuilder createSetCsrPointersBuilder = {
347  "mgpuSetCsrPointers",
348  llvmVoidType,
349  {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
350  llvmPointerType /*crd*/, llvmPointerType /*val*/,
351  llvmPointerType /*void *stream*/}};
352 };
353 
354 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
355 /// call. Currently it supports CUDA and ROCm (HIP).
356 class ConvertHostRegisterOpToGpuRuntimeCallPattern
357  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
358 public:
359  ConvertHostRegisterOpToGpuRuntimeCallPattern(
360  const LLVMTypeConverter &typeConverter)
361  : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
362 
363 private:
365  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
366  ConversionPatternRewriter &rewriter) const override;
367 };
368 
369 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
370  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
371 public:
372  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
373  const LLVMTypeConverter &typeConverter)
374  : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
375  }
376 
377 private:
379  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
380  ConversionPatternRewriter &rewriter) const override;
381 };
382 
383 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
384 /// call. Currently it supports CUDA and ROCm (HIP).
385 class ConvertAllocOpToGpuRuntimeCallPattern
386  : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
387 public:
388  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
389  : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
390 
391 private:
393  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
394  ConversionPatternRewriter &rewriter) const override;
395 };
396 
397 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
398 /// call. Currently it supports CUDA and ROCm (HIP).
399 class ConvertDeallocOpToGpuRuntimeCallPattern
400  : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
401 public:
402  ConvertDeallocOpToGpuRuntimeCallPattern(
403  const LLVMTypeConverter &typeConverter)
404  : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
405 
406 private:
408  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
409  ConversionPatternRewriter &rewriter) const override;
410 };
411 
412 class ConvertAsyncYieldToGpuRuntimeCallPattern
413  : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
414 public:
415  ConvertAsyncYieldToGpuRuntimeCallPattern(
416  const LLVMTypeConverter &typeConverter)
417  : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
418 
419 private:
421  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
422  ConversionPatternRewriter &rewriter) const override;
423 };
424 
425 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
426 /// call. Currently it supports CUDA and ROCm (HIP).
427 class ConvertWaitOpToGpuRuntimeCallPattern
428  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
429 public:
430  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
431  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
432 
433 private:
435  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
436  ConversionPatternRewriter &rewriter) const override;
437 };
438 
439 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
440 /// call. Currently it supports CUDA and ROCm (HIP).
441 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
442  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
443 public:
444  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
445  const LLVMTypeConverter &typeConverter)
446  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
447 
448 private:
450  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
451  ConversionPatternRewriter &rewriter) const override;
452 };
453 
454 /// A rewrite pattern to convert gpu.launch_func operations into a sequence of
455 /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
456 ///
457 /// In essence, a gpu.launch_func operation gets compiled into the following
458 /// sequence of runtime calls:
459 ///
460 /// * moduleLoad -- loads the module given the cubin / hsaco data
461 /// * moduleGetFunction -- gets a handle to the actual kernel function
462 /// * getStreamHelper -- initializes a new compute stream on GPU
463 /// * launchKernel -- launches the kernel on a stream
464 /// * streamSynchronize -- waits for operations on the stream to finish
465 ///
466 /// Intermediate data structures are allocated on the stack.
467 class ConvertLaunchFuncOpToGpuRuntimeCallPattern
468  : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
469 public:
470  ConvertLaunchFuncOpToGpuRuntimeCallPattern(
471  const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation,
472  bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable)
473  : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
474  gpuBinaryAnnotation(gpuBinaryAnnotation),
475  kernelBarePtrCallConv(kernelBarePtrCallConv),
476  cachedModuleTable(cachedModuleTable) {}
477 
478 private:
479  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
480  OpBuilder &builder) const;
481  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
482  Location loc, OpBuilder &builder) const;
483 
485  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
486  ConversionPatternRewriter &rewriter) const override;
487 
488  llvm::SmallString<32> gpuBinaryAnnotation;
489  bool kernelBarePtrCallConv;
490  SymbolTable *cachedModuleTable;
491 };
492 
493 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
495 
496  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
497  PatternRewriter &rewriter) const override {
498  // GPU kernel modules are no longer necessary since we have a global
499  // constant with the CUBIN, or HSACO data.
500  rewriter.eraseOp(op);
501  return success();
502  }
503 };
504 
505 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
506 /// call. Currently it supports CUDA and ROCm (HIP).
507 class ConvertMemcpyOpToGpuRuntimeCallPattern
508  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
509 public:
510  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
511  : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
512 
513 private:
515  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
516  ConversionPatternRewriter &rewriter) const override;
517 };
518 
519 /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
520 /// call. Currently it supports CUDA and ROCm (HIP).
521 class ConvertMemsetOpToGpuRuntimeCallPattern
522  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
523 public:
524  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
525  : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
526 
527 private:
529  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
530  ConversionPatternRewriter &rewriter) const override;
531 };
532 
533 /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
534 /// Currently supports CUDA and ROCm (HIP)
535 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
536  : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
537 public:
538  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
539  const LLVMTypeConverter &typeConverter)
540  : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
541  typeConverter) {}
542 
544  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
545  ConversionPatternRewriter &rewriter) const override;
546 };
547 
/// Declares the conversion pattern class
/// Convert<op_name>ToGpuRuntimeCallPattern for the op gpu::<op_name>. Only a
/// constructor and a matchAndRewrite declaration are generated, so every
/// instantiation needs a separate out-of-line matchAndRewrite definition.
/// Used for the operations on sparse matrices, which currently target CUDA
/// (through cuSPARSE and cuSPARSELt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                              \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
    using PatternBase = ConvertOpToGpuRuntimeCallPattern<gpu::op_name>;        \
                                                                               \
  public:                                                                      \
    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
        const LLVMTypeConverter &typeConverter)                                \
        : PatternBase(typeConverter) {}                                        \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
563 
581 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
585 
586 } // namespace
587 
588 void GpuToLLVMConversionPass::runOnOperation() {
589  MLIRContext *context = &getContext();
590  SymbolTable symbolTable = SymbolTable(getOperation());
591  LowerToLLVMOptions options(context);
592  options.useBarePtrCallConv = hostBarePtrCallConv;
593  RewritePatternSet patterns(context);
594  ConversionTarget target(*context);
595  target.addLegalDialect<LLVM::LLVMDialect>();
596  LLVMTypeConverter converter(context, options);
597 
598  // Populate all patterns from all dialects that implement the
599  // `ConvertToLLVMPatternInterface` interface.
600  for (Dialect *dialect : context->getLoadedDialects()) {
601  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
602  if (!iface)
603  continue;
604  iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
605  }
606 
607  // Preserve GPU modules if they have target attributes.
608  target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
609  [](gpu::GPUModuleOp module) -> bool {
610  return module.getTargetsAttr() != nullptr;
611  });
612  // Accept as legal LaunchFuncOps if they refer to GPU Modules with targets and
613  // the operands have been lowered.
614  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
615  [&](gpu::LaunchFuncOp op) -> bool {
616  auto module =
617  symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
618  return converter.isLegal(op->getOperandTypes()) &&
619  converter.isLegal(op->getResultTypes()) &&
620  (module && module.getTargetsAttr() &&
621  !module.getTargetsAttr().empty());
622  });
623 
624  // These aren't covered by the ConvertToLLVMPatternInterface right now.
625  populateVectorToLLVMConversionPatterns(converter, patterns);
628  target);
629  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
630  kernelBarePtrCallConv, &symbolTable);
631 
632  if (failed(
633  applyPartialConversion(getOperation(), target, std::move(patterns))))
634  signalPassFailure();
635 }
636 
637 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
638  ArrayRef<Value> arguments) const {
639  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
640  auto function = [&] {
641  if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
642  return function;
643  return OpBuilder::atBlockEnd(module.getBody())
644  .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
645  }();
646  return builder.create<LLVM::CallOp>(loc, function, arguments);
647 }
648 
649 // Corresponding to cusparseIndexType_t defined in cusparse.h.
650 static int32_t getCuSparseIndexTypeFrom(Type type) {
651  if (type.isInteger(16))
652  return 1; // CUSPARSE_INDEX_16U
653  if (type.isInteger(32))
654  return 2; // CUSPARSE_INDEX_32I
655  return 3; // CUSPARSE_INDEX_64I
656 }
657 
658 static int32_t getCuSparseLtDataTypeFrom(Type type) {
659  if (type.isF16())
660  return 0; // CUSPARSE_COMPUTE_16F,
661  if (type.isInteger(32))
662  return 1; // CUSPARSE_COMPUTE_32I
663  llvm_unreachable("unsupported type");
664  // TODO: add support to TF32
665 }
666 
667 // Corresponding to cudaDataType_t defined in CUDA library_types.h.
668 static int32_t getCuSparseDataTypeFrom(Type type) {
669  if (llvm::isa<ComplexType>(type)) {
670  // get the element type
671  auto elementType = cast<ComplexType>(type).getElementType();
672  if (elementType.isBF16())
673  return 15; // CUDA_C_16BF
674  if (elementType.isF16())
675  return 6; // CUDA_C_16F
676  if (elementType.isF32())
677  return 4; // CUDA_C_32F
678  if (elementType.isF64())
679  return 5; // CUDA_C_64F
680  if (elementType.isInteger(8))
681  return 7; // CUDA_C_8I
682  if (elementType.isInteger(16))
683  return 21; // CUDA_C_16I
684  if (elementType.isInteger(32))
685  return 11; // CUDA_C_32I
686  }
687  if (type.isBF16())
688  return 14; // CUDA_R_16BF
689  if (type.isF16())
690  return 2; // CUDA_R_16F
691  if (type.isF32())
692  return 0; // CUDA_R_32F
693  if (type.isF64())
694  return 1; // CUDA_R_64F
695  if (type.isInteger(8))
696  return 3; // CUDA_R_8I
697  if (type.isInteger(16))
698  return 20; // CUDA_R_16I
699  if (type.isInteger(32))
700  return 10; // CUDA_R_32I
701 
702  llvm_unreachable("unsupported element type");
703 }
704 
705 static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
706  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
707 }
708 
709 // TODO: We may want a run-time (of the mlir compiler) disablement/warning:
710 // cusparseLt currently won't work for cuda architecture <8.0 and will trigger a
711 // runtime (of the CUDA program) error , but it might be great if we could at
712 // least output a warning when we found the target architecture is <8.0 and the
713 // user still wants to use cusparseLt. to make sure when lowering gpu sparse
714 // dialect to llvm calls, the cusparselt calls are disabled for cuda
715 // architecture <8.0
716 static bool is2To4Sparsity(Value spMat) {
717  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
718  return true;
719  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
720  return false;
721  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
722  return false;
723  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
724  return false;
725  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
726  return false;
727  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
728  return false;
729  // Print the spMat defining op
730  spMat.getDefiningOp()->print(llvm::errs());
731  llvm_unreachable("cannot find spmat def");
732 }
733 
734 static bool isSpMMCusparseLtOp(Value op) {
735  for (Operation *user : op.getUsers()) {
736  auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
737  // If the other operator is 50% sparsity then we should use cusparseLt
738  if (!spmmOp)
739  continue;
740  if (is2To4Sparsity(spmmOp.getSpmatA()))
741  return true;
742  }
743  return false;
744 }
745 
746 // Returns whether all operands are of LLVM type.
748  ConversionPatternRewriter &rewriter) {
749  if (!llvm::all_of(operands, [](Value value) {
750  return LLVM::isCompatibleType(value.getType());
751  }))
752  return rewriter.notifyMatchFailure(
753  op, "Cannot convert if operands aren't of LLVM type.");
754  return success();
755 }
756 
757 static LogicalResult
759  gpu::AsyncOpInterface op) {
760  if (op.getAsyncDependencies().size() != 1)
761  return rewriter.notifyMatchFailure(
762  op, "Can only convert with exactly one async dependency.");
763 
764  if (!op.getAsyncToken())
765  return rewriter.notifyMatchFailure(op, "Can convert only async version.");
766 
767  return success();
768 }
769 
770 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
771  gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
772  ConversionPatternRewriter &rewriter) const {
773  auto *op = hostRegisterOp.getOperation();
774  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
775  return failure();
776 
777  Location loc = op->getLoc();
778 
779  auto memRefType = hostRegisterOp.getValue().getType();
780  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
781  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
782 
783  auto arguments = getTypeConverter()->promoteOperands(
784  loc, op->getOperands(), adaptor.getOperands(), rewriter);
785  arguments.push_back(elementSize);
786  hostRegisterCallBuilder.create(loc, rewriter, arguments);
787 
788  rewriter.eraseOp(op);
789  return success();
790 }
791 
792 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
793  gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
794  ConversionPatternRewriter &rewriter) const {
795  Operation *op = hostUnregisterOp.getOperation();
796  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
797  return failure();
798 
799  Location loc = op->getLoc();
800 
801  auto memRefType = hostUnregisterOp.getValue().getType();
802  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
803  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
804 
805  auto arguments = getTypeConverter()->promoteOperands(
806  loc, op->getOperands(), adaptor.getOperands(), rewriter);
807  arguments.push_back(elementSize);
808  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
809 
810  rewriter.eraseOp(op);
811  return success();
812 }
813 
814 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
815  gpu::AllocOp allocOp, OpAdaptor adaptor,
816  ConversionPatternRewriter &rewriter) const {
817 
818  MemRefType memRefType = allocOp.getType();
819 
820  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
821  !isConvertibleAndHasIdentityMaps(memRefType))
822  return failure();
823 
824  auto loc = allocOp.getLoc();
825 
826  bool isShared = allocOp.getHostShared();
827 
828  if (isShared && allocOp.getAsyncToken())
829  return rewriter.notifyMatchFailure(
830  allocOp, "Host Shared allocation cannot be done async");
831  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
832  return failure();
833 
834  // Get shape of the memref as values: static sizes are constant
835  // values and dynamic sizes are passed to 'alloc' as operands.
836  SmallVector<Value, 4> shape;
837  SmallVector<Value, 4> strides;
838  Value sizeBytes;
839  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
840  shape, strides, sizeBytes);
841 
842  // Allocate the underlying buffer and store a pointer to it in the MemRef
843  // descriptor.
844  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
845  Value stream = adaptor.getAsyncDependencies().empty()
846  ? nullPtr
847  : adaptor.getAsyncDependencies().front();
848 
849  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
850  loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
851 
852  Value allocatedPtr =
853  allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
854  .getResult();
855 
856  // No alignment.
857  Value alignedPtr = allocatedPtr;
858 
859  // Create the MemRef descriptor.
860  auto memRefDescriptor = this->createMemRefDescriptor(
861  loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
862 
863  if (allocOp.getAsyncToken()) {
864  // Async alloc: make dependent ops use the same stream.
865  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
866  } else {
867  rewriter.replaceOp(allocOp, {memRefDescriptor});
868  }
869 
870  return success();
871 }
872 
873 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
874  gpu::DeallocOp deallocOp, OpAdaptor adaptor,
875  ConversionPatternRewriter &rewriter) const {
876  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
877  failed(isAsyncWithOneDependency(rewriter, deallocOp)))
878  return failure();
879 
880  Location loc = deallocOp.getLoc();
881 
882  Value pointer =
883  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
884  Value stream = adaptor.getAsyncDependencies().front();
885  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
886 
887  rewriter.replaceOp(deallocOp, {stream});
888  return success();
889 }
890 
891 static bool isGpuAsyncTokenType(Value value) {
892  return isa<gpu::AsyncTokenType>(value.getType());
893 }
894 
895 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
896 // !gpu.async.token are lowered to stream within the async.execute region, but
897 // are passed as events between them. For each !gpu.async.token operand, we
898 // create an event and record it on the stream.
899 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
900  async::YieldOp yieldOp, OpAdaptor adaptor,
901  ConversionPatternRewriter &rewriter) const {
902  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
903  return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
904 
905  Location loc = yieldOp.getLoc();
906  SmallVector<Value, 4> newOperands(adaptor.getOperands());
907  llvm::SmallDenseSet<Value> streams;
908  for (auto &operand : yieldOp->getOpOperands()) {
909  if (!isGpuAsyncTokenType(operand.get()))
910  continue;
911  auto idx = operand.getOperandNumber();
912  auto stream = adaptor.getOperands()[idx];
913  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
914  eventRecordCallBuilder.create(loc, rewriter, {event, stream});
915  newOperands[idx] = event;
916  streams.insert(stream);
917  }
918  for (auto stream : streams)
919  streamDestroyCallBuilder.create(loc, rewriter, {stream});
920 
921  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
922  return success();
923 }
924 
925 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
926 static bool isDefinedByCallTo(Value value, StringRef functionName) {
927  assert(isa<LLVM::LLVMPointerType>(value.getType()));
928  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
929  return defOp.getCallee()->equals(functionName);
930  return false;
931 }
932 
933 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
934 // with the stream/event operands. The operands are destroyed. That is, it
935 // assumes that it is not used afterwards or elsewhere. Otherwise we will get a
936 // runtime error. Eventually, we should guarantee this property.
937 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
938  gpu::WaitOp waitOp, OpAdaptor adaptor,
939  ConversionPatternRewriter &rewriter) const {
940  if (waitOp.getAsyncToken())
941  return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
942 
943  Location loc = waitOp.getLoc();
944 
945  for (auto operand : adaptor.getOperands()) {
946  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
947  // The converted operand's definition created a stream.
948  streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
949  streamDestroyCallBuilder.create(loc, rewriter, {operand});
950  } else {
951  // Otherwise the converted operand is an event. This assumes that we use
952  // events in control flow code as well.
953  eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
954  eventDestroyCallBuilder.create(loc, rewriter, {operand});
955  }
956  }
957 
958  rewriter.eraseOp(waitOp);
959  return success();
960 }
961 
962 // Converts `gpu.wait async` to runtime calls. The converted op creates a new
963 // stream that is synchronized with stream/event operands. The operands are
964 // destroyed. That is, it assumes that it is not used afterwards or elsewhere.
965 // Otherwise we will get a runtime error. Eventually, we should guarantee this
966 // property.
967 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
968  gpu::WaitOp waitOp, OpAdaptor adaptor,
969  ConversionPatternRewriter &rewriter) const {
970  if (!waitOp.getAsyncToken())
971  return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
972 
973  Location loc = waitOp.getLoc();
974 
975  auto insertionPoint = rewriter.saveInsertionPoint();
976  SmallVector<Value, 1> events;
977  for (auto pair :
978  llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
979  auto operand = std::get<1>(pair);
980  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
981  // The converted operand's definition created a stream. Insert an event
982  // into the stream just after the last use of the original token operand.
983  auto *defOp = std::get<0>(pair).getDefiningOp();
984  rewriter.setInsertionPointAfter(defOp);
985  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
986  eventRecordCallBuilder.create(loc, rewriter, {event, operand});
987  events.push_back(event);
988  } else {
989  // Otherwise the converted operand is an event. This assumes that we use
990  // events in control flow code as well.
991  events.push_back(operand);
992  }
993  }
994  rewriter.restoreInsertionPoint(insertionPoint);
995  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
996  for (auto event : events)
997  streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
998  for (auto event : events)
999  eventDestroyCallBuilder.create(loc, rewriter, {event});
1000  rewriter.replaceOp(waitOp, {stream});
1001 
1002  return success();
1003 }
1004 
1005 // Creates a struct containing all kernel parameters on the stack and returns
1006 // an array of type-erased pointers to the fields of the struct. The array can
1007 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
1008 // The generated code is essentially as follows:
1009 //
1010 // %struct = alloca(sizeof(struct { Parameters... }))
1011 // %array = alloca(NumParameters * sizeof(void *))
1012 // for (i : [0, NumParameters))
1013 // %fieldPtr = llvm.getelementptr %struct[0, i]
1014 // llvm.store parameters[i], %fieldPtr
1015 // %elementPtr = llvm.getelementptr %array[i]
1016 // llvm.store %fieldPtr, %elementPtr
1017 // return %array
1018 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
1019  gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
1020  auto loc = launchOp.getLoc();
1021  auto numKernelOperands = launchOp.getNumKernelOperands();
1022  // Note: If `useBarePtrCallConv` is set in the type converter's options,
1023  // the value of `kernelBarePtrCallConv` will be ignored.
1024  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
1025  loc, launchOp.getOperands().take_back(numKernelOperands),
1026  adaptor.getOperands().take_back(numKernelOperands), builder,
1027  /*useBarePtrCallConv=*/kernelBarePtrCallConv);
1028  auto numArguments = arguments.size();
1029  SmallVector<Type, 4> argumentTypes;
1030  argumentTypes.reserve(numArguments);
1031  for (auto argument : arguments)
1032  argumentTypes.push_back(argument.getType());
1033  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
1034  argumentTypes);
1035  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
1036  auto structPtr =
1037  builder.create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
1038  /*alignment=*/0);
1039  auto arraySize =
1040  builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
1041  auto arrayPtr = builder.create<LLVM::AllocaOp>(
1042  loc, llvmPointerType, llvmPointerType, arraySize, /*alignment=*/0);
1043  for (const auto &en : llvm::enumerate(arguments)) {
1044  const auto index = static_cast<int32_t>(en.index());
1045  Value fieldPtr =
1046  builder.create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
1047  ArrayRef<LLVM::GEPArg>{0, index});
1048  builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
1049  auto elementPtr =
1050  builder.create<LLVM::GEPOp>(loc, llvmPointerType, llvmPointerType,
1051  arrayPtr, ArrayRef<LLVM::GEPArg>{index});
1052  builder.create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
1053  }
1054  return arrayPtr;
1055 }
1056 
1057 // Generates an LLVM IR dialect global that contains the name of the given
1058 // kernel function as a C string, and returns a pointer to its beginning.
1059 // The code is essentially:
1060 //
1061 // llvm.global constant @kernel_name("function_name\00")
1062 // func(...) {
1063 // %0 = llvm.addressof @kernel_name
1064 // %1 = llvm.constant (0 : index)
1065 // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
1066 // }
1067 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
1068  StringRef moduleName, StringRef name, Location loc,
1069  OpBuilder &builder) const {
1070  // Make sure the trailing zero is included in the constant.
1071  std::vector<char> kernelName(name.begin(), name.end());
1072  kernelName.push_back('\0');
1073 
1074  std::string globalName =
1075  std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
1076  return LLVM::createGlobalString(
1077  loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
1078  LLVM::Linkage::Internal);
1079 }
1080 
1081 // Emits LLVM IR to launch a kernel function. Expects the module that contains
1082 // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
1083 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
1084 //
1085 // %0 = call %binarygetter
1086 // %1 = call %moduleLoad(%0)
1087 // %2 = <see generateKernelNameConstant>
1088 // %3 = call %moduleGetFunction(%1, %2)
1089 // %4 = call %streamCreate()
1090 // %5 = <see generateParamsArray>
1091 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
1092 // call %streamSynchronize(%4)
1093 // call %streamDestroy(%4)
1094 // call %moduleUnload(%1)
1095 //
1096 // If the op is async, the stream corresponds to the (single) async dependency
1097 // as well as the async token the op produces.
1098 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
1099  gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
1100  ConversionPatternRewriter &rewriter) const {
1101  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
1102  return failure();
1103 
1104  if (launchOp.getAsyncDependencies().size() > 1)
1105  return rewriter.notifyMatchFailure(
1106  launchOp, "Cannot convert with more than one async dependency.");
1107 
1108  // Fail when the synchronous version of the op has async dependencies. The
1109  // lowering destroys the stream, and we do not want to check that there is no
1110  // use of the stream after this op.
1111  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
1112  return rewriter.notifyMatchFailure(
1113  launchOp, "Cannot convert non-async op with async dependencies.");
1114 
1115  Location loc = launchOp.getLoc();
1116 
1117  // Create an LLVM global with CUBIN extracted from the kernel annotation and
1118  // obtain a pointer to the first byte in it.
1119  gpu::GPUModuleOp kernelModule;
1120  if (cachedModuleTable)
1121  kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
1122  launchOp.getKernelModuleName());
1123  else
1124  kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
1125  launchOp, launchOp.getKernelModuleName());
1126  assert(kernelModule && "expected a kernel module");
1127 
1128  // If the module has Targets then just update the op operands.
1129  if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
1130  Value stream = Value();
1131  if (!adaptor.getAsyncDependencies().empty())
1132  stream = adaptor.getAsyncDependencies().front();
1133  // If the async keyword is present and there are no dependencies, then a
1134  // stream must be created to pass to subsequent operations.
1135  else if (launchOp.getAsyncToken())
1136  stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1137 
1138  // Lower the kernel operands to match kernel parameters.
1139  // Note: If `useBarePtrCallConv` is set in the type converter's options,
1140  // the value of `kernelBarePtrCallConv` will be ignored.
1141  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
1142  loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
1143  rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
1144 
1145  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1146  if (launchOp.hasClusterSize()) {
1147  clusterSize =
1148  gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1149  adaptor.getClusterSizeZ()};
1150  }
1151  rewriter.create<gpu::LaunchFuncOp>(
1152  launchOp.getLoc(), launchOp.getKernelAttr(),
1153  gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1154  adaptor.getGridSizeZ()},
1155  gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1156  adaptor.getBlockSizeZ()},
1157  adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
1158  if (launchOp.getAsyncToken())
1159  rewriter.replaceOp(launchOp, {stream});
1160  else
1161  rewriter.eraseOp(launchOp);
1162  return success();
1163  }
1164 
1165  auto binaryAttr =
1166  kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
1167  if (!binaryAttr) {
1168  kernelModule.emitOpError()
1169  << "missing " << gpuBinaryAnnotation << " attribute";
1170  return failure();
1171  }
1172 
1173  SmallString<128> nameBuffer(kernelModule.getName());
1174  nameBuffer.append(kGpuBinaryStorageSuffix);
1175  Value data =
1176  LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
1177  binaryAttr.getValue(), LLVM::Linkage::Internal);
1178 
1179  // Pass the binary size. SPIRV requires binary size.
1180  auto gpuBlob = binaryAttr.getValue();
1181  auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>(
1182  loc, llvmInt64Type,
1183  mlir::IntegerAttr::get(llvmInt64Type,
1184  static_cast<int64_t>(gpuBlob.size())));
1185 
1186  auto module =
1187  moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize});
1188 
1189  // Pass the count of the parameters to runtime wrappers
1190  auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>(
1191  loc, llvmInt64Type,
1193  llvmInt64Type,
1194  static_cast<int64_t>(launchOp.getNumKernelOperands())));
1195 
1196  // Get the function from the module. The name corresponds to the name of
1197  // the kernel function.
1198  auto kernelName = generateKernelNameConstant(
1199  launchOp.getKernelModuleName().getValue(),
1200  launchOp.getKernelName().getValue(), loc, rewriter);
1201  auto function = moduleGetFunctionCallBuilder.create(
1202  loc, rewriter, {module.getResult(), kernelName});
1203  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
1204  Value stream =
1205  adaptor.getAsyncDependencies().empty()
1206  ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
1207  : adaptor.getAsyncDependencies().front();
1208  // Create array of pointers to kernel arguments.
1209  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
1210  auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType);
1211  Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
1212  ? launchOp.getDynamicSharedMemorySize()
1213  : zero;
1214  launchKernelCallBuilder.create(
1215  loc, rewriter,
1216  {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1217  adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1218  adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
1219  /*extra=*/nullpointer, paramsCount});
1220 
1221  if (launchOp.getAsyncToken()) {
1222  // Async launch: make dependent ops use the same stream.
1223  rewriter.replaceOp(launchOp, {stream});
1224  } else {
1225  // Synchronize with host and destroy stream. This must be the stream created
1226  // above (with no other uses) because we check that the synchronous version
1227  // does not have any async dependencies.
1228  streamSynchronizeCallBuilder.create(loc, rewriter, stream);
1229  streamDestroyCallBuilder.create(loc, rewriter, stream);
1230  rewriter.eraseOp(launchOp);
1231  }
1232  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
1233 
1234  return success();
1235 }
1236 
1238  ConversionPatternRewriter &rewriter,
1239  LLVM::LLVMPointerType destinationType,
1240  Value sourcePtr,
1241  const LLVMTypeConverter &typeConverter) {
1242  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1243  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1244  sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1245  loc,
1247  destinationType.getAddressSpace()),
1248  sourcePtr);
1249  return sourcePtr;
1250 }
1251 
1252 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1253  gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1254  ConversionPatternRewriter &rewriter) const {
1255  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1256 
1257  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1258  !isConvertibleAndHasIdentityMaps(memRefType) ||
1259  failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1260  return failure();
1261 
1262  auto loc = memcpyOp.getLoc();
1263 
1264  MemRefDescriptor srcDesc(adaptor.getSrc());
1265  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1266 
1267  Type elementPtrType = getElementPtrType(memRefType);
1268  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
1269  Value gepPtr = rewriter.create<LLVM::GEPOp>(
1270  loc, elementPtrType,
1271  typeConverter->convertType(memRefType.getElementType()), nullPtr,
1272  numElements);
1273  auto sizeBytes =
1274  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1275 
1276  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1277  srcDesc.alignedPtr(rewriter, loc),
1278  *getTypeConverter());
1279  auto dst = bitAndAddrspaceCast(
1280  loc, rewriter, llvmPointerType,
1281  MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1282  *getTypeConverter());
1283 
1284  auto stream = adaptor.getAsyncDependencies().front();
1285  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1286 
1287  rewriter.replaceOp(memcpyOp, {stream});
1288 
1289  return success();
1290 }
1291 
1292 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1293  gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1294  ConversionPatternRewriter &rewriter) const {
1295  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1296 
1297  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1298  !isConvertibleAndHasIdentityMaps(memRefType) ||
1299  failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1300  return failure();
1301 
1302  auto loc = memsetOp.getLoc();
1303 
1304  Type valueType = adaptor.getValue().getType();
1305  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1306  // Ints and floats of 16 or 32 bit width are allowed.
1307  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1308  return rewriter.notifyMatchFailure(
1309  memsetOp, "value must be a 16 or 32 bit int or float");
1310  }
1311 
1312  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1313  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1314 
1315  MemRefDescriptor dstDesc(adaptor.getDst());
1316  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1317 
1318  auto value =
1319  rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1320  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1321  dstDesc.alignedPtr(rewriter, loc),
1322  *getTypeConverter());
1323 
1324  auto stream = adaptor.getAsyncDependencies().front();
1325  FunctionCallBuilder builder =
1326  valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1327  builder.create(loc, rewriter, {dst, value, numElements, stream});
1328 
1329  rewriter.replaceOp(memsetOp, {stream});
1330  return success();
1331 }
1332 
1333 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1334  gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1335  ConversionPatternRewriter &rewriter) const {
1336  Location loc = op.getLoc();
1337  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1338  {adaptor.getDevIndex()});
1339  rewriter.replaceOp(op, call);
1340  return success();
1341 }
1342 
1343 template <typename T>
1344 static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1345  Type llvmInt32Type = builder.getIntegerType(32);
1346  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
1347  static_cast<int32_t>(tValue));
1348 }
1349 
1350 template <typename T>
1351 static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1352  Type llvmFloat32Type = builder.getF32Type();
1353  return builder.create<LLVM::ConstantOp>(
1354  loc, llvmFloat32Type,
1355  builder.getF32FloatAttr(static_cast<float>(tValue)));
1356 }
1357 
1358 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1359  gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1360  ConversionPatternRewriter &rewriter) const {
1361  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1362  failed(isAsyncWithOneDependency(rewriter, op)))
1363  return failure();
1364  Location loc = op.getLoc();
1365  auto stream = adaptor.getAsyncDependencies().front();
1366  Value pTensor =
1367  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1368  Type dType = op.getMemref().getType().getElementType();
1369  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1370 
1371  SmallVector<Value, 4> dims;
1372  for (Value dim : adaptor.getDims()) {
1373  dims.push_back(dim);
1374  }
1375 
1376  Value handle;
1377  // TODO: For now, we track the use of the handle and lower it to cusparse /
1378  // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
1379  // used, we require two separate Creation ops to be the correct logic. In
1380  // future, we may add support to using one handle in sparse tensor / GPU
1381  // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
1382  // the dnmat is used with spmat with 2:4 sparsity
1383  if (dims.size() == 2) {
1384  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1385  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1386  loc, getIndexType(), rewriter.getIndexAttr(11032));
1387  handle = rewriter.create<LLVM::AllocaOp>(
1388  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1389  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1390 
1391  createLtDnMatCallBuilder
1392  .create(loc, rewriter,
1393  {handle, dims[0], dims[1], pTensor, dtp, stream})
1394  .getResult();
1395  } else {
1396  handle =
1397  createDnMatCallBuilder
1398  .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1399  .getResult();
1400  }
1401  } else {
1402  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1403  handle = createDnVecCallBuilder
1404  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1405  .getResult();
1406  }
1407  rewriter.replaceOp(op, {handle, stream});
1408  return success();
1409 }
1410 
1411 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1412  gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1413  ConversionPatternRewriter &rewriter) const {
1414  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1415  failed(isAsyncWithOneDependency(rewriter, op)))
1416  return failure();
1417  Location loc = op.getLoc();
1418  auto stream = adaptor.getAsyncDependencies().front();
1419  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1420  SmallVector<Value, 4> dims;
1421  for (Value dim : definingOp.getDims()) {
1422  dims.push_back(dim);
1423  }
1424  if (dims.size() == 2) {
1425  // Use the cusparseLt destroy call if the dnmat is used with spmat with
1426  // 2:4 sparsity
1427  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1428  destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1429  {adaptor.getDnTensor(), stream});
1430  } else {
1431  destroyDnMatCallBuilder.create(loc, rewriter,
1432  {adaptor.getDnTensor(), stream});
1433  }
1434  } else {
1435  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1436  destroyDnVecCallBuilder.create(loc, rewriter,
1437  {adaptor.getDnTensor(), stream});
1438  }
1439  rewriter.replaceOp(op, {stream});
1440  return success();
1441 }
1442 
1443 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1444  gpu::CreateCooOp op, OpAdaptor adaptor,
1445  ConversionPatternRewriter &rewriter) const {
1446  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1447  failed(isAsyncWithOneDependency(rewriter, op)))
1448  return failure();
1449  Location loc = op.getLoc();
1450  auto stream = adaptor.getAsyncDependencies().front();
1451  Value pRowIdxs =
1452  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1453  Value pColIdxs =
1454  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1455  Value pValues =
1456  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1457  Type iType =
1458  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1459  Type dType =
1460  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1461  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1462  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1463  auto handle =
1464  createCooCallBuilder
1465  .create(loc, rewriter,
1466  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1467  pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1468  .getResult();
1469  rewriter.replaceOp(op, {handle, stream});
1470  return success();
1471 }
1472 
1473 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1474  gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1475  ConversionPatternRewriter &rewriter) const {
1476  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1477  failed(isAsyncWithOneDependency(rewriter, op)))
1478  return failure();
1479  Location loc = op.getLoc();
1480  auto stream = adaptor.getAsyncDependencies().front();
1481  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1482  Value pValues =
1483  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1484  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1485  Type dType =
1486  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1487  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1488  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1489  auto handle =
1490  createCooAoSCallBuilder
1491  .create(loc, rewriter,
1492  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1493  pIdxs, pValues, itp, dtp, stream})
1494  .getResult();
1495  rewriter.replaceOp(op, {handle, stream});
1496  return success();
1497 }
1498 
1499 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1500  gpu::CreateCsrOp op, OpAdaptor adaptor,
1501  ConversionPatternRewriter &rewriter) const {
1502  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1503  failed(isAsyncWithOneDependency(rewriter, op)))
1504  return failure();
1505  Location loc = op.getLoc();
1506  auto stream = adaptor.getAsyncDependencies().front();
1507  Value pRowPos =
1508  MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1509  Value pColIdxs =
1510  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1511  Value pValues =
1512  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1513  Type pType =
1514  llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1515  Type iType =
1516  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1517  Type dType =
1518  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1519  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1520  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1521  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1522  auto handle =
1523  createCsrCallBuilder
1524  .create(loc, rewriter,
1525  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1526  pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1527  .getResult();
1528  rewriter.replaceOp(op, {handle, stream});
1529  return success();
1530 }
1531 
1532 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1533  gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1534  ConversionPatternRewriter &rewriter) const {
1535  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1536  failed(isAsyncWithOneDependency(rewriter, op)))
1537  return failure();
1538  Location loc = op.getLoc();
1539  auto stream = adaptor.getAsyncDependencies().front();
1540  Value pMat =
1541  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1542  Type dType =
1543  llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1544  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1545 
1546  // CUDA runner asserts the size is 44104 bytes.
1547  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1548  loc, getIndexType(), rewriter.getIndexAttr(44104));
1549  Value handle = rewriter.create<LLVM::AllocaOp>(
1550  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1551  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1552 
1553  create2To4SpMatCallBuilder
1554  .create(loc, rewriter,
1555  {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1556  .getResult();
1557  rewriter.replaceOp(op, {handle, stream});
1558  return success();
1559 }
1560 
1561 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1562  gpu::DestroySpMatOp op, OpAdaptor adaptor,
1563  ConversionPatternRewriter &rewriter) const {
1564  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1565  failed(isAsyncWithOneDependency(rewriter, op)))
1566  return failure();
1567  Location loc = op.getLoc();
1568  auto stream = adaptor.getAsyncDependencies().front();
1569  // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1570  if (is2To4Sparsity(op.getSpmat())) {
1571  destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1572  {adaptor.getSpmat(), stream});
1573 
1574  } else {
1575  destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1576  }
1577  rewriter.replaceOp(op, {stream});
1578  return success();
1579 }
1580 
1581 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1582  gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1583  ConversionPatternRewriter &rewriter) const {
1584  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1585  failed(isAsyncWithOneDependency(rewriter, op)))
1586  return failure();
1587  Location loc = op.getLoc();
1588  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1589  auto computeType = genConstInt32From(
1590  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1591  auto stream = adaptor.getAsyncDependencies().front();
1592  auto bufferSize = spMVBufferSizeCallBuilder
1593  .create(loc, rewriter,
1594  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1595  adaptor.getDnY(), computeType, stream})
1596  .getResult();
1597  rewriter.replaceOp(op, {bufferSize, stream});
1598  return success();
1599 }
1600 
1601 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1602  gpu::SpMVOp op, OpAdaptor adaptor,
1603  ConversionPatternRewriter &rewriter) const {
1604  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1605  failed(isAsyncWithOneDependency(rewriter, op)))
1606  return failure();
1607  Location loc = op.getLoc();
1608  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1609  auto computeType = genConstInt32From(
1610  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1611  auto stream = adaptor.getAsyncDependencies().front();
1612  Value pBuf =
1613  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1614  spMVCallBuilder.create(loc, rewriter,
1615  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1616  adaptor.getDnY(), computeType, pBuf, stream});
1617  rewriter.replaceOp(op, {stream});
1618  return success();
1619 }
1620 
1621 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1622  gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1623  ConversionPatternRewriter &rewriter) const {
1624  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1625  failed(isAsyncWithOneDependency(rewriter, op)))
1626  return failure();
1627  Location loc = op.getLoc();
1628  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1629  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1630  auto stream = adaptor.getAsyncDependencies().front();
1631  Value bufferSize;
1632  if (is2To4Sparsity(op.getSpmatA())) {
1633  auto pruneFlag =
1634  genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1635  auto computeType = genConstInt32From(
1636  rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1637  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1638  rewriter.getIndexAttr(3));
1639  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1640  loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1641  createCuSparseLtSpMMBufferSizeBuilder
1642  .create(loc, rewriter,
1643  {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1644  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1645  pruneFlag, stream})
1646  .getResult();
1647 
1648  auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1649  loc, llvmPointerType, llvmPointerType, bufferSize,
1650  ValueRange{rewriter.create<LLVM::ConstantOp>(
1651  loc, getIndexType(), rewriter.getIndexAttr(1))});
1652  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1653  loc, llvmPointerType, llvmPointerType, bufferSize,
1654  ValueRange{rewriter.create<LLVM::ConstantOp>(
1655  loc, getIndexType(), rewriter.getIndexAttr(2))});
1656  auto bufferSize0 =
1657  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1658  auto bufferSize1 =
1659  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1660  auto bufferSize2 =
1661  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1662 
1663  rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1664  } else {
1665  auto computeType = genConstInt32From(
1666  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1667  bufferSize =
1668  createSpMMBufferSizeCallBuilder
1669  .create(loc, rewriter,
1670  {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1671  adaptor.getDnmatC(), computeType, stream})
1672  .getResult();
1673  rewriter.replaceOp(op, {bufferSize, stream});
1674  }
1675  return success();
1676 }
1677 
1678 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1679  gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1680  ConversionPatternRewriter &rewriter) const {
1681  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1682  failed(isAsyncWithOneDependency(rewriter, op)))
1683  return failure();
1684  Location loc = op.getLoc();
1685  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1686  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1687  auto computeType = genConstInt32From(
1688  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1689  auto stream = adaptor.getAsyncDependencies().front();
1690  auto bufferSize =
1691  createSDDMMBufferSizeCallBuilder
1692  .create(loc, rewriter,
1693  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1694  adaptor.getSpmatC(), computeType, stream})
1695  .getResult();
1696  rewriter.replaceOp(op, {bufferSize, stream});
1697  return success();
1698 }
1699 
1700 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1701  gpu::SpMMOp op, OpAdaptor adaptor,
1702  ConversionPatternRewriter &rewriter) const {
1703  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1704  failed(isAsyncWithOneDependency(rewriter, op)))
1705  return failure();
1706  Location loc = op.getLoc();
1707  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1708  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1709  auto computeType = genConstInt32From(
1710  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1711 
1712  auto stream = adaptor.getAsyncDependencies().front();
1713 
1714  // Lower to cusparseLt if applicable
1715  if (is2To4Sparsity(op.getSpmatA())) {
1716  SmallVector<Value> pBufs;
1717  for (Value buffer : adaptor.getBuffers()) {
1718  Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1719  pBufs.push_back(pBuf);
1720  }
1721  createCuSparseLtSpMMBuilder.create(
1722  loc, rewriter,
1723  {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1724  pBufs[0], pBufs[1], pBufs[2], stream});
1725  } else {
1726  Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1727  .allocatedPtr(rewriter, loc);
1728  createSpMMCallBuilder.create(loc, rewriter,
1729  {modeA, modeB, adaptor.getSpmatA(),
1730  adaptor.getDnmatB(), adaptor.getDnmatC(),
1731  computeType, pBuf, stream});
1732  }
1733  rewriter.replaceOp(op, {stream});
1734  return success();
1735 }
1736 
1737 template <typename T>
1739  converter.addConversion([&converter](T) -> Type {
1740  return LLVM::LLVMPointerType::get(&converter.getContext());
1741  });
1742 }
1743 
1744 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1745  gpu::SDDMMOp op, OpAdaptor adaptor,
1746  ConversionPatternRewriter &rewriter) const {
1747  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1748  failed(isAsyncWithOneDependency(rewriter, op)))
1749  return failure();
1750  Location loc = op.getLoc();
1751  auto computeType = genConstInt32From(
1752  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1753  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1754  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1755  auto stream = adaptor.getAsyncDependencies().front();
1756  Value pBuf =
1757  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1758  createSDDMMCallBuilder.create(loc, rewriter,
1759  {modeA, modeB, adaptor.getDnmatA(),
1760  adaptor.getDnmatB(), adaptor.getSpmatC(),
1761  computeType, pBuf, stream});
1762  rewriter.replaceOp(op, {stream});
1763  return success();
1764 }
1765 
1767 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1768  gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1769  ConversionPatternRewriter &rewriter) const {
1770  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1771  failed(isAsyncWithOneDependency(rewriter, op)))
1772  return failure();
1773  Location loc = op.getLoc();
1774  auto stream = adaptor.getAsyncDependencies().front();
1775  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1776  .getResult();
1777  rewriter.replaceOp(op, {descr, stream});
1778  return success();
1779 }
1780 
1782 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1783  gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1784  ConversionPatternRewriter &rewriter) const {
1785  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1786  failed(isAsyncWithOneDependency(rewriter, op)))
1787  return failure();
1788  Location loc = op.getLoc();
1789  auto stream = adaptor.getAsyncDependencies().front();
1790  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1791  {adaptor.getDesc(), stream});
1792  rewriter.replaceOp(op, {stream});
1793  return success();
1794 }
1795 
1797 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1798  gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1799  ConversionPatternRewriter &rewriter) const {
1800  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1801  failed(isAsyncWithOneDependency(rewriter, op)))
1802  return failure();
1803  Location loc = op.getLoc();
1804  auto computeType = genConstInt32From(
1805  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1806  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1807  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1808  auto stream = adaptor.getAsyncDependencies().front();
1809 
1810  Value pBuf =
1811  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1812  Value bufferSizeNew;
1813 
1814  if (adaptor.getKind() ==
1815  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1816  bufferSizeNew =
1817  createSpGEMMWorkEstimationBuilder
1818  .create(loc, rewriter,
1819  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1820  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1821  adaptor.getBufferSz(), pBuf, stream})
1822  .getResult();
1823  } else {
1824  bufferSizeNew =
1825  createSpGEMMComputeBuilder
1826  .create(loc, rewriter,
1827  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1828  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1829  adaptor.getBufferSz(), pBuf, stream})
1830  .getResult();
1831  }
1832  rewriter.replaceOp(op, {bufferSizeNew, stream});
1833  return success();
1834 }
1835 
1836 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1837  gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1838  ConversionPatternRewriter &rewriter) const {
1839  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1840  failed(isAsyncWithOneDependency(rewriter, op)))
1841  return failure();
1842  Location loc = op.getLoc();
1843  auto computeType = genConstInt32From(
1844  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1845  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1846  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1847  auto stream = adaptor.getAsyncDependencies().front();
1848  createSpGEMMCopyBuilder.create(loc, rewriter,
1849  {adaptor.getDesc(), modeA, modeB,
1850  adaptor.getSpmatA(), adaptor.getSpmatB(),
1851  adaptor.getSpmatC(), computeType, stream});
1852  rewriter.replaceOp(op, {stream});
1853  return success();
1854 }
1855 
1856 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1857  gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1858  ConversionPatternRewriter &rewriter) const {
1859  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1860  failed(isAsyncWithOneDependency(rewriter, op)))
1861  return failure();
1862  Location loc = op.getLoc();
1863  auto stream = adaptor.getAsyncDependencies().front();
1864 
1865  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1866  rewriter.getIndexAttr(3));
1867  auto buffer = rewriter.create<LLVM::AllocaOp>(
1868  loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1869 
1870  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1871  loc, llvmPointerType, llvmPointerType, buffer,
1872  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1873  rewriter.getIndexAttr(0))});
1874  auto colsPtr = rewriter.create<LLVM::GEPOp>(
1875  loc, llvmPointerType, llvmPointerType, buffer,
1876  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1877  rewriter.getIndexAttr(1))});
1878  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1879  loc, llvmPointerType, llvmPointerType, buffer,
1880  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1881  rewriter.getIndexAttr(2))});
1882  createSpMatGetSizeBuilder.create(
1883  loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1884  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1885  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1886  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1887 
1888  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1889  return success();
1890 }
1891 
1892 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1893  gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1894  ConversionPatternRewriter &rewriter) const {
1895  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1896  failed(isAsyncWithOneDependency(rewriter, op)))
1897  return failure();
1898  Location loc = op.getLoc();
1899  auto stream = adaptor.getAsyncDependencies().front();
1900  Value pPos =
1901  MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1902  Value pCrd =
1903  MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1904  Value pVal =
1905  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1906  createSetCsrPointersBuilder.create(
1907  loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1908  rewriter.replaceOp(op, {stream});
1909  return success();
1910 }
1911 
1912 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1913  gpu::CreateCscOp op, OpAdaptor adaptor,
1914  ConversionPatternRewriter &rewriter) const {
1915  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1916  failed(isAsyncWithOneDependency(rewriter, op)))
1917  return failure();
1918  Location loc = op.getLoc();
1919  auto stream = adaptor.getAsyncDependencies().front();
1920  Value pColPos =
1921  MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1922  Value pRowIdxs =
1923  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1924  Value pValues =
1925  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1926  Type pType =
1927  llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1928  Type iType =
1929  llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1930  Type dType =
1931  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1932  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1933  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1934  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1935  auto handle =
1936  createCscCallBuilder
1937  .create(loc, rewriter,
1938  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1939  pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1940  .getResult();
1941  rewriter.replaceOp(op, {handle, stream});
1942  return success();
1943 }
1944 
1945 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1946  gpu::CreateBsrOp op, OpAdaptor adaptor,
1947  ConversionPatternRewriter &rewriter) const {
1948  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1949  failed(isAsyncWithOneDependency(rewriter, op)))
1950  return failure();
1951  Location loc = op.getLoc();
1952  auto stream = adaptor.getAsyncDependencies().front();
1953  Value pRowPos =
1954  MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1955  Value pColIdxs =
1956  MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1957  Value pValues =
1958  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1959  Type pType =
1960  llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1961  Type iType =
1962  llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1963  Type dType =
1964  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1965  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1966  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1967  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1968  auto handle =
1969  createBsrCallBuilder
1970  .create(loc, rewriter,
1971  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1972  adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1973  pColIdxs, pValues, ptp, itp, dtp, stream})
1974  .getResult();
1975  rewriter.replaceOp(op, {handle, stream});
1976  return success();
1977 }
1978 
1980  RewritePatternSet &patterns,
1981  StringRef gpuBinaryAnnotation,
1982  bool kernelBarePtrCallConv,
1983  SymbolTable *cachedModuleTable) {
1984  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1985  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1986  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1987  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1988 
1989  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1990  ConvertDeallocOpToGpuRuntimeCallPattern,
1991  ConvertHostRegisterOpToGpuRuntimeCallPattern,
1992  ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1993  ConvertMemcpyOpToGpuRuntimeCallPattern,
1994  ConvertMemsetOpToGpuRuntimeCallPattern,
1995  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1996  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1997  ConvertWaitOpToGpuRuntimeCallPattern,
1998  ConvertAsyncYieldToGpuRuntimeCallPattern,
1999  ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
2000  ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
2001  ConvertCreateCooOpToGpuRuntimeCallPattern,
2002  ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
2003  ConvertCreateCsrOpToGpuRuntimeCallPattern,
2004  ConvertCreateCscOpToGpuRuntimeCallPattern,
2005  ConvertCreateBsrOpToGpuRuntimeCallPattern,
2006  ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
2007  ConvertDestroySpMatOpToGpuRuntimeCallPattern,
2008  ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
2009  ConvertSpMVOpToGpuRuntimeCallPattern,
2010  ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
2011  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
2012  ConvertSpMMOpToGpuRuntimeCallPattern,
2013  ConvertSDDMMOpToGpuRuntimeCallPattern,
2014  ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
2015  ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
2016  ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
2017  ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
2018  ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
2019  ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
2020  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
2021  converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
2022  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
2023 }
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static constexpr const char * kGpuBinaryStorageSuffix
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static MLIRContext * getContext(OpFoldResult val)
llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
FloatType getF32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:55
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:253
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:234
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition: Pattern.cpp:36
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition: Pattern.cpp:53
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
MLIRContext & getContext() const
Returns the MLIR context.
static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name, ArrayRef< Type > elements, bool isPacked=false)
Gets a new identified struct with the given body.
Definition: LLVMTypes.cpp:436
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
This class helps build Operations.
Definition: Builders.h:209
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:387
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:248
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:392
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:450
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition: Region.h:205
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
bool isF16() const
Definition: Types.cpp:49
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isBF16() const
Definition: Types.cpp:48
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" in the module surrounding the insertio...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
Definition: MemoryOps.cpp:263
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef gpuBinaryAnnotation={}, bool kernelBarePtrCallConv=false, SymbolTable *cachedModuleTable=nullptr)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LLVM::LLVMFunctionType functionType
Definition: GPUCommonPass.h:64
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition: GPUDialect.h:38