GPUToLLVMConversion.cpp
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert gpu.launch_func op into a sequence of
10 // GPU runtime calls. As most GPU runtimes do not have a stable published
11 // ABI, this pass uses a slim runtime layer that builds on top of the public
12 // API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
15 
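// Illustrative example (a rough sketch, not IR produced verbatim by this
// file): host-side GPU ops are rewritten into calls to thin "mgpu*" runtime
// wrapper functions. For instance, a synchronous
//   gpu.wait [%t]          // where %t has been lowered to a stream pointer
// becomes, roughly,
//   llvm.call @mgpuStreamSynchronize(%t) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%t)     : (!llvm.ptr) -> ()
// gpu.launch_func is only legalized here (its operands are converted to LLVM
// types); gpu.module and gpu.binary remain legal so they can be compiled
// later, e.g. by `gpu-module-to-binary`.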
17 
35 #include "mlir/IR/Attributes.h"
36 #include "mlir/IR/Builders.h"
37 #include "mlir/IR/BuiltinOps.h"
38 #include "mlir/IR/BuiltinTypes.h"
39 
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/Support/Error.h"
42 #include "llvm/Support/FormatVariadic.h"
43 
44 #define DEBUG_TYPE "gpu-to-llvm"
45 
46 namespace mlir {
47 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
48 #include "mlir/Conversion/Passes.h.inc"
49 } // namespace mlir
50 
51 using namespace mlir;
52 
53 namespace {
54 class GpuToLLVMConversionPass
55  : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
56 public:
57  using Base::Base;
58  void getDependentDialects(DialectRegistry &registry) const final {
59  Base::getDependentDialects(registry);
60  registerConvertToLLVMDependentDialectLoading(registry);
61  }
62  // Run the dialect converter on the module.
63  void runOnOperation() override;
64 };
65 
66 template <typename OpTy>
67 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
68 public:
69  explicit ConvertOpToGpuRuntimeCallPattern(
70  const LLVMTypeConverter &typeConverter)
71  : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
72 
73 protected:
74  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
75  MemRefType type, MemRefDescriptor desc) const {
76  Type indexType = ConvertToLLVMPattern::getIndexType();
77  return type.hasStaticShape()
78  ? ConvertToLLVMPattern::createIndexAttrConstant(
79  rewriter, loc, indexType, type.getNumElements())
80  // For identity maps (verified by caller), the number of
81  // elements is stride[0] * size[0].
82  : rewriter.create<LLVM::MulOp>(loc,
83  desc.stride(rewriter, loc, 0),
84  desc.size(rewriter, loc, 0));
85  }
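  // For example, for memref<4x5xf32> this produces the constant 20, while for
  // a dynamically shaped identity-layout memref it multiplies stride[0] by
  // size[0], which equals the total element count.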
86 
87  MLIRContext *context = &this->getTypeConverter()->getContext();
88 
89  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91  Type llvmInt8Type = IntegerType::get(context, 8);
92  Type llvmInt16Type = IntegerType::get(context, 16);
93  Type llvmInt32Type = IntegerType::get(context, 32);
94  Type llvmInt64Type = IntegerType::get(context, 64);
95  Type llvmFloat32Type = Float32Type::get(context);
96  Type llvmIntPtrType = IntegerType::get(
97  context, this->getTypeConverter()->getPointerBitwidth(0));
98 
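  // Builders for the runtime wrapper entry points used below. The names and
  // signatures must stay in sync with the corresponding mgpu* wrappers (e.g.
  // the CUDA/ROCm runtime wrapper libraries shipped with the MLIR
  // ExecutionEngine).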
99  FunctionCallBuilder streamCreateCallBuilder = {
100  "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
101  FunctionCallBuilder streamDestroyCallBuilder = {
102  "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
103  FunctionCallBuilder streamSynchronizeCallBuilder = {
104  "mgpuStreamSynchronize",
105  llvmVoidType,
106  {llvmPointerType /* void *stream */}};
107  FunctionCallBuilder streamWaitEventCallBuilder = {
108  "mgpuStreamWaitEvent",
109  llvmVoidType,
110  {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
111  FunctionCallBuilder eventCreateCallBuilder = {
112  "mgpuEventCreate", llvmPointerType /* void *event */, {}};
113  FunctionCallBuilder eventDestroyCallBuilder = {
114  "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
115  FunctionCallBuilder eventSynchronizeCallBuilder = {
116  "mgpuEventSynchronize",
117  llvmVoidType,
118  {llvmPointerType /* void *event */}};
119  FunctionCallBuilder eventRecordCallBuilder = {
120  "mgpuEventRecord",
121  llvmVoidType,
122  {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
123  FunctionCallBuilder hostRegisterCallBuilder = {
124  "mgpuMemHostRegisterMemRef",
125  llvmVoidType,
126  {llvmIntPtrType /* intptr_t rank */,
127  llvmPointerType /* void *memrefDesc */,
128  llvmIntPtrType /* intptr_t elementSizeBytes */}};
129  FunctionCallBuilder hostUnregisterCallBuilder = {
130  "mgpuMemHostUnregisterMemRef",
131  llvmVoidType,
132  {llvmIntPtrType /* intptr_t rank */,
133  llvmPointerType /* void *memrefDesc */,
134  llvmIntPtrType /* intptr_t elementSizeBytes */}};
135  FunctionCallBuilder allocCallBuilder = {
136  "mgpuMemAlloc",
137  llvmPointerType /* void * */,
138  {llvmIntPtrType /* intptr_t sizeBytes */,
139  llvmPointerType /* void *stream */,
140  llvmInt8Type /* bool isHostShared */}};
141  FunctionCallBuilder deallocCallBuilder = {
142  "mgpuMemFree",
143  llvmVoidType,
144  {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
145  FunctionCallBuilder memcpyCallBuilder = {
146  "mgpuMemcpy",
147  llvmVoidType,
148  {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
149  llvmIntPtrType /* intptr_t sizeBytes */,
150  llvmPointerType /* void *stream */}};
151  FunctionCallBuilder memset16CallBuilder = {
152  "mgpuMemset16",
153  llvmVoidType,
154  {llvmPointerType /* void *dst */,
155  llvmInt16Type /* unsigned short value */,
156  llvmIntPtrType /* intptr_t sizeBytes */,
157  llvmPointerType /* void *stream */}};
158  FunctionCallBuilder memset32CallBuilder = {
159  "mgpuMemset32",
160  llvmVoidType,
161  {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
162  llvmIntPtrType /* intptr_t sizeBytes */,
163  llvmPointerType /* void *stream */}};
164  FunctionCallBuilder setDefaultDeviceCallBuilder = {
165  "mgpuSetDefaultDevice",
166  llvmVoidType,
167  {llvmInt32Type /* uint32_t devIndex */}};
168  FunctionCallBuilder createDnVecCallBuilder = {
169  "mgpuCreateDnVec",
170  llvmPointerType,
171  {llvmIntPtrType, llvmPointerType, llvmInt32Type,
172  llvmPointerType /* void *stream */}};
173  FunctionCallBuilder destroyDnVecCallBuilder = {
174  "mgpuDestroyDnVec",
175  llvmVoidType,
176  {llvmPointerType, llvmPointerType /* void *stream */}};
177  FunctionCallBuilder createDnMatCallBuilder = {
178  "mgpuCreateDnMat",
179  llvmPointerType,
180  {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
181  llvmPointerType /* void *stream */}};
182  FunctionCallBuilder destroyDnMatCallBuilder = {
183  "mgpuDestroyDnMat",
184  llvmVoidType,
185  {llvmPointerType, llvmPointerType /* void *stream */}};
186  FunctionCallBuilder createCooCallBuilder = {
187  "mgpuCreateCoo",
188  llvmPointerType,
189  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
191  llvmPointerType /* void *stream */}};
192  FunctionCallBuilder createCooAoSCallBuilder = {
193  "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
194  llvmPointerType,
195  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196  llvmPointerType, llvmInt32Type, llvmInt32Type,
197  llvmPointerType /* void *stream */}};
198  FunctionCallBuilder createCsrCallBuilder = {
199  "mgpuCreateCsr",
200  llvmPointerType,
201  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203  llvmInt32Type, llvmPointerType /* void *stream */}};
204  FunctionCallBuilder createCscCallBuilder = {
205  "mgpuCreateCsc",
206  llvmPointerType,
207  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209  llvmInt32Type, llvmPointerType /* void *stream */}};
210  FunctionCallBuilder createBsrCallBuilder = {
211  "mgpuCreateBsr",
212  llvmPointerType,
213  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214  llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215  llvmInt32Type, llvmInt32Type, llvmInt32Type,
216  llvmPointerType /* void *stream */}};
217  FunctionCallBuilder destroySpMatCallBuilder = {
218  "mgpuDestroySpMat",
219  llvmVoidType,
220  {llvmPointerType, llvmPointerType /* void *stream */}};
221  FunctionCallBuilder spMVBufferSizeCallBuilder = {
222  "mgpuSpMVBufferSize",
223  llvmIntPtrType,
224  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225  llvmInt32Type, llvmPointerType /* void *stream */}};
226  FunctionCallBuilder spMVCallBuilder = {
227  "mgpuSpMV",
228  llvmVoidType,
229  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230  llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
231  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232  "mgpuSpMMBufferSize",
233  llvmIntPtrType,
234  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
236  FunctionCallBuilder createSpMMCallBuilder = {
237  "mgpuSpMM",
238  llvmVoidType,
239  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240  llvmPointerType, llvmInt32Type, llvmPointerType,
241  llvmPointerType /* void *stream */}};
242  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243  "mgpuSDDMMBufferSize",
244  llvmIntPtrType,
245  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
247  FunctionCallBuilder createSDDMMCallBuilder = {
248  "mgpuSDDMM",
249  llvmVoidType,
250  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251  llvmPointerType, llvmInt32Type, llvmPointerType,
252  llvmPointerType /* void *stream */}};
253  FunctionCallBuilder createLtDnMatCallBuilder = {
254  "mgpuCreateCuSparseLtDnMat",
255  llvmVoidType,
256  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257  llvmInt32Type, llvmPointerType /* void *stream */}};
258  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259  "mgpuDestroyCuSparseLtSpMat",
260  llvmVoidType,
261  {llvmPointerType, llvmPointerType /* void *stream */}};
262  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263  "mgpuDestroyCuSparseLtDnMat",
264  llvmVoidType,
265  {llvmPointerType, llvmPointerType /* void *stream */}};
266  FunctionCallBuilder create2To4SpMatCallBuilder = {
267  "mgpuCusparseLtCreate2To4SpMat",
268  llvmVoidType,
269  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270  llvmInt32Type, llvmPointerType /* void *stream */}};
271  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272  "mgpuCuSparseLtSpMMBufferSize",
273  llvmVoidType,
274  {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
276  llvmPointerType /*void *stream*/}};
277  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278  "mgpuCuSparseLtSpMM",
279  llvmVoidType,
280  {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281  llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
282  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283  "mgpuSpGEMMCreateDescr",
284  llvmPointerType,
285  {llvmPointerType /*void *stream*/}};
286  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287  "mgpuSpGEMMDestroyDescr",
288  llvmVoidType,
289  {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
290  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291  "mgpuSpGEMMWorkEstimation",
292  llvmIntPtrType,
293  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
294  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
295  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
296  llvmPointerType /*void *stream*/}};
297  FunctionCallBuilder createSpGEMMComputeBuilder = {
298  "mgpuSpGEMMCompute",
299  llvmIntPtrType,
300  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
301  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
302  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
303  llvmPointerType /*void *stream*/}};
304  FunctionCallBuilder createSpGEMMCopyBuilder = {
305  "mgpuSpGEMMCopy",
306  llvmVoidType,
307  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
308  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
309  llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
310  FunctionCallBuilder createSpMatGetSizeBuilder = {
311  "mgpuSpMatGetSize",
312  llvmVoidType,
313  {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
314  llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
315  FunctionCallBuilder createSetCsrPointersBuilder = {
316  "mgpuSetCsrPointers",
317  llvmVoidType,
318  {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
319  llvmPointerType /*crd*/, llvmPointerType /*val*/,
320  llvmPointerType /*void *stream*/}};
321 };
322 
323 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
324 /// call. Currently it supports CUDA and ROCm (HIP).
325 class ConvertHostRegisterOpToGpuRuntimeCallPattern
326  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327 public:
328  ConvertHostRegisterOpToGpuRuntimeCallPattern(
329  const LLVMTypeConverter &typeConverter)
330  : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
331 
332 private:
333  LogicalResult
334  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override;
336 };
337 
338 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340 public:
341  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342  const LLVMTypeConverter &typeConverter)
343  : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
344  }
345 
346 private:
347  LogicalResult
348  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override;
350 };
351 
352 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
353 /// call. Currently it supports CUDA and ROCm (HIP).
354 class ConvertAllocOpToGpuRuntimeCallPattern
355  : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
356 public:
357  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
358  : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
359 
360 private:
361  LogicalResult
362  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363  ConversionPatternRewriter &rewriter) const override;
364 };
365 
366 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
367 /// call. Currently it supports CUDA and ROCm (HIP).
368 class ConvertDeallocOpToGpuRuntimeCallPattern
369  : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370 public:
371  ConvertDeallocOpToGpuRuntimeCallPattern(
372  const LLVMTypeConverter &typeConverter)
373  : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
374 
375 private:
376  LogicalResult
377  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378  ConversionPatternRewriter &rewriter) const override;
379 };
380 
381 class ConvertAsyncYieldToGpuRuntimeCallPattern
382  : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383 public:
384  ConvertAsyncYieldToGpuRuntimeCallPattern(
385  const LLVMTypeConverter &typeConverter)
386  : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
387 
388 private:
389  LogicalResult
390  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
391  ConversionPatternRewriter &rewriter) const override;
392 };
393 
394 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
395 /// call. Currently it supports CUDA and ROCm (HIP).
396 class ConvertWaitOpToGpuRuntimeCallPattern
397  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
398 public:
399  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
400  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
401 
402 private:
403  LogicalResult
404  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
405  ConversionPatternRewriter &rewriter) const override;
406 };
407 
408 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
409 /// call. Currently it supports CUDA and ROCm (HIP).
410 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
412 public:
413  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414  const LLVMTypeConverter &typeConverter)
415  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
416 
417 private:
418  LogicalResult
419  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
420  ConversionPatternRewriter &rewriter) const override;
421 };
422 
423 /// A rewrite pattern to legalize gpu.launch_func with LLVM types.
424 class LegalizeLaunchFuncOpPattern
425  : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
426 public:
427  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
428  bool kernelBarePtrCallConv)
429  : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
430  kernelBarePtrCallConv(kernelBarePtrCallConv) {}
431 
432 private:
433  LogicalResult
434  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
435  ConversionPatternRewriter &rewriter) const override;
436 
437  bool kernelBarePtrCallConv;
438 };
439 
440 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
441 /// call. Currently it supports CUDA and ROCm (HIP).
442 class ConvertMemcpyOpToGpuRuntimeCallPattern
443  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
444 public:
445  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
446  : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
447 
448 private:
449  LogicalResult
450  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
451  ConversionPatternRewriter &rewriter) const override;
452 };
453 
454 /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
455 /// call. Currently it supports CUDA and ROCm (HIP).
456 class ConvertMemsetOpToGpuRuntimeCallPattern
457  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
458 public:
459  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
460  : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
461 
462 private:
463  LogicalResult
464  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
465  ConversionPatternRewriter &rewriter) const override;
466 };
467 
468 /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
469 /// Currently supports CUDA and ROCm (HIP)
470 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
471  : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
472 public:
473  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
474  const LLVMTypeConverter &typeConverter)
475  : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
476  typeConverter) {}
477 
478  LogicalResult
479  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
480  ConversionPatternRewriter &rewriter) const override;
481 };
482 
483 /// Generic rewriting rule for operations on sparse matrices.
484 /// Currently supports CUDA (by means of cuSparse and cuSparseLt).
485 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
486  class Convert##op_name##ToGpuRuntimeCallPattern \
487  : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
488  public: \
489  Convert##op_name##ToGpuRuntimeCallPattern( \
490  const LLVMTypeConverter &typeConverter) \
491  : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
492  \
493  private: \
494  LogicalResult \
495  matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
496  ConversionPatternRewriter &rewriter) const override; \
497  };
498 
499 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
500 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
501 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
502 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
503 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
504 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
505 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
506 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
507 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
508 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
509 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
510 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
511 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
512 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
513 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
514 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
515 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
516 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
517 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
518 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
519 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
520 
521 } // namespace
522 
523 void GpuToLLVMConversionPass::runOnOperation() {
524  MLIRContext *context = &getContext();
525  LowerToLLVMOptions options(context);
526  options.useBarePtrCallConv = hostBarePtrCallConv;
527  RewritePatternSet patterns(context);
528  ConversionTarget target(*context);
529  target.addLegalDialect<LLVM::LLVMDialect>();
530  LLVMTypeConverter converter(context, options);
531 
532  // Populate all patterns from all dialects that implement the
533  // `ConvertToLLVMPatternInterface` interface.
534  for (Dialect *dialect : context->getLoadedDialects()) {
535  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
536  if (!iface)
537  continue;
538  iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
539  }
540 
541  // Preserve GPU modules and binaries. Modules are preserved as they can be
542  // converted later by `gpu-module-to-binary`.
543  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
544  // Accept as legal LaunchFuncOps if the operands have been lowered.
545  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
546  [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
547 
548  // These aren't covered by the ConvertToLLVMPatternInterface right now.
549  populateVectorToLLVMConversionPatterns(converter, patterns);
550  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
551  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
552  target);
553  populateGpuToLLVMConversionPatterns(converter, patterns,
554  kernelBarePtrCallConv);
555 
556  if (failed(
557  applyPartialConversion(getOperation(), target, std::move(patterns))))
558  signalPassFailure();
559 }
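// A minimal usage sketch, assuming the default pass registration under the
// name "gpu-to-llvm":
//   mlir-opt input.mlir --gpu-to-llvm
// Kernel bodies inside gpu.module/gpu.binary are intentionally left untouched
// and are expected to be handled afterwards, e.g. by `gpu-module-to-binary`.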
560 
561 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
562  ArrayRef<Value> arguments) const {
563  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
564  auto function = [&] {
565  if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
566  return function;
567  return OpBuilder::atBlockEnd(module.getBody())
568  .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
569  }();
570  return builder.create<LLVM::CallOp>(loc, function, arguments);
571 }
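// For example, `streamCreateCallBuilder.create(loc, rewriter, {})` declares
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
// at the end of the module body (unless such a declaration already exists)
// and emits
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
// at the current insertion point.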
572 
573 // Corresponding to cusparseIndexType_t defined in cusparse.h.
574 static int32_t getCuSparseIndexTypeFrom(Type type) {
575  if (type.isInteger(16))
576  return 1; // CUSPARSE_INDEX_16U
577  if (type.isInteger(32))
578  return 2; // CUSPARSE_INDEX_32I
579  return 3; // CUSPARSE_INDEX_64I
580 }
581 
582 static int32_t getCuSparseLtDataTypeFrom(Type type) {
583  if (type.isF16())
584  return 0; // CUSPARSE_COMPUTE_16F,
585  if (type.isInteger(32))
586  return 1; // CUSPARSE_COMPUTE_32I
587  llvm_unreachable("unsupported type");
588  // TODO: add support to TF32
589 }
590 
591 // Corresponding to cudaDataType_t defined in CUDA library_types.h.
592 static int32_t getCuSparseDataTypeFrom(Type type) {
593  if (llvm::isa<ComplexType>(type)) {
594  // get the element type
595  auto elementType = cast<ComplexType>(type).getElementType();
596  if (elementType.isBF16())
597  return 15; // CUDA_C_16BF
598  if (elementType.isF16())
599  return 6; // CUDA_C_16F
600  if (elementType.isF32())
601  return 4; // CUDA_C_32F
602  if (elementType.isF64())
603  return 5; // CUDA_C_64F
604  if (elementType.isInteger(8))
605  return 7; // CUDA_C_8I
606  if (elementType.isInteger(16))
607  return 21; // CUDA_C_16I
608  if (elementType.isInteger(32))
609  return 11; // CUDA_C_32I
610  }
611  if (type.isBF16())
612  return 14; // CUDA_R_16BF
613  if (type.isF16())
614  return 2; // CUDA_R_16F
615  if (type.isF32())
616  return 0; // CUDA_R_32F
617  if (type.isF64())
618  return 1; // CUDA_R_64F
619  if (type.isInteger(8))
620  return 3; // CUDA_R_8I
621  if (type.isInteger(16))
622  return 20; // CUDA_R_16I
623  if (type.isInteger(32))
624  return 10; // CUDA_R_32I
625 
626  llvm_unreachable("unsupported element type");
627 }
628 
629 static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
630  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
631 }
632 
633 // TODO: We may want a compile-time (MLIR compiler) disablement or warning:
634 // cuSPARSELt currently does not work for CUDA architectures < 8.0 and will
635 // trigger a runtime (CUDA program) error, but it would be good to at least
636 // emit a warning when the target architecture is < 8.0 and the user still
637 // wants to use cuSPARSELt, and to make sure that, when lowering the GPU
638 // sparse dialect to LLVM calls, the cuSPARSELt calls are disabled for CUDA
639 // architectures < 8.0.
640 static bool is2To4Sparsity(Value spMat) {
641  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
642  return true;
643  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
644  return false;
645  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
646  return false;
647  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
648  return false;
649  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
650  return false;
651  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
652  return false;
653  // Print the spMat defining op
654  spMat.getDefiningOp()->print(llvm::errs());
655  llvm_unreachable("cannot find spmat def");
656 }
657 
658 static bool isSpMMCusparseLtOp(Value op) {
659  for (Operation *user : op.getUsers()) {
660  auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
661  // If the sparse operand uses 2:4 (50%) sparsity, we should use cuSPARSELt.
662  if (!spmmOp)
663  continue;
664  if (is2To4Sparsity(spmmOp.getSpmatA()))
665  return true;
666  }
667  return false;
668 }
669 
670 // Returns success if all operands are of LLVM type, and a match failure otherwise.
671 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
672  ConversionPatternRewriter &rewriter) {
673  if (!llvm::all_of(operands, [](Value value) {
674  return LLVM::isCompatibleType(value.getType());
675  }))
676  return rewriter.notifyMatchFailure(
677  op, "Cannot convert if operands aren't of LLVM type.");
678  return success();
679 }
680 
681 static LogicalResult
682 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
683  gpu::AsyncOpInterface op) {
684  if (op.getAsyncDependencies().size() != 1)
685  return rewriter.notifyMatchFailure(
686  op, "Can only convert with exactly one async dependency.");
687 
688  if (!op.getAsyncToken())
689  return rewriter.notifyMatchFailure(op, "Can convert only async version.");
690 
691  return success();
692 }
693 
694 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
695  gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
696  ConversionPatternRewriter &rewriter) const {
697  auto *op = hostRegisterOp.getOperation();
698  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
699  return failure();
700 
701  Location loc = op->getLoc();
702 
703  auto memRefType = hostRegisterOp.getValue().getType();
704  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
705  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
706 
707  auto arguments = getTypeConverter()->promoteOperands(
708  loc, op->getOperands(), adaptor.getOperands(), rewriter);
709  arguments.push_back(elementSize);
710  hostRegisterCallBuilder.create(loc, rewriter, arguments);
711 
712  rewriter.eraseOp(op);
713  return success();
714 }
715 
716 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
717  gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
718  ConversionPatternRewriter &rewriter) const {
719  Operation *op = hostUnregisterOp.getOperation();
720  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
721  return failure();
722 
723  Location loc = op->getLoc();
724 
725  auto memRefType = hostUnregisterOp.getValue().getType();
726  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
727  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
728 
729  auto arguments = getTypeConverter()->promoteOperands(
730  loc, op->getOperands(), adaptor.getOperands(), rewriter);
731  arguments.push_back(elementSize);
732  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
733 
734  rewriter.eraseOp(op);
735  return success();
736 }
737 
738 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
739  gpu::AllocOp allocOp, OpAdaptor adaptor,
740  ConversionPatternRewriter &rewriter) const {
741 
742  MemRefType memRefType = allocOp.getType();
743 
744  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
745  !isConvertibleAndHasIdentityMaps(memRefType))
746  return failure();
747 
748  auto loc = allocOp.getLoc();
749 
750  bool isShared = allocOp.getHostShared();
751 
752  if (isShared && allocOp.getAsyncToken())
753  return rewriter.notifyMatchFailure(
754  allocOp, "Host Shared allocation cannot be done async");
755  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
756  return failure();
757 
758  // Get shape of the memref as values: static sizes are constant
759  // values and dynamic sizes are passed to 'alloc' as operands.
760  SmallVector<Value, 4> shape;
761  SmallVector<Value, 4> strides;
762  Value sizeBytes;
763  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
764  shape, strides, sizeBytes);
765 
766  // Allocate the underlying buffer and store a pointer to it in the MemRef
767  // descriptor.
768  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
769  Value stream = adaptor.getAsyncDependencies().empty()
770  ? nullPtr
771  : adaptor.getAsyncDependencies().front();
772 
773  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
774  loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
775 
776  Value allocatedPtr =
777  allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
778  .getResult();
779 
780  // No alignment.
781  Value alignedPtr = allocatedPtr;
782 
783  // Create the MemRef descriptor.
784  auto memRefDescriptor = this->createMemRefDescriptor(
785  loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
786 
787  if (allocOp.getAsyncToken()) {
788  // Async alloc: make dependent ops use the same stream.
789  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
790  } else {
791  rewriter.replaceOp(allocOp, {memRefDescriptor});
792  }
793 
794  return success();
795 }
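// Illustrative sketch of the rewrite above (exact types depend on the index
// bitwidth): an op such as
//   %mem, %t1 = gpu.alloc async [%t0] (%n) : memref<?xf32>
// becomes, roughly,
//   %size = ...                                  ; size in bytes
//   %host = llvm.mlir.constant(0 : i8) : i8      ; isHostShared = false
//   %ptr  = llvm.call @mgpuMemAlloc(%size, %stream, %host)
//           : (i64, !llvm.ptr, i8) -> !llvm.ptr
// with %ptr packed into a new memref descriptor and the incoming stream reused
// as the result async token.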
796 
797 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
798  gpu::DeallocOp deallocOp, OpAdaptor adaptor,
799  ConversionPatternRewriter &rewriter) const {
800  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
801  failed(isAsyncWithOneDependency(rewriter, deallocOp)))
802  return failure();
803 
804  Location loc = deallocOp.getLoc();
805 
806  Value pointer =
807  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
808  Value stream = adaptor.getAsyncDependencies().front();
809  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
810 
811  rewriter.replaceOp(deallocOp, {stream});
812  return success();
813 }
814 
815 static bool isGpuAsyncTokenType(Value value) {
816  return isa<gpu::AsyncTokenType>(value.getType());
817 }
818 
819 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
820 // !gpu.async.token values are lowered to streams within an async.execute
821 // region but are passed as events between regions. For each !gpu.async.token operand, we
822 // create an event and record it on the stream.
823 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
824  async::YieldOp yieldOp, OpAdaptor adaptor,
825  ConversionPatternRewriter &rewriter) const {
826  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
827  return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
828 
829  Location loc = yieldOp.getLoc();
830  SmallVector<Value, 4> newOperands(adaptor.getOperands());
831  llvm::SmallDenseSet<Value> streams;
832  for (auto &operand : yieldOp->getOpOperands()) {
833  if (!isGpuAsyncTokenType(operand.get()))
834  continue;
835  auto idx = operand.getOperandNumber();
836  auto stream = adaptor.getOperands()[idx];
837  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
838  eventRecordCallBuilder.create(loc, rewriter, {event, stream});
839  newOperands[idx] = event;
840  streams.insert(stream);
841  }
842  for (auto stream : streams)
843  streamDestroyCallBuilder.create(loc, rewriter, {stream});
844 
845  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
846  return success();
847 }
848 
849 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
850 static bool isDefinedByCallTo(Value value, StringRef functionName) {
851  assert(isa<LLVM::LLVMPointerType>(value.getType()));
852  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
853  return *defOp.getCallee() == functionName;
854  return false;
855 }
856 
857 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
858 // with the stream/event operands. The operands are destroyed. That is, it
859 // assumes that they are not used afterwards or elsewhere. Otherwise we will get a
860 // runtime error. Eventually, we should guarantee this property.
861 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
862  gpu::WaitOp waitOp, OpAdaptor adaptor,
863  ConversionPatternRewriter &rewriter) const {
864  if (waitOp.getAsyncToken())
865  return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
866 
867  Location loc = waitOp.getLoc();
868 
869  for (auto operand : adaptor.getOperands()) {
870  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
871  // The converted operand's definition created a stream.
872  streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
873  streamDestroyCallBuilder.create(loc, rewriter, {operand});
874  } else {
875  // Otherwise the converted operand is an event. This assumes that we use
876  // events in control flow code as well.
877  eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
878  eventDestroyCallBuilder.create(loc, rewriter, {operand});
879  }
880  }
881 
882  rewriter.eraseOp(waitOp);
883  return success();
884 }
885 
886 // Converts `gpu.wait async` to runtime calls. The converted op creates a new
887 // stream that is synchronized with stream/event operands. The operands are
888 // destroyed. That is, it assumes that they are not used afterwards or elsewhere.
889 // Otherwise we will get a runtime error. Eventually, we should guarantee this
890 // property.
891 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
892  gpu::WaitOp waitOp, OpAdaptor adaptor,
893  ConversionPatternRewriter &rewriter) const {
894  if (!waitOp.getAsyncToken())
895  return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
896 
897  Location loc = waitOp.getLoc();
898 
899  auto insertionPoint = rewriter.saveInsertionPoint();
900  SmallVector<Value, 1> events;
901  for (auto pair :
902  llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
903  auto operand = std::get<1>(pair);
904  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
905  // The converted operand's definition created a stream. Insert an event
906  // into the stream just after the last use of the original token operand.
907  auto *defOp = std::get<0>(pair).getDefiningOp();
908  rewriter.setInsertionPointAfter(defOp);
909  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
910  eventRecordCallBuilder.create(loc, rewriter, {event, operand});
911  events.push_back(event);
912  } else {
913  // Otherwise the converted operand is an event. This assumes that we use
914  // events in control flow code as well.
915  events.push_back(operand);
916  }
917  }
918  rewriter.restoreInsertionPoint(insertionPoint);
919  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
920  for (auto event : events)
921  streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
922  for (auto event : events)
923  eventDestroyCallBuilder.create(loc, rewriter, {event});
924  rewriter.replaceOp(waitOp, {stream});
925 
926  return success();
927 }
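// Illustrative sketch: for
//   %t2 = gpu.wait async [%t0, %t1]
// where %t0 was lowered to a stream and %t1 to an event, the rewrite records a
// new event on %t0's stream, creates a fresh stream via @mgpuStreamCreate,
// makes it wait on both events with @mgpuStreamWaitEvent, destroys the events,
// and uses the new stream as the result token.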
928 
929 // Legalize the op's operands.
930 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
931  gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
932  ConversionPatternRewriter &rewriter) const {
933  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
934  return failure();
935 
936  if (launchOp.getAsyncDependencies().size() > 1)
937  return rewriter.notifyMatchFailure(
938  launchOp, "Cannot convert with more than one async dependency.");
939 
940  // Fail when the synchronous version of the op has async dependencies. The
941  // lowering destroys the stream, and we do not want to check that there is no
942  // use of the stream after this op.
943  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
944  return rewriter.notifyMatchFailure(
945  launchOp, "Cannot convert non-async op with async dependencies.");
946 
947  Location loc = launchOp.getLoc();
948 
949  Value stream = Value();
950  if (!adaptor.getAsyncDependencies().empty())
951  stream = adaptor.getAsyncDependencies().front();
952  // If the async keyword is present and there are no dependencies, then a
953  // stream must be created to pass to subsequent operations.
954  else if (launchOp.getAsyncToken())
955  stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
956  // Lower the kernel operands to match kernel parameters.
957  // Note: If `useBarePtrCallConv` is set in the type converter's options,
958  // the value of `kernelBarePtrCallConv` will be ignored.
959  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
960  loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
961  /*useBarePtrCallConv=*/kernelBarePtrCallConv);
962 
963  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
964  if (launchOp.hasClusterSize()) {
965  clusterSize =
966  gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
967  adaptor.getClusterSizeZ()};
968  }
969  rewriter.create<gpu::LaunchFuncOp>(
970  launchOp.getLoc(), launchOp.getKernelAttr(),
971  gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
972  adaptor.getGridSizeZ()},
973  gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
974  adaptor.getBlockSizeZ()},
975  adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
976  if (launchOp.getAsyncToken())
977  rewriter.replaceOp(launchOp, {stream});
978  else
979  rewriter.eraseOp(launchOp);
980  return success();
981 }
982 
983 static Value bitAndAddrspaceCast(Location loc,
984  ConversionPatternRewriter &rewriter,
985  LLVM::LLVMPointerType destinationType,
986  Value sourcePtr,
987  const LLVMTypeConverter &typeConverter) {
988  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
989  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
990  sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
991  loc,
992  LLVM::LLVMPointerType::get(rewriter.getContext(),
993  destinationType.getAddressSpace()),
994  sourcePtr);
995  return sourcePtr;
996 }
997 
998 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
999  gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1000  ConversionPatternRewriter &rewriter) const {
1001  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1002 
1003  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1004  !isConvertibleAndHasIdentityMaps(memRefType) ||
1005  failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1006  return failure();
1007 
1008  auto loc = memcpyOp.getLoc();
1009 
1010  MemRefDescriptor srcDesc(adaptor.getSrc());
1011  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1012 
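  // Compute the number of bytes to copy without hard-coding a sizeof: GEP from
  // a null pointer by `numElements` elements of the converted element type and
  // ptrtoint the result.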
1013  Type elementPtrType = getElementPtrType(memRefType);
1014  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
1015  Value gepPtr = rewriter.create<LLVM::GEPOp>(
1016  loc, elementPtrType,
1017  typeConverter->convertType(memRefType.getElementType()), nullPtr,
1018  numElements);
1019  auto sizeBytes =
1020  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1021 
1022  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1023  srcDesc.alignedPtr(rewriter, loc),
1024  *getTypeConverter());
1025  auto dst = bitAndAddrspaceCast(
1026  loc, rewriter, llvmPointerType,
1027  MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1028  *getTypeConverter());
1029 
1030  auto stream = adaptor.getAsyncDependencies().front();
1031  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1032 
1033  rewriter.replaceOp(memcpyOp, {stream});
1034 
1035  return success();
1036 }
1037 
1038 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1039  gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1040  ConversionPatternRewriter &rewriter) const {
1041  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1042 
1043  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1044  !isConvertibleAndHasIdentityMaps(memRefType) ||
1045  failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1046  return failure();
1047 
1048  auto loc = memsetOp.getLoc();
1049 
1050  Type valueType = adaptor.getValue().getType();
1051  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1052  // Ints and floats of 16 or 32 bit width are allowed.
1053  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1054  return rewriter.notifyMatchFailure(
1055  memsetOp, "value must be a 16 or 32 bit int or float");
1056  }
1057 
1058  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1059  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1060 
1061  MemRefDescriptor dstDesc(adaptor.getDst());
1062  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1063 
1064  auto value =
1065  rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1066  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1067  dstDesc.alignedPtr(rewriter, loc),
1068  *getTypeConverter());
1069 
1070  auto stream = adaptor.getAsyncDependencies().front();
1071  FunctionCallBuilder builder =
1072  valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1073  builder.create(loc, rewriter, {dst, value, numElements, stream});
1074 
1075  rewriter.replaceOp(memsetOp, {stream});
1076  return success();
1077 }
1078 
1079 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1080  gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1081  ConversionPatternRewriter &rewriter) const {
1082  Location loc = op.getLoc();
1083  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1084  {adaptor.getDevIndex()});
1085  rewriter.replaceOp(op, call);
1086  return success();
1087 }
1088 
1089 template <typename T>
1090 static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1091  Type llvmInt32Type = builder.getIntegerType(32);
1092  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
1093  static_cast<int32_t>(tValue));
1094 }
1095 
1096 template <typename T>
1097 static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1098  Type llvmFloat32Type = builder.getF32Type();
1099  return builder.create<LLVM::ConstantOp>(
1100  loc, llvmFloat32Type,
1101  builder.getF32FloatAttr(static_cast<float>(tValue)));
1102 }
1103 
1104 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1105  gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1106  ConversionPatternRewriter &rewriter) const {
1107  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1108  failed(isAsyncWithOneDependency(rewriter, op)))
1109  return failure();
1110  Location loc = op.getLoc();
1111  auto stream = adaptor.getAsyncDependencies().front();
1112  Value pTensor =
1113  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1114  Type dType = op.getMemref().getType().getElementType();
1115  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1116 
1117  SmallVector<Value, 4> dims;
1118  for (Value dim : adaptor.getDims()) {
1119  dims.push_back(dim);
1120  }
1121 
1122  Value handle;
1123  // TODO: For now, we track the use of the handle and lower it to cuSPARSE /
1124  // cuSPARSELt accordingly. If both cuSPARSE and cuSPARSELt are used within one
1125  // block, two separate creation ops are required for the logic to be correct.
1126  // In the future, we may add support for using one handle in the sparse
1127  // tensor / GPU dialects for both cuSPARSE and cuSPARSELt. Use the cuSPARSELt
1128  // create call if the dnmat is used with a spmat with 2:4 sparsity.
1129  if (dims.size() == 2) {
1130  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1131  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1132  loc, getIndexType(), rewriter.getIndexAttr(11032));
1133  handle = rewriter.create<LLVM::AllocaOp>(
1134  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1135  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1136 
1137  createLtDnMatCallBuilder
1138  .create(loc, rewriter,
1139  {handle, dims[0], dims[1], pTensor, dtp, stream})
1140  .getResult();
1141  } else {
1142  handle =
1143  createDnMatCallBuilder
1144  .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1145  .getResult();
1146  }
1147  } else {
1148  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1149  handle = createDnVecCallBuilder
1150  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1151  .getResult();
1152  }
1153  rewriter.replaceOp(op, {handle, stream});
1154  return success();
1155 }
1156 
1157 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1158  gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1159  ConversionPatternRewriter &rewriter) const {
1160  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1161  failed(isAsyncWithOneDependency(rewriter, op)))
1162  return failure();
1163  Location loc = op.getLoc();
1164  auto stream = adaptor.getAsyncDependencies().front();
1165  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1166  SmallVector<Value, 4> dims;
1167  for (Value dim : definingOp.getDims()) {
1168  dims.push_back(dim);
1169  }
1170  if (dims.size() == 2) {
1171  // Use the cusparseLt destroy call if the dnmat is used with spmat with
1172  // 2:4 sparsity
1173  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1174  destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1175  {adaptor.getDnTensor(), stream});
1176  } else {
1177  destroyDnMatCallBuilder.create(loc, rewriter,
1178  {adaptor.getDnTensor(), stream});
1179  }
1180  } else {
1181  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1182  destroyDnVecCallBuilder.create(loc, rewriter,
1183  {adaptor.getDnTensor(), stream});
1184  }
1185  rewriter.replaceOp(op, {stream});
1186  return success();
1187 }
1188 
1189 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1190  gpu::CreateCooOp op, OpAdaptor adaptor,
1191  ConversionPatternRewriter &rewriter) const {
1192  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1193  failed(isAsyncWithOneDependency(rewriter, op)))
1194  return failure();
1195  Location loc = op.getLoc();
1196  auto stream = adaptor.getAsyncDependencies().front();
1197  Value pRowIdxs =
1198  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1199  Value pColIdxs =
1200  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1201  Value pValues =
1202  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1203  Type iType =
1204  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1205  Type dType =
1206  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1207  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1208  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1209  auto handle =
1210  createCooCallBuilder
1211  .create(loc, rewriter,
1212  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1213  pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1214  .getResult();
1215  rewriter.replaceOp(op, {handle, stream});
1216  return success();
1217 }
1218 
1219 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1220  gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1221  ConversionPatternRewriter &rewriter) const {
1222  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1223  failed(isAsyncWithOneDependency(rewriter, op)))
1224  return failure();
1225  Location loc = op.getLoc();
1226  auto stream = adaptor.getAsyncDependencies().front();
1227  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1228  Value pValues =
1229  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1230  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1231  Type dType =
1232  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1233  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1234  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1235  auto handle =
1236  createCooAoSCallBuilder
1237  .create(loc, rewriter,
1238  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1239  pIdxs, pValues, itp, dtp, stream})
1240  .getResult();
1241  rewriter.replaceOp(op, {handle, stream});
1242  return success();
1243 }
1244 
1245 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1246  gpu::CreateCsrOp op, OpAdaptor adaptor,
1247  ConversionPatternRewriter &rewriter) const {
1248  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1249  failed(isAsyncWithOneDependency(rewriter, op)))
1250  return failure();
1251  Location loc = op.getLoc();
1252  auto stream = adaptor.getAsyncDependencies().front();
1253  Value pRowPos =
1254  MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1255  Value pColIdxs =
1256  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1257  Value pValues =
1258  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1259  Type pType =
1260  llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1261  Type iType =
1262  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1263  Type dType =
1264  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1265  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1266  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1267  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1268  auto handle =
1269  createCsrCallBuilder
1270  .create(loc, rewriter,
1271  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1272  pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1273  .getResult();
1274  rewriter.replaceOp(op, {handle, stream});
1275  return success();
1276 }
1277 
1278 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1279  gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1280  ConversionPatternRewriter &rewriter) const {
1281  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1282  failed(isAsyncWithOneDependency(rewriter, op)))
1283  return failure();
1284  Location loc = op.getLoc();
1285  auto stream = adaptor.getAsyncDependencies().front();
1286  Value pMat =
1287  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1288  Type dType =
1289  llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1290  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1291 
1292  // CUDA runner asserts the size is 44104 bytes.
1293  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1294  loc, getIndexType(), rewriter.getIndexAttr(44104));
1295  Value handle = rewriter.create<LLVM::AllocaOp>(
1296  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1297  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1298 
1299  create2To4SpMatCallBuilder
1300  .create(loc, rewriter,
1301  {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1302  .getResult();
1303  rewriter.replaceOp(op, {handle, stream});
1304  return success();
1305 }
1306 
1307 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1308  gpu::DestroySpMatOp op, OpAdaptor adaptor,
1309  ConversionPatternRewriter &rewriter) const {
1310  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1311  failed(isAsyncWithOneDependency(rewriter, op)))
1312  return failure();
1313  Location loc = op.getLoc();
1314  auto stream = adaptor.getAsyncDependencies().front();
1315  // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1316  if (is2To4Sparsity(op.getSpmat())) {
1317  destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1318  {adaptor.getSpmat(), stream});
1319 
1320  } else {
1321  destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1322  }
1323  rewriter.replaceOp(op, {stream});
1324  return success();
1325 }
1326 
1327 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1328  gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1329  ConversionPatternRewriter &rewriter) const {
1330  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1331  failed(isAsyncWithOneDependency(rewriter, op)))
1332  return failure();
1333  Location loc = op.getLoc();
1334  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1335  auto computeType = genConstInt32From(
1336  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1337  auto stream = adaptor.getAsyncDependencies().front();
1338  auto bufferSize = spMVBufferSizeCallBuilder
1339  .create(loc, rewriter,
1340  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1341  adaptor.getDnY(), computeType, stream})
1342  .getResult();
1343  rewriter.replaceOp(op, {bufferSize, stream});
1344  return success();
1345 }
1346 
1347 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1348  gpu::SpMVOp op, OpAdaptor adaptor,
1349  ConversionPatternRewriter &rewriter) const {
1350  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1351  failed(isAsyncWithOneDependency(rewriter, op)))
1352  return failure();
1353  Location loc = op.getLoc();
1354  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1355  auto computeType = genConstInt32From(
1356  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1357  auto stream = adaptor.getAsyncDependencies().front();
1358  Value pBuf =
1359  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1360  spMVCallBuilder.create(loc, rewriter,
1361  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1362  adaptor.getDnY(), computeType, pBuf, stream});
1363  rewriter.replaceOp(op, {stream});
1364  return success();
1365 }
1366 
1367 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1368  gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1369  ConversionPatternRewriter &rewriter) const {
1370  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1371  failed(isAsyncWithOneDependency(rewriter, op)))
1372  return failure();
1373  Location loc = op.getLoc();
1374  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1375  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1376  auto stream = adaptor.getAsyncDependencies().front();
1377  Value bufferSize;
1378  if (is2To4Sparsity(op.getSpmatA())) {
1379  auto pruneFlag =
1380  genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1381  auto computeType = genConstInt32From(
1382  rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1383  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1384  rewriter.getIndexAttr(3));
1385  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1386  loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1387  createCuSparseLtSpMMBufferSizeBuilder
1388  .create(loc, rewriter,
1389  {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1390  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1391  pruneFlag, stream})
1392  .getResult();
1393 
1394  auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1395  loc, llvmPointerType, llvmPointerType, bufferSize,
1396  ValueRange{rewriter.create<LLVM::ConstantOp>(
1397  loc, getIndexType(), rewriter.getIndexAttr(1))});
1398  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1399  loc, llvmPointerType, llvmPointerType, bufferSize,
1400  ValueRange{rewriter.create<LLVM::ConstantOp>(
1401  loc, getIndexType(), rewriter.getIndexAttr(2))});
1402  auto bufferSize0 =
1403  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1404  auto bufferSize1 =
1405  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1406  auto bufferSize2 =
1407  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1408 
1409  rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1410  } else {
1411  auto computeType = genConstInt32From(
1412  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1413  bufferSize =
1414  createSpMMBufferSizeCallBuilder
1415  .create(loc, rewriter,
1416  {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1417  adaptor.getDnmatC(), computeType, stream})
1418  .getResult();
1419  rewriter.replaceOp(op, {bufferSize, stream});
1420  }
1421  return success();
1422 }
1423 
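// Lowers gpu.sddmm_buffer_size to a runtime call that returns the workspace
// size for SDDMM with dense A/B operands and a sparse output C.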
1424 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1425  gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1426  ConversionPatternRewriter &rewriter) const {
1427  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1428  failed(isAsyncWithOneDependency(rewriter, op)))
1429  return failure();
1430  Location loc = op.getLoc();
1431  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1432  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1433  auto computeType = genConstInt32From(
1434  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1435  auto stream = adaptor.getAsyncDependencies().front();
1436  auto bufferSize =
1437  createSDDMMBufferSizeCallBuilder
1438  .create(loc, rewriter,
1439  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1440  adaptor.getSpmatC(), computeType, stream})
1441  .getResult();
1442  rewriter.replaceOp(op, {bufferSize, stream});
1443  return success();
1444 }
1445 
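// Lowers gpu.spmm. The 2:4 sparsity path passes the three workspace buffers
// required by cuSPARSELt; the generic path passes a single workspace buffer
// to the cuSPARSE SpMM wrapper.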
1446 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1447  gpu::SpMMOp op, OpAdaptor adaptor,
1448  ConversionPatternRewriter &rewriter) const {
1449  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1450  failed(isAsyncWithOneDependency(rewriter, op)))
1451  return failure();
1452  Location loc = op.getLoc();
1453  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1454  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1455  auto computeType = genConstInt32From(
1456  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1457 
1458  auto stream = adaptor.getAsyncDependencies().front();
1459 
1460  // Lower to cuSPARSELt if applicable.
1461  if (is2To4Sparsity(op.getSpmatA())) {
1462  SmallVector<Value> pBufs;
1463  for (Value buffer : adaptor.getBuffers()) {
1464  Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1465  pBufs.push_back(pBuf);
1466  }
1467  createCuSparseLtSpMMBuilder.create(
1468  loc, rewriter,
1469  {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1470  pBufs[0], pBufs[1], pBufs[2], stream});
1471  } else {
1472  Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1473  .allocatedPtr(rewriter, loc);
1474  createSpMMCallBuilder.create(loc, rewriter,
1475  {modeA, modeB, adaptor.getSpmatA(),
1476  adaptor.getDnmatB(), adaptor.getDnmatC(),
1477  computeType, pBuf, stream});
1478  }
1479  rewriter.replaceOp(op, {stream});
1480  return success();
1481 }
1482 
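// Converts the given GPU handle type (async tokens and the sparse tensor/
// matrix handle types) to an opaque LLVM pointer, since the runtime layer
// represents all of them as void *.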
1483 template <typename T>
1484 static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1485  converter.addConversion([&converter](T) -> Type {
1486  return LLVM::LLVMPointerType::get(&converter.getContext());
1487  });
1488 }
1489 
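// Lowers gpu.sddmm to a runtime call with the dense A/B handles, the sparse
// output handle, the compute type, and the raw workspace pointer.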
1490 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1491  gpu::SDDMMOp op, OpAdaptor adaptor,
1492  ConversionPatternRewriter &rewriter) const {
1493  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1494  failed(isAsyncWithOneDependency(rewriter, op)))
1495  return failure();
1496  Location loc = op.getLoc();
1497  auto computeType = genConstInt32From(
1498  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1499  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1500  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1501  auto stream = adaptor.getAsyncDependencies().front();
1502  Value pBuf =
1503  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1504  createSDDMMCallBuilder.create(loc, rewriter,
1505  {modeA, modeB, adaptor.getDnmatA(),
1506  adaptor.getDnmatB(), adaptor.getSpmatC(),
1507  computeType, pBuf, stream});
1508  rewriter.replaceOp(op, {stream});
1509  return success();
1510 }
1511 
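// Lowers gpu.spgemm_create_descr to a runtime call that returns an opaque
// SpGEMM descriptor handle.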
1512 LogicalResult
1513 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1514  gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1515  ConversionPatternRewriter &rewriter) const {
1516  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1517  failed(isAsyncWithOneDependency(rewriter, op)))
1518  return failure();
1519  Location loc = op.getLoc();
1520  auto stream = adaptor.getAsyncDependencies().front();
1521  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1522  .getResult();
1523  rewriter.replaceOp(op, {descr, stream});
1524  return success();
1525 }
1526 
1527 LogicalResult
1528 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1529  gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1530  ConversionPatternRewriter &rewriter) const {
1531  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1532  failed(isAsyncWithOneDependency(rewriter, op)))
1533  return failure();
1534  Location loc = op.getLoc();
1535  auto stream = adaptor.getAsyncDependencies().front();
1536  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1537  {adaptor.getDesc(), stream});
1538  rewriter.replaceOp(op, {stream});
1539  return success();
1540 }
1541 
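// Lowers gpu.spgemm_work_estimation_or_compute, dispatching on the op's kind
// attribute to either the work-estimation or the compute wrapper; both return
// the (possibly updated) buffer size required for the next phase.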
1542 LogicalResult
1543 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1544  gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1545  ConversionPatternRewriter &rewriter) const {
1546  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1547  failed(isAsyncWithOneDependency(rewriter, op)))
1548  return failure();
1549  Location loc = op.getLoc();
1550  auto computeType = genConstInt32From(
1551  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1552  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1553  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1554  auto stream = adaptor.getAsyncDependencies().front();
1555 
1556  Value pBuf =
1557  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1558  Value bufferSizeNew;
1559 
1560  if (adaptor.getKind() ==
1561  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1562  bufferSizeNew =
1563  createSpGEMMWorkEstimationBuilder
1564  .create(loc, rewriter,
1565  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1566  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1567  adaptor.getBufferSz(), pBuf, stream})
1568  .getResult();
1569  } else {
1570  bufferSizeNew =
1571  createSpGEMMComputeBuilder
1572  .create(loc, rewriter,
1573  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1574  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1575  adaptor.getBufferSz(), pBuf, stream})
1576  .getResult();
1577  }
1578  rewriter.replaceOp(op, {bufferSizeNew, stream});
1579  return success();
1580 }
1581 
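// Lowers gpu.spgemm_copy, which copies the SpGEMM result held by the
// descriptor into the output sparse matrix.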
1582 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1583  gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1584  ConversionPatternRewriter &rewriter) const {
1585  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1586  failed(isAsyncWithOneDependency(rewriter, op)))
1587  return failure();
1588  Location loc = op.getLoc();
1589  auto computeType = genConstInt32From(
1590  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1591  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1592  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1593  auto stream = adaptor.getAsyncDependencies().front();
1594  createSpGEMMCopyBuilder.create(loc, rewriter,
1595  {adaptor.getDesc(), modeA, modeB,
1596  adaptor.getSpmatA(), adaptor.getSpmatB(),
1597  adaptor.getSpmatC(), computeType, stream});
1598  rewriter.replaceOp(op, {stream});
1599  return success();
1600 }
1601 
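// Lowers gpu.spmat_get_size by allocating a small 3 x i64 stack buffer,
// passing pointers to its elements as out-parameters of the runtime call, and
// loading rows/cols/nnz back from it to form the op's results.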
1602 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1603  gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1604  ConversionPatternRewriter &rewriter) const {
1605  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1606  failed(isAsyncWithOneDependency(rewriter, op)))
1607  return failure();
1608  Location loc = op.getLoc();
1609  auto stream = adaptor.getAsyncDependencies().front();
1610 
1611  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1612  rewriter.getIndexAttr(3));
1613  auto buffer = rewriter.create<LLVM::AllocaOp>(
1614  loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1615 
1616  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1617  loc, llvmPointerType, llvmPointerType, buffer,
1618  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1619  rewriter.getIndexAttr(0))});
1620  auto colsPtr = rewriter.create<LLVM::GEPOp>(
1621  loc, llvmPointerType, llvmPointerType, buffer,
1622  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1623  rewriter.getIndexAttr(1))});
1624  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1625  loc, llvmPointerType, llvmPointerType, buffer,
1626  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1627  rewriter.getIndexAttr(2))});
1628  createSpMatGetSizeBuilder.create(
1629  loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1630  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1631  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1632  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1633 
1634  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1635  return success();
1636 }
1637 
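// Lowers gpu.set_csr_pointers, rebinding the positions, coordinates, and
// values buffers of an existing sparse matrix handle.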
1638 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1639  gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1640  ConversionPatternRewriter &rewriter) const {
1641  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1642  failed(isAsyncWithOneDependency(rewriter, op)))
1643  return failure();
1644  Location loc = op.getLoc();
1645  auto stream = adaptor.getAsyncDependencies().front();
1646  Value pPos =
1647  MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1648  Value pCrd =
1649  MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1650  Value pVal =
1651  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1652  createSetCsrPointersBuilder.create(
1653  loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1654  rewriter.replaceOp(op, {stream});
1655  return success();
1656 }
1657 
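// Lowers gpu.create_csc. The cuSPARSE index and data type enums are derived
// from the element types of the position, index, and value memrefs.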
1658 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1659  gpu::CreateCscOp op, OpAdaptor adaptor,
1660  ConversionPatternRewriter &rewriter) const {
1661  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1662  failed(isAsyncWithOneDependency(rewriter, op)))
1663  return failure();
1664  Location loc = op.getLoc();
1665  auto stream = adaptor.getAsyncDependencies().front();
1666  Value pColPos =
1667  MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1668  Value pRowIdxs =
1669  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1670  Value pValues =
1671  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1672  Type pType =
1673  llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1674  Type iType =
1675  llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1676  Type dType =
1677  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1678  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1679  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1680  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1681  auto handle =
1682  createCscCallBuilder
1683  .create(loc, rewriter,
1684  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1685  pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1686  .getResult();
1687  rewriter.replaceOp(op, {handle, stream});
1688  return success();
1689 }
1690 
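// Lowers gpu.create_bsr, forwarding the block row/column counts and block
// sizes in addition to the usual position/index/value buffers.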
1691 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1692  gpu::CreateBsrOp op, OpAdaptor adaptor,
1693  ConversionPatternRewriter &rewriter) const {
1694  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1695  failed(isAsyncWithOneDependency(rewriter, op)))
1696  return failure();
1697  Location loc = op.getLoc();
1698  auto stream = adaptor.getAsyncDependencies().front();
1699  Value pRowPos =
1700  MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1701  Value pColIdxs =
1702  MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1703  Value pValues =
1704  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1705  Type pType =
1706  llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1707  Type iType =
1708  llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1709  Type dType =
1710  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1711  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1712  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1713  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1714  auto handle =
1715  createBsrCallBuilder
1716  .create(loc, rewriter,
1717  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1718  adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1719  pColIdxs, pValues, ptp, itp, dtp, stream})
1720  .getResult();
1721  rewriter.replaceOp(op, {handle, stream});
1722  return success();
1723 }
1724 
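// Registers the opaque-pointer conversions for the GPU handle types and adds
// every GPU-to-runtime-call rewrite pattern. A minimal usage sketch, assuming
// the caller already set up an LLVMTypeConverter `converter` and a
// ConversionTarget `target`:
//   RewritePatternSet patterns(&converter.getContext());
//   populateGpuToLLVMConversionPatterns(converter, patterns,
//                                       /*kernelBarePtrCallConv=*/false);
//   (void)applyPartialConversion(module, target, std::move(patterns));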
1725 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
1726  RewritePatternSet &patterns,
1727  bool kernelBarePtrCallConv) {
1728  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1729  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1730  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1731  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1732 
1733  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1734  ConvertDeallocOpToGpuRuntimeCallPattern,
1735  ConvertHostRegisterOpToGpuRuntimeCallPattern,
1736  ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1737  ConvertMemcpyOpToGpuRuntimeCallPattern,
1738  ConvertMemsetOpToGpuRuntimeCallPattern,
1739  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1740  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1741  ConvertWaitOpToGpuRuntimeCallPattern,
1742  ConvertAsyncYieldToGpuRuntimeCallPattern,
1743  ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1744  ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1745  ConvertCreateCooOpToGpuRuntimeCallPattern,
1746  ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1747  ConvertCreateCsrOpToGpuRuntimeCallPattern,
1748  ConvertCreateCscOpToGpuRuntimeCallPattern,
1749  ConvertCreateBsrOpToGpuRuntimeCallPattern,
1750  ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1751  ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1752  ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1753  ConvertSpMVOpToGpuRuntimeCallPattern,
1754  ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1755  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1756  ConvertSpMMOpToGpuRuntimeCallPattern,
1757  ConvertSDDMMOpToGpuRuntimeCallPattern,
1758  ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1759  ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1760  ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1761  ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1762  ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1763  ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1764  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
1765 }
1766 
1767 //===----------------------------------------------------------------------===//
1768 // GPUModuleOp convert to LLVM op interface
1769 //===----------------------------------------------------------------------===//
1770 
1771 namespace {
1772 struct GPUModuleOpConvertToLLVMInterface
1773  : public ConvertToLLVMOpInterface::ExternalModel<
1774  GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1775  /// Get the conversion patterns from the target attribute.
1776  void getConvertToLLVMConversionAttrs(
1777  Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
1778 };
1779 } // namespace
1780 
1781 void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1782  Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1783  auto module = cast<gpu::GPUModuleOp>(op);
1784  ArrayAttr targetsAttr = module.getTargetsAttr();
1785  // Fail if there are no target attributes or there is more than one target.
1786  if (!targetsAttr || targetsAttr.size() != 1)
1787  return;
1788  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1789  attrs.push_back(patternAttr);
1790 }
1791 
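// Attaches the ConvertToLLVMOpInterface external model to gpu.module so that
// generic convert-to-LLVM drivers can collect the conversion attributes from
// the module's target attribute.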
1792 void mlir::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
1793  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1794  gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1795  });
1796 }