//===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C wrappers around the CUDA library for easy linking in the ORC
// JIT. Also adds some debugging helpers that are helpful when writing MLIR
// code to run on GPUs.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#include <cassert>
#include <cstdio>
#include <cstdlib>

#include "cuda.h"
#include "cuda_bf16.h"
#include "cuda_fp16.h"

#ifdef MLIR_ENABLE_CUDA_CUSPARSE
#include "cusparse.h"
#ifdef MLIR_ENABLE_CUDA_CUSPARSELT
#include "cusparseLt.h"
#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE

#ifdef _WIN32
#include <malloc.h>
#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
#else
#define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((visibility("default")))
#endif // _WIN32

#define CUDA_REPORT_IF_ERROR(expr)                                            \
  [](CUresult result) {                                                       \
    if (!result)                                                              \
      return;                                                                 \
    const char *name = nullptr;                                               \
    cuGetErrorName(result, &name);                                            \
    if (!name)                                                                \
      name = "<unknown>";                                                     \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                  \
  }(expr)

#define CUSPARSE_REPORT_IF_ERROR(expr)                                        \
  {                                                                           \
    cusparseStatus_t status = (expr);                                         \
    if (status != CUSPARSE_STATUS_SUCCESS) {                                  \
      fprintf(stderr, "cuSPARSE '%s' failed with '%s'\n", #expr,              \
              cusparseGetErrorString(status));                                \
    }                                                                         \
  }

thread_local static int32_t defaultDevice = 0;

/// Helper method that checks the environment variable that enables debug
/// output.
bool isDebugEnabled() {
  const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
  static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
  return isEnabled;
}

#define debug_print(fmt, ...)                                                 \
  do {                                                                        \
    if (isDebugEnabled())                                                     \
      fprintf(stderr, "%s:%d:%s(): " fmt, "CudaRuntimeWrappers.cpp",          \
              __LINE__, __func__, __VA_ARGS__);                               \
  } while (0)
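
// Illustrative note (not part of the upstream file): with the variable set,
// e.g. `MLIR_CUDA_DEBUG=1 ./app`, the wrappers below trace their calls as
//   CudaRuntimeWrappers.cpp:<line>:mgpuLaunchKernel(): Launching kernel, ...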

// Returns the default CUdevice.
CUdevice getDefaultCuDevice() {
  CUdevice device;
  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
  return device;
}

// Make the primary context of the current default device current for the
// duration of the instance and restore the previous context on destruction.
class ScopedContext {
public:
  ScopedContext() {
    // Static reference to CUDA primary context for device ordinal
    // defaultDevice.
    static CUcontext context = [] {
      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
      CUcontext ctx;
      // Note: this does not affect the current context.
      CUDA_REPORT_IF_ERROR(
          cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
      return ctx;
    }();

    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
  }

  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
};
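
// Illustrative sketch (not part of the upstream file): every wrapper below
// that talks to the driver opens a ScopedContext so the primary context is
// current for the duration of the call and restored afterwards.
[[maybe_unused]] static void exampleScopedContextUsage() {
  ScopedContext scopedContext; // pushes the primary context
  CUdeviceptr ptr = 0;
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, /*bytesize=*/256));
  CUDA_REPORT_IF_ERROR(cuMemFree(ptr));
} // ~ScopedContext pops the context again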

#ifdef MLIR_ENABLE_CUDA_CUSPARSE
// Note that (1) Nvidia confirms it is safe to share a handle across multiple
// instances and streams, and (2) clients are responsible for calling the
// @mgpu environment initialization/destruction in a thread-safe manner, e.g.,
// at the beginning of the program before multiple threads are created.
static cusparseHandle_t cusparse_env = nullptr;

#ifdef MLIR_ENABLE_CUDA_CUSPARSELT
// cusparseLtHandle_t is not a pointer type, so we need an additional flag to
// indicate whether it is initialized.
static cusparseLtHandle_t cusparseLt_env;
static bool cusparseLt_initiated = false;

#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule
mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
  return module;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoadJIT(void *data,
                                                                int optLevel) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  char jitErrorBuffer[4096] = {0};
  CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
                               CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
                               CU_JIT_OPTIMIZATION_LEVEL};
  void *jitOptionsVals[] = {jitErrorBuffer,
                            reinterpret_cast<void *>(sizeof(jitErrorBuffer)),
                            reinterpret_cast<void *>(optLevel)};

  CUresult result =
      cuModuleLoadDataEx(&module, data, 3, jitOptions, jitOptionsVals);
  if (result) {
    fprintf(stderr, "JIT compilation failed with: '%s'\n", jitErrorBuffer);
    CUDA_REPORT_IF_ERROR(result);
  }
  return module;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
mgpuModuleGetFunction(CUmodule module, const char *name) {
  CUfunction function = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
  return function;
}

// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                 intptr_t gridZ, intptr_t blockX, intptr_t blockY,
                 intptr_t blockZ, int32_t smem, CUstream stream, void **params,
                 void **extra, size_t /*paramsCount*/) {
  ScopedContext scopedContext;
  if (smem > 0) {
    // Query the driver only when dynamic shared memory is requested, as the
    // attribute queries are more expensive than this check.
    int32_t maxShmem = 0;
    CUdevice device = getDefaultCuDevice();
    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));
    if (maxShmem < smem) {
      fprintf(stderr,
              "Requested shared memory (%d bytes) is larger than maximum "
              "allowed shared memory (%d bytes) for this device\n",
              smem, maxShmem);
    }
    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
  }
  debug_print("Launching kernel, grid: %ld, %ld, %ld, "
              "threads: %ld, %ld, %ld, "
              "smem: %d bytes\n",
              gridX, gridY, gridZ, blockX, blockY, blockZ, smem);
  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                      blockY, blockZ, smem, stream, params,
                                      extra));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
  ScopedContext scopedContext;
  CUstream stream = nullptr;
  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  return stream;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuStreamSynchronize(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
                                                              CUevent event) {
  CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
  ScopedContext scopedContext;
  CUevent event = nullptr;
  CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
  return event;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventSynchronize(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event,
                                                          CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
}
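
// Illustrative sketch (not part of the upstream file): a typical host-side
// sequence over the wrappers above. The cubin blob and the kernel name
// "kernel" are hypothetical.
[[maybe_unused]] static void exampleLaunch(void *cubinBlob, void *arg) {
  CUmodule module = mgpuModuleLoad(cubinBlob, /*gpuBlobSize=*/0);
  CUfunction function = mgpuModuleGetFunction(module, "kernel");
  CUstream stream = mgpuStreamCreate();
  void *params[] = {&arg}; // kernel-specific argument pointers
  mgpuLaunchKernel(function, /*grid=*/1, 1, 1, /*block=*/128, 1, 1,
                   /*smem=*/0, stream, params, /*extra=*/nullptr,
                   /*paramsCount=*/1);
  mgpuStreamSynchronize(stream);
  mgpuStreamDestroy(stream);
  mgpuModuleUnload(module);
}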

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuMemAlloc(uint64_t sizeBytes, CUstream stream, bool isHostShared) {
  ScopedContext scopedContext;
  CUdeviceptr ptr = 0;
  if (sizeBytes == 0)
    return reinterpret_cast<void *>(ptr);

  if (isHostShared) {
    CUDA_REPORT_IF_ERROR(
        cuMemAllocManaged(&ptr, sizeBytes, CU_MEM_ATTACH_GLOBAL));
    return reinterpret_cast<void *>(ptr);
  }
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
  return reinterpret_cast<void *>(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemFree(void *ptr,
                                                      CUstream /*stream*/) {
  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemcpy(void *dst, void *src, size_t sizeBytes, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                     reinterpret_cast<CUdeviceptr>(src),
                                     sizeBytes, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset32(void *dst, unsigned int value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset16(void *dst, unsigned short value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}
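
// Illustrative sketch (not part of the upstream file): allocating device
// memory, clearing it, and copying it back through the wrappers above.
[[maybe_unused]] static void exampleMemoryOps(CUstream stream) {
  void *device = mgpuMemAlloc(/*sizeBytes=*/64, stream, /*isHostShared=*/false);
  mgpuMemset32(device, /*value=*/0, /*count=*/16, stream); // 16 x 4 bytes
  char host[64];
  mgpuMemcpy(host, device, /*sizeBytes=*/64, stream);
  mgpuStreamSynchronize(stream); // copies above are asynchronous
  mgpuMemFree(device, stream);
}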

///
/// Helper functions for writing MLIR example code.
///

// Registers a byte array with the CUDA runtime. Helpful until we have
// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}

/// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`. Helpful until we have
/// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                          int64_t elementSizeBytes) {
  // Only densely packed tensors are currently supported.
#ifdef _WIN32
  int64_t *denseStrides = (int64_t *)_alloca(rank * sizeof(int64_t));
#else
  int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
#endif // _WIN32
  int64_t *sizes = descriptor->sizes;
  for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
    denseStrides[i] = runningStride;
    runningStride *= sizes[i];
  }
  uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
  int64_t *strides = &sizes[rank];
  (void)strides;
  for (unsigned i = 0; i < rank; ++i)
    assert(strides[i] == denseStrides[i] &&
           "Mismatch in computed dense strides");

  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostRegister(ptr, sizeBytes);
}
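
// Illustrative sketch (not part of the upstream file): registering a densely
// packed 2x3 f32 memref. The loop above computes denseStrides = {3, 1} and
// sizeBytes = 2 * 3 * 4 = 24, i.e. the whole buffer is registered.
[[maybe_unused]] static void exampleRegisterMemRef() {
  static float storage[6] = {};
  StridedMemRefType<char, 2> descriptor;
  descriptor.basePtr = reinterpret_cast<char *>(storage);
  descriptor.data = reinterpret_cast<char *>(storage);
  descriptor.offset = 0;
  descriptor.sizes[0] = 2;
  descriptor.sizes[1] = 3;
  descriptor.strides[0] = 3;
  descriptor.strides[1] = 1;
  // The wrapper is rank-erased: callers pass the rank plus a rank-1 pointer.
  mgpuMemHostRegisterMemRef(
      /*rank=*/2, reinterpret_cast<StridedMemRefType<char, 1> *>(&descriptor),
      /*elementSizeBytes=*/sizeof(float));
}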

// Unregisters a byte array with the CUDA runtime.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr));
}

/// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostUnregisterMemRef(int64_t rank,
                            StridedMemRefType<char, 1> *descriptor,
                            int64_t elementSizeBytes) {
  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostUnregister(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
  defaultDevice = device;
}

///
/// Runtime methods using the CUDA 12.0+ driver.
///

#if (CUDA_VERSION >= 12000)

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
    CUfunction function, intptr_t clusterX, intptr_t clusterY,
    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
  ScopedContext scopedContext;
  if (smem > 0) {
    // Query the driver only when dynamic shared memory is requested, as the
    // attribute queries are more expensive than this check.
    int32_t maxShmem = 0;
    CUdevice device = getDefaultCuDevice();
    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));
    if (maxShmem < smem) {
      fprintf(stderr,
              "Requested shared memory (%d bytes) is larger than maximum "
              "allowed shared memory (%d bytes) for this device\n",
              smem, maxShmem);
    }
    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
  }
  CUlaunchConfig config;
  config.gridDimX = gridX;
  config.gridDimY = gridY;
  config.gridDimZ = gridZ;
  config.blockDimX = blockX;
  config.blockDimY = blockY;
  config.blockDimZ = blockZ;
  config.sharedMemBytes = smem;
  config.hStream = stream;
  CUlaunchAttribute launchAttr[2];
  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launchAttr[0].value.clusterDim.x = clusterX;
  launchAttr[0].value.clusterDim.y = clusterY;
  launchAttr[0].value.clusterDim.z = clusterZ;
  launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  launchAttr[1].value.clusterSchedulingPolicyPreference =
      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
  config.numAttrs = 2;
  config.attrs = launchAttr;

  debug_print("Launching kernel, "
              "cluster: %ld, %ld, %ld, "
              "grid: %ld, %ld, %ld, "
              "threads: %ld, %ld, %ld, "
              "smem: %d bytes\n",
              clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY,
              blockZ, smem);

  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
    CUtensorMap *tensorMap,             // Tensor map object
    CUtensorMapDataType tensorDataType, // Tensor data type
    cuuint32_t tensorRank,              // Dimensionality of tensor
    void *globalAddress,                // Starting address
    const cuuint64_t *globalDim,        // Tensor size (number of elements)
    const cuuint64_t *globalStrides,    // Stride size (in bytes)
    const cuuint32_t *boxDim,           // Traversal box (number of elements)
    const cuuint32_t *elementStrides,   // Traversal stride
    CUtensorMapInterleave interleave,   // Type of interleaved layout
    CUtensorMapSwizzle swizzle,         // Bank swizzling pattern
    CUtensorMapL2promotion l2Promotion, // L2 promotion size
    CUtensorMapFloatOOBfill oobFill     // Padding zero-fill or NaN-fill
) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuTensorMapEncodeTiled(
      tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
      globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
      oobFill));
  debug_print("Created TMA descriptor\n Addr: %p\n"
              "data type : %d\n"
              "rank : %d\n"
              "globalDim[5]: %zu, %zu, %zu, %zu, %zu\n"
              "globalStrides[5]: %zu, %zu, %zu, %zu, %zu\n"
              "boxDim[5]: %u, %u, %u, %u, %u\n"
              "elementStrides[5]: %u, %u, %u, %u, %u\n"
              "interleave: %u \n"
              "swizzle: %u \n"
              "l2Promotion: %u \n"
              "oobFill: %u \n",
              (void *)&tensorMap, tensorDataType, tensorRank, globalDim[0],
              globalDim[1], globalDim[2], globalDim[3], globalDim[4],
              globalStrides[0], globalStrides[1], globalStrides[2],
              globalStrides[3], globalStrides[4], boxDim[0], boxDim[1],
              boxDim[2], boxDim[3], boxDim[4], elementStrides[0],
              elementStrides[1], elementStrides[2], elementStrides[3],
              elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
}

template <int Rank>
void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr,
                               uint64_t *globalDim, uint64_t *globalStrides,
                               const CUtensorMapDataType tensorDataType) {
  auto descriptor =
      reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor);
  *addr = descriptor->data;
  // Reverse the dimensions: the tensor map expects the innermost dimension
  // first.
  for (int i = 0; i < Rank; ++i) {
    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]);
  }
  // Element size in bytes, indexed by CUtensorMapDataType.
  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
                                               4, 8, 2, 4, 4, 4};
  for (int i = 0; i < Rank - 1; ++i) {
    globalStrides[i] = static_cast<uint64_t>(
        descriptor->strides[Rank - i - 2] * elementSizeInBytes[tensorDataType]);
  }
}
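
// Illustrative worked example (not part of the upstream file): for a rank-2
// f32 memref with sizes = {64, 128} and strides = {128, 1}, the helper above
// yields globalDim = {128, 64} (innermost dimension first) and
// globalStrides[0] = 128 * 4 = 512 bytes for the outer dimension.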

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
    int64_t tensorRank,                       // Dimensionality of tensor
    void *rankedDescriptor,                   // Ranked MemRef descriptor
    const CUtensorMapDataType tensorDataType, // Tensor data type
    CUtensorMapInterleave interleave,         // Type of interleaved layout
    CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
    CUtensorMapL2promotion l2Promotion,       // L2 promotion size
    CUtensorMapFloatOOBfill oobFill,          // Padding zero-fill or NaN-fill
    int64_t *inputBoxDims                     // Box size (number of elements)
) {
  CUtensorMap tensorMap;

  uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1};
  uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0};
  uint32_t tensorRank32 = uint32_t(tensorRank);

  char *globalAddress = nullptr;
  switch (tensorRank) {
  case 1:
    mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 2:
    mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 3:
    mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 4:
    mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 5:
    mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  default:
    fprintf(
        stderr,
        "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
    return nullptr;
  }

  for (int64_t r = 0; r < tensorRank; ++r) {
    boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
  }

  ScopedContext scopedContext;
  mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                           globalAddress, globalDim, globalStrides, boxDim,
                           elementStrides, interleave, swizzle, l2Promotion,
                           oobFill);
  // Copy the created tensor map to the device.
  CUdeviceptr dTensorMap;
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&dTensorMap, sizeof(CUtensorMap)));
  CUDA_REPORT_IF_ERROR(cuMemcpy(dTensorMap,
                                reinterpret_cast<CUdeviceptr>(&tensorMap),
                                sizeof(CUtensorMap)));
  return reinterpret_cast<void *>(dTensorMap);
}
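
// Illustrative sketch (not part of the upstream file): encoding a TMA
// descriptor for a rank-2 f32 memref with a hypothetical 64x64 box.
[[maybe_unused]] static void exampleEncodeTensorMap(void *rankedDescriptor) {
  int64_t boxDims[2] = {64, 64};
  void *deviceTensorMap = mgpuTensorMapEncodeTiledMemref(
      /*tensorRank=*/2, rankedDescriptor, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
      CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
      boxDims);
  (void)deviceTensorMap; // passed to kernels that issue TMA loads/stores
}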
#endif // CUDA_VERSION >= 12000

#ifdef MLIR_ENABLE_CUDA_CUSPARSE

///
/// Wrapper methods for the cuSPARSE library.
///

// Some macro magic to get float/double alpha and beta on the host.
// TODO: add support for passing alpha and beta as arguments.
#define ALPHABETA(dtp, alpha, beta)                                           \
  __nv_bfloat16(alpha##16bf) = 1.0f;                                          \
  __nv_bfloat16(beta##16bf) = 1.0f;                                           \
  __half(alpha##16f) = 1.0f;                                                  \
  __half(beta##16f) = 1.0f;                                                   \
  float(alpha##f) = 1.0f;                                                     \
  float(beta##f) = 1.0f;                                                      \
  double(alpha##d) = 1.0;                                                     \
  double(beta##d) = 1.0;                                                      \
  const void *(alpha##p) = nullptr;                                           \
  const void *(beta##p) = nullptr;                                            \
  if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) {                             \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf));                    \
    (beta##p) = reinterpret_cast<void *>(&(beta##16bf));                      \
  } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) {                        \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16f));                     \
    (beta##p) = reinterpret_cast<void *>(&(beta##16f));                       \
  } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                        \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                       \
    (beta##p) = reinterpret_cast<void *>(&(beta##f));                         \
  } else {                                                                    \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                       \
    (beta##p) = reinterpret_cast<void *>(&(beta##d));                         \
  }
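
// Illustrative sketch (not part of the upstream file): ALPHABETA(dtp, alpha,
// beta) declares host-side constants (alpha16bf, alpha16f, alphaf, alphad and
// the beta equivalents), all equal to one, plus `alphap`/`betap` pointers
// bound to whichever constant matches the runtime data type.
[[maybe_unused]] static void exampleAlphaBeta() {
  cudaDataType_t cTp = CUDA_R_32F;
  ALPHABETA(cTp, alpha, beta) // alphap/betap now point at the f32 constants
  float one = *reinterpret_cast<const float *>(alphap);
  (void)one;
  (void)betap;
}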

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseEnv() {
  // ScopedContext is used for CUDA initialization.
  ScopedContext scopedContext;
  assert(!cusparse_env && "client called mgpuCreateSparseEnv() twice");
  CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&cusparse_env));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseEnv() {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(cusparse_env));
  cusparse_env = nullptr;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
  return reinterpret_cast<void *>(vec);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = reinterpret_cast<cusparseDnVecDescr_t>(v);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnVec(vec))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
                CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
                                               values, dTp, CUSPARSE_ORDER_ROW))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = reinterpret_cast<cusparseDnMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnMat(mat))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
              void *colIdxs, void *values, int32_t itp, int32_t dtp,
              CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
                                             colIdxs, values, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

#ifdef CUSPARSE_COO_AOS // deprecated in cuSPARSE 11.2
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCooAoS(intptr_t rows, intptr_t cols, intptr_t nnz, void *idxs,
                 void *values, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCooAoS(
      &mat, rows, cols, nnz, idxs, values, iTp, CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}
#endif // CUSPARSE_COO_AOS

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
              void *colIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
                                             colIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsc(intptr_t rows, intptr_t cols, intptr_t nnz, void *colPos,
              void *rowIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsc(&mat, rows, cols, nnz, colPos,
                                             rowIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateBsr(intptr_t brows, intptr_t bcols, intptr_t bnnz, intptr_t rBsz,
              intptr_t cBsz, void *rowPos, void *colIdxs, void *values,
              int32_t ptp, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
#if CUSPARSE_VERSION >= 12100
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateBsr(
      &mat, brows, bcols, bnnz, rBsz, cBsz, rowPos, colIdxs, values, pTp, iTp,
      CUSPARSE_INDEX_BASE_ZERO, dTp, CUSPARSE_ORDER_ROW))
#endif
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroySpMat(void *m, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
    int32_t ma, void *a, void *x, void *y, int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
      cusparse_env, modeA, alphap, matA, vecX, betap, vecY, cTp,
      CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(int32_t ma, void *a, void *x,
                                                   void *y, int32_t ctp,
                                                   void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(cusparse_env, modeA, alphap, matA, vecX,
                                        betap, vecY, cTp,
                                        CUSPARSE_SPMV_ALG_DEFAULT, buf))
}
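
// Illustrative sketch (not part of the upstream file): the intended call
// sequence for a sparse y = A * x, sizing the scratch buffer first. All
// device pointers are hypothetical and assumed preallocated.
[[maybe_unused]] static void exampleSpMV(void *dRowIdxs, void *dColIdxs,
                                         void *dValues, void *dX, void *dY,
                                         intptr_t n, intptr_t nnz,
                                         CUstream stream) {
  mgpuCreateSparseEnv();
  void *matA = mgpuCreateCoo(n, n, nnz, dRowIdxs, dColIdxs, dValues,
                             CUSPARSE_INDEX_32I, CUDA_R_32F, stream);
  void *vecX = mgpuCreateDnVec(n, dX, CUDA_R_32F, stream);
  void *vecY = mgpuCreateDnVec(n, dY, CUDA_R_32F, stream);
  intptr_t bufferSize = mgpuSpMVBufferSize(
      CUSPARSE_OPERATION_NON_TRANSPOSE, matA, vecX, vecY, CUDA_R_32F, stream);
  void *buffer = mgpuMemAlloc(bufferSize, stream, /*isHostShared=*/false);
  mgpuSpMV(CUSPARSE_OPERATION_NON_TRANSPOSE, matA, vecX, vecY, CUDA_R_32F,
           buffer, stream);
  mgpuMemFree(buffer, stream);
  mgpuDestroyDnVec(vecX, stream);
  mgpuDestroyDnVec(vecY, stream);
  mgpuDestroySpMat(matA, stream);
  mgpuDestroySparseEnv();
}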

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                   int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(int32_t ma, int32_t mb,
                                                   void *a, void *b, void *c,
                                                   int32_t ctp, void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(cusparse_env, modeA, modeB, alphap,
                                        matA, matB, betap, matC, cTp,
                                        CUSPARSE_SPMM_ALG_DEFAULT, buf))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSDDMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                    int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
                                                    void *a, void *b, void *c,
                                                    int32_t ctp, void *buf,
                                                    CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(cusparse_env, modeA, modeB, alphap,
                                         matA, matB, betap, matC, cTp,
                                         CUSPARSE_SDDMM_ALG_DEFAULT, buf))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = nullptr;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&spgemmDesc))
  return reinterpret_cast<void *>(spgemmDesc);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(spgemmDesc))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
    intptr_t bs, void *buf, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize = bs;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
  return newBufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
                  int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize2 = bsz2;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize2, buf2))
  return newBufferSize2;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
               int32_t ctp, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
                          matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
}
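
// Illustrative sketch (not part of the upstream file): the two-phase pattern
// used by the SpGEMM wrappers above. Each stage is called once with a null
// buffer to query the required size, then again with an allocated buffer.
[[maybe_unused]] static void exampleSpGEMM(void *matA, void *matB, void *matC,
                                           CUstream stream) {
  int32_t op = CUSPARSE_OPERATION_NON_TRANSPOSE;
  void *desc = mgpuSpGEMMCreateDescr(stream);
  intptr_t bufSz = mgpuSpGEMMWorkEstimation(desc, op, op, matA, matB, matC,
                                            CUDA_R_32F, /*bs=*/0,
                                            /*buf=*/nullptr, stream);
  void *buf = mgpuMemAlloc(bufSz, stream, /*isHostShared=*/false);
  mgpuSpGEMMWorkEstimation(desc, op, op, matA, matB, matC, CUDA_R_32F, bufSz,
                           buf, stream);
  intptr_t bufSz2 = mgpuSpGEMMCompute(desc, op, op, matA, matB, matC,
                                      CUDA_R_32F, /*bsz2=*/0,
                                      /*buf2=*/nullptr, stream);
  void *buf2 = mgpuMemAlloc(bufSz2, stream, /*isHostShared=*/false);
  mgpuSpGEMMCompute(desc, op, op, matA, matB, matC, CUDA_R_32F, bufSz2, buf2,
                    stream);
  // After computing, query the result size, set the C pointers (see the
  // wrappers below), then copy the result into C.
  mgpuSpGEMMCopy(desc, op, op, matA, matB, matC, CUDA_R_32F, stream);
  mgpuSpGEMMDestroyDescr(desc, stream);
  mgpuMemFree(buf2, stream);
  mgpuMemFree(buf, stream);
}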

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpMatGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
  cusparseConstSpMatDescr_t matDescr =
      reinterpret_cast<cusparseConstSpMatDescr_t>(m);
  int64_t *rows = reinterpret_cast<int64_t *>(r);
  int64_t *cols = reinterpret_cast<int64_t *>(c);
  int64_t *nnz = reinterpret_cast<int64_t *>(n);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
  cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));
}

#ifdef MLIR_ENABLE_CUDA_CUSPARSELT

///
/// Wrapper methods for the cuSparseLt library.
///

struct cusparseLtSpMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  // TODO: the following three are associated with the SpMM operator rather
  // than the sparse matrix. Create workspace buffers and pass them to the
  // SpMM execution.
  cusparseLtMatmulAlgSelection_t alg_sel;
  cusparseLtMatmulPlan_t plan;
  cusparseLtMatmulDescriptor_t matmul;
  void *values{nullptr};
};

struct cusparseLtDnMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  void *values{nullptr};
};

static_assert(sizeof(cusparseLtHandle_t) == 11024,
              "Unexpected cusparseLt handle size");
static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104,
              "Unexpected cusparseLt sparse matrix handle size");
static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032,
              "Unexpected cusparseLt dense matrix handle size");

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseLtEnv() {
  // ScopedContext is used for CUDA initialization.
  ScopedContext scopedContext;
  assert(!cusparseLt_initiated &&
         "client called mgpuCreateSparseLtEnv() twice");
  // Note that cuSparseLt still uses cusparseStatus_t.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(&cusparseLt_env));
  cusparseLt_initiated = true;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseLtEnv() {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDestroy(&cusparseLt_env));
  cusparseLt_initiated = false;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCreateCuSparseLtDnMat(void *dh, intptr_t rows, intptr_t cols, void *values,
                          int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  dnmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
      &cusparseLt_env, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
      alignment, dTp, CUSPARSE_ORDER_ROW))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtDnMat(void *dh, CUstream /*stream*/) {
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(dnmat_handle->mat)))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCusparseLtCreate2To4SpMat(void *sh, intptr_t rows, intptr_t cols,
                              void *values, int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  spmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
      &cusparseLt_env, &(spmat_handle->mat), rows, cols, /*ld=*/cols, alignment,
      dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtSpMat(void *sh, CUstream /*stream*/) {
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(spmat_handle->mat)))
}

// This stage performs several tasks: algorithm selection, planning, and
// returning the workspace and compressed-matrix data buffer sizes.
// The parameter prune_flag indicates whether pruning and a pruning check are
// performed: 0 means no pruning or check, 1 means prune, and 2 means prune
// and run the pruning check.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
                             void *c, int32_t ctp, int32_t prune_flag,
                             CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  // TODO: support more advanced settings, e.g., a sparse right-hand operand.
  // For now, matA is assumed to be the sparse matrix.
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
  auto workspace_size = reinterpret_cast<size_t *>(bs);
  auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
  auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
  auto cTp = static_cast<cusparseComputeType>(ctp);

  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
      &cusparseLt_env, &(matA->matmul), modeA, modeB, &(matA->mat),
      &(matB->mat), &(matC->mat), &(matC->mat), cTp))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
      &cusparseLt_env, &(matA->alg_sel), &(matA->matmul),
      CUSPARSELT_MATMUL_ALG_DEFAULT))
  int alg = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
      &cusparseLt_env, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
      sizeof(alg)))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
      &cusparseLt_env, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))

  // Pruning step (in-place).
  if (prune_flag > 0)
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPrune(
        &cusparseLt_env, &(matA->matmul), matA->values, matA->values,
        CUSPARSELT_PRUNE_SPMMA_STRIP, stream))

  // Check structure of A.
  // Note that this adds a synchronization on the stream.
  // TODO: Do we want that?
  if (prune_flag == 2) {
    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
        &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
    int valid = 0;
    mgpuMemcpy(&valid, dvalid, sizeof(int), stream);
    mgpuStreamSynchronize(stream);
    mgpuMemFree(dvalid, stream);
    if (valid != 0)
      fprintf(stderr, "CUSPARSE-LT: sparse matrix is not 2:4; computed results "
                      "will be invalid\n");
  }

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulGetWorkspace(
      &cusparseLt_env, &(matA->plan), workspace_size))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
      &cusparseLt_env, &(matA->plan), compressed_size, compressed_buffer_size))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMM(void *a, void *b, void *c, void *d_workspace,
                   void *dA_compressed, void *dA_compressedBuffer,
                   CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);

  ALPHABETA(CUDA_R_32F, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseLtSpMMACompress(&cusparseLt_env, &(matA->plan), (matA->values),
                              dA_compressed, dA_compressedBuffer, stream))

  // TODO: add support for multi-stream execution.
  // Perform the matrix multiplication D = A*B+C, using C == D for now.
  CUSPARSE_REPORT_IF_ERROR(
      cusparseLtMatmul(&cusparseLt_env, &(matA->plan), alphap, dA_compressed,
                       matB->values, betap, matC->values,
                       /*dD=*/matC->values, d_workspace, nullptr, 0))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat)))
  // Destroy the plan associated with the sparse matrix.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan)))
}
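
// Illustrative sketch (not part of the upstream file): how generated code is
// expected to drive the cuSparseLt wrappers. `sh`, `dhB`, and `dhC` point at
// preallocated handle storage (see the structs above); `ctp` is a
// cusparseComputeType value chosen by the frontend.
[[maybe_unused]] static void exampleSparseLtSpMM(void *sh, void *dhB, void *dhC,
                                                 int32_t ctp, CUstream stream) {
  size_t sizes[3]; // workspace, compressed, and compressed-buffer sizes
  mgpuCuSparseLtSpMMBufferSize(sizes, CUSPARSE_OPERATION_NON_TRANSPOSE,
                               CUSPARSE_OPERATION_NON_TRANSPOSE, sh, dhB, dhC,
                               ctp, /*prune_flag=*/1, stream);
  void *workspace = mgpuMemAlloc(sizes[0], stream, /*isHostShared=*/false);
  void *compressed = mgpuMemAlloc(sizes[1], stream, /*isHostShared=*/false);
  void *compressedBuf = mgpuMemAlloc(sizes[2], stream, /*isHostShared=*/false);
  mgpuCuSparseLtSpMM(sh, dhB, dhC, workspace, compressed, compressedBuf,
                     stream);
  mgpuMemFree(compressedBuf, stream);
  mgpuMemFree(compressed, stream);
  mgpuMemFree(workspace, stream);
}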

#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE