MLIR  21.0.0git
GPUToLLVMConversion.cpp
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert the gpu.launch_func op into a
10 // sequence of GPU runtime calls. As most GPU runtimes do not have a stable
11 // published ABI, this pass uses a slim runtime layer that builds on top of
12 // the public API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
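//
// As an illustrative sketch only (the exact IR depends on the type converter
// options; the mgpu* wrapper names below match the FunctionCallBuilder
// declarations later in this file):
//
//   %token = gpu.wait async
//   %memref, %t = gpu.alloc async [%token] () : memref<16xf32>
//
// is lowered roughly to
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   %size = llvm.mlir.constant(64 : i64) : i64
//   %host_shared = llvm.mlir.constant(0 : i8) : i8
//   %buffer = llvm.call @mgpuMemAlloc(%size, %stream, %host_shared)
//       : (i64, !llvm.ptr, i8) -> !llvm.ptr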
15 
17 
36 #include "mlir/IR/Attributes.h"
37 #include "mlir/IR/Builders.h"
38 #include "mlir/IR/BuiltinOps.h"
39 #include "mlir/IR/BuiltinTypes.h"
41 
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/Support/Error.h"
44 #include "llvm/Support/FormatVariadic.h"
45 
46 #define DEBUG_TYPE "gpu-to-llvm"
47 
48 namespace mlir {
49 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
50 #include "mlir/Conversion/Passes.h.inc"
51 } // namespace mlir
52 
53 using namespace mlir;
54 
55 namespace {
56 class GpuToLLVMConversionPass
57  : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
58 public:
59  using Base::Base;
60  void getDependentDialects(DialectRegistry &registry) const final {
61  Base::getDependentDialects(registry);
62  registerConvertToLLVMDependentDialectLoading(registry);
63  }
64  // Run the dialect converter on the module.
65  void runOnOperation() override;
66 };
67 
68 template <typename OpTy>
69 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
70 public:
71  explicit ConvertOpToGpuRuntimeCallPattern(
72  const LLVMTypeConverter &typeConverter)
73  : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
74 
75 protected:
76  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
77  MemRefType type, MemRefDescriptor desc) const {
78  Type indexType = ConvertToLLVMPattern::getIndexType();
79  if (type.hasStaticShape())
80  return ConvertToLLVMPattern::createIndexAttrConstant(
81  rewriter, loc, indexType, type.getNumElements());
82  // Compute the number of elements by multiplying all the dim sizes.
83  uint64_t rank = type.getRank();
84  Value numElements = desc.size(rewriter, loc, /*pos=*/0);
85  for (unsigned i = 1; i < rank; i++)
86  numElements = rewriter.create<LLVM::MulOp>(
87  loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
88  return numElements;
89  }
90 
91  MLIRContext *context = &this->getTypeConverter()->getContext();
92 
93  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
94  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
95  Type llvmInt8Type = IntegerType::get(context, 8);
96  Type llvmInt16Type = IntegerType::get(context, 16);
97  Type llvmInt32Type = IntegerType::get(context, 32);
98  Type llvmInt64Type = IntegerType::get(context, 64);
99  Type llvmFloat32Type = Float32Type::get(context);
100  Type llvmIntPtrType = IntegerType::get(
101  context, this->getTypeConverter()->getPointerBitwidth(0));
102 
103  FunctionCallBuilder streamCreateCallBuilder = {
104  "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
105  FunctionCallBuilder streamDestroyCallBuilder = {
106  "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
107  FunctionCallBuilder streamSynchronizeCallBuilder = {
108  "mgpuStreamSynchronize",
109  llvmVoidType,
110  {llvmPointerType /* void *stream */}};
111  FunctionCallBuilder streamWaitEventCallBuilder = {
112  "mgpuStreamWaitEvent",
113  llvmVoidType,
114  {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
115  FunctionCallBuilder eventCreateCallBuilder = {
116  "mgpuEventCreate", llvmPointerType /* void *event */, {}};
117  FunctionCallBuilder eventDestroyCallBuilder = {
118  "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
119  FunctionCallBuilder eventSynchronizeCallBuilder = {
120  "mgpuEventSynchronize",
121  llvmVoidType,
122  {llvmPointerType /* void *event */}};
123  FunctionCallBuilder eventRecordCallBuilder = {
124  "mgpuEventRecord",
125  llvmVoidType,
126  {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
127  FunctionCallBuilder hostRegisterCallBuilder = {
128  "mgpuMemHostRegisterMemRef",
129  llvmVoidType,
130  {llvmIntPtrType /* intptr_t rank */,
131  llvmPointerType /* void *memrefDesc */,
132  llvmIntPtrType /* intptr_t elementSizeBytes */}};
133  FunctionCallBuilder hostUnregisterCallBuilder = {
134  "mgpuMemHostUnregisterMemRef",
135  llvmVoidType,
136  {llvmIntPtrType /* intptr_t rank */,
137  llvmPointerType /* void *memrefDesc */,
138  llvmIntPtrType /* intptr_t elementSizeBytes */}};
139  FunctionCallBuilder allocCallBuilder = {
140  "mgpuMemAlloc",
141  llvmPointerType /* void * */,
142  {llvmIntPtrType /* intptr_t sizeBytes */,
143  llvmPointerType /* void *stream */,
144  llvmInt8Type /* bool isHostShared */}};
145  FunctionCallBuilder deallocCallBuilder = {
146  "mgpuMemFree",
147  llvmVoidType,
148  {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
149  FunctionCallBuilder memcpyCallBuilder = {
150  "mgpuMemcpy",
151  llvmVoidType,
152  {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
153  llvmIntPtrType /* intptr_t sizeBytes */,
154  llvmPointerType /* void *stream */}};
155  FunctionCallBuilder memset16CallBuilder = {
156  "mgpuMemset16",
157  llvmVoidType,
158  {llvmPointerType /* void *dst */,
159  llvmInt16Type /* unsigned short value */,
160  llvmIntPtrType /* intptr_t sizeBytes */,
161  llvmPointerType /* void *stream */}};
162  FunctionCallBuilder memset32CallBuilder = {
163  "mgpuMemset32",
164  llvmVoidType,
165  {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
166  llvmIntPtrType /* intptr_t sizeBytes */,
167  llvmPointerType /* void *stream */}};
168  FunctionCallBuilder setDefaultDeviceCallBuilder = {
169  "mgpuSetDefaultDevice",
170  llvmVoidType,
171  {llvmInt32Type /* uint32_t devIndex */}};
172  FunctionCallBuilder createDnVecCallBuilder = {
173  "mgpuCreateDnVec",
174  llvmPointerType,
175  {llvmIntPtrType, llvmPointerType, llvmInt32Type,
176  llvmPointerType /* void *stream */}};
177  FunctionCallBuilder destroyDnVecCallBuilder = {
178  "mgpuDestroyDnVec",
179  llvmVoidType,
180  {llvmPointerType, llvmPointerType /* void *stream */}};
181  FunctionCallBuilder createDnMatCallBuilder = {
182  "mgpuCreateDnMat",
183  llvmPointerType,
184  {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
185  llvmPointerType /* void *stream */}};
186  FunctionCallBuilder destroyDnMatCallBuilder = {
187  "mgpuDestroyDnMat",
188  llvmVoidType,
189  {llvmPointerType, llvmPointerType /* void *stream */}};
190  FunctionCallBuilder createCooCallBuilder = {
191  "mgpuCreateCoo",
192  llvmPointerType,
193  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
194  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
195  llvmPointerType /* void *stream */}};
196  FunctionCallBuilder createCooAoSCallBuilder = {
197  "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
198  llvmPointerType,
199  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
200  llvmPointerType, llvmInt32Type, llvmInt32Type,
201  llvmPointerType /* void *stream */}};
202  FunctionCallBuilder createCsrCallBuilder = {
203  "mgpuCreateCsr",
204  llvmPointerType,
205  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
206  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
207  llvmInt32Type, llvmPointerType /* void *stream */}};
208  FunctionCallBuilder createCscCallBuilder = {
209  "mgpuCreateCsc",
210  llvmPointerType,
211  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
212  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
213  llvmInt32Type, llvmPointerType /* void *stream */}};
214  FunctionCallBuilder createBsrCallBuilder = {
215  "mgpuCreateBsr",
216  llvmPointerType,
217  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
218  llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
219  llvmInt32Type, llvmInt32Type, llvmInt32Type,
220  llvmPointerType /* void *stream */}};
221  FunctionCallBuilder destroySpMatCallBuilder = {
222  "mgpuDestroySpMat",
223  llvmVoidType,
224  {llvmPointerType, llvmPointerType /* void *stream */}};
225  FunctionCallBuilder spMVBufferSizeCallBuilder = {
226  "mgpuSpMVBufferSize",
227  llvmIntPtrType,
228  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
229  llvmInt32Type, llvmPointerType /* void *stream */}};
230  FunctionCallBuilder spMVCallBuilder = {
231  "mgpuSpMV",
232  llvmVoidType,
233  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
234  llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
235  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
236  "mgpuSpMMBufferSize",
237  llvmIntPtrType,
238  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
239  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
240  FunctionCallBuilder createSpMMCallBuilder = {
241  "mgpuSpMM",
242  llvmVoidType,
243  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
244  llvmPointerType, llvmInt32Type, llvmPointerType,
245  llvmPointerType /* void *stream */}};
246  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
247  "mgpuSDDMMBufferSize",
248  llvmIntPtrType,
249  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
250  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
251  FunctionCallBuilder createSDDMMCallBuilder = {
252  "mgpuSDDMM",
253  llvmVoidType,
254  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
255  llvmPointerType, llvmInt32Type, llvmPointerType,
256  llvmPointerType /* void *stream */}};
257  FunctionCallBuilder createLtDnMatCallBuilder = {
258  "mgpuCreateCuSparseLtDnMat",
259  llvmVoidType,
260  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
261  llvmInt32Type, llvmPointerType /* void *stream */}};
262  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
263  "mgpuDestroyCuSparseLtSpMat",
264  llvmVoidType,
265  {llvmPointerType, llvmPointerType /* void *stream */}};
266  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
267  "mgpuDestroyCuSparseLtDnMat",
268  llvmVoidType,
269  {llvmPointerType, llvmPointerType /* void *stream */}};
270  FunctionCallBuilder create2To4SpMatCallBuilder = {
271  "mgpuCusparseLtCreate2To4SpMat",
272  llvmVoidType,
273  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
274  llvmInt32Type, llvmPointerType /* void *stream */}};
275  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
276  "mgpuCuSparseLtSpMMBufferSize",
277  llvmVoidType,
278  {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
279  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
280  llvmPointerType /*void *stream*/}};
281  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
282  "mgpuCuSparseLtSpMM",
283  llvmVoidType,
284  {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
285  llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
286  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
287  "mgpuSpGEMMCreateDescr",
288  llvmPointerType,
289  {llvmPointerType /*void *stream*/}};
290  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
291  "mgpuSpGEMMDestroyDescr",
292  llvmVoidType,
293  {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
294  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
295  "mgpuSpGEMMWorkEstimation",
296  llvmIntPtrType,
297  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
298  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
299  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
300  llvmPointerType /*void *stream*/}};
301  FunctionCallBuilder createSpGEMMComputeBuilder = {
302  "mgpuSpGEMMCompute",
303  llvmIntPtrType,
304  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
305  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
306  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
307  llvmPointerType /*void *stream*/}};
308  FunctionCallBuilder createSpGEMMCopyBuilder = {
309  "mgpuSpGEMMCopy",
310  llvmVoidType,
311  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
312  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
313  llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
314  FunctionCallBuilder createSpMatGetSizeBuilder = {
315  "mgpuSpMatGetSize",
316  llvmVoidType,
317  {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
318  llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
319  FunctionCallBuilder createSetCsrPointersBuilder = {
320  "mgpuSetCsrPointers",
321  llvmVoidType,
322  {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
323  llvmPointerType /*crd*/, llvmPointerType /*val*/,
324  llvmPointerType /*void *stream*/}};
325 };
326 
327 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
328 /// call. Currently it supports CUDA and ROCm (HIP).
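///
/// Illustrative sketch only (made-up SSA names; the descriptor argument is
/// produced by promoteOperands and the wrapper signature follows
/// hostRegisterCallBuilder above):
///
///   gpu.host_register %memref : memref<*xf32>
///
/// becomes roughly
///
///   llvm.call @mgpuMemHostRegisterMemRef(%rank, %descriptor, %element_size)
///       : (i64, !llvm.ptr, i64) -> ()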
329 class ConvertHostRegisterOpToGpuRuntimeCallPattern
330  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
331 public:
332  ConvertHostRegisterOpToGpuRuntimeCallPattern(
333  const LLVMTypeConverter &typeConverter)
334  : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
335 
336 private:
337  LogicalResult
338  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
339  ConversionPatternRewriter &rewriter) const override;
340 };
341 
342 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
343  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
344 public:
345  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
346  const LLVMTypeConverter &typeConverter)
347  : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
348  }
349 
350 private:
351  LogicalResult
352  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
353  ConversionPatternRewriter &rewriter) const override;
354 };
355 
356 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
357 /// call. Currently it supports CUDA and ROCm (HIP).
358 class ConvertAllocOpToGpuRuntimeCallPattern
359  : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
360 public:
361  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
362  : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
363 
364 private:
365  LogicalResult
366  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
367  ConversionPatternRewriter &rewriter) const override;
368 };
369 
370 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
371 /// call. Currently it supports CUDA and ROCm (HIP).
372 class ConvertDeallocOpToGpuRuntimeCallPattern
373  : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
374 public:
375  ConvertDeallocOpToGpuRuntimeCallPattern(
376  const LLVMTypeConverter &typeConverter)
377  : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
378 
379 private:
380  LogicalResult
381  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
382  ConversionPatternRewriter &rewriter) const override;
383 };
384 
385 class ConvertAsyncYieldToGpuRuntimeCallPattern
386  : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
387 public:
388  ConvertAsyncYieldToGpuRuntimeCallPattern(
389  const LLVMTypeConverter &typeConverter)
390  : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
391 
392 private:
393  LogicalResult
394  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
395  ConversionPatternRewriter &rewriter) const override;
396 };
397 
398 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
399 /// call. Currently it supports CUDA and ROCm (HIP).
400 class ConvertWaitOpToGpuRuntimeCallPattern
401  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
402 public:
403  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
404  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
405 
406 private:
407  LogicalResult
408  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
409  ConversionPatternRewriter &rewriter) const override;
410 };
411 
412 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
413 /// call. Currently it supports CUDA and ROCm (HIP).
414 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
415  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
416 public:
417  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
418  const LLVMTypeConverter &typeConverter)
419  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
420 
421 private:
422  LogicalResult
423  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
424  ConversionPatternRewriter &rewriter) const override;
425 };
426 
427 /// A rewrite pattern to legalize gpu.launch_func with LLVM types.
428 class LegalizeLaunchFuncOpPattern
429  : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
430 public:
431  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
432  bool kernelBarePtrCallConv,
433  bool kernelIntersperseSizeCallConv)
434  : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
435  kernelBarePtrCallConv(kernelBarePtrCallConv),
436  kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
437 
438 private:
439  LogicalResult
440  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
441  ConversionPatternRewriter &rewriter) const override;
442 
443  bool kernelBarePtrCallConv;
444  bool kernelIntersperseSizeCallConv;
445 };
446 
447 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
448 /// call. Currently it supports CUDA and ROCm (HIP).
449 class ConvertMemcpyOpToGpuRuntimeCallPattern
450  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
451 public:
452  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
453  : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
454 
455 private:
456  LogicalResult
457  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
458  ConversionPatternRewriter &rewriter) const override;
459 };
460 
461 /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
462 /// call. Currently it supports CUDA and ROCm (HIP).
463 class ConvertMemsetOpToGpuRuntimeCallPattern
464  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
465 public:
466  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
467  : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
468 
469 private:
470  LogicalResult
471  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
472  ConversionPatternRewriter &rewriter) const override;
473 };
474 
475 /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
476 /// Currently supports CUDA and ROCm (HIP).
477 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
478  : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
479 public:
480  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
481  const LLVMTypeConverter &typeConverter)
482  : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
483  typeConverter) {}
484 
485  LogicalResult
486  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
487  ConversionPatternRewriter &rewriter) const override;
488 };
489 
490 /// Generic rewriting rule for operations on sparse matrices.
491 /// Currently supports CUDA (by means of cuSparse and cuSparseLt).
492 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
493  class Convert##op_name##ToGpuRuntimeCallPattern \
494  : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
495  public: \
496  Convert##op_name##ToGpuRuntimeCallPattern( \
497  const LLVMTypeConverter &typeConverter) \
498  : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
499  \
500  private: \
501  LogicalResult \
502  matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
503  ConversionPatternRewriter &rewriter) const override; \
504  };
505 
506 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
507 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
508 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
509 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
510 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
511 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
512 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
513 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
514 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
515 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
516 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
517 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
518 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
519 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
520 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
521 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
522 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
523 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
524 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
525 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
526 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
527 
528 } // namespace
529 
530 void GpuToLLVMConversionPass::runOnOperation() {
531  MLIRContext *context = &getContext();
532 
533  // Perform progressive lowering of vector transfer operations.
534  {
535  RewritePatternSet patterns(&getContext());
536  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
537  vector::populateVectorTransferLoweringPatterns(patterns,
538      /*maxTransferRank=*/1);
539  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
540  return signalPassFailure();
541  }
542 
543  LowerToLLVMOptions options(context);
544  options.useBarePtrCallConv = hostBarePtrCallConv;
545  RewritePatternSet patterns(context);
546  ConversionTarget target(*context);
547  target.addLegalDialect<LLVM::LLVMDialect>();
548  LLVMTypeConverter converter(context, options);
549 
550  // Populate all patterns from all dialects that implement the
551  // `ConvertToLLVMPatternInterface` interface.
552  for (Dialect *dialect : context->getLoadedDialects()) {
553  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
554  if (!iface)
555  continue;
556  iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
557  }
558 
559  // Preserve GPU modules and binaries. Modules are preserved as they can be
560  // converted later by `gpu-module-to-binary`.
561  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
562  // Accept LaunchFuncOps as legal if their operands have been lowered.
563  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
564  [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
565 
566  // These aren't covered by the ConvertToLLVMPatternInterface right now.
567  populateVectorToLLVMConversionPatterns(converter, patterns);
568  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
569  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
570      target);
571  populateGpuToLLVMConversionPatterns(converter, patterns,
572      kernelBarePtrCallConv,
573      kernelIntersperseSizeCallConv);
574 
575  if (failed(
576  applyPartialConversion(getOperation(), target, std::move(patterns))))
577  signalPassFailure();
578 }
579 
580 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
581  ArrayRef<Value> arguments) const {
582  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
583  auto function = [&] {
584  if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
585  return function;
586  return OpBuilder::atBlockEnd(module.getBody())
587  .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
588  }();
589  return builder.create<LLVM::CallOp>(loc, function, arguments);
590 }
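// Example use, as done throughout this file: emitting a stream-create runtime
// call also declares `llvm.func @mgpuStreamCreate() -> !llvm.ptr` in the
// module the first time the builder is used.
//
//   Value stream =
//       streamCreateCallBuilder.create(loc, rewriter, {}).getResult();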
591 
592 // Corresponding to cusparseIndexType_t defined in cusparse.h.
593 static int32_t getCuSparseIndexTypeFrom(Type type) {
594  if (type.isInteger(16))
595  return 1; // CUSPARSE_INDEX_16U
596  if (type.isInteger(32))
597  return 2; // CUSPARSE_INDEX_32I
598  return 3; // CUSPARSE_INDEX_64I
599 }
600 
601 static int32_t getCuSparseLtDataTypeFrom(Type type) {
602  if (type.isF16())
603  return 0; // CUSPARSE_COMPUTE_16F,
604  if (type.isInteger(32))
605  return 1; // CUSPARSE_COMPUTE_32I
606  llvm_unreachable("unsupported type");
607  // TODO: add support to TF32
608 }
609 
610 // Corresponding to cudaDataType_t defined in CUDA library_types.h.
611 static int32_t getCuSparseDataTypeFrom(Type type) {
612  if (llvm::isa<ComplexType>(type)) {
613  // get the element type
614  auto elementType = cast<ComplexType>(type).getElementType();
615  if (elementType.isBF16())
616  return 15; // CUDA_C_16BF
617  if (elementType.isF16())
618  return 6; // CUDA_C_16F
619  if (elementType.isF32())
620  return 4; // CUDA_C_32F
621  if (elementType.isF64())
622  return 5; // CUDA_C_64F
623  if (elementType.isInteger(8))
624  return 7; // CUDA_C_8I
625  if (elementType.isInteger(16))
626  return 21; // CUDA_C_16I
627  if (elementType.isInteger(32))
628  return 11; // CUDA_C_32I
629  }
630  if (type.isBF16())
631  return 14; // CUDA_R_16BF
632  if (type.isF16())
633  return 2; // CUDA_R_16F
634  if (type.isF32())
635  return 0; // CUDA_R_32F
636  if (type.isF64())
637  return 1; // CUDA_R_64F
638  if (type.isInteger(8))
639  return 3; // CUDA_R_8I
640  if (type.isInteger(16))
641  return 20; // CUDA_R_16I
642  if (type.isInteger(32))
643  return 10; // CUDA_R_32I
644 
645  llvm_unreachable("unsupported element type");
646 }
647 
648 static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
649  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
650 }
651 
652 // TODO: We may want a compile-time (MLIR) disablement or warning.
653 // cusparseLt currently does not work for CUDA architectures < 8.0 and will
654 // trigger an error at CUDA program run time. It would be good to at least
655 // emit a warning when the target architecture is < 8.0 and the user still
656 // wants to use cusparseLt, and to make sure that, when lowering the GPU
657 // sparse dialect to LLVM calls, the cusparseLt calls are disabled for CUDA
658 // architectures < 8.0.
659 static bool is2To4Sparsity(Value spMat) {
660  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
661  return true;
662  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
663  return false;
664  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
665  return false;
666  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
667  return false;
668  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
669  return false;
670  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
671  return false;
672  // Print the spMat defining op
673  spMat.getDefiningOp()->print(llvm::errs());
674  llvm_unreachable("cannot find spmat def");
675 }
676 
677 static bool isSpMMCusparseLtOp(Value op) {
678  for (Operation *user : op.getUsers()) {
679  auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
680  // If the A operand of an SpMM user has 2:4 (50%) sparsity, use cusparseLt.
681  if (!spmmOp)
682  continue;
683  if (is2To4Sparsity(spmmOp.getSpmatA()))
684  return true;
685  }
686  return false;
687 }
688 
689 // Returns success if all operands are of LLVM type, a match failure otherwise.
690 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
691  ConversionPatternRewriter &rewriter) {
692  if (!llvm::all_of(operands, [](Value value) {
693  return LLVM::isCompatibleType(value.getType());
694  }))
695  return rewriter.notifyMatchFailure(
696  op, "Cannot convert if operands aren't of LLVM type.");
697  return success();
698 }
699 
700 static LogicalResult
701 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
702  gpu::AsyncOpInterface op) {
703  if (op.getAsyncDependencies().size() != 1)
704  return rewriter.notifyMatchFailure(
705  op, "Can only convert with exactly one async dependency.");
706 
707  if (!op.getAsyncToken())
708  return rewriter.notifyMatchFailure(op, "Can convert only async version.");
709 
710  return success();
711 }
712 
713 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714  gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
715  ConversionPatternRewriter &rewriter) const {
716  auto *op = hostRegisterOp.getOperation();
717  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
718  return failure();
719 
720  Location loc = op->getLoc();
721 
722  auto memRefType = hostRegisterOp.getValue().getType();
723  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
724  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
725 
726  auto arguments = getTypeConverter()->promoteOperands(
727  loc, op->getOperands(), adaptor.getOperands(), rewriter);
728  arguments.push_back(elementSize);
729  hostRegisterCallBuilder.create(loc, rewriter, arguments);
730 
731  rewriter.eraseOp(op);
732  return success();
733 }
734 
735 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736  gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
737  ConversionPatternRewriter &rewriter) const {
738  Operation *op = hostUnregisterOp.getOperation();
739  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
740  return failure();
741 
742  Location loc = op->getLoc();
743 
744  auto memRefType = hostUnregisterOp.getValue().getType();
745  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
746  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
747 
748  auto arguments = getTypeConverter()->promoteOperands(
749  loc, op->getOperands(), adaptor.getOperands(), rewriter);
750  arguments.push_back(elementSize);
751  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
752 
753  rewriter.eraseOp(op);
754  return success();
755 }
756 
757 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758  gpu::AllocOp allocOp, OpAdaptor adaptor,
759  ConversionPatternRewriter &rewriter) const {
760 
761  MemRefType memRefType = allocOp.getType();
762 
763  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
764  !isConvertibleAndHasIdentityMaps(memRefType))
765  return failure();
766 
767  auto loc = allocOp.getLoc();
768 
769  bool isShared = allocOp.getHostShared();
770 
771  if (isShared && allocOp.getAsyncToken())
772  return rewriter.notifyMatchFailure(
773  allocOp, "Host Shared allocation cannot be done async");
774  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
775  return failure();
776 
777  // Get shape of the memref as values: static sizes are constant
778  // values and dynamic sizes are passed to 'alloc' as operands.
779  SmallVector<Value, 4> shape;
780  SmallVector<Value, 4> strides;
781  Value sizeBytes;
782  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
783  shape, strides, sizeBytes);
784 
785  // Allocate the underlying buffer and store a pointer to it in the MemRef
786  // descriptor.
787  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
788  Value stream = adaptor.getAsyncDependencies().empty()
789  ? nullPtr
790  : adaptor.getAsyncDependencies().front();
791 
792  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
793  loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
794 
795  Value allocatedPtr =
796  allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
797  .getResult();
798 
799  // No alignment.
800  Value alignedPtr = allocatedPtr;
801 
802  // Create the MemRef descriptor.
803  auto memRefDescriptor = this->createMemRefDescriptor(
804  loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
805 
806  if (allocOp.getAsyncToken()) {
807  // Async alloc: make dependent ops use the same stream.
808  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
809  } else {
810  rewriter.replaceOp(allocOp, {memRefDescriptor});
811  }
812 
813  return success();
814 }
815 
816 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
817  gpu::DeallocOp deallocOp, OpAdaptor adaptor,
818  ConversionPatternRewriter &rewriter) const {
819  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
820  failed(isAsyncWithOneDependency(rewriter, deallocOp)))
821  return failure();
822 
823  Location loc = deallocOp.getLoc();
824 
825  Value pointer =
826  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
827  Value stream = adaptor.getAsyncDependencies().front();
828  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
829 
830  rewriter.replaceOp(deallocOp, {stream});
831  return success();
832 }
833 
834 static bool isGpuAsyncTokenType(Value value) {
835  return isa<gpu::AsyncTokenType>(value.getType());
836 }
837 
838 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
839 // !gpu.async.token values are lowered to streams within the async.execute
840 // region, but are passed as events between regions. For each !gpu.async.token
841 // operand, we create an event and record it on the stream.
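//
// Illustrative sketch only (made-up SSA names; the runtime calls are the
// mgpu* wrappers declared above):
//
//   async.execute {
//     ...
//     async.yield %token : !gpu.async.token
//   }
//
// is rewritten so that the region yields an event instead of its stream:
//
//   %event = llvm.call @mgpuEventCreate() : () -> !llvm.ptr
//   llvm.call @mgpuEventRecord(%event, %stream) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr) -> ()
//   async.yield %event : !llvm.ptr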
842 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
843  async::YieldOp yieldOp, OpAdaptor adaptor,
844  ConversionPatternRewriter &rewriter) const {
845  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
846  return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
847 
848  Location loc = yieldOp.getLoc();
849  SmallVector<Value, 4> newOperands(adaptor.getOperands());
850  llvm::SmallDenseSet<Value> streams;
851  for (auto &operand : yieldOp->getOpOperands()) {
852  if (!isGpuAsyncTokenType(operand.get()))
853  continue;
854  auto idx = operand.getOperandNumber();
855  auto stream = adaptor.getOperands()[idx];
856  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
857  eventRecordCallBuilder.create(loc, rewriter, {event, stream});
858  newOperands[idx] = event;
859  streams.insert(stream);
860  }
861  for (auto stream : streams)
862  streamDestroyCallBuilder.create(loc, rewriter, {stream});
863 
864  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
865  return success();
866 }
867 
868 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
869 static bool isDefinedByCallTo(Value value, StringRef functionName) {
870  assert(isa<LLVM::LLVMPointerType>(value.getType()));
871  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
872  return *defOp.getCallee() == functionName;
873  return false;
874 }
875 
876 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
877 // with the stream/event operands. The operands are destroyed, i.e. the
878 // lowering assumes they are not used afterwards or elsewhere; otherwise we
879 // will get a runtime error. Eventually, we should guarantee this property.
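//
// Illustrative sketch only, assuming the operand was produced by a call to
// mgpuStreamCreate (event operands use mgpuEventSynchronize/mgpuEventDestroy
// instead):
//
//   gpu.wait [%token]
//
// becomes roughly
//
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr) -> ()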
880 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
881  gpu::WaitOp waitOp, OpAdaptor adaptor,
882  ConversionPatternRewriter &rewriter) const {
883  if (waitOp.getAsyncToken())
884  return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
885 
886  Location loc = waitOp.getLoc();
887 
888  for (auto operand : adaptor.getOperands()) {
889  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
890  // The converted operand's definition created a stream.
891  streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
892  streamDestroyCallBuilder.create(loc, rewriter, {operand});
893  } else {
894  // Otherwise the converted operand is an event. This assumes that we use
895  // events in control flow code as well.
896  eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
897  eventDestroyCallBuilder.create(loc, rewriter, {operand});
898  }
899  }
900 
901  rewriter.eraseOp(waitOp);
902  return success();
903 }
904 
905 // Converts `gpu.wait async` to runtime calls. The converted op creates a new
906 // stream that is synchronized with the stream/event operands. The operands are
907 // destroyed, i.e. the lowering assumes they are not used afterwards or
908 // elsewhere; otherwise we will get a runtime error. Eventually, we should
909 // guarantee this property.
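//
// Illustrative sketch only, for a single event operand (stream operands are
// first wrapped into events right after their defining op):
//
//   %t = gpu.wait async [%event]
//
// becomes roughly
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//       : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%event) : (!llvm.ptr) -> ()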
910 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
911  gpu::WaitOp waitOp, OpAdaptor adaptor,
912  ConversionPatternRewriter &rewriter) const {
913  if (!waitOp.getAsyncToken())
914  return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
915 
916  Location loc = waitOp.getLoc();
917 
918  auto insertionPoint = rewriter.saveInsertionPoint();
919  SmallVector<Value, 1> events;
920  for (auto pair :
921  llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
922  auto operand = std::get<1>(pair);
923  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
924  // The converted operand's definition created a stream. Insert an event
925  // into the stream just after the last use of the original token operand.
926  auto *defOp = std::get<0>(pair).getDefiningOp();
927  rewriter.setInsertionPointAfter(defOp);
928  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
929  eventRecordCallBuilder.create(loc, rewriter, {event, operand});
930  events.push_back(event);
931  } else {
932  // Otherwise the converted operand is an event. This assumes that we use
933  // events in control flow code as well.
934  events.push_back(operand);
935  }
936  }
937  rewriter.restoreInsertionPoint(insertionPoint);
938  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
939  for (auto event : events)
940  streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
941  for (auto event : events)
942  eventDestroyCallBuilder.create(loc, rewriter, {event});
943  rewriter.replaceOp(waitOp, {stream});
944 
945  return success();
946 }
947 
948 // Legalize the op's operands.
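//
// A sketch of the kernelIntersperseSizeCallConv handling below (made-up SSA
// names, assuming the bare-pointer convention so each memref lowers to a
// single pointer): a kernel operand
//
//   %arg : memref<16xf32>
//
// is passed to the re-created gpu.launch_func as two arguments, the bare
// pointer followed by an index constant holding the static size in bytes,
// here 16 * 4 = 64.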
949 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
950  gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
951  ConversionPatternRewriter &rewriter) const {
952  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
953  return failure();
954 
955  if (launchOp.getAsyncDependencies().size() > 1)
956  return rewriter.notifyMatchFailure(
957  launchOp, "Cannot convert with more than one async dependency.");
958 
959  // Fail when the synchronous version of the op has async dependencies. The
960  // lowering destroys the stream, and we do not want to check that there is no
961  // use of the stream after this op.
962  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
963  return rewriter.notifyMatchFailure(
964  launchOp, "Cannot convert non-async op with async dependencies.");
965 
966  Location loc = launchOp.getLoc();
967 
968  Value stream = Value();
969  if (!adaptor.getAsyncDependencies().empty())
970  stream = adaptor.getAsyncDependencies().front();
971  // If the async keyword is present and there are no dependencies, then a
972  // stream must be created to pass to subsequent operations.
973  else if (launchOp.getAsyncToken())
974  stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
975 
976  // Lower the kernel operands to match kernel parameters.
977  // Note: If `useBarePtrCallConv` is set in the type converter's options,
978  // the value of `kernelBarePtrCallConv` will be ignored.
979  OperandRange origArguments = launchOp.getKernelOperands();
980  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
981  loc, origArguments, adaptor.getKernelOperands(), rewriter,
982  /*useBarePtrCallConv=*/kernelBarePtrCallConv);
983  SmallVector<Value, 8> llvmArgumentsWithSizes;
984 
985  // Intersperse size information if requested.
986  if (kernelIntersperseSizeCallConv) {
987  if (origArguments.size() != llvmArguments.size()) {
988  // This shouldn't happen if the bare-pointer calling convention is used.
989  return rewriter.notifyMatchFailure(
990  launchOp,
991  "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
992  }
993 
994  llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
995  for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
996  auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
997  if (!memrefTy) {
998  return rewriter.notifyMatchFailure(
999  launchOp, "Operand to launch op is not a memref.");
1000  }
1001 
1002  if (!memrefTy.hasStaticShape() ||
1003  !memrefTy.getElementType().isIntOrFloat()) {
1004  return rewriter.notifyMatchFailure(
1005  launchOp, "Operand to launch op is not a memref with a static "
1006  "shape and an integer or float element type.");
1007  }
1008 
1009  unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1010  if (bitwidth % 8 != 0) {
1011  return rewriter.notifyMatchFailure(
1012  launchOp, "Operand to launch op is not a memref with a "
1013  "byte-aligned element type.");
1014  }
1015 
1016  uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1017  static_cast<uint64_t>(memrefTy.getNumElements());
1018 
1019  Value sizeArg = rewriter.create<LLVM::ConstantOp>(
1020  loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1021  llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
1022  llvmArgumentsWithSizes.push_back(sizeArg);
1023  }
1024  }
1025 
1026  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1027  if (launchOp.hasClusterSize()) {
1028  clusterSize =
1029  gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1030  adaptor.getClusterSizeZ()};
1031  }
1032  rewriter.create<gpu::LaunchFuncOp>(
1033  launchOp.getLoc(), launchOp.getKernelAttr(),
1034  gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1035  adaptor.getGridSizeZ()},
1036  gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1037  adaptor.getBlockSizeZ()},
1038  adaptor.getDynamicSharedMemorySize(),
1039  llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1040  stream, clusterSize);
1041  if (launchOp.getAsyncToken())
1042  rewriter.replaceOp(launchOp, {stream});
1043  else
1044  rewriter.eraseOp(launchOp);
1045  return success();
1046 }
1047 
1048 static Value bitAndAddrspaceCast(Location loc,
1049  ConversionPatternRewriter &rewriter,
1050  LLVM::LLVMPointerType destinationType,
1051  Value sourcePtr,
1052  const LLVMTypeConverter &typeConverter) {
1053  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1054  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1055  sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1056  loc,
1057  LLVM::LLVMPointerType::get(rewriter.getContext(),
1058  destinationType.getAddressSpace()),
1059  sourcePtr);
1060  return sourcePtr;
1061 }
1062 
1063 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1064  gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1065  ConversionPatternRewriter &rewriter) const {
1066  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1067 
1068  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1069  !isConvertibleAndHasIdentityMaps(memRefType) ||
1070  failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1071  return failure();
1072 
1073  auto loc = memcpyOp.getLoc();
1074 
1075  MemRefDescriptor srcDesc(adaptor.getSrc());
1076  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1077 
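// Compute the copy size in bytes with the null-GEP idiom: GEP from a null
// pointer by `numElements` elements and ptrtoint the result, which yields
// numElements * sizeof(element) without needing target size info here.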
1078  Type elementPtrType = getElementPtrType(memRefType);
1079  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
1080  Value gepPtr = rewriter.create<LLVM::GEPOp>(
1081  loc, elementPtrType,
1082  typeConverter->convertType(memRefType.getElementType()), nullPtr,
1083  numElements);
1084  auto sizeBytes =
1085  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1086 
1087  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1088  srcDesc.alignedPtr(rewriter, loc),
1089  *getTypeConverter());
1090  auto dst = bitAndAddrspaceCast(
1091  loc, rewriter, llvmPointerType,
1092  MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1093  *getTypeConverter());
1094 
1095  auto stream = adaptor.getAsyncDependencies().front();
1096  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1097 
1098  rewriter.replaceOp(memcpyOp, {stream});
1099 
1100  return success();
1101 }
1102 
1103 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1104  gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1105  ConversionPatternRewriter &rewriter) const {
1106  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1107 
1108  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1109  !isConvertibleAndHasIdentityMaps(memRefType) ||
1110  failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1111  return failure();
1112 
1113  auto loc = memsetOp.getLoc();
1114 
1115  Type valueType = adaptor.getValue().getType();
1116  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1117  // Ints and floats of 16 or 32 bit width are allowed.
1118  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1119  return rewriter.notifyMatchFailure(
1120  memsetOp, "value must be a 16 or 32 bit int or float");
1121  }
1122 
1123  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1124  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1125 
1126  MemRefDescriptor dstDesc(adaptor.getDst());
1127  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1128 
1129  auto value =
1130  rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1131  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1132  dstDesc.alignedPtr(rewriter, loc),
1133  *getTypeConverter());
1134 
1135  auto stream = adaptor.getAsyncDependencies().front();
1136  FunctionCallBuilder builder =
1137  valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1138  builder.create(loc, rewriter, {dst, value, numElements, stream});
1139 
1140  rewriter.replaceOp(memsetOp, {stream});
1141  return success();
1142 }
1143 
1144 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1145  gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1146  ConversionPatternRewriter &rewriter) const {
1147  Location loc = op.getLoc();
1148  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1149  {adaptor.getDevIndex()});
1150  rewriter.replaceOp(op, call);
1151  return success();
1152 }
1153 
1154 template <typename T>
1155 static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1156  Type llvmInt32Type = builder.getIntegerType(32);
1157  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
1158  static_cast<int32_t>(tValue));
1159 }
1160 
1161 template <typename T>
1162 static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1163  Type llvmFloat32Type = builder.getF32Type();
1164  return builder.create<LLVM::ConstantOp>(
1165  loc, llvmFloat32Type,
1166  builder.getF32FloatAttr(static_cast<float>(tValue)));
1167 }
1168 
1169 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1170  gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1171  ConversionPatternRewriter &rewriter) const {
1172  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1173  failed(isAsyncWithOneDependency(rewriter, op)))
1174  return failure();
1175  Location loc = op.getLoc();
1176  auto stream = adaptor.getAsyncDependencies().front();
1177  Value pTensor =
1178  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1179  Type dType = op.getMemref().getType().getElementType();
1180  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1181 
1182  SmallVector<Value, 4> dims;
1183  for (Value dim : adaptor.getDims()) {
1184  dims.push_back(dim);
1185  }
1186 
1187  Value handle;
1188  // TODO: For now, we track how the handle is used and lower it to cusparse or
1189  // cusparseLt accordingly. If both cusparse and cusparseLt are used within one
1190  // block, two separate creation ops are required for the logic to be correct.
1191  // In the future, we may support using a single handle for both cusparse and
1192  // cusparseLt in the sparse tensor / GPU dialects. Use the cusparseLt create
1193  // call if the dnmat is used with a spmat that has 2:4 sparsity.
1194  if (dims.size() == 2) {
1195  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1196  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1197  loc, getIndexType(), rewriter.getIndexAttr(11032));
1198  handle = rewriter.create<LLVM::AllocaOp>(
1199  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1200  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1201 
1202  createLtDnMatCallBuilder
1203  .create(loc, rewriter,
1204  {handle, dims[0], dims[1], pTensor, dtp, stream})
1205  .getResult();
1206  } else {
1207  handle =
1208  createDnMatCallBuilder
1209  .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1210  .getResult();
1211  }
1212  } else {
1213  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1214  handle = createDnVecCallBuilder
1215  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1216  .getResult();
1217  }
1218  rewriter.replaceOp(op, {handle, stream});
1219  return success();
1220 }
1221 
1222 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1223  gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1224  ConversionPatternRewriter &rewriter) const {
1225  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1226  failed(isAsyncWithOneDependency(rewriter, op)))
1227  return failure();
1228  Location loc = op.getLoc();
1229  auto stream = adaptor.getAsyncDependencies().front();
1230  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1231  SmallVector<Value, 4> dims;
1232  for (Value dim : definingOp.getDims()) {
1233  dims.push_back(dim);
1234  }
1235  if (dims.size() == 2) {
1236  // Use the cusparseLt destroy call if the dnmat is used with a spmat that
1237  // has 2:4 sparsity.
1238  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1239  destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1240  {adaptor.getDnTensor(), stream});
1241  } else {
1242  destroyDnMatCallBuilder.create(loc, rewriter,
1243  {adaptor.getDnTensor(), stream});
1244  }
1245  } else {
1246  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1247  destroyDnVecCallBuilder.create(loc, rewriter,
1248  {adaptor.getDnTensor(), stream});
1249  }
1250  rewriter.replaceOp(op, {stream});
1251  return success();
1252 }
1253 
1254 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1255  gpu::CreateCooOp op, OpAdaptor adaptor,
1256  ConversionPatternRewriter &rewriter) const {
1257  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1258  failed(isAsyncWithOneDependency(rewriter, op)))
1259  return failure();
1260  Location loc = op.getLoc();
1261  auto stream = adaptor.getAsyncDependencies().front();
1262  Value pRowIdxs =
1263  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1264  Value pColIdxs =
1265  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1266  Value pValues =
1267  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1268  Type iType =
1269  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1270  Type dType =
1271  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1272  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1273  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1274  auto handle =
1275  createCooCallBuilder
1276  .create(loc, rewriter,
1277  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1278  pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1279  .getResult();
1280  rewriter.replaceOp(op, {handle, stream});
1281  return success();
1282 }
1283 
1284 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1285  gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1286  ConversionPatternRewriter &rewriter) const {
1287  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1288  failed(isAsyncWithOneDependency(rewriter, op)))
1289  return failure();
1290  Location loc = op.getLoc();
1291  auto stream = adaptor.getAsyncDependencies().front();
1292  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1293  Value pValues =
1294  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1295  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1296  Type dType =
1297  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1298  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1299  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1300  auto handle =
1301  createCooAoSCallBuilder
1302  .create(loc, rewriter,
1303  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1304  pIdxs, pValues, itp, dtp, stream})
1305  .getResult();
1306  rewriter.replaceOp(op, {handle, stream});
1307  return success();
1308 }
1309 
1310 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1311  gpu::CreateCsrOp op, OpAdaptor adaptor,
1312  ConversionPatternRewriter &rewriter) const {
1313  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1314  failed(isAsyncWithOneDependency(rewriter, op)))
1315  return failure();
1316  Location loc = op.getLoc();
1317  auto stream = adaptor.getAsyncDependencies().front();
1318  Value pRowPos =
1319  MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1320  Value pColIdxs =
1321  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1322  Value pValues =
1323  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1324  Type pType =
1325  llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1326  Type iType =
1327  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1328  Type dType =
1329  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1330  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1331  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1332  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1333  auto handle =
1334  createCsrCallBuilder
1335  .create(loc, rewriter,
1336  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1337  pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1338  .getResult();
1339  rewriter.replaceOp(op, {handle, stream});
1340  return success();
1341 }
1342 
1343 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1344  gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1345  ConversionPatternRewriter &rewriter) const {
1346  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1347  failed(isAsyncWithOneDependency(rewriter, op)))
1348  return failure();
1349  Location loc = op.getLoc();
1350  auto stream = adaptor.getAsyncDependencies().front();
1351  Value pMat =
1352  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1353  Type dType =
1354  llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1355  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1356 
1357  // CUDA runner asserts the size is 44104 bytes.
1358  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1359  loc, getIndexType(), rewriter.getIndexAttr(44104));
1360  Value handle = rewriter.create<LLVM::AllocaOp>(
1361  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1362  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1363 
1364  create2To4SpMatCallBuilder
1365  .create(loc, rewriter,
1366  {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1367  .getResult();
1368  rewriter.replaceOp(op, {handle, stream});
1369  return success();
1370 }
1371 
1372 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1373  gpu::DestroySpMatOp op, OpAdaptor adaptor,
1374  ConversionPatternRewriter &rewriter) const {
1375  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1376  failed(isAsyncWithOneDependency(rewriter, op)))
1377  return failure();
1378  Location loc = op.getLoc();
1379  auto stream = adaptor.getAsyncDependencies().front();
1380  // Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
1381  if (is2To4Sparsity(op.getSpmat())) {
1382  destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1383  {adaptor.getSpmat(), stream});
1384 
1385  } else {
1386  destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1387  }
1388  rewriter.replaceOp(op, {stream});
1389  return success();
1390 }
1391 
1392 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1393  gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1394  ConversionPatternRewriter &rewriter) const {
1395  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1396  failed(isAsyncWithOneDependency(rewriter, op)))
1397  return failure();
1398  Location loc = op.getLoc();
1399  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1400  auto computeType = genConstInt32From(
1401  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1402  auto stream = adaptor.getAsyncDependencies().front();
1403  auto bufferSize = spMVBufferSizeCallBuilder
1404  .create(loc, rewriter,
1405  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1406  adaptor.getDnY(), computeType, stream})
1407  .getResult();
1408  rewriter.replaceOp(op, {bufferSize, stream});
1409  return success();
1410 }
1411 
1412 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1413  gpu::SpMVOp op, OpAdaptor adaptor,
1414  ConversionPatternRewriter &rewriter) const {
1415  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1416  failed(isAsyncWithOneDependency(rewriter, op)))
1417  return failure();
1418  Location loc = op.getLoc();
1419  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1420  auto computeType = genConstInt32From(
1421  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1422  auto stream = adaptor.getAsyncDependencies().front();
1423  Value pBuf =
1424  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1425  spMVCallBuilder.create(loc, rewriter,
1426  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1427  adaptor.getDnY(), computeType, pBuf, stream});
1428  rewriter.replaceOp(op, {stream});
1429  return success();
1430 }
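// Illustrative shape of the lowering above (GPU dialect syntax abbreviated;
// @mgpuWrapper stands in for the symbol bound to spMVCallBuilder earlier in
// this file):
//
//   %token = gpu.spmv async [%dep] %spmatA, %dnX, %dnY, %buffer ...
//
// becomes, roughly,
//
//   llvm.call @mgpuWrapper(%modeA, %spmatA, %dnX, %dnY, %computeType, %pBuf,
//                          %stream) : (...) -> ()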
1431 
1432 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1433  gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1434  ConversionPatternRewriter &rewriter) const {
1435  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1436  failed(isAsyncWithOneDependency(rewriter, op)))
1437  return failure();
1438  Location loc = op.getLoc();
1439  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1440  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1441  auto stream = adaptor.getAsyncDependencies().front();
1442  Value bufferSize;
1443  if (is2To4Sparsity(op.getSpmatA())) {
1444  auto pruneFlag =
1445  genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1446  auto computeType = genConstInt32From(
1447  rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1448  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1449  rewriter.getIndexAttr(3));
1450  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1451  loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1452  createCuSparseLtSpMMBufferSizeBuilder
1453  .create(loc, rewriter,
1454  {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1455  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1456  pruneFlag, stream})
1457  .getResult();
1458 
1459  auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1460  loc, llvmPointerType, llvmPointerType, bufferSize,
1461  ValueRange{rewriter.create<LLVM::ConstantOp>(
1462  loc, getIndexType(), rewriter.getIndexAttr(1))});
1463  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1464  loc, llvmPointerType, llvmPointerType, bufferSize,
1465  ValueRange{rewriter.create<LLVM::ConstantOp>(
1466  loc, getIndexType(), rewriter.getIndexAttr(2))});
1467  auto bufferSize0 =
1468  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1469  auto bufferSize1 =
1470  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1471  auto bufferSize2 =
1472  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1473 
1474  rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1475  } else {
1476  auto computeType = genConstInt32From(
1477  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1478  bufferSize =
1479  createSpMMBufferSizeCallBuilder
1480  .create(loc, rewriter,
1481  {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1482  adaptor.getDnmatC(), computeType, stream})
1483  .getResult();
1484  rewriter.replaceOp(op, {bufferSize, stream});
1485  }
1486  return success();
1487 }
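// In the 2:4 branch above, cusparseLt reports three workspace sizes rather
// than the single size of the plain cuSPARSE path. The pattern therefore
// stack-allocates a three-slot buffer, lets the wrapper write into it, and
// loads the three i64 values back as the op's results; the else branch keeps
// the usual one-size convention and forwards the call result directly.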
1488 
1489 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1490  gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1491  ConversionPatternRewriter &rewriter) const {
1492  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1493  failed(isAsyncWithOneDependency(rewriter, op)))
1494  return failure();
1495  Location loc = op.getLoc();
1496  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1497  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1498  auto computeType = genConstInt32From(
1499  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1500  auto stream = adaptor.getAsyncDependencies().front();
1501  auto bufferSize =
1502  createSDDMMBufferSizeCallBuilder
1503  .create(loc, rewriter,
1504  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1505  adaptor.getSpmatC(), computeType, stream})
1506  .getResult();
1507  rewriter.replaceOp(op, {bufferSize, stream});
1508  return success();
1509 }
1510 
1511 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1512  gpu::SpMMOp op, OpAdaptor adaptor,
1513  ConversionPatternRewriter &rewriter) const {
1514  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1515  failed(isAsyncWithOneDependency(rewriter, op)))
1516  return failure();
1517  Location loc = op.getLoc();
1518  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1519  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1520  auto computeType = genConstInt32From(
1521  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1522 
1523  auto stream = adaptor.getAsyncDependencies().front();
1524 
1525  // Lower to cusparseLt if applicable
1526  if (is2To4Sparsity(op.getSpmatA())) {
1527  SmallVector<Value> pBufs;
1528  for (Value buffer : adaptor.getBuffers()) {
1529  Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1530  pBufs.push_back(pBuf);
1531  }
1532  createCuSparseLtSpMMBuilder.create(
1533  loc, rewriter,
1534  {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1535  pBufs[0], pBufs[1], pBufs[2], stream});
1536  } else {
1537  Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1538  .allocatedPtr(rewriter, loc);
1539  createSpMMCallBuilder.create(loc, rewriter,
1540  {modeA, modeB, adaptor.getSpmatA(),
1541  adaptor.getDnmatB(), adaptor.getDnmatC(),
1542  computeType, pBuf, stream});
1543  }
1544  rewriter.replaceOp(op, {stream});
1545  return success();
1546 }
1547 
1548 template <typename T>
1549 static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1550  converter.addConversion([&converter](T) -> Type {
1551  return LLVM::LLVMPointerType::get(&converter.getContext());
1552  });
1553 }
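// With this helper, GPU handle-like types (async tokens, dense tensor
// handles, sparse matrix handles, SpGEMM descriptors) all convert to a bare
// !llvm.ptr, which is the representation the mgpu* runtime wrappers expect.
// Minimal usage sketch, assuming a converter is already available:
//
//   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
//   // converter.convertType(gpu::AsyncTokenType::get(ctx)) yields !llvm.ptr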
1554 
1555 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1556  gpu::SDDMMOp op, OpAdaptor adaptor,
1557  ConversionPatternRewriter &rewriter) const {
1558  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1559  failed(isAsyncWithOneDependency(rewriter, op)))
1560  return failure();
1561  Location loc = op.getLoc();
1562  auto computeType = genConstInt32From(
1563  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1564  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1565  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1566  auto stream = adaptor.getAsyncDependencies().front();
1567  Value pBuf =
1568  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1569  createSDDMMCallBuilder.create(loc, rewriter,
1570  {modeA, modeB, adaptor.getDnmatA(),
1571  adaptor.getDnmatB(), adaptor.getSpmatC(),
1572  computeType, pBuf, stream});
1573  rewriter.replaceOp(op, {stream});
1574  return success();
1575 }
1576 
1577 LogicalResult
1578 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1579  gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1580  ConversionPatternRewriter &rewriter) const {
1581  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1582  failed(isAsyncWithOneDependency(rewriter, op)))
1583  return failure();
1584  Location loc = op.getLoc();
1585  auto stream = adaptor.getAsyncDependencies().front();
1586  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1587  .getResult();
1588  rewriter.replaceOp(op, {descr, stream});
1589  return success();
1590 }
1591 
1592 LogicalResult
1593 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1594  gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1595  ConversionPatternRewriter &rewriter) const {
1596  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1597  failed(isAsyncWithOneDependency(rewriter, op)))
1598  return failure();
1599  Location loc = op.getLoc();
1600  auto stream = adaptor.getAsyncDependencies().front();
1601  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1602  {adaptor.getDesc(), stream});
1603  rewriter.replaceOp(op, {stream});
1604  return success();
1605 }
1606 
1607 LogicalResult
1608 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1609  gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1610  ConversionPatternRewriter &rewriter) const {
1611  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1612  failed(isAsyncWithOneDependency(rewriter, op)))
1613  return failure();
1614  Location loc = op.getLoc();
1615  auto computeType = genConstInt32From(
1616  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1617  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1618  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1619  auto stream = adaptor.getAsyncDependencies().front();
1620 
1621  Value pBuf =
1622  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1623  Value bufferSizeNew;
1624 
1625  if (adaptor.getKind() ==
1626  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1627  bufferSizeNew =
1628  createSpGEMMWorkEstimationBuilder
1629  .create(loc, rewriter,
1630  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1631  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1632  adaptor.getBufferSz(), pBuf, stream})
1633  .getResult();
1634  } else {
1635  bufferSizeNew =
1636  createSpGEMMComputeBuilder
1637  .create(loc, rewriter,
1638  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1639  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1640  adaptor.getBufferSz(), pBuf, stream})
1641  .getResult();
1642  }
1643  rewriter.replaceOp(op, {bufferSizeNew, stream});
1644  return success();
1645 }
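// Both SpGEMM phases share the same wrapper argument list; the op's kind
// attribute only selects between the work-estimation and the compute call
// builder. Each call returns an updated buffer size, which is expected to be
// fed into a follow-up invocation once a sufficiently large buffer has been
// allocated, mirroring the usual two-round cuSPARSE SpGEMM protocol.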
1646 
1647 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1648  gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1649  ConversionPatternRewriter &rewriter) const {
1650  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1651  failed(isAsyncWithOneDependency(rewriter, op)))
1652  return failure();
1653  Location loc = op.getLoc();
1654  auto computeType = genConstInt32From(
1655  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1656  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1657  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1658  auto stream = adaptor.getAsyncDependencies().front();
1659  createSpGEMMCopyBuilder.create(loc, rewriter,
1660  {adaptor.getDesc(), modeA, modeB,
1661  adaptor.getSpmatA(), adaptor.getSpmatB(),
1662  adaptor.getSpmatC(), computeType, stream});
1663  rewriter.replaceOp(op, {stream});
1664  return success();
1665 }
1666 
1667 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1668  gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1669  ConversionPatternRewriter &rewriter) const {
1670  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1671  failed(isAsyncWithOneDependency(rewriter, op)))
1672  return failure();
1673  Location loc = op.getLoc();
1674  auto stream = adaptor.getAsyncDependencies().front();
1675 
1676  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1677  rewriter.getIndexAttr(3));
1678  auto buffer = rewriter.create<LLVM::AllocaOp>(
1679  loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1680 
1681  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1682  loc, llvmPointerType, llvmPointerType, buffer,
1683  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1684  rewriter.getIndexAttr(0))});
1685  auto colsPtr = rewriter.create<LLVM::GEPOp>(
1686  loc, llvmPointerType, llvmPointerType, buffer,
1687  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1688  rewriter.getIndexAttr(1))});
1689  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1690  loc, llvmPointerType, llvmPointerType, buffer,
1691  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1692  rewriter.getIndexAttr(2))});
1693  createSpMatGetSizeBuilder.create(
1694  loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1695  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1696  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1697  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1698 
1699  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1700  return success();
1701 }
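// The runtime wrapper reports the sizes through out-parameters, so the
// pattern materializes a three-element i64 stack buffer, passes pointers to
// its slots, and loads rows/cols/nnz back as the lowered op results.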
1702 
1703 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1704  gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1705  ConversionPatternRewriter &rewriter) const {
1706  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1707  failed(isAsyncWithOneDependency(rewriter, op)))
1708  return failure();
1709  Location loc = op.getLoc();
1710  auto stream = adaptor.getAsyncDependencies().front();
1711  Value pPos =
1712  MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1713  Value pCrd =
1714  MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1715  Value pVal =
1716  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1717  createSetCsrPointersBuilder.create(
1718  loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1719  rewriter.replaceOp(op, {stream});
1720  return success();
1721 }
1722 
1723 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1724  gpu::CreateCscOp op, OpAdaptor adaptor,
1725  ConversionPatternRewriter &rewriter) const {
1726  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1727  failed(isAsyncWithOneDependency(rewriter, op)))
1728  return failure();
1729  Location loc = op.getLoc();
1730  auto stream = adaptor.getAsyncDependencies().front();
1731  Value pColPos =
1732  MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1733  Value pRowIdxs =
1734  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1735  Value pValues =
1736  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1737  Type pType =
1738  llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1739  Type iType =
1740  llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1741  Type dType =
1742  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1743  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1744  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1745  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1746  auto handle =
1747  createCscCallBuilder
1748  .create(loc, rewriter,
1749  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1750  pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1751  .getResult();
1752  rewriter.replaceOp(op, {handle, stream});
1753  return success();
1754 }
1755 
1756 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1757  gpu::CreateBsrOp op, OpAdaptor adaptor,
1758  ConversionPatternRewriter &rewriter) const {
1759  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1760  failed(isAsyncWithOneDependency(rewriter, op)))
1761  return failure();
1762  Location loc = op.getLoc();
1763  auto stream = adaptor.getAsyncDependencies().front();
1764  Value pRowPos =
1765  MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1766  Value pColIdxs =
1767  MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1768  Value pValues =
1769  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1770  Type pType =
1771  llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1772  Type iType =
1773  llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1774  Type dType =
1775  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1776  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1777  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1778  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1779  auto handle =
1780  createBsrCallBuilder
1781  .create(loc, rewriter,
1782  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1783  adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1784  pColIdxs, pValues, ptp, itp, dtp, stream})
1785  .getResult();
1786  rewriter.replaceOp(op, {handle, stream});
1787  return success();
1788 }
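// For BSR the dimensions passed to the wrapper are expressed in blocks
// (brows/bcols/bnnz plus the row and column block sizes); otherwise the
// lowering mirrors the CSR/CSC patterns above: extract the allocated pointers
// from the memref descriptors and encode the position, index, and data
// element types as cuSPARSE enum constants.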
1789 
1790 void mlir::populateGpuToLLVMConversionPatterns(
1791  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1792  bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1793  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1794  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1795  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1796  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1797 
1798  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1799  ConvertDeallocOpToGpuRuntimeCallPattern,
1800  ConvertHostRegisterOpToGpuRuntimeCallPattern,
1801  ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1802  ConvertMemcpyOpToGpuRuntimeCallPattern,
1803  ConvertMemsetOpToGpuRuntimeCallPattern,
1804  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1805  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1806  ConvertWaitOpToGpuRuntimeCallPattern,
1807  ConvertAsyncYieldToGpuRuntimeCallPattern,
1808  ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1809  ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1810  ConvertCreateCooOpToGpuRuntimeCallPattern,
1811  ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1812  ConvertCreateCsrOpToGpuRuntimeCallPattern,
1813  ConvertCreateCscOpToGpuRuntimeCallPattern,
1814  ConvertCreateBsrOpToGpuRuntimeCallPattern,
1815  ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1816  ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1817  ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1818  ConvertSpMVOpToGpuRuntimeCallPattern,
1819  ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1820  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1821  ConvertSpMMOpToGpuRuntimeCallPattern,
1822  ConvertSDDMMOpToGpuRuntimeCallPattern,
1823  ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1824  ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1825  ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1826  ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1827  ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1828  ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1829  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1830  kernelIntersperseSizeCallConv);
1831 }
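// Typical use from a conversion pass (sketch only; the option and target
// setup varies with the pass configuration):
//
//   LowerToLLVMOptions options(&getContext());
//   LLVMTypeConverter converter(&getContext(), options);
//   RewritePatternSet patterns(&getContext());
//   LLVMConversionTarget target(getContext());
//   populateGpuToLLVMConversionPatterns(converter, patterns,
//                                       /*kernelBarePtrCallConv=*/false,
//                                       /*kernelIntersperseSizeCallConv=*/false);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();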
1832 
1833 //===----------------------------------------------------------------------===//
1834 // GPUModuleOp convert to LLVM op interface
1835 //===----------------------------------------------------------------------===//
1836 
1837 namespace {
1838 struct GPUModuleOpConvertToLLVMInterface
1839  : public ConvertToLLVMOpInterface::ExternalModel<
1840  GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1841  /// Get the conversion patterns from the target attribute.
1842  void getConvertToLLVMConversionAttrs(
1843  Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
1844 };
1845 } // namespace
1846 
1847 void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1848  Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1849  auto module = cast<gpu::GPUModuleOp>(op);
1850  ArrayAttr targetsAttr = module.getTargetsAttr();
1851  // Fail if there are no target attributes or there is more than one target.
1852  if (!targetsAttr || targetsAttr.size() != 1)
1853  return;
1854  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1855  attrs.push_back(patternAttr);
1856 }
1857 
1858 void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
1859  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1860  gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1861  });
1862 }
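// Typical registration from a tool or pipeline setup (sketch):
//
//   DialectRegistry registry;
//   gpu::registerConvertGpuToLLVMInterface(registry);
//   context.appendDialectRegistry(registry);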