MLIR 23.0.0git
GPUToLLVMConversion.cpp
Go to the documentation of this file.
1//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert gpu.launch_func op into a sequence of
10// GPU runtime calls. As most of GPU runtimes does not have a stable published
11// ABI, this pass uses a slim runtime layer that builds on top of the public
12// API from GPU runtime headers.
13//
14//===----------------------------------------------------------------------===//
15
17
34#include "mlir/IR/Attributes.h"
35#include "mlir/IR/Builders.h"
36#include "mlir/IR/BuiltinOps.h"
39
40#include "llvm/ADT/STLExtras.h"
41
42#define DEBUG_TYPE "gpu-to-llvm"
43
44namespace mlir {
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
47} // namespace mlir
48
49using namespace mlir;
50
51namespace {
// Pass that lowers GPU host-side ops (launch, alloc, memcpy, sparse ops, ...)
// to LLVM dialect calls into the slim mgpu* runtime wrapper API.
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    // Register everything the tablegen'd base pass depends on first.
    Base::getDependentDialects(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};
63
/// Common base for all patterns that lower a GPU dialect op to a call into
/// the mgpu* runtime wrapper library. Caches the LLVM types used in wrapper
/// signatures and provides one FunctionCallBuilder per wrapper entry point.
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  // Returns the number of elements of `type` as an index-typed Value:
  // a constant when the shape is static, otherwise the product of the
  // sizes read out of the memref descriptor `desc`.
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    if (type.hasStaticShape())
          rewriter, loc, indexType, type.getNumElements());
    // Compute the number of elements by multiplying all the dim sizes.
    uint64_t rank = type.getRank();
    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
    for (unsigned i = 1; i < rank; i++)
      numElements = LLVM::MulOp::create(rewriter, loc, numElements,
                                        desc.size(rewriter, loc, /*pos=*/i));
    return numElements;
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  // LLVM-dialect types reused across all wrapper signatures below.
  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  // Integer wide enough to hold a pointer in address space 0.
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  // Stream and event management wrappers.
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  // Host memory registration and device memory management wrappers.
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */,
       llvmInt16Type /* unsigned short value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
  // cuSPARSE-style sparse matrix/vector wrappers. Untagged pointer/int
  // parameters follow the corresponding mgpu* wrapper signatures.
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  // cuSPARSELt (2:4 structured sparsity) wrappers.
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
  // SpGEMM wrappers.
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
       llvmPointerType /*crd*/, llvmPointerType /*val*/,
       llvmPointerType /*void *stream*/}};
};
322
323/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
324/// call. Currently it supports CUDA and ROCm (HIP).
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327public:
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
331
332private:
333 LogicalResult
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override;
336};
337
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340public:
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
344 }
345
346private:
347 LogicalResult
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override;
350};
351
352/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
353/// call. Currently it supports CUDA and ROCm (HIP).
354class ConvertAllocOpToGpuRuntimeCallPattern
355 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
356public:
357 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
359
360private:
361 LogicalResult
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override;
364};
365
366/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
367/// call. Currently it supports CUDA and ROCm (HIP).
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370public:
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
374
375private:
376 LogicalResult
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter) const override;
379};
380
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383public:
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
387
388private:
389 LogicalResult
390 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
391 ConversionPatternRewriter &rewriter) const override;
392};
393
394/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
395/// call. Currently it supports CUDA and ROCm (HIP).
396class ConvertWaitOpToGpuRuntimeCallPattern
397 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
398public:
399 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
400 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
401
402private:
403 LogicalResult
404 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override;
406};
407
408/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
409/// call. Currently it supports CUDA and ROCm (HIP).
410class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
412public:
413 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414 const LLVMTypeConverter &typeConverter)
415 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
416
417private:
418 LogicalResult
419 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter) const override;
421};
422
423/// A rewrite patter to legalize gpu.launch_func with LLVM types.
424class LegalizeLaunchFuncOpPattern
425 : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
426public:
427 LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
428 bool kernelBarePtrCallConv,
429 bool kernelIntersperseSizeCallConv)
430 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
431 kernelBarePtrCallConv(kernelBarePtrCallConv),
432 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
433
434private:
435 LogicalResult
436 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
437 ConversionPatternRewriter &rewriter) const override;
438
439 bool kernelBarePtrCallConv;
440 bool kernelIntersperseSizeCallConv;
441};
442
443/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
444/// call. Currently it supports CUDA and ROCm (HIP).
445class ConvertMemcpyOpToGpuRuntimeCallPattern
446 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
447public:
448 ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
449 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
450
451private:
452 LogicalResult
453 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
454 ConversionPatternRewriter &rewriter) const override;
455};
456
457/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
458/// call. Currently it supports CUDA and ROCm (HIP).
459class ConvertMemsetOpToGpuRuntimeCallPattern
460 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
461public:
462 ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
463 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
464
465private:
466 LogicalResult
467 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter) const override;
469};
470
471/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
472/// Currently supports CUDA and ROCm (HIP)
473class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
474 : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
475public:
476 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
477 const LLVMTypeConverter &typeConverter)
478 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
479 typeConverter) {}
480
481 LogicalResult
482 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter) const override;
484};
485
/// Generic rewriting rule for operation on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
/// Expands to a pattern class named Convert<op_name>ToGpuRuntimeCallPattern
/// for the given gpu::<op_name>; the matchAndRewrite body is defined out of
/// line further below in this file.
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                              \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
  public:                                                                      \
    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
        const LLVMTypeConverter &typeConverter)                                \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}     \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
501
519DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
523
524} // namespace
525
void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();

  // Perform progressive lowering of vector transfer operations.
  {
    RewritePatternSet patterns(&getContext());
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
        /*maxTransferRank=*/1);
    // Transform N-D vector.from_elements to 1-D vector.from_elements before
    // conversion.
    vector::populateVectorFromElementsUnrollPatterns(patterns);
    // This greedy rewrite is a pre-conversion canonicalization step; failure
    // to converge aborts the whole pass.
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }

  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate all patterns from all dialects that implement the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
  }

  // Preserve GPU modules and binaries. Modules are preserved as they can be
  // converted later by `gpu-module-to-binary`.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept as legal LaunchFuncOps if the operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
      target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  // Partial conversion: ops not marked illegal may legitimately remain.
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
578
580 ArrayRef<Value> arguments) const {
581 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
582 auto function = [&] {
583 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
584 return function;
585 auto builder = OpBuilder::atBlockEnd(module.getBody());
586 return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
587 }();
588 return LLVM::CallOp::create(builder, loc, function, arguments);
589}
590
591// Corresponding to cusparseIndexType_t defined in cusparse.h.
592static int32_t getCuSparseIndexTypeFrom(Type type) {
593 if (type.isInteger(16))
594 return 1; // CUSPARSE_INDEX_16U
595 if (type.isInteger(32))
596 return 2; // CUSPARSE_INDEX_32I
597 return 3; // CUSPARSE_INDEX_64I
598}
599
600static int32_t getCuSparseLtDataTypeFrom(Type type) {
601 if (type.isF16())
602 return 0; // CUSPARSE_COMPUTE_16F,
603 if (type.isInteger(32))
604 return 1; // CUSPARSE_COMPUTE_32I
605 llvm_unreachable("unsupported type");
606 // TODO: add support to TF32
607}
608
609// Corresponding to cudaDataType_t defined in CUDA library_types.h.
610static int32_t getCuSparseDataTypeFrom(Type type) {
611 if (llvm::isa<ComplexType>(type)) {
612 // get the element type
613 auto elementType = cast<ComplexType>(type).getElementType();
614 if (elementType.isBF16())
615 return 15; // CUDA_C_16BF
616 if (elementType.isF16())
617 return 6; // CUDA_C_16F
618 if (elementType.isF32())
619 return 4; // CUDA_C_32F
620 if (elementType.isF64())
621 return 5; // CUDA_C_64F
622 if (elementType.isInteger(8))
623 return 7; // CUDA_C_8I
624 if (elementType.isInteger(16))
625 return 21; // CUDA_C_16I
626 if (elementType.isInteger(32))
627 return 11; // CUDA_C_32I
628 }
629 if (type.isBF16())
630 return 14; // CUDA_R_16BF
631 if (type.isF16())
632 return 2; // CUDA_R_16F
633 if (type.isF32())
634 return 0; // CUDA_R_32F
635 if (type.isF64())
636 return 1; // CUDA_R_64F
637 if (type.isInteger(8))
638 return 3; // CUDA_R_8I
639 if (type.isInteger(16))
640 return 20; // CUDA_R_16I
641 if (type.isInteger(32))
642 return 10; // CUDA_R_32I
643
644 llvm_unreachable("unsupported element type");
645}
646
647static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
648 return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
649}
650
651// TODO: We may want a run-time (of the mlir compiler) disablement/warning:
652// cusparseLt currently won't work for cuda architecture <8.0 and will trigger a
653// runtime (of the CUDA program) error , but it might be great if we could at
654// least output a warning when we found the target architecture is <8.0 and the
655// user still wants to use cusparseLt. to make sure when lowering gpu sparse
656// dialect to llvm calls, the cusparselt calls are disabled for cuda
657// architecture <8.0
658static bool is2To4Sparsity(Value spMat) {
659 if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
660 return true;
661 if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
662 return false;
663 if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
664 return false;
665 if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
666 return false;
667 if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
668 return false;
669 if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
670 return false;
671 // Print the spMat defining op
672 spMat.getDefiningOp()->print(llvm::errs());
673 llvm_unreachable("cannot find spmat def");
674}
675
676static bool isSpMMCusparseLtOp(Value op) {
677 for (Operation *user : op.getUsers()) {
678 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
679 // If the other operator is 50% sparsity then we should use cusparseLt
680 if (!spmmOp)
681 continue;
682 if (is2To4Sparsity(spmmOp.getSpmatA()))
683 return true;
684 }
685 return false;
686}
687
688// Returns whether all operands are of LLVM type.
689static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
690 ConversionPatternRewriter &rewriter) {
691 if (!llvm::all_of(operands, [](Value value) {
692 return LLVM::isCompatibleType(value.getType());
693 }))
694 return rewriter.notifyMatchFailure(
695 op, "Cannot convert if operands aren't of LLVM type.");
696 return success();
697}
698
699static LogicalResult
700isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
701 gpu::AsyncOpInterface op) {
702 if (op.getAsyncDependencies().size() != 1)
703 return rewriter.notifyMatchFailure(
704 op, "Can only convert with exactly one async dependency.");
705
706 if (!op.getAsyncToken())
707 return rewriter.notifyMatchFailure(op, "Can convert only async version.");
708
709 return success();
710}
711
712LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
713 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
714 ConversionPatternRewriter &rewriter) const {
715 auto *op = hostRegisterOp.getOperation();
716 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
717 return failure();
718
719 Location loc = op->getLoc();
720
721 auto memRefType = hostRegisterOp.getValue().getType();
722 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
723 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
724
725 auto arguments = getTypeConverter()->promoteOperands(
726 loc, op->getOperands(), adaptor.getOperands(), rewriter);
727 arguments.push_back(elementSize);
728 hostRegisterCallBuilder.create(loc, rewriter, arguments);
729
730 rewriter.eraseOp(op);
731 return success();
732}
733
734LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
735 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
736 ConversionPatternRewriter &rewriter) const {
737 Operation *op = hostUnregisterOp.getOperation();
738 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
739 return failure();
740
741 Location loc = op->getLoc();
742
743 auto memRefType = hostUnregisterOp.getValue().getType();
744 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
745 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
746
747 auto arguments = getTypeConverter()->promoteOperands(
748 loc, op->getOperands(), adaptor.getOperands(), rewriter);
749 arguments.push_back(elementSize);
750 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
751
752 rewriter.eraseOp(op);
753 return success();
754}
755
// Lowers gpu.alloc to a runtime mgpuMemAlloc call and rebuilds the memref
// descriptor around the returned pointer. The async form passes (and
// forwards) the dependency stream; host-shared allocations must be sync.
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
  // Synchronous (host-shared) allocations pass a null stream.
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = mlir::LLVM::ConstantOp::create(
      rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}
