GPUToLLVMConversion.cpp
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert the gpu.launch_func op into a
10 // sequence of GPU runtime calls. As most GPU runtimes do not have a stable
11 // published ABI, this pass uses a slim runtime layer that builds on top of
12 // the public API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
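//
// For example, an asynchronous allocation chain (`gpu.wait async`, then
// `gpu.alloc async`, then a final `gpu.wait` on the resulting token) is
// rewritten by the patterns below into a call sequence roughly equivalent to
// the following sketch (the wrapper names match the call builders declared
// further down in this file):
//
//   void *stream = mgpuStreamCreate();
//   void *buffer = mgpuMemAlloc(sizeBytes, stream, /*isHostShared=*/false);
//   mgpuStreamSynchronize(stream);
//   mgpuStreamDestroy(stream);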
15 
17 
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/Builders.h"
36 #include "mlir/IR/BuiltinOps.h"
37 #include "mlir/IR/BuiltinTypes.h"
38 
39 #include "llvm/ADT/STLExtras.h"
40 #include "llvm/Support/Error.h"
41 #include "llvm/Support/FormatVariadic.h"
42 
43 #define DEBUG_TYPE "gpu-to-llvm"
44 
45 namespace mlir {
46 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
47 #include "mlir/Conversion/Passes.h.inc"
48 } // namespace mlir
49 
50 using namespace mlir;
51 
52 namespace {
53 class GpuToLLVMConversionPass
54  : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
55 public:
56  using Base::Base;
57  void getDependentDialects(DialectRegistry &registry) const final {
58  Base::getDependentDialects(registry);
60  }
61  // Run the dialect converter on the module.
62  void runOnOperation() override;
63 };
64 
65 template <typename OpTy>
66 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
67 public:
68  explicit ConvertOpToGpuRuntimeCallPattern(
69  const LLVMTypeConverter &typeConverter)
70  : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
71 
72 protected:
73  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
74  MemRefType type, MemRefDescriptor desc) const {
75  Type indexType = ConvertToLLVMPattern::getIndexType();
76  return type.hasStaticShape()
77  ? ConvertToLLVMPattern::createIndexAttrConstant(
78  rewriter, loc, indexType, type.getNumElements())
79  // For identity maps (verified by caller), the number of
80  // elements is stride[0] * size[0].
81  : rewriter.create<LLVM::MulOp>(loc,
82  desc.stride(rewriter, loc, 0),
83  desc.size(rewriter, loc, 0));
84  }
85 
86  MLIRContext *context = &this->getTypeConverter()->getContext();
87 
88  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
89  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
90  Type llvmInt8Type = IntegerType::get(context, 8);
91  Type llvmInt16Type = IntegerType::get(context, 16);
92  Type llvmInt32Type = IntegerType::get(context, 32);
93  Type llvmInt64Type = IntegerType::get(context, 64);
94  Type llvmFloat32Type = Float32Type::get(context);
95  Type llvmIntPtrType = IntegerType::get(
96  context, this->getTypeConverter()->getPointerBitwidth(0));
97 
98  FunctionCallBuilder streamCreateCallBuilder = {
99  "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
100  FunctionCallBuilder streamDestroyCallBuilder = {
101  "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
102  FunctionCallBuilder streamSynchronizeCallBuilder = {
103  "mgpuStreamSynchronize",
104  llvmVoidType,
105  {llvmPointerType /* void *stream */}};
106  FunctionCallBuilder streamWaitEventCallBuilder = {
107  "mgpuStreamWaitEvent",
108  llvmVoidType,
109  {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
110  FunctionCallBuilder eventCreateCallBuilder = {
111  "mgpuEventCreate", llvmPointerType /* void *event */, {}};
112  FunctionCallBuilder eventDestroyCallBuilder = {
113  "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
114  FunctionCallBuilder eventSynchronizeCallBuilder = {
115  "mgpuEventSynchronize",
116  llvmVoidType,
117  {llvmPointerType /* void *event */}};
118  FunctionCallBuilder eventRecordCallBuilder = {
119  "mgpuEventRecord",
120  llvmVoidType,
121  {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
122  FunctionCallBuilder hostRegisterCallBuilder = {
123  "mgpuMemHostRegisterMemRef",
124  llvmVoidType,
125  {llvmIntPtrType /* intptr_t rank */,
126  llvmPointerType /* void *memrefDesc */,
127  llvmIntPtrType /* intptr_t elementSizeBytes */}};
128  FunctionCallBuilder hostUnregisterCallBuilder = {
129  "mgpuMemHostUnregisterMemRef",
130  llvmVoidType,
131  {llvmIntPtrType /* intptr_t rank */,
132  llvmPointerType /* void *memrefDesc */,
133  llvmIntPtrType /* intptr_t elementSizeBytes */}};
134  FunctionCallBuilder allocCallBuilder = {
135  "mgpuMemAlloc",
136  llvmPointerType /* void * */,
137  {llvmIntPtrType /* intptr_t sizeBytes */,
138  llvmPointerType /* void *stream */,
139  llvmInt8Type /* bool isHostShared */}};
140  FunctionCallBuilder deallocCallBuilder = {
141  "mgpuMemFree",
142  llvmVoidType,
143  {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
144  FunctionCallBuilder memcpyCallBuilder = {
145  "mgpuMemcpy",
146  llvmVoidType,
147  {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
148  llvmIntPtrType /* intptr_t sizeBytes */,
149  llvmPointerType /* void *stream */}};
150  FunctionCallBuilder memset16CallBuilder = {
151  "mgpuMemset16",
152  llvmVoidType,
153  {llvmPointerType /* void *dst */,
154  llvmInt16Type /* unsigned short value */,
155  llvmIntPtrType /* intptr_t sizeBytes */,
156  llvmPointerType /* void *stream */}};
157  FunctionCallBuilder memset32CallBuilder = {
158  "mgpuMemset32",
159  llvmVoidType,
160  {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
161  llvmIntPtrType /* intptr_t sizeBytes */,
162  llvmPointerType /* void *stream */}};
163  FunctionCallBuilder setDefaultDeviceCallBuilder = {
164  "mgpuSetDefaultDevice",
165  llvmVoidType,
166  {llvmInt32Type /* uint32_t devIndex */}};
167  FunctionCallBuilder createDnVecCallBuilder = {
168  "mgpuCreateDnVec",
169  llvmPointerType,
170  {llvmIntPtrType, llvmPointerType, llvmInt32Type,
171  llvmPointerType /* void *stream */}};
172  FunctionCallBuilder destroyDnVecCallBuilder = {
173  "mgpuDestroyDnVec",
174  llvmVoidType,
175  {llvmPointerType, llvmPointerType /* void *stream */}};
176  FunctionCallBuilder createDnMatCallBuilder = {
177  "mgpuCreateDnMat",
178  llvmPointerType,
179  {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
180  llvmPointerType /* void *stream */}};
181  FunctionCallBuilder destroyDnMatCallBuilder = {
182  "mgpuDestroyDnMat",
183  llvmVoidType,
184  {llvmPointerType, llvmPointerType /* void *stream */}};
185  FunctionCallBuilder createCooCallBuilder = {
186  "mgpuCreateCoo",
187  llvmPointerType,
188  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
189  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
190  llvmPointerType /* void *stream */}};
191  FunctionCallBuilder createCooAoSCallBuilder = {
192  "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
193  llvmPointerType,
194  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
195  llvmPointerType, llvmInt32Type, llvmInt32Type,
196  llvmPointerType /* void *stream */}};
197  FunctionCallBuilder createCsrCallBuilder = {
198  "mgpuCreateCsr",
199  llvmPointerType,
200  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
201  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
202  llvmInt32Type, llvmPointerType /* void *stream */}};
203  FunctionCallBuilder createCscCallBuilder = {
204  "mgpuCreateCsc",
205  llvmPointerType,
206  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
207  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
208  llvmInt32Type, llvmPointerType /* void *stream */}};
209  FunctionCallBuilder createBsrCallBuilder = {
210  "mgpuCreateBsr",
211  llvmPointerType,
212  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
213  llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
214  llvmInt32Type, llvmInt32Type, llvmInt32Type,
215  llvmPointerType /* void *stream */}};
216  FunctionCallBuilder destroySpMatCallBuilder = {
217  "mgpuDestroySpMat",
218  llvmVoidType,
219  {llvmPointerType, llvmPointerType /* void *stream */}};
220  FunctionCallBuilder spMVBufferSizeCallBuilder = {
221  "mgpuSpMVBufferSize",
222  llvmIntPtrType,
223  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
224  llvmInt32Type, llvmPointerType /* void *stream */}};
225  FunctionCallBuilder spMVCallBuilder = {
226  "mgpuSpMV",
227  llvmVoidType,
228  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
229  llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
230  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
231  "mgpuSpMMBufferSize",
232  llvmIntPtrType,
233  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
234  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
235  FunctionCallBuilder createSpMMCallBuilder = {
236  "mgpuSpMM",
237  llvmVoidType,
238  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
239  llvmPointerType, llvmInt32Type, llvmPointerType,
240  llvmPointerType /* void *stream */}};
241  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
242  "mgpuSDDMMBufferSize",
243  llvmIntPtrType,
244  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
245  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
246  FunctionCallBuilder createSDDMMCallBuilder = {
247  "mgpuSDDMM",
248  llvmVoidType,
249  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
250  llvmPointerType, llvmInt32Type, llvmPointerType,
251  llvmPointerType /* void *stream */}};
252  FunctionCallBuilder createLtDnMatCallBuilder = {
253  "mgpuCreateCuSparseLtDnMat",
254  llvmVoidType,
255  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
256  llvmInt32Type, llvmPointerType /* void *stream */}};
257  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
258  "mgpuDestroyCuSparseLtSpMat",
259  llvmVoidType,
260  {llvmPointerType, llvmPointerType /* void *stream */}};
261  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
262  "mgpuDestroyCuSparseLtDnMat",
263  llvmVoidType,
264  {llvmPointerType, llvmPointerType /* void *stream */}};
265  FunctionCallBuilder create2To4SpMatCallBuilder = {
266  "mgpuCusparseLtCreate2To4SpMat",
267  llvmVoidType,
268  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
269  llvmInt32Type, llvmPointerType /* void *stream */}};
270  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
271  "mgpuCuSparseLtSpMMBufferSize",
272  llvmVoidType,
273  {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
274  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
275  llvmPointerType /*void *stream*/}};
276  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
277  "mgpuCuSparseLtSpMM",
278  llvmVoidType,
279  {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
280  llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
281  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
282  "mgpuSpGEMMCreateDescr",
283  llvmPointerType,
284  {llvmPointerType /*void *stream*/}};
285  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
286  "mgpuSpGEMMDestroyDescr",
287  llvmVoidType,
288  {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
289  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
290  "mgpuSpGEMMWorkEstimation",
291  llvmIntPtrType,
292  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
293  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
294  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
295  llvmPointerType /*void *stream*/}};
296  FunctionCallBuilder createSpGEMMComputeBuilder = {
297  "mgpuSpGEMMCompute",
298  llvmIntPtrType,
299  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
300  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
301  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
302  llvmPointerType /*void *stream*/}};
303  FunctionCallBuilder createSpGEMMCopyBuilder = {
304  "mgpuSpGEMMCopy",
305  llvmVoidType,
306  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
307  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
308  llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
309  FunctionCallBuilder createSpMatGetSizeBuilder = {
310  "mgpuSpMatGetSize",
311  llvmVoidType,
312  {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
313  llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
314  FunctionCallBuilder createSetCsrPointersBuilder = {
315  "mgpuSetCsrPointers",
316  llvmVoidType,
317  {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
318  llvmPointerType /*crd*/, llvmPointerType /*val*/,
319  llvmPointerType /*void *stream*/}};
320 };
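// For reference, the call builders above bind to a small C runtime wrapper API
// that, judging from the argument comments, looks roughly like the following
// sketch (the actual definitions live in the GPU runtime wrapper libraries):
//
//   extern "C" void *mgpuStreamCreate();
//   extern "C" void mgpuStreamDestroy(void *stream);
//   extern "C" void mgpuStreamSynchronize(void *stream);
//   extern "C" void *mgpuMemAlloc(intptr_t sizeBytes, void *stream,
//                                 bool isHostShared);
//   extern "C" void mgpuMemFree(void *ptr, void *stream);
//   extern "C" void mgpuMemcpy(void *dst, void *src, intptr_t sizeBytes,
//                              void *stream);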
321 
322 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
323 /// call. Currently it supports CUDA and ROCm (HIP).
324 class ConvertHostRegisterOpToGpuRuntimeCallPattern
325  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
326 public:
327  ConvertHostRegisterOpToGpuRuntimeCallPattern(
328  const LLVMTypeConverter &typeConverter)
329  : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
330 
331 private:
332  LogicalResult
333  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
334  ConversionPatternRewriter &rewriter) const override;
335 };
336 
337 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
338  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
339 public:
340  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
341  const LLVMTypeConverter &typeConverter)
342  : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
343  }
344 
345 private:
346  LogicalResult
347  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
348  ConversionPatternRewriter &rewriter) const override;
349 };
350 
351 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
352 /// call. Currently it supports CUDA and ROCm (HIP).
353 class ConvertAllocOpToGpuRuntimeCallPattern
354  : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
355 public:
356  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
357  : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
358 
359 private:
360  LogicalResult
361  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
362  ConversionPatternRewriter &rewriter) const override;
363 };
364 
365 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
366 /// call. Currently it supports CUDA and ROCm (HIP).
367 class ConvertDeallocOpToGpuRuntimeCallPattern
368  : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
369 public:
370  ConvertDeallocOpToGpuRuntimeCallPattern(
371  const LLVMTypeConverter &typeConverter)
372  : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
373 
374 private:
375  LogicalResult
376  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
377  ConversionPatternRewriter &rewriter) const override;
378 };
379 
380 class ConvertAsyncYieldToGpuRuntimeCallPattern
381  : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
382 public:
383  ConvertAsyncYieldToGpuRuntimeCallPattern(
384  const LLVMTypeConverter &typeConverter)
385  : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
386 
387 private:
388  LogicalResult
389  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
390  ConversionPatternRewriter &rewriter) const override;
391 };
392 
393 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
394 /// call. Currently it supports CUDA and ROCm (HIP).
395 class ConvertWaitOpToGpuRuntimeCallPattern
396  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
397 public:
398  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
399  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
400 
401 private:
402  LogicalResult
403  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
404  ConversionPatternRewriter &rewriter) const override;
405 };
406 
407 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
408 /// call. Currently it supports CUDA and ROCm (HIP).
409 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
410  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
411 public:
412  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
413  const LLVMTypeConverter &typeConverter)
414  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
415 
416 private:
417  LogicalResult
418  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
419  ConversionPatternRewriter &rewriter) const override;
420 };
421 
422 /// A rewrite pattern to legalize gpu.launch_func with LLVM types.
423 class LegalizeLaunchFuncOpPattern
424  : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
425 public:
426  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
427  bool kernelBarePtrCallConv)
428  : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
429  kernelBarePtrCallConv(kernelBarePtrCallConv) {}
430 
431 private:
432  LogicalResult
433  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
434  ConversionPatternRewriter &rewriter) const override;
435 
436  bool kernelBarePtrCallConv;
437 };
438 
439 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
440 /// call. Currently it supports CUDA and ROCm (HIP).
441 class ConvertMemcpyOpToGpuRuntimeCallPattern
442  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
443 public:
444  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
445  : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
446 
447 private:
448  LogicalResult
449  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
450  ConversionPatternRewriter &rewriter) const override;
451 };
452 
453 /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
454 /// call. Currently it supports CUDA and ROCm (HIP).
455 class ConvertMemsetOpToGpuRuntimeCallPattern
456  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
457 public:
458  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
459  : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
460 
461 private:
462  LogicalResult
463  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
464  ConversionPatternRewriter &rewriter) const override;
465 };
466 
467 /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
468 /// Currently supports CUDA and ROCm (HIP).
469 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
470  : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
471 public:
472  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
473  const LLVMTypeConverter &typeConverter)
474  : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
475  typeConverter) {}
476 
477  LogicalResult
478  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
479  ConversionPatternRewriter &rewriter) const override;
480 };
481 
482 /// Generic rewriting rule for operations on sparse matrices.
483 /// Currently supports CUDA (by means of cuSparse and cuSparseLt).
484 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
485  class Convert##op_name##ToGpuRuntimeCallPattern \
486  : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
487  public: \
488  Convert##op_name##ToGpuRuntimeCallPattern( \
489  const LLVMTypeConverter &typeConverter) \
490  : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
491  \
492  private: \
493  LogicalResult \
494  matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
495  ConversionPatternRewriter &rewriter) const override; \
496  };
497 
515 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
519 
520 } // namespace
521 
522 void GpuToLLVMConversionPass::runOnOperation() {
523  MLIRContext *context = &getContext();
524  LowerToLLVMOptions options(context);
525  options.useBarePtrCallConv = hostBarePtrCallConv;
526  RewritePatternSet patterns(context);
527  ConversionTarget target(*context);
528  target.addLegalDialect<LLVM::LLVMDialect>();
529  LLVMTypeConverter converter(context, options);
530 
531  // Populate all patterns from all dialects that implement the
532  // `ConvertToLLVMPatternInterface` interface.
533  for (Dialect *dialect : context->getLoadedDialects()) {
534  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
535  if (!iface)
536  continue;
537  iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
538  }
539 
540  // Preserve GPU modules and binaries. Modules are preserved as they can be
541  // converted later by `gpu-module-to-binary`.
542  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
543  // Accept LaunchFuncOps as legal if their operands have been lowered.
544  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
545  [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
546 
547  // These aren't covered by the ConvertToLLVMPatternInterface right now.
548  populateVectorToLLVMConversionPatterns(converter, patterns);
549  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
550  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
551  target);
552  populateGpuToLLVMConversionPatterns(converter, patterns,
553  kernelBarePtrCallConv);
554 
555  if (failed(
556  applyPartialConversion(getOperation(), target, std::move(patterns))))
557  signalPassFailure();
558 }
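// A minimal way to run this conversion programmatically is sketched below; it
// assumes the pass constructor generated from the pass definition included
// above (e.g. createGpuToLLVMConversionPass()) is available:
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(createGpuToLLVMConversionPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "gpu-to-llvm conversion failed\n";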
559 
560 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
561  ArrayRef<Value> arguments) const {
562  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
563  auto function = [&] {
564  if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
565  return function;
566  return OpBuilder::atBlockEnd(module.getBody())
567  .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
568  }();
569  return builder.create<LLVM::CallOp>(loc, function, arguments);
570 }
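// Typical usage from the conversion patterns below: the builder declares the
// runtime function the first time it is needed and returns the generated
// llvm.call op, e.g.
//
//   Value stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
//   streamSynchronizeCallBuilder.create(loc, rewriter, {stream});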
571 
572 // Corresponding to cusparseIndexType_t defined in cusparse.h.
573 static int32_t getCuSparseIndexTypeFrom(Type type) {
574  if (type.isInteger(16))
575  return 1; // CUSPARSE_INDEX_16U
576  if (type.isInteger(32))
577  return 2; // CUSPARSE_INDEX_32I
578  return 3; // CUSPARSE_INDEX_64I
579 }
580 
581 static int32_t getCuSparseLtDataTypeFrom(Type type) {
582  if (type.isF16())
583  return 0; // CUSPARSE_COMPUTE_16F,
584  if (type.isInteger(32))
585  return 1; // CUSPARSE_COMPUTE_32I
586  llvm_unreachable("unsupported type");
587  // TODO: add support for TF32
588 }
589 
590 // Corresponding to cudaDataType_t defined in CUDA library_types.h.
591 static int32_t getCuSparseDataTypeFrom(Type type) {
592  if (llvm::isa<ComplexType>(type)) {
593  // get the element type
594  auto elementType = cast<ComplexType>(type).getElementType();
595  if (elementType.isBF16())
596  return 15; // CUDA_C_16BF
597  if (elementType.isF16())
598  return 6; // CUDA_C_16F
599  if (elementType.isF32())
600  return 4; // CUDA_C_32F
601  if (elementType.isF64())
602  return 5; // CUDA_C_64F
603  if (elementType.isInteger(8))
604  return 7; // CUDA_C_8I
605  if (elementType.isInteger(16))
606  return 21; // CUDA_C_16I
607  if (elementType.isInteger(32))
608  return 11; // CUDA_C_32I
609  }
610  if (type.isBF16())
611  return 14; // CUDA_R_16BF
612  if (type.isF16())
613  return 2; // CUDA_R_16F
614  if (type.isF32())
615  return 0; // CUDA_R_32F
616  if (type.isF64())
617  return 1; // CUDA_R_64F
618  if (type.isInteger(8))
619  return 3; // CUDA_R_8I
620  if (type.isInteger(16))
621  return 20; // CUDA_R_16I
622  if (type.isInteger(32))
623  return 10; // CUDA_R_32I
624 
625  llvm_unreachable("unsupported element type");
626 }
627 
628 static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
629  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
630 }
631 
632 // TODO: We may want a warning or disablement at MLIR compile time (i.e.,
633 // while the mlir compiler runs): cusparseLt currently does not work for CUDA
634 // architectures < 8.0 and will trigger a runtime (i.e., in the CUDA program)
635 // error. It would be good to at least emit a warning when the target
636 // architecture is < 8.0 and the user still wants to use cusparseLt, and to
637 // make sure that, when lowering the GPU sparse dialect to LLVM calls, the
638 // cusparseLt calls are disabled for CUDA architectures < 8.0.
639 static bool is2To4Sparsity(Value spMat) {
640  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
641  return true;
642  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
643  return false;
644  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
645  return false;
646  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
647  return false;
648  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
649  return false;
650  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
651  return false;
652  // Print the spMat defining op
653  spMat.getDefiningOp()->print(llvm::errs());
654  llvm_unreachable("cannot find spmat def");
655 }
656 
657 static bool isSpMMCusparseLtOp(Value op) {
658  for (Operation *user : op.getUsers()) {
659  auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
660  // If the A operand of the SpMM has 2:4 (50%) sparsity, we should use cusparseLt.
661  if (!spmmOp)
662  continue;
663  if (is2To4Sparsity(spmmOp.getSpmatA()))
664  return true;
665  }
666  return false;
667 }
668 
669 // Returns success if all operands are of LLVM type, and a match failure otherwise.
670 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
671  ConversionPatternRewriter &rewriter) {
672  if (!llvm::all_of(operands, [](Value value) {
673  return LLVM::isCompatibleType(value.getType());
674  }))
675  return rewriter.notifyMatchFailure(
676  op, "Cannot convert if operands aren't of LLVM type.");
677  return success();
678 }
679 
680 static LogicalResult
681 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
682  gpu::AsyncOpInterface op) {
683  if (op.getAsyncDependencies().size() != 1)
684  return rewriter.notifyMatchFailure(
685  op, "Can only convert with exactly one async dependency.");
686 
687  if (!op.getAsyncToken())
688  return rewriter.notifyMatchFailure(op, "Can convert only async version.");
689 
690  return success();
691 }
692 
693 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
694  gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
695  ConversionPatternRewriter &rewriter) const {
696  auto *op = hostRegisterOp.getOperation();
697  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
698  return failure();
699 
700  Location loc = op->getLoc();
701 
702  auto memRefType = hostRegisterOp.getValue().getType();
703  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
704  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
705 
706  auto arguments = getTypeConverter()->promoteOperands(
707  loc, op->getOperands(), adaptor.getOperands(), rewriter);
708  arguments.push_back(elementSize);
709  hostRegisterCallBuilder.create(loc, rewriter, arguments);
710 
711  rewriter.eraseOp(op);
712  return success();
713 }
714 
715 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
716  gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
717  ConversionPatternRewriter &rewriter) const {
718  Operation *op = hostUnregisterOp.getOperation();
719  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
720  return failure();
721 
722  Location loc = op->getLoc();
723 
724  auto memRefType = hostUnregisterOp.getValue().getType();
725  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
726  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
727 
728  auto arguments = getTypeConverter()->promoteOperands(
729  loc, op->getOperands(), adaptor.getOperands(), rewriter);
730  arguments.push_back(elementSize);
731  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
732 
733  rewriter.eraseOp(op);
734  return success();
735 }
736 
737 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
738  gpu::AllocOp allocOp, OpAdaptor adaptor,
739  ConversionPatternRewriter &rewriter) const {
740 
741  MemRefType memRefType = allocOp.getType();
742 
743  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
744  !isConvertibleAndHasIdentityMaps(memRefType))
745  return failure();
746 
747  auto loc = allocOp.getLoc();
748 
749  bool isShared = allocOp.getHostShared();
750 
751  if (isShared && allocOp.getAsyncToken())
752  return rewriter.notifyMatchFailure(
753  allocOp, "Host Shared allocation cannot be done async");
754  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
755  return failure();
756 
757  // Get shape of the memref as values: static sizes are constant
758  // values and dynamic sizes are passed to 'alloc' as operands.
759  SmallVector<Value, 4> shape;
760  SmallVector<Value, 4> strides;
761  Value sizeBytes;
762  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
763  shape, strides, sizeBytes);
764 
765  // Allocate the underlying buffer and store a pointer to it in the MemRef
766  // descriptor.
767  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
768  Value stream = adaptor.getAsyncDependencies().empty()
769  ? nullPtr
770  : adaptor.getAsyncDependencies().front();
771 
772  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
773  loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
774 
775  Value allocatedPtr =
776  allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
777  .getResult();
778 
779  // No alignment.
780  Value alignedPtr = allocatedPtr;
781 
782  // Create the MemRef descriptor.
783  auto memRefDescriptor = this->createMemRefDescriptor(
784  loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
785 
786  if (allocOp.getAsyncToken()) {
787  // Async alloc: make dependent ops use the same stream.
788  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
789  } else {
790  rewriter.replaceOp(allocOp, {memRefDescriptor});
791  }
792 
793  return success();
794 }
795 
796 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
797  gpu::DeallocOp deallocOp, OpAdaptor adaptor,
798  ConversionPatternRewriter &rewriter) const {
799  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
800  failed(isAsyncWithOneDependency(rewriter, deallocOp)))
801  return failure();
802 
803  Location loc = deallocOp.getLoc();
804 
805  Value pointer =
806  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
807  Value stream = adaptor.getAsyncDependencies().front();
808  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
809 
810  rewriter.replaceOp(deallocOp, {stream});
811  return success();
812 }
813 
814 static bool isGpuAsyncTokenType(Value value) {
815  return isa<gpu::AsyncTokenType>(value.getType());
816 }
817 
818 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
819 // !gpu.async.token values are lowered to streams within the async.execute
820 // region, but are passed as events between regions. For each !gpu.async.token
821 // operand, we create an event and record it on the stream.
822 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
823  async::YieldOp yieldOp, OpAdaptor adaptor,
824  ConversionPatternRewriter &rewriter) const {
825  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
826  return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
827 
828  Location loc = yieldOp.getLoc();
829  SmallVector<Value, 4> newOperands(adaptor.getOperands());
830  llvm::SmallDenseSet<Value> streams;
831  for (auto &operand : yieldOp->getOpOperands()) {
832  if (!isGpuAsyncTokenType(operand.get()))
833  continue;
834  auto idx = operand.getOperandNumber();
835  auto stream = adaptor.getOperands()[idx];
836  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
837  eventRecordCallBuilder.create(loc, rewriter, {event, stream});
838  newOperands[idx] = event;
839  streams.insert(stream);
840  }
841  for (auto stream : streams)
842  streamDestroyCallBuilder.create(loc, rewriter, {stream});
843 
844  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
845  return success();
846 }
847 
848 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
849 static bool isDefinedByCallTo(Value value, StringRef functionName) {
850  assert(isa<LLVM::LLVMPointerType>(value.getType()));
851  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
852  return *defOp.getCallee() == functionName;
853  return false;
854 }
855 
856 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
857 // with the stream/event operands. The operands are destroyed; that is, it is
858 // assumed that they are not used afterwards or elsewhere. Otherwise we will
859 // get a runtime error. Eventually, we should guarantee this property.
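// For instance, waiting on a token whose converted operand was produced by
// mgpuStreamCreate lowers, in effect, to
//   mgpuStreamSynchronize(stream);
//   mgpuStreamDestroy(stream);
// whereas waiting on an event lowers to the corresponding mgpuEventSynchronize
// and mgpuEventDestroy calls.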
860 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
861  gpu::WaitOp waitOp, OpAdaptor adaptor,
862  ConversionPatternRewriter &rewriter) const {
863  if (waitOp.getAsyncToken())
864  return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
865 
866  Location loc = waitOp.getLoc();
867 
868  for (auto operand : adaptor.getOperands()) {
869  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
870  // The converted operand's definition created a stream.
871  streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
872  streamDestroyCallBuilder.create(loc, rewriter, {operand});
873  } else {
874  // Otherwise the converted operand is an event. This assumes that we use
875  // events in control flow code as well.
876  eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
877  eventDestroyCallBuilder.create(loc, rewriter, {operand});
878  }
879  }
880 
881  rewriter.eraseOp(waitOp);
882  return success();
883 }
884 
885 // Converts `gpu.wait async` to runtime calls. The converted op creates a new
886 // stream that is synchronized with the stream/event operands. The operands are
887 // destroyed; that is, it is assumed that they are not used afterwards or
888 // elsewhere. Otherwise we will get a runtime error. Eventually, we should
889 // guarantee this property.
890 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
891  gpu::WaitOp waitOp, OpAdaptor adaptor,
892  ConversionPatternRewriter &rewriter) const {
893  if (!waitOp.getAsyncToken())
894  return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
895 
896  Location loc = waitOp.getLoc();
897 
898  auto insertionPoint = rewriter.saveInsertionPoint();
899  SmallVector<Value, 1> events;
900  for (auto pair :
901  llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
902  auto operand = std::get<1>(pair);
903  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
904  // The converted operand's definition created a stream. Insert an event
905  // into the stream just after the last use of the original token operand.
906  auto *defOp = std::get<0>(pair).getDefiningOp();
907  rewriter.setInsertionPointAfter(defOp);
908  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
909  eventRecordCallBuilder.create(loc, rewriter, {event, operand});
910  events.push_back(event);
911  } else {
912  // Otherwise the converted operand is an event. This assumes that we use
913  // events in control flow code as well.
914  events.push_back(operand);
915  }
916  }
917  rewriter.restoreInsertionPoint(insertionPoint);
918  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
919  for (auto event : events)
920  streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
921  for (auto event : events)
922  eventDestroyCallBuilder.create(loc, rewriter, {event});
923  rewriter.replaceOp(waitOp, {stream});
924 
925  return success();
926 }
927 
928 // Legalize the op's operands.
929 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
930  gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
931  ConversionPatternRewriter &rewriter) const {
932  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
933  return failure();
934 
935  if (launchOp.getAsyncDependencies().size() > 1)
936  return rewriter.notifyMatchFailure(
937  launchOp, "Cannot convert with more than one async dependency.");
938 
939  // Fail when the synchronous version of the op has async dependencies. The
940  // lowering destroys the stream, and we do not want to check that there is no
941  // use of the stream after this op.
942  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
943  return rewriter.notifyMatchFailure(
944  launchOp, "Cannot convert non-async op with async dependencies.");
945 
946  Location loc = launchOp.getLoc();
947 
948  Value stream = Value();
949  if (!adaptor.getAsyncDependencies().empty())
950  stream = adaptor.getAsyncDependencies().front();
951  // If the async keyword is present and there are no dependencies, then a
952  // stream must be created to pass to subsequent operations.
953  else if (launchOp.getAsyncToken())
954  stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
955  // Lower the kernel operands to match kernel parameters.
956  // Note: If `useBarePtrCallConv` is set in the type converter's options,
957  // the value of `kernelBarePtrCallConv` will be ignored.
958  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
959  loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
960  /*useBarePtrCallConv=*/kernelBarePtrCallConv);
961 
962  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
963  if (launchOp.hasClusterSize()) {
964  clusterSize =
965  gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
966  adaptor.getClusterSizeZ()};
967  }
968  rewriter.create<gpu::LaunchFuncOp>(
969  launchOp.getLoc(), launchOp.getKernelAttr(),
970  gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
971  adaptor.getGridSizeZ()},
972  gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
973  adaptor.getBlockSizeZ()},
974  adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
975  if (launchOp.getAsyncToken())
976  rewriter.replaceOp(launchOp, {stream});
977  else
978  rewriter.eraseOp(launchOp);
979  return success();
980 }
981 
982 static Value bitAndAddrspaceCast(Location loc,
983  ConversionPatternRewriter &rewriter,
984  LLVM::LLVMPointerType destinationType,
985  Value sourcePtr,
986  const LLVMTypeConverter &typeConverter) {
987  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
988  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
989  sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
990  loc,
991  LLVM::LLVMPointerType::get(rewriter.getContext(),
992  destinationType.getAddressSpace()),
993  sourcePtr);
994  return sourcePtr;
995 }
996 
997 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
998  gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
999  ConversionPatternRewriter &rewriter) const {
1000  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1001 
1002  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1003  !isConvertibleAndHasIdentityMaps(memRefType) ||
1004  failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1005  return failure();
1006 
1007  auto loc = memcpyOp.getLoc();
1008 
1009  MemRefDescriptor srcDesc(adaptor.getSrc());
1010  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1011 
1012  Type elementPtrType = getElementPtrType(memRefType);
1013  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
1014  Value gepPtr = rewriter.create<LLVM::GEPOp>(
1015  loc, elementPtrType,
1016  typeConverter->convertType(memRefType.getElementType()), nullPtr,
1017  numElements);
1018  auto sizeBytes =
1019  rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1020 
1021  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1022  srcDesc.alignedPtr(rewriter, loc),
1023  *getTypeConverter());
1024  auto dst = bitAndAddrspaceCast(
1025  loc, rewriter, llvmPointerType,
1026  MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1027  *getTypeConverter());
1028 
1029  auto stream = adaptor.getAsyncDependencies().front();
1030  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1031 
1032  rewriter.replaceOp(memcpyOp, {stream});
1033 
1034  return success();
1035 }
1036 
1037 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1038  gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1039  ConversionPatternRewriter &rewriter) const {
1040  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1041 
1042  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1043  !isConvertibleAndHasIdentityMaps(memRefType) ||
1044  failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1045  return failure();
1046 
1047  auto loc = memsetOp.getLoc();
1048 
1049  Type valueType = adaptor.getValue().getType();
1050  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1051  // Ints and floats of 16 or 32 bit width are allowed.
1052  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1053  return rewriter.notifyMatchFailure(
1054  memsetOp, "value must be a 16 or 32 bit int or float");
1055  }
1056 
1057  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1058  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1059 
1060  MemRefDescriptor dstDesc(adaptor.getDst());
1061  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1062 
1063  auto value =
1064  rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1065  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1066  dstDesc.alignedPtr(rewriter, loc),
1067  *getTypeConverter());
1068 
1069  auto stream = adaptor.getAsyncDependencies().front();
1070  FunctionCallBuilder builder =
1071  valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1072  builder.create(loc, rewriter, {dst, value, numElements, stream});
1073 
1074  rewriter.replaceOp(memsetOp, {stream});
1075  return success();
1076 }
1077 
1078 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1079  gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1080  ConversionPatternRewriter &rewriter) const {
1081  Location loc = op.getLoc();
1082  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1083  {adaptor.getDevIndex()});
1084  rewriter.replaceOp(op, call);
1085  return success();
1086 }
1087 
1088 template <typename T>
1089 static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1090  Type llvmInt32Type = builder.getIntegerType(32);
1091  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
1092  static_cast<int32_t>(tValue));
1093 }
1094 
1095 template <typename T>
1096 static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1097  Type llvmFloat32Type = builder.getF32Type();
1098  return builder.create<LLVM::ConstantOp>(
1099  loc, llvmFloat32Type,
1100  builder.getF32FloatAttr(static_cast<float>(tValue)));
1101 }
1102 
1103 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1104  gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1105  ConversionPatternRewriter &rewriter) const {
1106  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1107  failed(isAsyncWithOneDependency(rewriter, op)))
1108  return failure();
1109  Location loc = op.getLoc();
1110  auto stream = adaptor.getAsyncDependencies().front();
1111  Value pTensor =
1112  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1113  Type dType = op.getMemref().getType().getElementType();
1114  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1115 
1116  SmallVector<Value, 4> dims;
1117  for (Value dim : adaptor.getDims()) {
1118  dims.push_back(dim);
1119  }
1120 
1121  Value handle;
1122  // TODO: For now, we track the use of the handle and lower it to cusparse /
1123  // cusparseLt accordingly. If both cusparse and cusparseLt are used within one
1124  // block, two separate creation ops are required for the logic to be correct.
1125  // In the future, we may add support for using a single handle with both
1126  // cusparse and cusparseLt in the sparse tensor / GPU dialects. Use the
1127  // cusparseLt create call if the dnmat is used with a spmat with 2:4 sparsity.
1128  if (dims.size() == 2) {
1129  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1130  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1131  loc, getIndexType(), rewriter.getIndexAttr(11032));
1132  handle = rewriter.create<LLVM::AllocaOp>(
1133  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1134  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1135 
1136  createLtDnMatCallBuilder
1137  .create(loc, rewriter,
1138  {handle, dims[0], dims[1], pTensor, dtp, stream})
1139  .getResult();
1140  } else {
1141  handle =
1142  createDnMatCallBuilder
1143  .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1144  .getResult();
1145  }
1146  } else {
1147  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1148  handle = createDnVecCallBuilder
1149  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1150  .getResult();
1151  }
1152  rewriter.replaceOp(op, {handle, stream});
1153  return success();
1154 }
1155 
1156 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1157  gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1158  ConversionPatternRewriter &rewriter) const {
1159  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1160  failed(isAsyncWithOneDependency(rewriter, op)))
1161  return failure();
1162  Location loc = op.getLoc();
1163  auto stream = adaptor.getAsyncDependencies().front();
1164  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1165  SmallVector<Value, 4> dims;
1166  for (Value dim : definingOp.getDims()) {
1167  dims.push_back(dim);
1168  }
1169  if (dims.size() == 2) {
1170  // Use the cusparseLt destroy call if the dnmat is used with spmat with
1171  // 2:4 sparsity
1172  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1173  destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1174  {adaptor.getDnTensor(), stream});
1175  } else {
1176  destroyDnMatCallBuilder.create(loc, rewriter,
1177  {adaptor.getDnTensor(), stream});
1178  }
1179  } else {
1180  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1181  destroyDnVecCallBuilder.create(loc, rewriter,
1182  {adaptor.getDnTensor(), stream});
1183  }
1184  rewriter.replaceOp(op, {stream});
1185  return success();
1186 }
1187 
1188 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1189  gpu::CreateCooOp op, OpAdaptor adaptor,
1190  ConversionPatternRewriter &rewriter) const {
1191  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1192  failed(isAsyncWithOneDependency(rewriter, op)))
1193  return failure();
1194  Location loc = op.getLoc();
1195  auto stream = adaptor.getAsyncDependencies().front();
1196  Value pRowIdxs =
1197  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1198  Value pColIdxs =
1199  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1200  Value pValues =
1201  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1202  Type iType =
1203  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1204  Type dType =
1205  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1206  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1207  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1208  auto handle =
1209  createCooCallBuilder
1210  .create(loc, rewriter,
1211  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1212  pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1213  .getResult();
1214  rewriter.replaceOp(op, {handle, stream});
1215  return success();
1216 }
1217 
1218 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1219  gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1220  ConversionPatternRewriter &rewriter) const {
1221  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1222  failed(isAsyncWithOneDependency(rewriter, op)))
1223  return failure();
1224  Location loc = op.getLoc();
1225  auto stream = adaptor.getAsyncDependencies().front();
1226  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1227  Value pValues =
1228  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1229  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1230  Type dType =
1231  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1232  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1233  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1234  auto handle =
1235  createCooAoSCallBuilder
1236  .create(loc, rewriter,
1237  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1238  pIdxs, pValues, itp, dtp, stream})
1239  .getResult();
1240  rewriter.replaceOp(op, {handle, stream});
1241  return success();
1242 }
1243 
1244 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1245  gpu::CreateCsrOp op, OpAdaptor adaptor,
1246  ConversionPatternRewriter &rewriter) const {
1247  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1248  failed(isAsyncWithOneDependency(rewriter, op)))
1249  return failure();
1250  Location loc = op.getLoc();
1251  auto stream = adaptor.getAsyncDependencies().front();
1252  Value pRowPos =
1253  MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1254  Value pColIdxs =
1255  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1256  Value pValues =
1257  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1258  Type pType =
1259  llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1260  Type iType =
1261  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1262  Type dType =
1263  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1264  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1265  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1266  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1267  auto handle =
1268  createCsrCallBuilder
1269  .create(loc, rewriter,
1270  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1271  pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1272  .getResult();
1273  rewriter.replaceOp(op, {handle, stream});
1274  return success();
1275 }
1276 
1277 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1278  gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1279  ConversionPatternRewriter &rewriter) const {
1280  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1281  failed(isAsyncWithOneDependency(rewriter, op)))
1282  return failure();
1283  Location loc = op.getLoc();
1284  auto stream = adaptor.getAsyncDependencies().front();
1285  Value pMat =
1286  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1287  Type dType =
1288  llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1289  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1290 
1291  // CUDA runner asserts the size is 44104 bytes.
1292  auto handleSz = rewriter.create<LLVM::ConstantOp>(
1293  loc, getIndexType(), rewriter.getIndexAttr(44104));
1294  Value handle = rewriter.create<LLVM::AllocaOp>(
1295  loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1296  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1297 
1298  create2To4SpMatCallBuilder
1299  .create(loc, rewriter,
1300  {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1301  .getResult();
1302  rewriter.replaceOp(op, {handle, stream});
1303  return success();
1304 }
1305 
1306 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1307  gpu::DestroySpMatOp op, OpAdaptor adaptor,
1308  ConversionPatternRewriter &rewriter) const {
1309  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1310  failed(isAsyncWithOneDependency(rewriter, op)))
1311  return failure();
1312  Location loc = op.getLoc();
1313  auto stream = adaptor.getAsyncDependencies().front();
1314  // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1315  if (is2To4Sparsity(op.getSpmat())) {
1316  destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1317  {adaptor.getSpmat(), stream});
1318 
1319  } else {
1320  destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1321  }
1322  rewriter.replaceOp(op, {stream});
1323  return success();
1324 }
1325 
1326 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1327  gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1328  ConversionPatternRewriter &rewriter) const {
1329  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1330  failed(isAsyncWithOneDependency(rewriter, op)))
1331  return failure();
1332  Location loc = op.getLoc();
1333  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1334  auto computeType = genConstInt32From(
1335  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1336  auto stream = adaptor.getAsyncDependencies().front();
1337  auto bufferSize = spMVBufferSizeCallBuilder
1338  .create(loc, rewriter,
1339  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1340  adaptor.getDnY(), computeType, stream})
1341  .getResult();
1342  rewriter.replaceOp(op, {bufferSize, stream});
1343  return success();
1344 }
1345 
1346 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1347  gpu::SpMVOp op, OpAdaptor adaptor,
1348  ConversionPatternRewriter &rewriter) const {
1349  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1350  failed(isAsyncWithOneDependency(rewriter, op)))
1351  return failure();
1352  Location loc = op.getLoc();
1353  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1354  auto computeType = genConstInt32From(
1355  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1356  auto stream = adaptor.getAsyncDependencies().front();
1357  Value pBuf =
1358  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1359  spMVCallBuilder.create(loc, rewriter,
1360  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1361  adaptor.getDnY(), computeType, pBuf, stream});
1362  rewriter.replaceOp(op, {stream});
1363  return success();
1364 }
1365 
1366 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1367  gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1368  ConversionPatternRewriter &rewriter) const {
1369  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1370  failed(isAsyncWithOneDependency(rewriter, op)))
1371  return failure();
1372  Location loc = op.getLoc();
1373  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1374  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1375  auto stream = adaptor.getAsyncDependencies().front();
1376  Value bufferSize;
1377  if (is2To4Sparsity(op.getSpmatA())) {
1378  auto pruneFlag =
1379  genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1380  auto computeType = genConstInt32From(
1381  rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1382  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1383  rewriter.getIndexAttr(3));
1384  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1385  loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1386  createCuSparseLtSpMMBufferSizeBuilder
1387  .create(loc, rewriter,
1388  {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1389  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1390  pruneFlag, stream})
1391  .getResult();
1392 
1393  auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1394  loc, llvmPointerType, llvmPointerType, bufferSize,
1395  ValueRange{rewriter.create<LLVM::ConstantOp>(
1396  loc, getIndexType(), rewriter.getIndexAttr(1))});
1397  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1398  loc, llvmPointerType, llvmPointerType, bufferSize,
1399  ValueRange{rewriter.create<LLVM::ConstantOp>(
1400  loc, getIndexType(), rewriter.getIndexAttr(2))});
1401  auto bufferSize0 =
1402  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1403  auto bufferSize1 =
1404  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1405  auto bufferSize2 =
1406  rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1407 
1408  rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1409  } else {
1410  auto computeType = genConstInt32From(
1411  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1412  bufferSize =
1413  createSpMMBufferSizeCallBuilder
1414  .create(loc, rewriter,
1415  {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1416  adaptor.getDnmatC(), computeType, stream})
1417  .getResult();
1418  rewriter.replaceOp(op, {bufferSize, stream});
1419  }
1420  return success();
1421 }
1422 
1423 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1424  gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1425  ConversionPatternRewriter &rewriter) const {
1426  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1427  failed(isAsyncWithOneDependency(rewriter, op)))
1428  return failure();
1429  Location loc = op.getLoc();
1430  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1431  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1432  auto computeType = genConstInt32From(
1433  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1434  auto stream = adaptor.getAsyncDependencies().front();
1435  auto bufferSize =
1436  createSDDMMBufferSizeCallBuilder
1437  .create(loc, rewriter,
1438  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1439  adaptor.getSpmatC(), computeType, stream})
1440  .getResult();
1441  rewriter.replaceOp(op, {bufferSize, stream});
1442  return success();
1443 }
1444 
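// SpMM mirrors the buffer-size lowering above: the cuSPARSELt path consumes
// the three workspace buffers produced for 2:4 sparsity, while the default
// cuSPARSE path takes a single workspace buffer.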
1445 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1446  gpu::SpMMOp op, OpAdaptor adaptor,
1447  ConversionPatternRewriter &rewriter) const {
1448  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1449  failed(isAsyncWithOneDependency(rewriter, op)))
1450  return failure();
1451  Location loc = op.getLoc();
1452  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1453  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1454  auto computeType = genConstInt32From(
1455  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1456 
1457  auto stream = adaptor.getAsyncDependencies().front();
1458 
1459  // Lower to cusparseLt if applicable
1460  if (is2To4Sparsity(op.getSpmatA())) {
1461  SmallVector<Value> pBufs;
1462  for (Value buffer : adaptor.getBuffers()) {
1463  Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1464  pBufs.push_back(pBuf);
1465  }
1466  createCuSparseLtSpMMBuilder.create(
1467  loc, rewriter,
1468  {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1469  pBufs[0], pBufs[1], pBufs[2], stream});
1470  } else {
1471  Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1472  .allocatedPtr(rewriter, loc);
1473  createSpMMCallBuilder.create(loc, rewriter,
1474  {modeA, modeB, adaptor.getSpmatA(),
1475  adaptor.getDnmatB(), adaptor.getDnmatC(),
1476  computeType, pBuf, stream});
1477  }
1478  rewriter.replaceOp(op, {stream});
1479  return success();
1480 }
1481 
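// GPU handle types (async tokens and the sparse tensor/matrix/SpGEMM
// descriptor types) have no structured LLVM counterpart; they are all
// type-converted to an opaque !llvm.ptr.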
1482 template <typename T>
1483 static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1484  converter.addConversion([&converter](T) -> Type {
1485  return LLVM::LLVMPointerType::get(&converter.getContext());
1486  });
1487 }
1488 
1489 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1490  gpu::SDDMMOp op, OpAdaptor adaptor,
1491  ConversionPatternRewriter &rewriter) const {
1492  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1493  failed(isAsyncWithOneDependency(rewriter, op)))
1494  return failure();
1495  Location loc = op.getLoc();
1496  auto computeType = genConstInt32From(
1497  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1498  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1499  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1500  auto stream = adaptor.getAsyncDependencies().front();
1501  Value pBuf =
1502  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1503  createSDDMMCallBuilder.create(loc, rewriter,
1504  {modeA, modeB, adaptor.getDnmatA(),
1505  adaptor.getDnmatB(), adaptor.getSpmatC(),
1506  computeType, pBuf, stream});
1507  rewriter.replaceOp(op, {stream});
1508  return success();
1509 }
1510 
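// The SpGEMM ops are lowered as a staged sequence that mirrors the cuSPARSE
// SpGEMM API: create a descriptor, run work estimation and compute (growing
// the workspace as needed), copy the result into the output sparse matrix,
// and finally destroy the descriptor.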
1511 LogicalResult
1512 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1513  gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1514  ConversionPatternRewriter &rewriter) const {
1515  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1516  failed(isAsyncWithOneDependency(rewriter, op)))
1517  return failure();
1518  Location loc = op.getLoc();
1519  auto stream = adaptor.getAsyncDependencies().front();
1520  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1521  .getResult();
1522  rewriter.replaceOp(op, {descr, stream});
1523  return success();
1524 }
1525 
1526 LogicalResult
1527 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1528  gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1529  ConversionPatternRewriter &rewriter) const {
1530  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1531  failed(isAsyncWithOneDependency(rewriter, op)))
1532  return failure();
1533  Location loc = op.getLoc();
1534  auto stream = adaptor.getAsyncDependencies().front();
1535  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1536  {adaptor.getDesc(), stream});
1537  rewriter.replaceOp(op, {stream});
1538  return success();
1539 }
1540 
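// A single pattern serves both kinds of SpGEMMWorkEstimationOrComputeOp: the
// kind attribute selects between the work-estimation and compute call
// builders, and the buffer size reported back by the runtime replaces the
// op's size result.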
1541 LogicalResult
1542 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1543  gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1544  ConversionPatternRewriter &rewriter) const {
1545  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1546  failed(isAsyncWithOneDependency(rewriter, op)))
1547  return failure();
1548  Location loc = op.getLoc();
1549  auto computeType = genConstInt32From(
1550  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1551  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1552  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1553  auto stream = adaptor.getAsyncDependencies().front();
1554 
1555  Value pBuf =
1556  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1557  Value bufferSizeNew;
1558 
1559  if (adaptor.getKind() ==
1560  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1561  bufferSizeNew =
1562  createSpGEMMWorkEstimationBuilder
1563  .create(loc, rewriter,
1564  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1565  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1566  adaptor.getBufferSz(), pBuf, stream})
1567  .getResult();
1568  } else {
1569  bufferSizeNew =
1570  createSpGEMMComputeBuilder
1571  .create(loc, rewriter,
1572  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1573  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1574  adaptor.getBufferSz(), pBuf, stream})
1575  .getResult();
1576  }
1577  rewriter.replaceOp(op, {bufferSizeNew, stream});
1578  return success();
1579 }
1580 
1581 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1582  gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1583  ConversionPatternRewriter &rewriter) const {
1584  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1585  failed(isAsyncWithOneDependency(rewriter, op)))
1586  return failure();
1587  Location loc = op.getLoc();
1588  auto computeType = genConstInt32From(
1589  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1590  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1591  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1592  auto stream = adaptor.getAsyncDependencies().front();
1593  createSpGEMMCopyBuilder.create(loc, rewriter,
1594  {adaptor.getDesc(), modeA, modeB,
1595  adaptor.getSpmatA(), adaptor.getSpmatB(),
1596  adaptor.getSpmatC(), computeType, stream});
1597  rewriter.replaceOp(op, {stream});
1598  return success();
1599 }
1600 
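// The matrix sizes are returned through memory: the pattern allocas room for
// three i64 values, hands the runtime a pointer to each slot, and then loads
// rows, cols, and nnz to replace the op's results.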
1601 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1602  gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1603  ConversionPatternRewriter &rewriter) const {
1604  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1605  failed(isAsyncWithOneDependency(rewriter, op)))
1606  return failure();
1607  Location loc = op.getLoc();
1608  auto stream = adaptor.getAsyncDependencies().front();
1609 
1610  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1611  rewriter.getIndexAttr(3));
1612  auto buffer = rewriter.create<LLVM::AllocaOp>(
1613  loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1614 
1615  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1616  loc, llvmPointerType, llvmPointerType, buffer,
1617  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1618  rewriter.getIndexAttr(0))});
1619  auto colsPtr = rewriter.create<LLVM::GEPOp>(
1620  loc, llvmPointerType, llvmPointerType, buffer,
1621  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1622  rewriter.getIndexAttr(1))});
1623  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1624  loc, llvmPointerType, llvmPointerType, buffer,
1625  ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1626  rewriter.getIndexAttr(2))});
1627  createSpMatGetSizeBuilder.create(
1628  loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1629  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1630  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1631  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1632 
1633  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1634  return success();
1635 }
1636 
1637 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1638  gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1639  ConversionPatternRewriter &rewriter) const {
1640  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1641  failed(isAsyncWithOneDependency(rewriter, op)))
1642  return failure();
1643  Location loc = op.getLoc();
1644  auto stream = adaptor.getAsyncDependencies().front();
1645  Value pPos =
1646  MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1647  Value pCrd =
1648  MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1649  Value pVal =
1650  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1651  createSetCsrPointersBuilder.create(
1652  loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1653  rewriter.replaceOp(op, {stream});
1654  return success();
1655 }
1656 
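// The CSC and BSR creation patterns below extract the allocated pointers of
// the position/index/value memrefs and encode their element types as i32
// cuSPARSE index- and data-type enum constants before calling the runtime.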
1657 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1658  gpu::CreateCscOp op, OpAdaptor adaptor,
1659  ConversionPatternRewriter &rewriter) const {
1660  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1661  failed(isAsyncWithOneDependency(rewriter, op)))
1662  return failure();
1663  Location loc = op.getLoc();
1664  auto stream = adaptor.getAsyncDependencies().front();
1665  Value pColPos =
1666  MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1667  Value pRowIdxs =
1668  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1669  Value pValues =
1670  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1671  Type pType =
1672  llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1673  Type iType =
1674  llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1675  Type dType =
1676  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1677  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1678  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1679  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1680  auto handle =
1681  createCscCallBuilder
1682  .create(loc, rewriter,
1683  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1684  pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1685  .getResult();
1686  rewriter.replaceOp(op, {handle, stream});
1687  return success();
1688 }
1689 
1690 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1691  gpu::CreateBsrOp op, OpAdaptor adaptor,
1692  ConversionPatternRewriter &rewriter) const {
1693  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1694  failed(isAsyncWithOneDependency(rewriter, op)))
1695  return failure();
1696  Location loc = op.getLoc();
1697  auto stream = adaptor.getAsyncDependencies().front();
1698  Value pRowPos =
1699  MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1700  Value pColIdxs =
1701  MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1702  Value pValues =
1703  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1704  Type pType =
1705  llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1706  Type iType =
1707  llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1708  Type dType =
1709  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1710  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1711  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1712  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1713  auto handle =
1714  createBsrCallBuilder
1715  .create(loc, rewriter,
1716  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1717  adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1718  pColIdxs, pValues, ptp, itp, dtp, stream})
1719  .getResult();
1720  rewriter.replaceOp(op, {handle, stream});
1721  return success();
1722 }
1723 
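// populateGpuToLLVMConversionPatterns registers the opaque-pointer type
// conversions for the GPU handle types together with all runtime-call
// rewrite patterns; the launch-func legalization is added separately so it
// can honor the bare-pointer calling convention flag.
//
// A minimal usage sketch (assuming the usual conversion-pass setup; the pass
// defined at the top of this file does essentially this):
//
//   LLVMTypeConverter converter(module.getContext());
//   RewritePatternSet patterns(module.getContext());
//   LLVMConversionTarget target(*module.getContext());
//   populateGpuToLLVMConversionPatterns(converter, patterns,
//                                       /*kernelBarePtrCallConv=*/false);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();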
1724 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
1725  RewritePatternSet &patterns,
1726  bool kernelBarePtrCallConv) {
1727  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1728  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1729  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1730  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1731 
1732  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1733  ConvertDeallocOpToGpuRuntimeCallPattern,
1734  ConvertHostRegisterOpToGpuRuntimeCallPattern,
1735  ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1736  ConvertMemcpyOpToGpuRuntimeCallPattern,
1737  ConvertMemsetOpToGpuRuntimeCallPattern,
1738  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1739  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1740  ConvertWaitOpToGpuRuntimeCallPattern,
1741  ConvertAsyncYieldToGpuRuntimeCallPattern,
1742  ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1743  ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1744  ConvertCreateCooOpToGpuRuntimeCallPattern,
1745  ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1746  ConvertCreateCsrOpToGpuRuntimeCallPattern,
1747  ConvertCreateCscOpToGpuRuntimeCallPattern,
1748  ConvertCreateBsrOpToGpuRuntimeCallPattern,
1749  ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1750  ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1751  ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1752  ConvertSpMVOpToGpuRuntimeCallPattern,
1753  ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1754  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1755  ConvertSpMMOpToGpuRuntimeCallPattern,
1756  ConvertSDDMMOpToGpuRuntimeCallPattern,
1757  ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1758  ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1759  ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1760  ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1761  ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1762  ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1763  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
1764 }