GPUToLLVMConversion.cpp
1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert the gpu.launch_func op into a
10 // sequence of GPU runtime calls. As most GPU runtimes do not have a stable
11 // published ABI, this pass uses a slim runtime layer built on top of the
12 // public API from GPU runtime headers.
13 //
14 //===----------------------------------------------------------------------===//
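// As context, the "slim runtime layer" referred to above is a set of C wrapper
// functions (mgpu*) whose signatures can be read off the FunctionCallBuilder
// declarations below. A minimal sketch, assuming the wrappers provided by the
// runtime wrapper libraries (e.g. mlir_cuda_runtime / mlir_rocm_runtime):
//
//   extern "C" void *mgpuStreamCreate();
//   extern "C" void mgpuStreamDestroy(void *stream);
//   extern "C" void mgpuStreamSynchronize(void *stream);
//   extern "C" void *mgpuMemAlloc(intptr_t sizeBytes, void *stream,
//                                 bool isHostShared);
//   extern "C" void mgpuMemFree(void *ptr, void *stream);
//   extern "C" void mgpuMemcpy(void *dst, void *src, intptr_t sizeBytes,
//                              void *stream);
//
// The pass only emits llvm.call ops to such symbols; the definitions are
// supplied by the runtime wrapper library linked or loaded at execution time.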
15 
17 
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/Builders.h"
36 #include "mlir/IR/BuiltinOps.h"
37 #include "mlir/IR/BuiltinTypes.h"
39 
40 #include "llvm/ADT/STLExtras.h"
41 
42 #define DEBUG_TYPE "gpu-to-llvm"
43 
44 namespace mlir {
45 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46 #include "mlir/Conversion/Passes.h.inc"
47 } // namespace mlir
48 
49 using namespace mlir;
50 
51 namespace {
52 class GpuToLLVMConversionPass
53  : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
54 public:
55  using Base::Base;
56  void getDependentDialects(DialectRegistry &registry) const final {
57  Base::getDependentDialects(registry);
58  registerConvertToLLVMDependentDialectLoading(registry);
59  }
60  // Run the dialect converter on the module.
61  void runOnOperation() override;
62 };
63 
64 template <typename OpTy>
65 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
66 public:
67  explicit ConvertOpToGpuRuntimeCallPattern(
68  const LLVMTypeConverter &typeConverter)
69  : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
70 
71 protected:
72  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73  MemRefType type, MemRefDescriptor desc) const {
74  Type indexType = ConvertToLLVMPattern::getIndexType();
75  if (type.hasStaticShape())
76  return ConvertToLLVMPattern::createIndexAttrConstant(
77  rewriter, loc, indexType, type.getNumElements());
78  // Compute the number of elements by multiplying all the dim sizes.
79  uint64_t rank = type.getRank();
80  Value numElements = desc.size(rewriter, loc, /*pos=*/0);
81  for (unsigned i = 1; i < rank; i++)
82  numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83  desc.size(rewriter, loc, /*pos=*/i));
84  return numElements;
85  }
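// For example, for a dynamically shaped memref<?x4xf32> whose descriptor
// reports sizes (%d0, 4), the loop above yields numElements = %d0 * 4; for a
// statically shaped memref<8x4xf32> the static path folds it to the constant
// 32.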
86 
87  MLIRContext *context = &this->getTypeConverter()->getContext();
88 
89  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91  Type llvmInt8Type = IntegerType::get(context, 8);
92  Type llvmInt16Type = IntegerType::get(context, 16);
93  Type llvmInt32Type = IntegerType::get(context, 32);
94  Type llvmInt64Type = IntegerType::get(context, 64);
95  Type llvmFloat32Type = Float32Type::get(context);
96  Type llvmIntPtrType = IntegerType::get(
97  context, this->getTypeConverter()->getPointerBitwidth(0));
98 
99  FunctionCallBuilder streamCreateCallBuilder = {
100  "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
101  FunctionCallBuilder streamDestroyCallBuilder = {
102  "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
103  FunctionCallBuilder streamSynchronizeCallBuilder = {
104  "mgpuStreamSynchronize",
105  llvmVoidType,
106  {llvmPointerType /* void *stream */}};
107  FunctionCallBuilder streamWaitEventCallBuilder = {
108  "mgpuStreamWaitEvent",
109  llvmVoidType,
110  {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
111  FunctionCallBuilder eventCreateCallBuilder = {
112  "mgpuEventCreate", llvmPointerType /* void *event */, {}};
113  FunctionCallBuilder eventDestroyCallBuilder = {
114  "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
115  FunctionCallBuilder eventSynchronizeCallBuilder = {
116  "mgpuEventSynchronize",
117  llvmVoidType,
118  {llvmPointerType /* void *event */}};
119  FunctionCallBuilder eventRecordCallBuilder = {
120  "mgpuEventRecord",
121  llvmVoidType,
122  {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
123  FunctionCallBuilder hostRegisterCallBuilder = {
124  "mgpuMemHostRegisterMemRef",
125  llvmVoidType,
126  {llvmIntPtrType /* intptr_t rank */,
127  llvmPointerType /* void *memrefDesc */,
128  llvmIntPtrType /* intptr_t elementSizeBytes */}};
129  FunctionCallBuilder hostUnregisterCallBuilder = {
130  "mgpuMemHostUnregisterMemRef",
131  llvmVoidType,
132  {llvmIntPtrType /* intptr_t rank */,
133  llvmPointerType /* void *memrefDesc */,
134  llvmIntPtrType /* intptr_t elementSizeBytes */}};
135  FunctionCallBuilder allocCallBuilder = {
136  "mgpuMemAlloc",
137  llvmPointerType /* void * */,
138  {llvmIntPtrType /* intptr_t sizeBytes */,
139  llvmPointerType /* void *stream */,
140  llvmInt8Type /* bool isHostShared */}};
141  FunctionCallBuilder deallocCallBuilder = {
142  "mgpuMemFree",
143  llvmVoidType,
144  {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
145  FunctionCallBuilder memcpyCallBuilder = {
146  "mgpuMemcpy",
147  llvmVoidType,
148  {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
149  llvmIntPtrType /* intptr_t sizeBytes */,
150  llvmPointerType /* void *stream */}};
151  FunctionCallBuilder memset16CallBuilder = {
152  "mgpuMemset16",
153  llvmVoidType,
154  {llvmPointerType /* void *dst */,
155  llvmInt16Type /* unsigned short value */,
156  llvmIntPtrType /* intptr_t sizeBytes */,
157  llvmPointerType /* void *stream */}};
158  FunctionCallBuilder memset32CallBuilder = {
159  "mgpuMemset32",
160  llvmVoidType,
161  {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
162  llvmIntPtrType /* intptr_t sizeBytes */,
163  llvmPointerType /* void *stream */}};
164  FunctionCallBuilder setDefaultDeviceCallBuilder = {
165  "mgpuSetDefaultDevice",
166  llvmVoidType,
167  {llvmInt32Type /* uint32_t devIndex */}};
168  FunctionCallBuilder createDnVecCallBuilder = {
169  "mgpuCreateDnVec",
170  llvmPointerType,
171  {llvmIntPtrType, llvmPointerType, llvmInt32Type,
172  llvmPointerType /* void *stream */}};
173  FunctionCallBuilder destroyDnVecCallBuilder = {
174  "mgpuDestroyDnVec",
175  llvmVoidType,
176  {llvmPointerType, llvmPointerType /* void *stream */}};
177  FunctionCallBuilder createDnMatCallBuilder = {
178  "mgpuCreateDnMat",
179  llvmPointerType,
180  {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
181  llvmPointerType /* void *stream */}};
182  FunctionCallBuilder destroyDnMatCallBuilder = {
183  "mgpuDestroyDnMat",
184  llvmVoidType,
185  {llvmPointerType, llvmPointerType /* void *stream */}};
186  FunctionCallBuilder createCooCallBuilder = {
187  "mgpuCreateCoo",
188  llvmPointerType,
189  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
191  llvmPointerType /* void *stream */}};
192  FunctionCallBuilder createCooAoSCallBuilder = {
193  "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
194  llvmPointerType,
195  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196  llvmPointerType, llvmInt32Type, llvmInt32Type,
197  llvmPointerType /* void *stream */}};
198  FunctionCallBuilder createCsrCallBuilder = {
199  "mgpuCreateCsr",
200  llvmPointerType,
201  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203  llvmInt32Type, llvmPointerType /* void *stream */}};
204  FunctionCallBuilder createCscCallBuilder = {
205  "mgpuCreateCsc",
206  llvmPointerType,
207  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209  llvmInt32Type, llvmPointerType /* void *stream */}};
210  FunctionCallBuilder createBsrCallBuilder = {
211  "mgpuCreateBsr",
212  llvmPointerType,
213  {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214  llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215  llvmInt32Type, llvmInt32Type, llvmInt32Type,
216  llvmPointerType /* void *stream */}};
217  FunctionCallBuilder destroySpMatCallBuilder = {
218  "mgpuDestroySpMat",
219  llvmVoidType,
220  {llvmPointerType, llvmPointerType /* void *stream */}};
221  FunctionCallBuilder spMVBufferSizeCallBuilder = {
222  "mgpuSpMVBufferSize",
223  llvmIntPtrType,
224  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225  llvmInt32Type, llvmPointerType /* void *stream */}};
226  FunctionCallBuilder spMVCallBuilder = {
227  "mgpuSpMV",
228  llvmVoidType,
229  {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230  llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
231  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232  "mgpuSpMMBufferSize",
233  llvmIntPtrType,
234  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
236  FunctionCallBuilder createSpMMCallBuilder = {
237  "mgpuSpMM",
238  llvmVoidType,
239  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240  llvmPointerType, llvmInt32Type, llvmPointerType,
241  llvmPointerType /* void *stream */}};
242  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243  "mgpuSDDMMBufferSize",
244  llvmIntPtrType,
245  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246  llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
247  FunctionCallBuilder createSDDMMCallBuilder = {
248  "mgpuSDDMM",
249  llvmVoidType,
250  {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251  llvmPointerType, llvmInt32Type, llvmPointerType,
252  llvmPointerType /* void *stream */}};
253  FunctionCallBuilder createLtDnMatCallBuilder = {
254  "mgpuCreateCuSparseLtDnMat",
255  llvmVoidType,
256  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257  llvmInt32Type, llvmPointerType /* void *stream */}};
258  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259  "mgpuDestroyCuSparseLtSpMat",
260  llvmVoidType,
261  {llvmPointerType, llvmPointerType /* void *stream */}};
262  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263  "mgpuDestroyCuSparseLtDnMat",
264  llvmVoidType,
265  {llvmPointerType, llvmPointerType /* void *stream */}};
266  FunctionCallBuilder create2To4SpMatCallBuilder = {
267  "mgpuCusparseLtCreate2To4SpMat",
268  llvmVoidType,
269  {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270  llvmInt32Type, llvmPointerType /* void *stream */}};
271  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272  "mgpuCuSparseLtSpMMBufferSize",
273  llvmVoidType,
274  {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275  llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
276  llvmPointerType /*void *stream*/}};
277  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278  "mgpuCuSparseLtSpMM",
279  llvmVoidType,
280  {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281  llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
282  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283  "mgpuSpGEMMCreateDescr",
284  llvmPointerType,
285  {llvmPointerType /*void *stream*/}};
286  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287  "mgpuSpGEMMDestroyDescr",
288  llvmVoidType,
289  {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
290  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291  "mgpuSpGEMMWorkEstimation",
292  llvmIntPtrType,
293  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
294  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
295  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
296  llvmPointerType /*void *stream*/}};
297  FunctionCallBuilder createSpGEMMComputeBuilder = {
298  "mgpuSpGEMMCompute",
299  llvmIntPtrType,
300  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
301  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
302  llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
303  llvmPointerType /*void *stream*/}};
304  FunctionCallBuilder createSpGEMMCopyBuilder = {
305  "mgpuSpGEMMCopy",
306  llvmVoidType,
307  {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
308  llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
309  llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
310  FunctionCallBuilder createSpMatGetSizeBuilder = {
311  "mgpuSpMatGetSize",
312  llvmVoidType,
313  {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
314  llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
315  FunctionCallBuilder createSetCsrPointersBuilder = {
316  "mgpuSetCsrPointers",
317  llvmVoidType,
318  {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
319  llvmPointerType /*crd*/, llvmPointerType /*val*/,
320  llvmPointerType /*void *stream*/}};
321 };
322 
323 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
324 /// call. Currently it supports CUDA and ROCm (HIP).
325 class ConvertHostRegisterOpToGpuRuntimeCallPattern
326  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327 public:
328  ConvertHostRegisterOpToGpuRuntimeCallPattern(
329  const LLVMTypeConverter &typeConverter)
330  : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
331 
332 private:
333  LogicalResult
334  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335  ConversionPatternRewriter &rewriter) const override;
336 };
337 
338 class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339  : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340 public:
341  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342  const LLVMTypeConverter &typeConverter)
343  : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
344  }
345 
346 private:
347  LogicalResult
348  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override;
350 };
351 
352 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
353 /// call. Currently it supports CUDA and ROCm (HIP).
354 class ConvertAllocOpToGpuRuntimeCallPattern
355  : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
356 public:
357  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
358  : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
359 
360 private:
361  LogicalResult
362  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363  ConversionPatternRewriter &rewriter) const override;
364 };
365 
366 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
367 /// call. Currently it supports CUDA and ROCm (HIP).
368 class ConvertDeallocOpToGpuRuntimeCallPattern
369  : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370 public:
371  ConvertDeallocOpToGpuRuntimeCallPattern(
372  const LLVMTypeConverter &typeConverter)
373  : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
374 
375 private:
376  LogicalResult
377  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378  ConversionPatternRewriter &rewriter) const override;
379 };
380 
381 class ConvertAsyncYieldToGpuRuntimeCallPattern
382  : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383 public:
384  ConvertAsyncYieldToGpuRuntimeCallPattern(
385  const LLVMTypeConverter &typeConverter)
386  : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
387 
388 private:
389  LogicalResult
390  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
391  ConversionPatternRewriter &rewriter) const override;
392 };
393 
394 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
395 /// call. Currently it supports CUDA and ROCm (HIP).
396 class ConvertWaitOpToGpuRuntimeCallPattern
397  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
398 public:
399  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
400  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
401 
402 private:
403  LogicalResult
404  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
405  ConversionPatternRewriter &rewriter) const override;
406 };
407 
408 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
409 /// call. Currently it supports CUDA and ROCm (HIP).
410 class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411  : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
412 public:
413  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414  const LLVMTypeConverter &typeConverter)
415  : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
416 
417 private:
418  LogicalResult
419  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
420  ConversionPatternRewriter &rewriter) const override;
421 };
422 
423 /// A rewrite pattern to legalize gpu.launch_func with LLVM types.
424 class LegalizeLaunchFuncOpPattern
425  : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
426 public:
427  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
428  bool kernelBarePtrCallConv,
429  bool kernelIntersperseSizeCallConv)
430  : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
431  kernelBarePtrCallConv(kernelBarePtrCallConv),
432  kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
433 
434 private:
435  LogicalResult
436  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
437  ConversionPatternRewriter &rewriter) const override;
438 
439  bool kernelBarePtrCallConv;
440  bool kernelIntersperseSizeCallConv;
441 };
442 
443 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
444 /// call. Currently it supports CUDA and ROCm (HIP).
445 class ConvertMemcpyOpToGpuRuntimeCallPattern
446  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
447 public:
448  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
449  : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
450 
451 private:
452  LogicalResult
453  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
454  ConversionPatternRewriter &rewriter) const override;
455 };
456 
457 /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
458 /// call. Currently it supports CUDA and ROCm (HIP).
459 class ConvertMemsetOpToGpuRuntimeCallPattern
460  : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
461 public:
462  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
463  : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
464 
465 private:
466  LogicalResult
467  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
468  ConversionPatternRewriter &rewriter) const override;
469 };
470 
471 /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
472 /// Currently supports CUDA and ROCm (HIP).
473 class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
474  : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
475 public:
476  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
477  const LLVMTypeConverter &typeConverter)
478  : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
479  typeConverter) {}
480 
481  LogicalResult
482  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
483  ConversionPatternRewriter &rewriter) const override;
484 };
485 
486 /// Generic rewriting rule for operations on sparse matrices.
487 /// Currently supports CUDA (by means of cuSparse and cuSparseLt).
488 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
489  class Convert##op_name##ToGpuRuntimeCallPattern \
490  : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
491  public: \
492  Convert##op_name##ToGpuRuntimeCallPattern( \
493  const LLVMTypeConverter &typeConverter) \
494  : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
495  \
496  private: \
497  LogicalResult \
498  matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
499  ConversionPatternRewriter &rewriter) const override; \
500  };
501 
502 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
503 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
504 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
505 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
506 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
507 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
508 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
509 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
510 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
511 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
512 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
513 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
514 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
515 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
516 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
517 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
518 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
519 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
520 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
521 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
522 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
523 
524 } // namespace
525 
526 void GpuToLLVMConversionPass::runOnOperation() {
527  MLIRContext *context = &getContext();
528 
529  // Perform progressive lowering of vector transfer operations.
530  {
531  RewritePatternSet patterns(&getContext());
532  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
533  vector::populateVectorTransferLoweringPatterns(patterns,
534  /*maxTransferRank=*/1);
535  // Transform N-D vector.from_elements to 1-D vector.from_elements before
536  // conversion.
537  vector::populateVectorFromElementsLoweringPatterns(patterns);
538  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
539  return signalPassFailure();
540  }
541 
542  LowerToLLVMOptions options(context);
543  options.useBarePtrCallConv = hostBarePtrCallConv;
544  RewritePatternSet patterns(context);
545  ConversionTarget target(*context);
546  target.addLegalDialect<LLVM::LLVMDialect>();
547  LLVMTypeConverter converter(context, options);
548 
549  // Populate all patterns from all dialects that implement the
550  // `ConvertToLLVMPatternInterface` interface.
551  for (Dialect *dialect : context->getLoadedDialects()) {
552  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
553  if (!iface)
554  continue;
555  iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
556  }
557 
558  // Preserve GPU modules and binaries. Modules are preserved as they can be
559  // converted later by `gpu-module-to-binary`.
560  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
561  // Accept as legal LaunchFuncOps if the operands have been lowered.
562  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
563  [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
564 
565  // These aren't covered by the ConvertToLLVMPatternInterface right now.
566  populateVectorToLLVMConversionPatterns(converter, patterns);
567  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
568  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
569  target);
570  populateGpuToLLVMConversionPatterns(converter, patterns,
571  kernelBarePtrCallConv,
572  kernelIntersperseSizeCallConv);
573 
574  if (failed(
575  applyPartialConversion(getOperation(), target, std::move(patterns))))
576  signalPassFailure();
577 }
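// A minimal usage sketch, assuming the tablegen-generated options struct and
// factory declared in mlir/Conversion/Passes.h (names follow the GEN_PASS_DEF
// above; the pass is typically exposed to mlir-opt as `--gpu-to-llvm`):
//
//   GpuToLLVMConversionPassOptions options;
//   options.hostBarePtrCallConv = false;
//   options.kernelBarePtrCallConv = false;
//   options.kernelIntersperseSizeCallConv = false;
//   PassManager pm(module->getContext());
//   pm.addPass(createGpuToLLVMConversionPass(options));
//   if (failed(pm.run(module)))
//     return failure();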
578 
579 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
580  ArrayRef<Value> arguments) const {
581  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
582  auto function = [&] {
583  if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
584  return function;
585  auto builder = OpBuilder::atBlockEnd(module.getBody());
586  return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
587  }();
588  return LLVM::CallOp::create(builder, loc, function, arguments);
589 }
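// For illustration, a call such as
//   streamCreateCallBuilder.create(loc, rewriter, {})
// first materializes the function declaration (once per module) and then the
// call, producing roughly the following LLVM dialect IR (opaque pointers
// assumed):
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
//   ...
//   %0 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr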
590 
591 // Corresponding to cusparseIndexType_t defined in cusparse.h.
592 static int32_t getCuSparseIndexTypeFrom(Type type) {
593  if (type.isInteger(16))
594  return 1; // CUSPARSE_INDEX_16U
595  if (type.isInteger(32))
596  return 2; // CUSPARSE_INDEX_32I
597  return 3; // CUSPARSE_INDEX_64I
598 }
599 
600 static int32_t getCuSparseLtDataTypeFrom(Type type) {
601  if (type.isF16())
602  return 0; // CUSPARSE_COMPUTE_16F,
603  if (type.isInteger(32))
604  return 1; // CUSPARSE_COMPUTE_32I
605  llvm_unreachable("unsupported type");
606  // TODO: add support to TF32
607 }
608 
609 // Corresponding to cudaDataType_t defined in CUDA library_types.h.
610 static int32_t getCuSparseDataTypeFrom(Type type) {
611  if (llvm::isa<ComplexType>(type)) {
612  // get the element type
613  auto elementType = cast<ComplexType>(type).getElementType();
614  if (elementType.isBF16())
615  return 15; // CUDA_C_16BF
616  if (elementType.isF16())
617  return 6; // CUDA_C_16F
618  if (elementType.isF32())
619  return 4; // CUDA_C_32F
620  if (elementType.isF64())
621  return 5; // CUDA_C_64F
622  if (elementType.isInteger(8))
623  return 7; // CUDA_C_8I
624  if (elementType.isInteger(16))
625  return 21; // CUDA_C_16I
626  if (elementType.isInteger(32))
627  return 11; // CUDA_C_32I
628  }
629  if (type.isBF16())
630  return 14; // CUDA_R_16BF
631  if (type.isF16())
632  return 2; // CUDA_R_16F
633  if (type.isF32())
634  return 0; // CUDA_R_32F
635  if (type.isF64())
636  return 1; // CUDA_R_64F
637  if (type.isInteger(8))
638  return 3; // CUDA_R_8I
639  if (type.isInteger(16))
640  return 20; // CUDA_R_16I
641  if (type.isInteger(32))
642  return 10; // CUDA_R_32I
643 
644  llvm_unreachable("unsupported element type");
645 }
646 
647 static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
648  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
649 }
650 
651 // TODO: We may want a disablement or warning at MLIR compile time (i.e. at
652 // run time of the MLIR compiler): cusparseLt currently does not work for CUDA
653 // architectures < 8.0 and will trigger a runtime error in the CUDA program.
654 // It would be good to at least emit a warning when the target architecture is
655 // < 8.0 and the user still wants to use cusparseLt, and to make sure that,
656 // when lowering the GPU sparse dialect to LLVM calls, the cusparseLt calls
657 // are disabled for CUDA architectures < 8.0.
658 static bool is2To4Sparsity(Value spMat) {
659  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
660  return true;
661  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
662  return false;
663  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
664  return false;
665  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
666  return false;
667  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
668  return false;
669  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
670  return false;
671  // Print the spMat defining op
672  spMat.getDefiningOp()->print(llvm::errs());
673  llvm_unreachable("cannot find spmat def");
674 }
675 
676 static bool isSpMMCusparseLtOp(Value op) {
677  for (Operation *user : op.getUsers()) {
678  auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
679  // If the SpMM's A operand has 2:4 (50%) sparsity, we should use cusparseLt.
680  if (!spmmOp)
681  continue;
682  if (is2To4Sparsity(spmmOp.getSpmatA()))
683  return true;
684  }
685  return false;
686 }
687 
688 // Returns success if all operands are of LLVM type.
689 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
690  ConversionPatternRewriter &rewriter) {
691  if (!llvm::all_of(operands, [](Value value) {
692  return LLVM::isCompatibleType(value.getType());
693  }))
694  return rewriter.notifyMatchFailure(
695  op, "Cannot convert if operands aren't of LLVM type.");
696  return success();
697 }
698 
699 static LogicalResult
700 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
701  gpu::AsyncOpInterface op) {
702  if (op.getAsyncDependencies().size() != 1)
703  return rewriter.notifyMatchFailure(
704  op, "Can only convert with exactly one async dependency.");
705 
706  if (!op.getAsyncToken())
707  return rewriter.notifyMatchFailure(op, "Can convert only async version.");
708 
709  return success();
710 }
711 
712 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
713  gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
714  ConversionPatternRewriter &rewriter) const {
715  auto *op = hostRegisterOp.getOperation();
716  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
717  return failure();
718 
719  Location loc = op->getLoc();
720 
721  auto memRefType = hostRegisterOp.getValue().getType();
722  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
723  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
724 
725  auto arguments = getTypeConverter()->promoteOperands(
726  loc, op->getOperands(), adaptor.getOperands(), rewriter);
727  arguments.push_back(elementSize);
728  hostRegisterCallBuilder.create(loc, rewriter, arguments);
729 
730  rewriter.eraseOp(op);
731  return success();
732 }
733 
734 LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
735  gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
736  ConversionPatternRewriter &rewriter) const {
737  Operation *op = hostUnregisterOp.getOperation();
738  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
739  return failure();
740 
741  Location loc = op->getLoc();
742 
743  auto memRefType = hostUnregisterOp.getValue().getType();
744  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
745  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
746 
747  auto arguments = getTypeConverter()->promoteOperands(
748  loc, op->getOperands(), adaptor.getOperands(), rewriter);
749  arguments.push_back(elementSize);
750  hostUnregisterCallBuilder.create(loc, rewriter, arguments);
751 
752  rewriter.eraseOp(op);
753  return success();
754 }
755 
756 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
757  gpu::AllocOp allocOp, OpAdaptor adaptor,
758  ConversionPatternRewriter &rewriter) const {
759 
760  MemRefType memRefType = allocOp.getType();
761 
762  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
763  !isConvertibleAndHasIdentityMaps(memRefType))
764  return failure();
765 
766  auto loc = allocOp.getLoc();
767 
768  bool isShared = allocOp.getHostShared();
769 
770  if (isShared && allocOp.getAsyncToken())
771  return rewriter.notifyMatchFailure(
772  allocOp, "Host Shared allocation cannot be done async");
773  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
774  return failure();
775 
776  // Get shape of the memref as values: static sizes are constant
777  // values and dynamic sizes are passed to 'alloc' as operands.
778  SmallVector<Value, 4> shape;
779  SmallVector<Value, 4> strides;
780  Value sizeBytes;
781  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
782  shape, strides, sizeBytes);
783 
784  // Allocate the underlying buffer and store a pointer to it in the MemRef
785  // descriptor.
786  auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
787  Value stream = adaptor.getAsyncDependencies().empty()
788  ? nullPtr
789  : adaptor.getAsyncDependencies().front();
790 
791  auto isHostShared = mlir::LLVM::ConstantOp::create(
792  rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
793 
794  Value allocatedPtr =
795  allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
796  .getResult();
797 
798  // No alignment.
799  Value alignedPtr = allocatedPtr;
800 
801  // Create the MemRef descriptor.
802  auto memRefDescriptor = this->createMemRefDescriptor(
803  loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
804 
805  if (allocOp.getAsyncToken()) {
806  // Async alloc: make dependent ops use the same stream.
807  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
808  } else {
809  rewriter.replaceOp(allocOp, {memRefDescriptor});
810  }
811 
812  return success();
813 }
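// Illustrative result (SSA names hypothetical): an async `gpu.alloc` on stream
// %s lowers to roughly
//   %hostShared = llvm.mlir.constant(0 : i8) : i8
//   %ptr = llvm.call @mgpuMemAlloc(%size, %s, %hostShared)
//          : (i64, !llvm.ptr, i8) -> !llvm.ptr
// (with i64 standing in for intptr_t on a 64-bit target), after which %ptr is
// packed into the resulting memref descriptor as both the allocated and the
// aligned pointer.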
814 
815 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
816  gpu::DeallocOp deallocOp, OpAdaptor adaptor,
817  ConversionPatternRewriter &rewriter) const {
818  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
819  failed(isAsyncWithOneDependency(rewriter, deallocOp)))
820  return failure();
821 
822  Location loc = deallocOp.getLoc();
823 
824  Value pointer =
825  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
826  Value stream = adaptor.getAsyncDependencies().front();
827  deallocCallBuilder.create(loc, rewriter, {pointer, stream});
828 
829  rewriter.replaceOp(deallocOp, {stream});
830  return success();
831 }
832 
833 static bool isGpuAsyncTokenType(Value value) {
834  return isa<gpu::AsyncTokenType>(value.getType());
835 }
836 
837 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
838 // !gpu.async.token values are lowered to streams within the async.execute
839 // region, but are passed as events between regions. For each !gpu.async.token
840 // operand, we create an event and record it on the stream.
841 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
842  async::YieldOp yieldOp, OpAdaptor adaptor,
843  ConversionPatternRewriter &rewriter) const {
844  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
845  return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
846 
847  Location loc = yieldOp.getLoc();
848  SmallVector<Value, 4> newOperands(adaptor.getOperands());
849  llvm::SmallDenseSet<Value> streams;
850  for (auto &operand : yieldOp->getOpOperands()) {
851  if (!isGpuAsyncTokenType(operand.get()))
852  continue;
853  auto idx = operand.getOperandNumber();
854  auto stream = adaptor.getOperands()[idx];
855  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
856  eventRecordCallBuilder.create(loc, rewriter, {event, stream});
857  newOperands[idx] = event;
858  streams.insert(stream);
859  }
860  for (auto stream : streams)
861  streamDestroyCallBuilder.create(loc, rewriter, {stream});
862 
863  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
864  return success();
865 }
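// Illustrative effect (SSA names hypothetical): for `async.yield %token` where
// %token was lowered to a stream %s, the rewrite emits
//   %e = llvm.call @mgpuEventCreate() : () -> !llvm.ptr
//   llvm.call @mgpuEventRecord(%e, %s) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%s) : (!llvm.ptr) -> ()
// and yields %e in place of the original token.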
866 
867 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
868 static bool isDefinedByCallTo(Value value, StringRef functionName) {
869  assert(isa<LLVM::LLVMPointerType>(value.getType()));
870  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
871  return *defOp.getCallee() == functionName;
872  return false;
873 }
874 
875 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the
876 // host with the stream/event operands. The operands are destroyed, i.e. it is
877 // assumed that they are not used afterwards or elsewhere. Otherwise we will
878 // get a runtime error. Eventually, we should guarantee this property.
879 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
880  gpu::WaitOp waitOp, OpAdaptor adaptor,
881  ConversionPatternRewriter &rewriter) const {
882  if (waitOp.getAsyncToken())
883  return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
884 
885  Location loc = waitOp.getLoc();
886 
887  for (auto operand : adaptor.getOperands()) {
888  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
889  // The converted operand's definition created a stream.
890  streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
891  streamDestroyCallBuilder.create(loc, rewriter, {operand});
892  } else {
893  // Otherwise the converted operand is an event. This assumes that we use
894  // events in control flow code as well.
895  eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
896  eventDestroyCallBuilder.create(loc, rewriter, {operand});
897  }
898  }
899 
900  rewriter.eraseOp(waitOp);
901  return success();
902 }
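// For example (SSA names hypothetical), `gpu.wait [%t]` where %t was lowered
// to a stream %s becomes
//   llvm.call @mgpuStreamSynchronize(%s) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%s) : (!llvm.ptr) -> ()
// while an event operand is handled with the mgpuEventSynchronize /
// mgpuEventDestroy calls instead.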
903 
904 // Converts `gpu.wait async` to runtime calls. The converted op creates a new
905 // stream that is synchronized with the stream/event operands. The operands
906 // are destroyed, i.e. it is assumed that they are not used afterwards or
907 // elsewhere. Otherwise we will get a runtime error. Eventually, we should
908 // guarantee this property.
909 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
910  gpu::WaitOp waitOp, OpAdaptor adaptor,
911  ConversionPatternRewriter &rewriter) const {
912  if (!waitOp.getAsyncToken())
913  return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
914 
915  Location loc = waitOp.getLoc();
916 
917  auto insertionPoint = rewriter.saveInsertionPoint();
918  SmallVector<Value, 1> events;
919  for (auto pair :
920  llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
921  auto operand = std::get<1>(pair);
922  if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
923  // The converted operand's definition created a stream. Insert an event
924  // into the stream just after the last use of the original token operand.
925  auto *defOp = std::get<0>(pair).getDefiningOp();
926  rewriter.setInsertionPointAfter(defOp);
927  auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
928  eventRecordCallBuilder.create(loc, rewriter, {event, operand});
929  events.push_back(event);
930  } else {
931  // Otherwise the converted operand is an event. This assumes that we use
932  // events in control flow code as well.
933  events.push_back(operand);
934  }
935  }
936  rewriter.restoreInsertionPoint(insertionPoint);
937  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
938  for (auto event : events)
939  streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
940  for (auto event : events)
941  eventDestroyCallBuilder.create(loc, rewriter, {event});
942  rewriter.replaceOp(waitOp, {stream});
943 
944  return success();
945 }
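// Illustrative lowering (SSA names hypothetical) of `%t = gpu.wait async [%d]`
// where %d was lowered to an event %e:
//   %s = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   llvm.call @mgpuStreamWaitEvent(%s, %e) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%e) : (!llvm.ptr) -> ()
// with %s replacing the async token %t.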
946 
947 // Legalize the op's operands.
948 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
949  gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
950  ConversionPatternRewriter &rewriter) const {
951  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
952  return failure();
953 
954  if (launchOp.getAsyncDependencies().size() > 1)
955  return rewriter.notifyMatchFailure(
956  launchOp, "Cannot convert with more than one async dependency.");
957 
958  // Fail when the synchronous version of the op has async dependencies. The
959  // lowering destroys the stream, and we do not want to check that there is no
960  // use of the stream after this op.
961  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
962  return rewriter.notifyMatchFailure(
963  launchOp, "Cannot convert non-async op with async dependencies.");
964 
965  Location loc = launchOp.getLoc();
966 
967  Value stream = Value();
968  if (!adaptor.getAsyncDependencies().empty())
969  stream = adaptor.getAsyncDependencies().front();
970  // If the async keyword is present and there are no dependencies, then a
971  // stream must be created to pass to subsequent operations.
972  else if (launchOp.getAsyncToken())
973  stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
974 
975  // Lower the kernel operands to match kernel parameters.
976  // Note: If `useBarePtrCallConv` is set in the type converter's options,
977  // the value of `kernelBarePtrCallConv` will be ignored.
978  OperandRange origArguments = launchOp.getKernelOperands();
979  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
980  loc, origArguments, adaptor.getKernelOperands(), rewriter,
981  /*useBarePtrCallConv=*/kernelBarePtrCallConv);
982  SmallVector<Value, 8> llvmArgumentsWithSizes;
983 
984  // Intersperse size information if requested.
985  if (kernelIntersperseSizeCallConv) {
986  if (origArguments.size() != llvmArguments.size()) {
987  // This shouldn't happen if the bare-pointer calling convention is used.
988  return rewriter.notifyMatchFailure(
989  launchOp,
990  "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
991  }
992 
993  llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
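  // The interspersed form pairs every kernel argument with its size, i.e. it
  // produces (ptr0, size0, ptr1, size1, ...), where each size is the static
  // byte size of the corresponding memref operand.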
994  for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
995  auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
996  if (!memrefTy) {
997  return rewriter.notifyMatchFailure(
998  launchOp, "Operand to launch op is not a memref.");
999  }
1000 
1001  if (!memrefTy.hasStaticShape() ||
1002  !memrefTy.getElementType().isIntOrFloat()) {
1003  return rewriter.notifyMatchFailure(
1004  launchOp, "Operand to launch op is not a memref with a static "
1005  "shape and an integer or float element type.");
1006  }
1007 
1008  unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1009  if (bitwidth % 8 != 0) {
1010  return rewriter.notifyMatchFailure(
1011  launchOp, "Operand to launch op is not a memref with a "
1012  "byte-aligned element type.");
1013  }
1014 
1015  uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1016  static_cast<uint64_t>(memrefTy.getNumElements());
1017 
1018  Value sizeArg = LLVM::ConstantOp::create(
1019  rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1020  llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
1021  llvmArgumentsWithSizes.push_back(sizeArg);
1022  }
1023  }
1024 
1025  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1026  if (launchOp.hasClusterSize()) {
1027  clusterSize =
1028  gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1029  adaptor.getClusterSizeZ()};
1030  }
1031  gpu::LaunchFuncOp::create(
1032  rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1033  gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1034  adaptor.getGridSizeZ()},
1035  gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1036  adaptor.getBlockSizeZ()},
1037  adaptor.getDynamicSharedMemorySize(),
1038  llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1039  stream, clusterSize);
1040  if (launchOp.getAsyncToken())
1041  rewriter.replaceOp(launchOp, {stream});
1042  else
1043  rewriter.eraseOp(launchOp);
1044  return success();
1045 }
1046 
1047 static Value bitAndAddrspaceCast(Location loc,
1048  ConversionPatternRewriter &rewriter,
1049  LLVM::LLVMPointerType destinationType,
1050  Value sourcePtr,
1051  const LLVMTypeConverter &typeConverter) {
1052  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1053  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1054  sourcePtr = LLVM::AddrSpaceCastOp::create(
1055  rewriter, loc,
1056  LLVM::LLVMPointerType::get(rewriter.getContext(),
1057  destinationType.getAddressSpace()),
1058  sourcePtr);
1059  return sourcePtr;
1060 }
1061 
1062 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1063  gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1064  ConversionPatternRewriter &rewriter) const {
1065  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1066 
1067  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1068  !isConvertibleAndHasIdentityMaps(memRefType) ||
1069  failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1070  return failure();
1071 
1072  auto loc = memcpyOp.getLoc();
1073 
1074  MemRefDescriptor srcDesc(adaptor.getSrc());
1075  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1076 
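  // Compute the copy size in bytes via the null-pointer GEP idiom: advancing a
  // null pointer by numElements elements and converting the result with
  // ptrtoint yields numElements * sizeof(elementType) without needing a
  // target-specific size constant.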
1077  Type elementPtrType = getElementPtrType(memRefType);
1078  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1079  Value gepPtr = LLVM::GEPOp::create(
1080  rewriter, loc, elementPtrType,
1081  typeConverter->convertType(memRefType.getElementType()), nullPtr,
1082  numElements);
1083  auto sizeBytes =
1084  LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1085 
1086  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1087  srcDesc.alignedPtr(rewriter, loc),
1088  *getTypeConverter());
1089  auto dst = bitAndAddrspaceCast(
1090  loc, rewriter, llvmPointerType,
1091  MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1092  *getTypeConverter());
1093 
1094  auto stream = adaptor.getAsyncDependencies().front();
1095  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1096 
1097  rewriter.replaceOp(memcpyOp, {stream});
1098 
1099  return success();
1100 }
1101 
1102 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1103  gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1104  ConversionPatternRewriter &rewriter) const {
1105  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1106 
1107  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1108  !isConvertibleAndHasIdentityMaps(memRefType) ||
1109  failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1110  return failure();
1111 
1112  auto loc = memsetOp.getLoc();
1113 
1114  Type valueType = adaptor.getValue().getType();
1115  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1116  // Ints and floats of 16 or 32 bit width are allowed.
1117  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1118  return rewriter.notifyMatchFailure(
1119  memsetOp, "value must be a 16 or 32 bit int or float");
1120  }
1121 
1122  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1123  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1124 
1125  MemRefDescriptor dstDesc(adaptor.getDst());
1126  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1127 
1128  auto value =
1129  LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1130  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1131  dstDesc.alignedPtr(rewriter, loc),
1132  *getTypeConverter());
1133 
1134  auto stream = adaptor.getAsyncDependencies().front();
1135  FunctionCallBuilder builder =
1136  valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1137  builder.create(loc, rewriter, {dst, value, numElements, stream});
1138 
1139  rewriter.replaceOp(memsetOp, {stream});
1140  return success();
1141 }
1142 
1143 LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1144  gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1145  ConversionPatternRewriter &rewriter) const {
1146  Location loc = op.getLoc();
1147  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1148  {adaptor.getDevIndex()});
1149  rewriter.replaceOp(op, call);
1150  return success();
1151 }
1152 
1153 template <typename T>
1154 static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1155  Type llvmInt32Type = builder.getIntegerType(32);
1156  return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1157  static_cast<int32_t>(tValue));
1158 }
1159 
1160 template <typename T>
1161 static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1162  Type llvmFloat32Type = builder.getF32Type();
1163  return LLVM::ConstantOp::create(
1164  builder, loc, llvmFloat32Type,
1165  builder.getF32FloatAttr(static_cast<float>(tValue)));
1166 }
1167 
1168 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1169  gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1170  ConversionPatternRewriter &rewriter) const {
1171  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1172  failed(isAsyncWithOneDependency(rewriter, op)))
1173  return failure();
1174  Location loc = op.getLoc();
1175  auto stream = adaptor.getAsyncDependencies().front();
1176  Value pTensor =
1177  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1178  Type dType = op.getMemref().getType().getElementType();
1179  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1180 
1181  SmallVector<Value, 4> dims;
1182  for (Value dim : adaptor.getDims()) {
1183  dims.push_back(dim);
1184  }
1185 
1186  Value handle;
1187  // TODO: For now, we track the use of the handle and lower it to cusparse or
1188  // cusparseLt accordingly. If both cusparse and cusparseLt are used within one
1189  // block, two separate creation ops are required for the logic to be correct.
1190  // In the future, we may add support for using one handle for both cusparse
1191  // and cusparseLt in the sparse tensor / GPU dialects. Use the cusparseLt
1192  // create call if the dnmat is used with a spmat with 2:4 sparsity.
1193  if (dims.size() == 2) {
1194  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1195  auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1196  rewriter.getIndexAttr(11032));
1197  handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1198  llvmInt8Type, handleSz, /*alignment=*/16);
1199  handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1200 
1201  createLtDnMatCallBuilder
1202  .create(loc, rewriter,
1203  {handle, dims[0], dims[1], pTensor, dtp, stream})
1204  .getResult();
1205  } else {
1206  handle =
1207  createDnMatCallBuilder
1208  .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1209  .getResult();
1210  }
1211  } else {
1212  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1213  handle = createDnVecCallBuilder
1214  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1215  .getResult();
1216  }
1217  rewriter.replaceOp(op, {handle, stream});
1218  return success();
1219 }
1220 
1221 LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1222  gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1223  ConversionPatternRewriter &rewriter) const {
1224  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1225  failed(isAsyncWithOneDependency(rewriter, op)))
1226  return failure();
1227  Location loc = op.getLoc();
1228  auto stream = adaptor.getAsyncDependencies().front();
1229  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1230  SmallVector<Value, 4> dims;
1231  for (Value dim : definingOp.getDims()) {
1232  dims.push_back(dim);
1233  }
1234  if (dims.size() == 2) {
1235  // Use the cusparseLt destroy call if the dnmat is used with a spmat with
1236  // 2:4 sparsity.
1237  if (isSpMMCusparseLtOp(op.getDnTensor())) {
1238  destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1239  {adaptor.getDnTensor(), stream});
1240  } else {
1241  destroyDnMatCallBuilder.create(loc, rewriter,
1242  {adaptor.getDnTensor(), stream});
1243  }
1244  } else {
1245  assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1246  destroyDnVecCallBuilder.create(loc, rewriter,
1247  {adaptor.getDnTensor(), stream});
1248  }
1249  rewriter.replaceOp(op, {stream});
1250  return success();
1251 }
1252 
1253 LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1254  gpu::CreateCooOp op, OpAdaptor adaptor,
1255  ConversionPatternRewriter &rewriter) const {
1256  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1257  failed(isAsyncWithOneDependency(rewriter, op)))
1258  return failure();
1259  Location loc = op.getLoc();
1260  auto stream = adaptor.getAsyncDependencies().front();
1261  Value pRowIdxs =
1262  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1263  Value pColIdxs =
1264  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1265  Value pValues =
1266  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1267  Type iType =
1268  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1269  Type dType =
1270  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1271  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1272  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1273  auto handle =
1274  createCooCallBuilder
1275  .create(loc, rewriter,
1276  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1277  pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1278  .getResult();
1279  rewriter.replaceOp(op, {handle, stream});
1280  return success();
1281 }
1282 
1283 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1284  gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1285  ConversionPatternRewriter &rewriter) const {
1286  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1287  failed(isAsyncWithOneDependency(rewriter, op)))
1288  return failure();
1289  Location loc = op.getLoc();
1290  auto stream = adaptor.getAsyncDependencies().front();
1291  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1292  Value pValues =
1293  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1294  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1295  Type dType =
1296  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1297  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1298  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1299  auto handle =
1300  createCooAoSCallBuilder
1301  .create(loc, rewriter,
1302  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1303  pIdxs, pValues, itp, dtp, stream})
1304  .getResult();
1305  rewriter.replaceOp(op, {handle, stream});
1306  return success();
1307 }
1308 
1309 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1310  gpu::CreateCsrOp op, OpAdaptor adaptor,
1311  ConversionPatternRewriter &rewriter) const {
1312  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1313  failed(isAsyncWithOneDependency(rewriter, op)))
1314  return failure();
1315  Location loc = op.getLoc();
1316  auto stream = adaptor.getAsyncDependencies().front();
1317  Value pRowPos =
1318  MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1319  Value pColIdxs =
1320  MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1321  Value pValues =
1322  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1323  Type pType =
1324  llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1325  Type iType =
1326  llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1327  Type dType =
1328  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1329  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1330  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1331  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1332  auto handle =
1333  createCsrCallBuilder
1334  .create(loc, rewriter,
1335  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1336  pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1337  .getResult();
1338  rewriter.replaceOp(op, {handle, stream});
1339  return success();
1340 }
1341 
1342 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1343  gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1344  ConversionPatternRewriter &rewriter) const {
1345  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1346  failed(isAsyncWithOneDependency(rewriter, op)))
1347  return failure();
1348  Location loc = op.getLoc();
1349  auto stream = adaptor.getAsyncDependencies().front();
1350  Value pMat =
1351  MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1352  Type dType =
1353  llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1354  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1355 
1356  // CUDA runner asserts the size is 44104 bytes.
1357  auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1358  rewriter.getIndexAttr(44104));
1359  Value handle = LLVM::AllocaOp::create(
1360  rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1361  handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1362 
1363  create2To4SpMatCallBuilder
1364  .create(loc, rewriter,
1365  {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1366  .getResult();
1367  rewriter.replaceOp(op, {handle, stream});
1368  return success();
1369 }
1370 
1371 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1372  gpu::DestroySpMatOp op, OpAdaptor adaptor,
1373  ConversionPatternRewriter &rewriter) const {
1374  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1375  failed(isAsyncWithOneDependency(rewriter, op)))
1376  return failure();
1377  Location loc = op.getLoc();
1378  auto stream = adaptor.getAsyncDependencies().front();
1379  // Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
1380  if (is2To4Sparsity(op.getSpmat())) {
1381  destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1382  {adaptor.getSpmat(), stream});
1383 
1384  } else {
1385  destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1386  }
1387  rewriter.replaceOp(op, {stream});
1388  return success();
1389 }
1390 
1391 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1392  gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1393  ConversionPatternRewriter &rewriter) const {
1394  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1395  failed(isAsyncWithOneDependency(rewriter, op)))
1396  return failure();
1397  Location loc = op.getLoc();
1398  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1399  auto computeType = genConstInt32From(
1400  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1401  auto stream = adaptor.getAsyncDependencies().front();
1402  auto bufferSize = spMVBufferSizeCallBuilder
1403  .create(loc, rewriter,
1404  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1405  adaptor.getDnY(), computeType, stream})
1406  .getResult();
1407  rewriter.replaceOp(op, {bufferSize, stream});
1408  return success();
1409 }
1410 
1411 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1412  gpu::SpMVOp op, OpAdaptor adaptor,
1413  ConversionPatternRewriter &rewriter) const {
1414  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1415  failed(isAsyncWithOneDependency(rewriter, op)))
1416  return failure();
1417  Location loc = op.getLoc();
1418  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1419  auto computeType = genConstInt32From(
1420  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1421  auto stream = adaptor.getAsyncDependencies().front();
1422  Value pBuf =
1423  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1424  spMVCallBuilder.create(loc, rewriter,
1425  {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1426  adaptor.getDnY(), computeType, pBuf, stream});
1427  rewriter.replaceOp(op, {stream});
1428  return success();
1429 }
1430 
1431 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1432  gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1433  ConversionPatternRewriter &rewriter) const {
1434  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1435  failed(isAsyncWithOneDependency(rewriter, op)))
1436  return failure();
1437  Location loc = op.getLoc();
1438  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1439  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1440  auto stream = adaptor.getAsyncDependencies().front();
1441  Value bufferSize;
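 // cusparseLt reports three workspace sizes, read back below as 64-bit
 // integers from a small stack buffer, whereas the plain cuSPARSE path
 // returns a single size; the two branches therefore produce four or two
 // results respectively.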
1442  if (is2To4Sparsity(op.getSpmatA())) {
1443  auto pruneFlag =
1444  genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1445  auto computeType = genConstInt32From(
1446  rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1447  auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1448  rewriter.getIndexAttr(3));
1449  auto bufferSize =
1450  LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1451  three, /*alignment=*/16);
1452  createCuSparseLtSpMMBufferSizeBuilder
1453  .create(loc, rewriter,
1454  {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1455  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1456  pruneFlag, stream})
1457  .getResult();
1458 
1459  auto bufferSizePtr1 = LLVM::GEPOp::create(
1460  rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1461  ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1462  rewriter.getIndexAttr(1))});
1463  auto bufferSizePtr2 = LLVM::GEPOp::create(
1464  rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1465  ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1466  rewriter.getIndexAttr(2))});
1467  auto bufferSize0 =
1468  LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1469  auto bufferSize1 =
1470  LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1471  auto bufferSize2 =
1472  LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1473 
1474  rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1475  } else {
1476  auto computeType = genConstInt32From(
1477  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1478  bufferSize =
1479  createSpMMBufferSizeCallBuilder
1480  .create(loc, rewriter,
1481  {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1482  adaptor.getDnmatC(), computeType, stream})
1483  .getResult();
1484  rewriter.replaceOp(op, {bufferSize, stream});
1485  }
1486  return success();
1487 }
1488 
1489 LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1490  gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1491  ConversionPatternRewriter &rewriter) const {
1492  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1493  failed(isAsyncWithOneDependency(rewriter, op)))
1494  return failure();
1495  Location loc = op.getLoc();
1496  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1497  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1498  auto computeType = genConstInt32From(
1499  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1500  auto stream = adaptor.getAsyncDependencies().front();
1501  auto bufferSize =
1502  createSDDMMBufferSizeCallBuilder
1503  .create(loc, rewriter,
1504  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1505  adaptor.getSpmatC(), computeType, stream})
1506  .getResult();
1507  rewriter.replaceOp(op, {bufferSize, stream});
1508  return success();
1509 }
1510 
1511 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1512  gpu::SpMMOp op, OpAdaptor adaptor,
1513  ConversionPatternRewriter &rewriter) const {
1514  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1515  failed(isAsyncWithOneDependency(rewriter, op)))
1516  return failure();
1517  Location loc = op.getLoc();
1518  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1519  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1520  auto computeType = genConstInt32From(
1521  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1522 
1523  auto stream = adaptor.getAsyncDependencies().front();
1524 
1525  // Lower to cusparseLt if applicable
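  // The 2:4 path consumes the three buffers produced by the corresponding
  // buffer-size op; the generic cuSPARSE path takes a single workspace buffer.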
1526  if (is2To4Sparsity(op.getSpmatA())) {
1527  SmallVector<Value> pBufs;
1528  for (Value buffer : adaptor.getBuffers()) {
1529  Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1530  pBufs.push_back(pBuf);
1531  }
1532  createCuSparseLtSpMMBuilder.create(
1533  loc, rewriter,
1534  {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1535  pBufs[0], pBufs[1], pBufs[2], stream});
1536  } else {
1537  Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1538  .allocatedPtr(rewriter, loc);
1539  createSpMMCallBuilder.create(loc, rewriter,
1540  {modeA, modeB, adaptor.getSpmatA(),
1541  adaptor.getDnmatB(), adaptor.getDnmatC(),
1542  computeType, pBuf, stream});
1543  }
1544  rewriter.replaceOp(op, {stream});
1545  return success();
1546 }
1547 
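 // Map a GPU dialect handle type (e.g. the async token or one of the sparse
 // handle types registered below) to an opaque LLVM pointer in the converter.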
1548 template <typename T>
1549 static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1550  converter.addConversion([&converter](T) -> Type {
1551  return LLVM::LLVMPointerType::get(&converter.getContext());
1552  });
1553 }
1554 
1555 LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1556  gpu::SDDMMOp op, OpAdaptor adaptor,
1557  ConversionPatternRewriter &rewriter) const {
1558  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1559  failed(isAsyncWithOneDependency(rewriter, op)))
1560  return failure();
1561  Location loc = op.getLoc();
1562  auto computeType = genConstInt32From(
1563  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1564  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1565  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1566  auto stream = adaptor.getAsyncDependencies().front();
1567  Value pBuf =
1568  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1569  createSDDMMCallBuilder.create(loc, rewriter,
1570  {modeA, modeB, adaptor.getDnmatA(),
1571  adaptor.getDnmatB(), adaptor.getSpmatC(),
1572  computeType, pBuf, stream});
1573  rewriter.replaceOp(op, {stream});
1574  return success();
1575 }
1576 
1577 LogicalResult
1578 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1579  gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1580  ConversionPatternRewriter &rewriter) const {
1581  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1582  failed(isAsyncWithOneDependency(rewriter, op)))
1583  return failure();
1584  Location loc = op.getLoc();
1585  auto stream = adaptor.getAsyncDependencies().front();
1586  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1587  .getResult();
1588  rewriter.replaceOp(op, {descr, stream});
1589  return success();
1590 }
1591 
1592 LogicalResult
1593 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1594  gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1595  ConversionPatternRewriter &rewriter) const {
1596  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1597  failed(isAsyncWithOneDependency(rewriter, op)))
1598  return failure();
1599  Location loc = op.getLoc();
1600  auto stream = adaptor.getAsyncDependencies().front();
1601  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1602  {adaptor.getDesc(), stream});
1603  rewriter.replaceOp(op, {stream});
1604  return success();
1605 }
1606 
1607 LogicalResult
1608 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1609  gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1610  ConversionPatternRewriter &rewriter) const {
1611  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1612  failed(isAsyncWithOneDependency(rewriter, op)))
1613  return failure();
1614  Location loc = op.getLoc();
1615  auto computeType = genConstInt32From(
1616  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1617  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1618  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1619  auto stream = adaptor.getAsyncDependencies().front();
1620 
1621  Value pBuf =
1622  MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1623  Value bufferSizeNew;
1624 
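 // The same op covers both phases of cuSPARSE SpGEMM: a work-estimation call
 // and a compute call; each returns an updated buffer size to the caller.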
1625  if (adaptor.getKind() ==
1626  gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1627  bufferSizeNew =
1628  createSpGEMMWorkEstimationBuilder
1629  .create(loc, rewriter,
1630  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1631  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1632  adaptor.getBufferSz(), pBuf, stream})
1633  .getResult();
1634  } else {
1635  bufferSizeNew =
1636  createSpGEMMComputeBuilder
1637  .create(loc, rewriter,
1638  {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1639  adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1640  adaptor.getBufferSz(), pBuf, stream})
1641  .getResult();
1642  }
1643  rewriter.replaceOp(op, {bufferSizeNew, stream});
1644  return success();
1645 }
1646 
1647 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1648  gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1649  ConversionPatternRewriter &rewriter) const {
1650  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1651  failed(isAsyncWithOneDependency(rewriter, op)))
1652  return failure();
1653  Location loc = op.getLoc();
1654  auto computeType = genConstInt32From(
1655  rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1656  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1657  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1658  auto stream = adaptor.getAsyncDependencies().front();
1659  createSpGEMMCopyBuilder.create(loc, rewriter,
1660  {adaptor.getDesc(), modeA, modeB,
1661  adaptor.getSpmatA(), adaptor.getSpmatB(),
1662  adaptor.getSpmatC(), computeType, stream});
1663  rewriter.replaceOp(op, {stream});
1664  return success();
1665 }
1666 
1667 LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1668  gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1669  ConversionPatternRewriter &rewriter) const {
1670  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1671  failed(isAsyncWithOneDependency(rewriter, op)))
1672  return failure();
1673  Location loc = op.getLoc();
1674  auto stream = adaptor.getAsyncDependencies().front();
1675 
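 // Stack-allocate room for three 64-bit results; the runtime call writes the
 // row, column, and nonzero counts into it, and they are loaded back below.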
1676  auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1677  rewriter.getIndexAttr(3));
1678  auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1679  llvmInt64Type, three, /*alignment=*/16);
1680 
1681  auto rowsPtr = LLVM::GEPOp::create(
1682  rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1683  ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1684  rewriter.getIndexAttr(0))});
1685  auto colsPtr = LLVM::GEPOp::create(
1686  rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1687  ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1688  rewriter.getIndexAttr(1))});
1689  auto nnzsPtr = LLVM::GEPOp::create(
1690  rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1691  ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1692  rewriter.getIndexAttr(2))});
1693  createSpMatGetSizeBuilder.create(
1694  loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1695  auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1696  auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1697  auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1698 
1699  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1700  return success();
1701 }
1702 
1703 LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1704  gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1705  ConversionPatternRewriter &rewriter) const {
1706  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1707  failed(isAsyncWithOneDependency(rewriter, op)))
1708  return failure();
1709  Location loc = op.getLoc();
1710  auto stream = adaptor.getAsyncDependencies().front();
1711  Value pPos =
1712  MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1713  Value pCrd =
1714  MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1715  Value pVal =
1716  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1717  createSetCsrPointersBuilder.create(
1718  loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1719  rewriter.replaceOp(op, {stream});
1720  return success();
1721 }
1722 
1723 LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1724  gpu::CreateCscOp op, OpAdaptor adaptor,
1725  ConversionPatternRewriter &rewriter) const {
1726  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1727  failed(isAsyncWithOneDependency(rewriter, op)))
1728  return failure();
1729  Location loc = op.getLoc();
1730  auto stream = adaptor.getAsyncDependencies().front();
1731  Value pColPos =
1732  MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1733  Value pRowIdxs =
1734  MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1735  Value pValues =
1736  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1737  Type pType =
1738  llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1739  Type iType =
1740  llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1741  Type dType =
1742  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1743  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1744  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1745  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1746  auto handle =
1747  createCscCallBuilder
1748  .create(loc, rewriter,
1749  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1750  pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1751  .getResult();
1752  rewriter.replaceOp(op, {handle, stream});
1753  return success();
1754 }
1755 
1756 LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1757  gpu::CreateBsrOp op, OpAdaptor adaptor,
1758  ConversionPatternRewriter &rewriter) const {
1759  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1760  failed(isAsyncWithOneDependency(rewriter, op)))
1761  return failure();
1762  Location loc = op.getLoc();
1763  auto stream = adaptor.getAsyncDependencies().front();
1764  Value pRowPos =
1765  MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1766  Value pColIdxs =
1767  MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1768  Value pValues =
1769  MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1770  Type pType =
1771  llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1772  Type iType =
1773  llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1774  Type dType =
1775  llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1776  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1777  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1778  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1779  auto handle =
1780  createBsrCallBuilder
1781  .create(loc, rewriter,
1782  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1783  adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1784  pColIdxs, pValues, ptp, itp, dtp, stream})
1785  .getResult();
1786  rewriter.replaceOp(op, {handle, stream});
1787  return success();
1788 }
1789 
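 // A minimal sketch of how a driver might use this helper. The surrounding
 // pass boilerplate (getContext(), getOperation(), signalPassFailure()) is
 // illustrative only and assumes we are inside a pass's runOnOperation():
 //
 //   LLVMTypeConverter converter(&getContext());
 //   RewritePatternSet patterns(&getContext());
 //   populateGpuToLLVMConversionPatterns(converter, patterns);
 //   LLVMConversionTarget target(getContext());
 //   if (failed(applyPartialConversion(getOperation(), target,
 //                                     std::move(patterns))))
 //     signalPassFailure();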
1790 void mlir::populateGpuToLLVMConversionPatterns(
1791  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1792  bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1793  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1794  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1795  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1796  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1797 
1798  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1799  ConvertDeallocOpToGpuRuntimeCallPattern,
1800  ConvertHostRegisterOpToGpuRuntimeCallPattern,
1801  ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1802  ConvertMemcpyOpToGpuRuntimeCallPattern,
1803  ConvertMemsetOpToGpuRuntimeCallPattern,
1804  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1805  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1806  ConvertWaitOpToGpuRuntimeCallPattern,
1807  ConvertAsyncYieldToGpuRuntimeCallPattern,
1808  ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1809  ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1810  ConvertCreateCooOpToGpuRuntimeCallPattern,
1811  ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1812  ConvertCreateCsrOpToGpuRuntimeCallPattern,
1813  ConvertCreateCscOpToGpuRuntimeCallPattern,
1814  ConvertCreateBsrOpToGpuRuntimeCallPattern,
1815  ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1816  ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1817  ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1818  ConvertSpMVOpToGpuRuntimeCallPattern,
1819  ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1820  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1821  ConvertSpMMOpToGpuRuntimeCallPattern,
1822  ConvertSDDMMOpToGpuRuntimeCallPattern,
1823  ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1824  ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1825  ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1826  ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1827  ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1828  ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1829  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1830  kernelIntersperseSizeCallConv);
1831 }
1832 
1833 //===----------------------------------------------------------------------===//
1834 // GPUModuleOp convert to LLVM op interface
1835 //===----------------------------------------------------------------------===//
1836 
1837 namespace {
1838 struct GPUModuleOpConvertToLLVMInterface
1839  : public ConvertToLLVMOpInterface::ExternalModel<
1840  GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1841  /// Get the conversion patterns from the target attribute.
1842  void getConvertToLLVMConversionAttrs(
1843  Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
1844 };
1845 } // namespace
1846 
1847 void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1848  Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1849  auto module = cast<gpu::GPUModuleOp>(op);
1850  ArrayAttr targetsAttr = module.getTargetsAttr();
1851  // Fail if there are no target attributes or there is more than one target.
1852  if (!targetsAttr || targetsAttr.size() != 1)
1853  return;
1854  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1855  attrs.push_back(patternAttr);
1856 }
1857 
1858 void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
1859  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1860  gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1861  });
1862 }
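 // A minimal sketch of wiring this extension into a tool; the registry setup
 // and the set of dialects inserted are illustrative, not part of this file:
 //
 //   DialectRegistry registry;
 //   registry.insert<gpu::GPUDialect, LLVM::LLVMDialect>();
 //   mlir::gpu::registerConvertGpuToLLVMInterface(registry);
 //   MLIRContext context(registry);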