814
815LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
816 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter) const {
818 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
819 failed(isAsyncWithOneDependency(rewriter, deallocOp)))
820 return failure();
821
822 Location loc = deallocOp.getLoc();
823
824 Value pointer =
825 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
826 Value stream = adaptor.getAsyncDependencies().front();
827 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
828
829 rewriter.replaceOp(deallocOp, {stream});
830 return success();
831}
832
833static bool isGpuAsyncTokenType(Value value) {
834 return isa<gpu::AsyncTokenType>(value.getType());
835}
836
// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token are lowered to stream within the async.execute region, but
// are passed as events between them. For each !gpu.async.token operand, we
// create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  // Streams backing the converted tokens; deduplicated so each stream is
  // destroyed exactly once, even when it fed several yielded tokens.
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    // Record an event on the token's stream and yield the event instead.
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  // Streams are local to the async.execute region (see comment above); now
  // that recorded events capture their pending work, destroy them.
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}
866
867// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
868static bool isDefinedByCallTo(Value value, StringRef functionName) {
869 assert(isa<LLVM::LLVMPointerType>(value.getType()));
870 if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
871 return *defOp.getCallee() == functionName;
872 return false;
873}
874
875// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
876// with the stream/event operands. The operands are destroyed. That is, it
877// assumes that it is not used afterwards or elsewhere. Otherwise we will get a
878// runtime error. Eventually, we should guarantee this property.
879LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
880 gpu::WaitOp waitOp, OpAdaptor adaptor,
881 ConversionPatternRewriter &rewriter) const {
882 if (waitOp.getAsyncToken())
883 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
884
885 Location loc = waitOp.getLoc();
886
887 for (auto operand : adaptor.getOperands()) {
888 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
889 // The converted operand's definition created a stream.
890 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
891 streamDestroyCallBuilder.create(loc, rewriter, {operand});
892 } else {
893 // Otherwise the converted operand is an event. This assumes that we use
894 // events in control flow code as well.
895 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
896 eventDestroyCallBuilder.create(loc, rewriter, {operand});
897 }
898 }
899
900 rewriter.eraseOp(waitOp);
901 return success();
902}
903
// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with stream/event operands. The operands are
// destroyed. That is, it assumes that it is not used afterwards or elsewhere.
// Otherwise we will get a runtime error. Eventually, we should guarantee this
// property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  // Collect one event per operand; event-record calls are inserted right
  // after the op defining the original token, so remember the current
  // insertion point to come back to afterwards.
  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  // Create the new stream at the original insertion point and make it wait
  // on every collected event; the events are no longer needed afterwards.
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
946
// Legalize the op's operands: lower the async dependencies to a single
// stream, lower the kernel operands to LLVM types (optionally using the
// bare-pointer convention and/or interspersing buffer sizes), and re-create
// the gpu.launch_func with the legalized values.
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty()) {
    // The first dependency's stream becomes the launch stream.
    stream = adaptor.getAsyncDependencies().front();
    // Synchronize additional async dependencies onto the primary stream using
    // events, following the same approach as gpu.wait async lowering.
    if (adaptor.getAsyncDependencies().size() > 1) {
      auto insertionPoint = rewriter.saveInsertionPoint();
      SmallVector<Value, 4> events;
      for (auto [origDep, convertedDep] :
           llvm::zip(launchOp.getAsyncDependencies().drop_front(),
                     adaptor.getAsyncDependencies().drop_front())) {
        // A dependency that is not a stream is already an event; use it as-is.
        if (!isDefinedByCallTo(convertedDep,
                               streamCreateCallBuilder.functionName)) {
          events.push_back(convertedDep);
          continue;
        }
        // Record an event on the dependency's stream, right after the op
        // defining the original token.
        Operation *defOp = origDep.getDefiningOp();
        rewriter.setInsertionPointAfter(defOp);
        Value event =
            eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
        eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
        events.push_back(event);
      }
      rewriter.restoreInsertionPoint(insertionPoint);
      // Make the primary stream wait on all events, then destroy them.
      for (Value event : events)
        streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
      for (Value event : events)
        eventDestroyCallBuilder.create(loc, rewriter, {event});
    }
  }
  // If the async keyword is present and there are no dependencies, then a
  // stream must be created to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

  // Lower the kernel operands to match kernel parameters.
  // Note: If `useBarePtrCallConv` is set in the type converter's options,
  // the value of `kernelBarePtrCallConv` will be ignored.
  OperandRange origArguments = launchOp.getKernelOperands();
  bool effectiveBarePtr = kernelBarePtrCallConv ||
                          getTypeConverter()->getOptions().useBarePtrCallConv;
  if (effectiveBarePtr) {
    // Unranked memrefs have no single underlying pointer to pass bare.
    for (Value arg : origArguments) {
      if (isa<UnrankedMemRefType>(arg.getType()))
        return rewriter.notifyMatchFailure(
            loc, "unranked memref kernel argument is not supported with "
                 "the bare-pointer calling convention");
    }
  }
  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;

  // Intersperse size information if requested: each (presumably bare-pointer)
  // argument is followed by its static buffer size in bytes.
  if (kernelIntersperseSizeCallConv) {
    if (origArguments.size() != llvmArguments.size()) {
      // This shouldn't happen if the bare-pointer calling convention is used.
      return rewriter.notifyMatchFailure(
          launchOp,
          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
    }

    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
      if (!memrefTy) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref.");
      }

      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat()) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");
      }

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");
      }

      // Size in bytes = element size (bytes) * number of elements.
      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
                            static_cast<uint64_t>(memrefTy.getNumElements());

      Value sizeArg = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
      llvmArgumentsWithSizes.push_back(sizeArg);
    }
  }

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  // Re-create the launch with legalized operands; the original op is then
  // replaced by the stream (async) or erased (sync).
  gpu::LaunchFuncOp::create(
      rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(),
      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
      stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}
1079
1081 ConversionPatternRewriter &rewriter,
1082 LLVM::LLVMPointerType destinationType,
1083 Value sourcePtr,
1084 const LLVMTypeConverter &typeConverter) {
1085 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1086 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1087 sourcePtr = LLVM::AddrSpaceCastOp::create(
1088 rewriter, loc,
1089 LLVM::LLVMPointerType::get(rewriter.getContext(),
1090 destinationType.getAddressSpace()),
1091 sourcePtr);
1092 return sourcePtr;
1093}
1094
// Lowers gpu.memcpy to a runtime memcpy call issued on the stream carried by
// the op's single async dependency.
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  // Require converted operands, an identity-layout memref, and exactly one
  // async dependency (the stream).
  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  // Compute the size in bytes with the null-pointer GEP idiom: GEP from a
  // null pointer by `numElements` elements, then ptrtoint the result. This
  // avoids hard-coding the element size.
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
  Value gepPtr = LLVM::GEPOp::create(
      rewriter, loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);

  // Cast both aligned pointers to the expected address space if needed.
  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  // Dependent ops continue on the same stream.
  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}
1134
1135LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1136 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1137 ConversionPatternRewriter &rewriter) const {
1138 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1139
1140 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1141 !isConvertibleAndHasIdentityMaps(memRefType) ||
1142 failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1143 return failure();
1144
1145 auto loc = memsetOp.getLoc();
1146
1147 Type valueType = adaptor.getValue().getType();
1148 unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1149 // Ints and floats of 16 or 32 bit width are allowed.
1150 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1151 return rewriter.notifyMatchFailure(
1152 memsetOp, "value must be a 16 or 32 bit int or float");
1153 }
1154
1155 unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1156 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1157
1158 MemRefDescriptor dstDesc(adaptor.getDst());
1159 Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1160
1161 auto value =
1162 LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1163 auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1164 dstDesc.alignedPtr(rewriter, loc),
1165 *getTypeConverter());
1166
1167 auto stream = adaptor.getAsyncDependencies().front();
1168 FunctionCallBuilder builder =
1169 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1170 builder.create(loc, rewriter, {dst, value, numElements, stream});
1171
1172 rewriter.replaceOp(memsetOp, {stream});
1173 return success();
1174}
1175
1176LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1177 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1178 ConversionPatternRewriter &rewriter) const {
1179 Location loc = op.getLoc();
1180 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1181 {adaptor.getDevIndex()});
1182 rewriter.replaceOp(op, call);
1183 return success();
1184}
1185
1186template <typename T>
1187static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1188 Type llvmInt32Type = builder.getIntegerType(32);
1189 return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1190 static_cast<int32_t>(tValue));
1191}
1192
1193template <typename T>
1194static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1195 Type llvmFloat32Type = builder.getF32Type();
1196 return LLVM::ConstantOp::create(
1197 builder, loc, llvmFloat32Type,
1198 builder.getF32FloatAttr(static_cast<float>(tValue)));
1199}
1200
// Lowers gpu.create_dn_tensor to a runtime call creating a dense vector
// (1-D) or dense matrix (2-D) descriptor handle.
LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Raw pointer to the tensor data.
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  // Element type encoded as a cusparse data-type enum constant.
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // TODO: For now, we track the use of the handle and lower it to cusparse /
  // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
  // used, we require two separate Creation ops to be the correct logic. In
  // future, we may add support to using one handle in sparse tensor / GPU
  // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
  // the dnmat is used with spmat with 2:4 sparsity
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      // cusparseLt takes the descriptor as an out-parameter, so allocate
      // opaque stack storage for it. 11032 is presumably the runtime's
      // descriptor size in bytes — confirm against the cusparseLt wrapper.
      auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                               rewriter.getIndexAttr(11032));
      handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
                                      llvmInt8Type, handleSz, /*alignment=*/16);
      handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);

      // The call fills `handle`; its return value is unused.
      createLtDnMatCallBuilder
          .create(loc, rewriter,
                  {handle, dims[0], dims[1], pTensor, dtp, stream})
          .getResult();
    } else {
      // Plain cusparse returns the dnmat handle directly.
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
1253
1254LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1255 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1256 ConversionPatternRewriter &rewriter) const {
1257 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1258 failed(isAsyncWithOneDependency(rewriter, op)))
1259 return failure();
1260 Location loc = op.getLoc();
1261 auto stream = adaptor.getAsyncDependencies().front();
1262 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1263 SmallVector<Value, 4> dims;
1264 for (Value dim : definingOp.getDims()) {
1265 dims.push_back(dim);
1266 }
1267 if (dims.size() == 2) {
1268 // Use the cusparseLt destroy call if the dnmat is used with spmat with
1269 // 2:4 sparsity
1270 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1271 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1272 {adaptor.getDnTensor(), stream});
1273 } else {
1274 destroyDnMatCallBuilder.create(loc, rewriter,
1275 {adaptor.getDnTensor(), stream});
1276 }
1277 } else {
1278 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1279 destroyDnVecCallBuilder.create(loc, rewriter,
1280 {adaptor.getDnTensor(), stream});
1281 }
1282 rewriter.replaceOp(op, {stream});
1283 return success();
1284}
1285
1286LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1287 gpu::CreateCooOp op, OpAdaptor adaptor,
1288 ConversionPatternRewriter &rewriter) const {
1289 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1290 failed(isAsyncWithOneDependency(rewriter, op)))
1291 return failure();
1292 Location loc = op.getLoc();
1293 auto stream = adaptor.getAsyncDependencies().front();
1294 Value pRowIdxs =
1295 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1296 Value pColIdxs =
1297 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1298 Value pValues =
1299 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1300 Type iType =
1301 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1302 Type dType =
1303 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1304 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1305 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1306 auto handle =
1307 createCooCallBuilder
1308 .create(loc, rewriter,
1309 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1310 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1311 .getResult();
1312 rewriter.replaceOp(op, {handle, stream});
1313 return success();
1314}
1315
1316LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1317 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1318 ConversionPatternRewriter &rewriter) const {
1319 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1320 failed(isAsyncWithOneDependency(rewriter, op)))
1321 return failure();
1322 Location loc = op.getLoc();
1323 auto stream = adaptor.getAsyncDependencies().front();
1324 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1325 Value pValues =
1326 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1327 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1328 Type dType =
1329 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1330 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1331 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1332 auto handle =
1333 createCooAoSCallBuilder
1334 .create(loc, rewriter,
1335 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1336 pIdxs, pValues, itp, dtp, stream})
1337 .getResult();
1338 rewriter.replaceOp(op, {handle, stream});
1339 return success();
1340}
1341
1342LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1343 gpu::CreateCsrOp op, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter) const {
1345 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1346 failed(isAsyncWithOneDependency(rewriter, op)))
1347 return failure();
1348 Location loc = op.getLoc();
1349 auto stream = adaptor.getAsyncDependencies().front();
1350 Value pRowPos =
1351 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1352 Value pColIdxs =
1353 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1354 Value pValues =
1355 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1356 Type pType =
1357 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1358 Type iType =
1359 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1360 Type dType =
1361 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1362 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1363 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1364 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1365 auto handle =
1366 createCsrCallBuilder
1367 .create(loc, rewriter,
1368 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1369 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1370 .getResult();
1371 rewriter.replaceOp(op, {handle, stream});
1372 return success();
1373}
1374
// Lowers gpu.create_2to4_spmat to a cusparseLt runtime call. The descriptor
// is produced through an out-parameter backed by opaque stack storage.
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Raw pointer to the matrix data.
  Value pMat =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  // Element type encoded as a cusparse data-type enum constant.
  Type dType =
      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  // CUDA runner asserts the size is 44104 bytes.
  auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           rewriter.getIndexAttr(44104));
  Value handle = LLVM::AllocaOp::create(
      rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
  handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);

  // The call fills `handle`; its return value is unused.
  create2To4SpMatCallBuilder
      .create(loc, rewriter,
              {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
      .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
1403
1404LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1405 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1406 ConversionPatternRewriter &rewriter) const {
1407 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1408 failed(isAsyncWithOneDependency(rewriter, op)))
1409 return failure();
1410 Location loc = op.getLoc();
1411 auto stream = adaptor.getAsyncDependencies().front();
1412 // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1413 if (is2To4Sparsity(op.getSpmat())) {
1414 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1415 {adaptor.getSpmat(), stream});
1416
1417 } else {
1418 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1419 }
1420 rewriter.replaceOp(op, {stream});
1421 return success();
1422}
1423
1424LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1425 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1426 ConversionPatternRewriter &rewriter) const {
1427 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1428 failed(isAsyncWithOneDependency(rewriter, op)))
1429 return failure();
1430 Location loc = op.getLoc();
1431 auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1432 auto computeType = genConstInt32From(
1433 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1434 auto stream = adaptor.getAsyncDependencies().front();
1435 auto bufferSize = spMVBufferSizeCallBuilder
1436 .create(loc, rewriter,
1437 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1438 adaptor.getDnY(), computeType, stream})
1439 .getResult();
1440 rewriter.replaceOp(op, {bufferSize, stream});
1441 return success();
1442}
1443
1444LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1445 gpu::SpMVOp op, OpAdaptor adaptor,
1446 ConversionPatternRewriter &rewriter) const {
1447 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1448 failed(isAsyncWithOneDependency(rewriter, op)))
1449 return failure();
1450 Location loc = op.getLoc();
1451 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1452 auto computeType = genConstInt32From(
1453 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1454 auto stream = adaptor.getAsyncDependencies().front();
1455 Value pBuf =
1456 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1457 spMVCallBuilder.create(loc, rewriter,
1458 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1459 adaptor.getDnY(), computeType, pBuf, stream});
1460 rewriter.replaceOp(op, {stream});
1461 return success();
1462}
1463
1464LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1465 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1466 ConversionPatternRewriter &rewriter) const {
1467 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1468 failed(isAsyncWithOneDependency(rewriter, op)))
1469 return failure();
1470 Location loc = op.getLoc();
1471 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1472 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1473 auto stream = adaptor.getAsyncDependencies().front();
1474 Value bufferSize;
1475 if (is2To4Sparsity(op.getSpmatA())) {
1476 auto pruneFlag =
1477 genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1478 auto computeType = genConstInt32From(
1479 rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1480 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1481 rewriter.getIndexAttr(3));
1482 auto bufferSize =
1483 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1484 three, /*alignment=*/16);
1485 createCuSparseLtSpMMBufferSizeBuilder
1486 .create(loc, rewriter,
1487 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1488 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1489 pruneFlag, stream})
1490 .getResult();
1491
1492 auto bufferSizePtr1 = LLVM::GEPOp::create(
1493 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1494 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1495 rewriter.getIndexAttr(1))});
1496 auto bufferSizePtr2 = LLVM::GEPOp::create(
1497 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1498 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1499 rewriter.getIndexAttr(2))});
1500 auto bufferSize0 =
1501 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1502 auto bufferSize1 =
1503 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1504 auto bufferSize2 =
1505 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1506
1507 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1508 } else {
1509 auto computeType = genConstInt32From(
1510 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1511 bufferSize =
1512 createSpMMBufferSizeCallBuilder
1513 .create(loc, rewriter,
1514 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1515 adaptor.getDnmatC(), computeType, stream})
1516 .getResult();
1517 rewriter.replaceOp(op, {bufferSize, stream});
1518 }
1519 return success();
1520}
1521
1522LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1523 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1524 ConversionPatternRewriter &rewriter) const {
1525 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1526 failed(isAsyncWithOneDependency(rewriter, op)))
1527 return failure();
1528 Location loc = op.getLoc();
1529 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1530 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1531 auto computeType = genConstInt32From(
1532 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1533 auto stream = adaptor.getAsyncDependencies().front();
1534 auto bufferSize =
1535 createSDDMMBufferSizeCallBuilder
1536 .create(loc, rewriter,
1537 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1538 adaptor.getSpmatC(), computeType, stream})
1539 .getResult();
1540 rewriter.replaceOp(op, {bufferSize, stream});
1541 return success();
1542}
1543
// Lowers gpu.spmm to a runtime call. A 2:4-structured-sparse A matrix is
// dispatched to cusparseLt (which takes three workspace buffers); all other
// matrices go through plain cusparse with a single workspace buffer.
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));

  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to cusparseLt if applicable
  if (is2To4Sparsity(op.getSpmatA())) {
    // Collect raw pointers of all workspace buffers; the cusparseLt call
    // consumes exactly three of them.
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    // Plain cusparse takes a single workspace buffer.
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
1580
1581template <typename T>
1583 converter.addConversion([&converter](T) -> Type {
1584 return LLVM::LLVMPointerType::get(&converter.getContext());
1585 });
1586}
1587
1588LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1589 gpu::SDDMMOp op, OpAdaptor adaptor,
1590 ConversionPatternRewriter &rewriter) const {
1591 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1592 failed(isAsyncWithOneDependency(rewriter, op)))
1593 return failure();
1594 Location loc = op.getLoc();
1595 auto computeType = genConstInt32From(
1596 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1597 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1598 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1599 auto stream = adaptor.getAsyncDependencies().front();
1600 Value pBuf =
1601 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1602 createSDDMMCallBuilder.create(loc, rewriter,
1603 {modeA, modeB, adaptor.getDnmatA(),
1604 adaptor.getDnmatB(), adaptor.getSpmatC(),
1605 computeType, pBuf, stream});
1606 rewriter.replaceOp(op, {stream});
1607 return success();
1608}
1609
1610LogicalResult
1611ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1612 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1613 ConversionPatternRewriter &rewriter) const {
1614 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1615 failed(isAsyncWithOneDependency(rewriter, op)))
1616 return failure();
1617 Location loc = op.getLoc();
1618 auto stream = adaptor.getAsyncDependencies().front();
1619 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1620 .getResult();
1621 rewriter.replaceOp(op, {descr, stream});
1622 return success();
1623}
1624
1625LogicalResult
1626ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1627 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1628 ConversionPatternRewriter &rewriter) const {
1629 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1630 failed(isAsyncWithOneDependency(rewriter, op)))
1631 return failure();
1632 Location loc = op.getLoc();
1633 auto stream = adaptor.getAsyncDependencies().front();
1634 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1635 {adaptor.getDesc(), stream});
1636 rewriter.replaceOp(op, {stream});
1637 return success();
1638}
1639
1640LogicalResult
1641ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1642 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1643 ConversionPatternRewriter &rewriter) const {
1644 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1645 failed(isAsyncWithOneDependency(rewriter, op)))
1646 return failure();
1647 Location loc = op.getLoc();
1648 auto computeType = genConstInt32From(
1649 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1650 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1651 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1652 auto stream = adaptor.getAsyncDependencies().front();
1653
1654 Value pBuf =
1655 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1656 Value bufferSizeNew;
1657
1658 if (adaptor.getKind() ==
1659 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1660 bufferSizeNew =
1661 createSpGEMMWorkEstimationBuilder
1662 .create(loc, rewriter,
1663 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1664 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1665 adaptor.getBufferSz(), pBuf, stream})
1666 .getResult();
1667 } else {
1668 bufferSizeNew =
1669 createSpGEMMComputeBuilder
1670 .create(loc, rewriter,
1671 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1672 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1673 adaptor.getBufferSz(), pBuf, stream})
1674 .getResult();
1675 }
1676 rewriter.replaceOp(op, {bufferSizeNew, stream});
1677 return success();
1678}
1679
1680LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1681 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1682 ConversionPatternRewriter &rewriter) const {
1683 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1684 failed(isAsyncWithOneDependency(rewriter, op)))
1685 return failure();
1686 Location loc = op.getLoc();
1687 auto computeType = genConstInt32From(
1688 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1689 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1690 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1691 auto stream = adaptor.getAsyncDependencies().front();
1692 createSpGEMMCopyBuilder.create(loc, rewriter,
1693 {adaptor.getDesc(), modeA, modeB,
1694 adaptor.getSpmatA(), adaptor.getSpmatB(),
1695 adaptor.getSpmatC(), computeType, stream});
1696 rewriter.replaceOp(op, {stream});
1697 return success();
1698}
1699
1700LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1701 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1702 ConversionPatternRewriter &rewriter) const {
1703 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1704 failed(isAsyncWithOneDependency(rewriter, op)))
1705 return failure();
1706 Location loc = op.getLoc();
1707 auto stream = adaptor.getAsyncDependencies().front();
1708
1709 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1710 rewriter.getIndexAttr(3));
1711 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1712 llvmInt64Type, three, /*alignment=*/16);
1713
1714 auto rowsPtr = LLVM::GEPOp::create(
1715 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1716 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1717 rewriter.getIndexAttr(0))});
1718 auto colsPtr = LLVM::GEPOp::create(
1719 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1720 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1721 rewriter.getIndexAttr(1))});
1722 auto nnzsPtr = LLVM::GEPOp::create(
1723 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1724 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1725 rewriter.getIndexAttr(2))});
1726 createSpMatGetSizeBuilder.create(
1727 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1728 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1729 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1730 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1731
1732 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1733 return success();
1734}
1735
1736LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1737 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1738 ConversionPatternRewriter &rewriter) const {
1739 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1740 failed(isAsyncWithOneDependency(rewriter, op)))
1741 return failure();
1742 Location loc = op.getLoc();
1743 auto stream = adaptor.getAsyncDependencies().front();
1744 Value pPos =
1745 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1746 Value pCrd =
1747 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1748 Value pVal =
1749 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1750 createSetCsrPointersBuilder.create(
1751 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1752 rewriter.replaceOp(op, {stream});
1753 return success();
1754}
1755
1756LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1757 gpu::CreateCscOp op, OpAdaptor adaptor,
1758 ConversionPatternRewriter &rewriter) const {
1759 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1760 failed(isAsyncWithOneDependency(rewriter, op)))
1761 return failure();
1762 Location loc = op.getLoc();
1763 auto stream = adaptor.getAsyncDependencies().front();
1764 Value pColPos =
1765 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1766 Value pRowIdxs =
1767 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1768 Value pValues =
1769 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1770 Type pType =
1771 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1772 Type iType =
1773 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1774 Type dType =
1775 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1776 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1777 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1778 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1779 auto handle =
1780 createCscCallBuilder
1781 .create(loc, rewriter,
1782 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1783 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1784 .getResult();
1785 rewriter.replaceOp(op, {handle, stream});
1786 return success();
1787}
1788
1789LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1790 gpu::CreateBsrOp op, OpAdaptor adaptor,
1791 ConversionPatternRewriter &rewriter) const {
1792 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1793 failed(isAsyncWithOneDependency(rewriter, op)))
1794 return failure();
1795 Location loc = op.getLoc();
1796 auto stream = adaptor.getAsyncDependencies().front();
1797 Value pRowPos =
1798 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1799 Value pColIdxs =
1800 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1801 Value pValues =
1802 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1803 Type pType =
1804 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1805 Type iType =
1806 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1807 Type dType =
1808 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1809 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1810 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1811 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1812 auto handle =
1813 createBsrCallBuilder
1814 .create(loc, rewriter,
1815 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1816 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1817 pColIdxs, pValues, ptp, itp, dtp, stream})
1818 .getResult();
1819 rewriter.replaceOp(op, {handle, stream});
1820 return success();
1821}
1822
1824 LLVMTypeConverter &converter, RewritePatternSet &patterns,
1825 bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1830
1831 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1832 ConvertDeallocOpToGpuRuntimeCallPattern,
1833 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1834 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1835 ConvertMemcpyOpToGpuRuntimeCallPattern,
1836 ConvertMemsetOpToGpuRuntimeCallPattern,
1837 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1838 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1839 ConvertWaitOpToGpuRuntimeCallPattern,
1840 ConvertAsyncYieldToGpuRuntimeCallPattern,
1841 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1842 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1843 ConvertCreateCooOpToGpuRuntimeCallPattern,
1844 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1845 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1846 ConvertCreateCscOpToGpuRuntimeCallPattern,
1847 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1848 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1849 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1850 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1851 ConvertSpMVOpToGpuRuntimeCallPattern,
1852 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1853 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1854 ConvertSpMMOpToGpuRuntimeCallPattern,
1855 ConvertSDDMMOpToGpuRuntimeCallPattern,
1856 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1857 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1858 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1859 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1860 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1861 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1862 patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1863 kernelIntersperseSizeCallConv);
1864}
1865
1866//===----------------------------------------------------------------------===//
1867// GPUModuleOp convert to LLVM op interface
1868//===----------------------------------------------------------------------===//
1869
1870namespace {
1871struct GPUModuleOpConvertToLLVMInterface
1872 : public ConvertToLLVMOpInterface::ExternalModel<
1873 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1874 /// Get the conversion patterns from the target attribute.
1875 void getConvertToLLVMConversionAttrs(
1877};
1878} // namespace
1879
1880void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1881 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1882 auto module = cast<gpu::GPUModuleOp>(op);
1883 ArrayAttr targetsAttr = module.getTargetsAttr();
1884 // Fail if there are no target attributes or there is more than one target.
1885 if (!targetsAttr || targetsAttr.size() != 1)
1886 return;
1887 if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1888 attrs.push_back(patternAttr);
1889}
1890
1892 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1893 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1894 });
1895}
return success()
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
b getContext())
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:250
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition Pattern.cpp:38
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:58
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
Definition Builders.h:209
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition Builders.h:248
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const