MLIR 23.0.0git
GPUToLLVMConversion.cpp
Go to the documentation of this file.
1//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert gpu.launch_func op into a sequence of
10// GPU runtime calls. As most GPU runtimes do not have a stable published
11// ABI, this pass uses a slim runtime layer that builds on top of the public
12// API from GPU runtime headers.
13//
14//===----------------------------------------------------------------------===//
15
17
34#include "mlir/IR/Attributes.h"
35#include "mlir/IR/Builders.h"
36#include "mlir/IR/BuiltinOps.h"
39
40#include "llvm/ADT/STLExtras.h"
41
42#define DEBUG_TYPE "gpu-to-llvm"
43
44namespace mlir {
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
47} // namespace mlir
48
49using namespace mlir;
50
51namespace {
52class GpuToLLVMConversionPass
53 : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
54public:
55 using Base::Base;
56 void getDependentDialects(DialectRegistry &registry) const final {
57 Base::getDependentDialects(registry);
59 }
60 // Run the dialect converter on the module.
61 void runOnOperation() override;
62};
63
64template <typename OpTy>
65class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
66public:
67 explicit ConvertOpToGpuRuntimeCallPattern(
68 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
69 : ConvertOpToLLVMPattern<OpTy>(typeConverter, benefit) {}
70
71protected:
72 Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73 MemRefType type, MemRefDescriptor desc) const {
74 Type indexType = ConvertToLLVMPattern::getIndexType();
75 if (type.hasStaticShape())
77 rewriter, loc, indexType, type.getNumElements());
78 // Compute the number of elements by multiplying all the dim sizes.
79 uint64_t rank = type.getRank();
80 Value numElements = desc.size(rewriter, loc, /*pos=*/0);
81 for (unsigned i = 1; i < rank; i++)
82 numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83 desc.size(rewriter, loc, /*pos=*/i));
84 return numElements;
85 }
86
87 MLIRContext *context = &this->getTypeConverter()->getContext();
88
89 Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91 Type llvmInt8Type = IntegerType::get(context, 8);
92 Type llvmInt16Type = IntegerType::get(context, 16);
93 Type llvmInt32Type = IntegerType::get(context, 32);
94 Type llvmInt64Type = IntegerType::get(context, 64);
95 Type llvmFloat32Type = Float32Type::get(context);
96 Type llvmIntPtrType = IntegerType::get(
97 context, this->getTypeConverter()->getPointerBitwidth(0));
98
99 FunctionCallBuilder streamCreateCallBuilder = {
100 "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
101 FunctionCallBuilder streamDestroyCallBuilder = {
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
103 FunctionCallBuilder streamSynchronizeCallBuilder = {
104 "mgpuStreamSynchronize",
105 llvmVoidType,
106 {llvmPointerType /* void *stream */}};
107 FunctionCallBuilder streamWaitEventCallBuilder = {
108 "mgpuStreamWaitEvent",
109 llvmVoidType,
110 {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
111 FunctionCallBuilder eventCreateCallBuilder = {
112 "mgpuEventCreate", llvmPointerType /* void *event */, {}};
113 FunctionCallBuilder eventDestroyCallBuilder = {
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
115 FunctionCallBuilder eventSynchronizeCallBuilder = {
116 "mgpuEventSynchronize",
117 llvmVoidType,
118 {llvmPointerType /* void *event */}};
119 FunctionCallBuilder eventRecordCallBuilder = {
120 "mgpuEventRecord",
121 llvmVoidType,
122 {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
123 FunctionCallBuilder hostRegisterCallBuilder = {
124 "mgpuMemHostRegisterMemRef",
125 llvmVoidType,
126 {llvmIntPtrType /* intptr_t rank */,
127 llvmPointerType /* void *memrefDesc */,
128 llvmIntPtrType /* intptr_t elementSizeBytes */}};
129 FunctionCallBuilder hostUnregisterCallBuilder = {
130 "mgpuMemHostUnregisterMemRef",
131 llvmVoidType,
132 {llvmIntPtrType /* intptr_t rank */,
133 llvmPointerType /* void *memrefDesc */,
134 llvmIntPtrType /* intptr_t elementSizeBytes */}};
135 FunctionCallBuilder allocCallBuilder = {
136 "mgpuMemAlloc",
137 llvmPointerType /* void * */,
138 {llvmIntPtrType /* intptr_t sizeBytes */,
139 llvmPointerType /* void *stream */,
140 llvmInt8Type /* bool isHostShared */}};
141 FunctionCallBuilder deallocCallBuilder = {
142 "mgpuMemFree",
143 llvmVoidType,
144 {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
145 FunctionCallBuilder memcpyCallBuilder = {
146 "mgpuMemcpy",
147 llvmVoidType,
148 {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
149 llvmIntPtrType /* intptr_t sizeBytes */,
150 llvmPointerType /* void *stream */}};
151 FunctionCallBuilder memset16CallBuilder = {
152 "mgpuMemset16",
153 llvmVoidType,
154 {llvmPointerType /* void *dst */,
155 llvmInt16Type /* unsigned short value */,
156 llvmIntPtrType /* intptr_t sizeBytes */,
157 llvmPointerType /* void *stream */}};
158 FunctionCallBuilder memset32CallBuilder = {
159 "mgpuMemset32",
160 llvmVoidType,
161 {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
162 llvmIntPtrType /* intptr_t sizeBytes */,
163 llvmPointerType /* void *stream */}};
164 FunctionCallBuilder setDefaultDeviceCallBuilder = {
165 "mgpuSetDefaultDevice",
166 llvmVoidType,
167 {llvmInt32Type /* uint32_t devIndex */}};
168 FunctionCallBuilder createDnVecCallBuilder = {
169 "mgpuCreateDnVec",
170 llvmPointerType,
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
172 llvmPointerType /* void *stream */}};
173 FunctionCallBuilder destroyDnVecCallBuilder = {
174 "mgpuDestroyDnVec",
175 llvmVoidType,
176 {llvmPointerType, llvmPointerType /* void *stream */}};
177 FunctionCallBuilder createDnMatCallBuilder = {
178 "mgpuCreateDnMat",
179 llvmPointerType,
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
181 llvmPointerType /* void *stream */}};
182 FunctionCallBuilder destroyDnMatCallBuilder = {
183 "mgpuDestroyDnMat",
184 llvmVoidType,
185 {llvmPointerType, llvmPointerType /* void *stream */}};
186 FunctionCallBuilder createCooCallBuilder = {
187 "mgpuCreateCoo",
188 llvmPointerType,
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
191 llvmPointerType /* void *stream */}};
192 FunctionCallBuilder createCooAoSCallBuilder = {
193 "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
194 llvmPointerType,
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
197 llvmPointerType /* void *stream */}};
198 FunctionCallBuilder createCsrCallBuilder = {
199 "mgpuCreateCsr",
200 llvmPointerType,
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType /* void *stream */}};
204 FunctionCallBuilder createCscCallBuilder = {
205 "mgpuCreateCsc",
206 llvmPointerType,
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType /* void *stream */}};
210 FunctionCallBuilder createBsrCallBuilder = {
211 "mgpuCreateBsr",
212 llvmPointerType,
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
216 llvmPointerType /* void *stream */}};
217 FunctionCallBuilder destroySpMatCallBuilder = {
218 "mgpuDestroySpMat",
219 llvmVoidType,
220 {llvmPointerType, llvmPointerType /* void *stream */}};
221 FunctionCallBuilder spMVBufferSizeCallBuilder = {
222 "mgpuSpMVBufferSize",
223 llvmIntPtrType,
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType /* void *stream */}};
226 FunctionCallBuilder spMVCallBuilder = {
227 "mgpuSpMV",
228 llvmVoidType,
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
231 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232 "mgpuSpMMBufferSize",
233 llvmIntPtrType,
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
236 FunctionCallBuilder createSpMMCallBuilder = {
237 "mgpuSpMM",
238 llvmVoidType,
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
241 llvmPointerType /* void *stream */}};
242 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243 "mgpuSDDMMBufferSize",
244 llvmIntPtrType,
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
247 FunctionCallBuilder createSDDMMCallBuilder = {
248 "mgpuSDDMM",
249 llvmVoidType,
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
252 llvmPointerType /* void *stream */}};
253 FunctionCallBuilder createLtDnMatCallBuilder = {
254 "mgpuCreateCuSparseLtDnMat",
255 llvmVoidType,
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType /* void *stream */}};
258 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259 "mgpuDestroyCuSparseLtSpMat",
260 llvmVoidType,
261 {llvmPointerType, llvmPointerType /* void *stream */}};
262 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263 "mgpuDestroyCuSparseLtDnMat",
264 llvmVoidType,
265 {llvmPointerType, llvmPointerType /* void *stream */}};
266 FunctionCallBuilder create2To4SpMatCallBuilder = {
267 "mgpuCusparseLtCreate2To4SpMat",
268 llvmVoidType,
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType /* void *stream */}};
271 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272 "mgpuCuSparseLtSpMMBufferSize",
273 llvmVoidType,
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
276 llvmPointerType /*void *stream*/}};
277 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278 "mgpuCuSparseLtSpMM",
279 llvmVoidType,
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
282 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283 "mgpuSpGEMMCreateDescr",
284 llvmPointerType,
285 {llvmPointerType /*void *stream*/}};
286 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287 "mgpuSpGEMMDestroyDescr",
288 llvmVoidType,
289 {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
290 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291 "mgpuSpGEMMWorkEstimation",
292 llvmIntPtrType,
293 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
294 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
295 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
296 llvmPointerType /*void *stream*/}};
297 FunctionCallBuilder createSpGEMMComputeBuilder = {
298 "mgpuSpGEMMCompute",
299 llvmIntPtrType,
300 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
301 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
302 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
303 llvmPointerType /*void *stream*/}};
304 FunctionCallBuilder createSpGEMMCopyBuilder = {
305 "mgpuSpGEMMCopy",
306 llvmVoidType,
307 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
308 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
309 llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
310 FunctionCallBuilder createSpMatGetSizeBuilder = {
311 "mgpuSpMatGetSize",
312 llvmVoidType,
313 {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
314 llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
315 FunctionCallBuilder createSetCsrPointersBuilder = {
316 "mgpuSetCsrPointers",
317 llvmVoidType,
318 {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
319 llvmPointerType /*crd*/, llvmPointerType /*val*/,
320 llvmPointerType /*void *stream*/}};
321};
322
323/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
324/// call. Currently it supports CUDA and ROCm (HIP).
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327public:
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
331
332private:
333 LogicalResult
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override;
336};
337
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340public:
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
344 }
345
346private:
347 LogicalResult
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override;
350};
351
352/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
353/// call. Currently it supports CUDA and ROCm (HIP).
354class ConvertAllocOpToGpuRuntimeCallPattern
355 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
356public:
357 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
359
360private:
361 LogicalResult
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override;
364};
365
366/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
367/// call. Currently it supports CUDA and ROCm (HIP).
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370public:
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
374
375private:
376 LogicalResult
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter) const override;
379};
380
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383public:
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter,
387 benefit) {}
388
389private:
390 LogicalResult
391 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter) const override;
393};
394
395/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
396/// call. Currently it supports CUDA and ROCm (HIP).
397class ConvertWaitOpToGpuRuntimeCallPattern
398 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
399public:
400 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
401 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
402
403private:
404 LogicalResult
405 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const override;
407};
408
409/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
410/// call. Currently it supports CUDA and ROCm (HIP).
411class ConvertWaitAsyncOpToGpuRuntimeCallPattern
412 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
413public:
414 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
415 const LLVMTypeConverter &typeConverter)
416 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
417
418private:
419 LogicalResult
420 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override;
422};
423
424/// A rewrite patter to legalize gpu.launch_func with LLVM types.
425class LegalizeLaunchFuncOpPattern
426 : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
427public:
428 LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
429 bool kernelBarePtrCallConv,
430 bool kernelIntersperseSizeCallConv)
431 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
432 kernelBarePtrCallConv(kernelBarePtrCallConv),
433 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
434
435private:
436 LogicalResult
437 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
438 ConversionPatternRewriter &rewriter) const override;
439
440 bool kernelBarePtrCallConv;
441 bool kernelIntersperseSizeCallConv;
442};
443
444/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
445/// call. Currently it supports CUDA and ROCm (HIP).
446class ConvertMemcpyOpToGpuRuntimeCallPattern
447 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
448public:
449 ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
450 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
451
452private:
453 LogicalResult
454 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
455 ConversionPatternRewriter &rewriter) const override;
456};
457
458/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
459/// call. Currently it supports CUDA and ROCm (HIP).
460class ConvertMemsetOpToGpuRuntimeCallPattern
461 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
462public:
463 ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
464 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
465
466private:
467 LogicalResult
468 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override;
470};
471
472/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
473/// Currently supports CUDA and ROCm (HIP)
474class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
475 : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
476public:
477 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
478 const LLVMTypeConverter &typeConverter)
479 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
480 typeConverter) {}
481
482 LogicalResult
483 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override;
485};
486
/// Declares `Convert<Op>ToGpuRuntimeCallPattern`, a pattern that lowers the
/// operation `gpu::<Op>` (an operation on sparse matrices) into GPU runtime
/// calls. Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(OpName)                 \
  class Convert##OpName##ToGpuRuntimeCallPattern                               \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::OpName> {                 \
    using RuntimeCallBase = ConvertOpToGpuRuntimeCallPattern<gpu::OpName>;     \
                                                                               \
  public:                                                                      \
    Convert##OpName##ToGpuRuntimeCallPattern(                                  \
        const LLVMTypeConverter &converter)                                    \
        : RuntimeCallBase(converter) {}                                        \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::OpName op, OpAdaptor adaptor,                         \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
502
520DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
524
525} // namespace
526
527void GpuToLLVMConversionPass::runOnOperation() {
528 MLIRContext *context = &getContext();
529
530 // Perform progressive lowering of vector transfer operations.
531 {
532 RewritePatternSet patterns(&getContext());
533 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
535 /*maxTransferRank=*/1);
536 // Transform N-D vector.from_elements to 1-D vector.from_elements before
537 // conversion.
538 vector::populateVectorFromElementsUnrollPatterns(patterns);
539 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
540 return signalPassFailure();
541 }
542
543 LowerToLLVMOptions options(context);
544 options.useBarePtrCallConv = hostBarePtrCallConv;
545 RewritePatternSet patterns(context);
546 ConversionTarget target(*context);
547 target.addLegalDialect<LLVM::LLVMDialect>();
548 LLVMTypeConverter converter(context, options);
549
550 // Populate all patterns from all dialects that implement the
551 // `ConvertToLLVMPatternInterface` interface.
552 for (Dialect *dialect : context->getLoadedDialects()) {
553 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
554 if (!iface)
555 continue;
556 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
557 }
558
559 // Preserve GPU modules and binaries. Modules are preserved as they can be
560 // converted later by `gpu-module-to-binary`.
561 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
562 // Accept as legal LaunchFuncOps if the operands have been lowered.
563 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
564 [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
565
566 // These aren't covered by the ConvertToLLVMPatternInterface right now.
567 populateVectorToLLVMConversionPatterns(converter, patterns);
570 target);
571 populateGpuToLLVMConversionPatterns(converter, patterns,
572 kernelBarePtrCallConv,
573 kernelIntersperseSizeCallConv);
574
575 if (failed(
576 applyPartialConversion(getOperation(), target, std::move(patterns))))
577 signalPassFailure();
578}
579
581 ArrayRef<Value> arguments) const {
582 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
583 auto function = [&] {
584 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
585 return function;
586 auto builder = OpBuilder::atBlockEnd(module.getBody());
587 return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
588 }();
589 return LLVM::CallOp::create(builder, loc, function, arguments);
590}
591
592// Corresponding to cusparseIndexType_t defined in cusparse.h.
593static int32_t getCuSparseIndexTypeFrom(Type type) {
594 if (type.isInteger(16))
595 return 1; // CUSPARSE_INDEX_16U
596 if (type.isInteger(32))
597 return 2; // CUSPARSE_INDEX_32I
598 return 3; // CUSPARSE_INDEX_64I
599}
600
601static int32_t getCuSparseLtDataTypeFrom(Type type) {
602 if (type.isF16())
603 return 0; // CUSPARSE_COMPUTE_16F,
604 if (type.isInteger(32))
605 return 1; // CUSPARSE_COMPUTE_32I
606 llvm_unreachable("unsupported type");
607 // TODO: add support to TF32
608}
609
610// Corresponding to cudaDataType_t defined in CUDA library_types.h.
611static int32_t getCuSparseDataTypeFrom(Type type) {
612 if (llvm::isa<ComplexType>(type)) {
613 // get the element type
614 auto elementType = cast<ComplexType>(type).getElementType();
615 if (elementType.isBF16())
616 return 15; // CUDA_C_16BF
617 if (elementType.isF16())
618 return 6; // CUDA_C_16F
619 if (elementType.isF32())
620 return 4; // CUDA_C_32F
621 if (elementType.isF64())
622 return 5; // CUDA_C_64F
623 if (elementType.isInteger(8))
624 return 7; // CUDA_C_8I
625 if (elementType.isInteger(16))
626 return 21; // CUDA_C_16I
627 if (elementType.isInteger(32))
628 return 11; // CUDA_C_32I
629 }
630 if (type.isBF16())
631 return 14; // CUDA_R_16BF
632 if (type.isF16())
633 return 2; // CUDA_R_16F
634 if (type.isF32())
635 return 0; // CUDA_R_32F
636 if (type.isF64())
637 return 1; // CUDA_R_64F
638 if (type.isInteger(8))
639 return 3; // CUDA_R_8I
640 if (type.isInteger(16))
641 return 20; // CUDA_R_16I
642 if (type.isInteger(32))
643 return 10; // CUDA_R_32I
644
645 llvm_unreachable("unsupported element type");
646}
647
648static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
649 return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
650}
651
652// TODO: We may want a run-time (of the mlir compiler) disablement/warning:
653// cusparseLt currently won't work for cuda architecture <8.0 and will trigger a
654// runtime (of the CUDA program) error , but it might be great if we could at
655// least output a warning when we found the target architecture is <8.0 and the
656// user still wants to use cusparseLt. to make sure when lowering gpu sparse
657// dialect to llvm calls, the cusparselt calls are disabled for cuda
658// architecture <8.0
659static bool is2To4Sparsity(Value spMat) {
660 if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
661 return true;
662 if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
663 return false;
664 if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
665 return false;
666 if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
667 return false;
668 if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
669 return false;
670 if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
671 return false;
672 // Print the spMat defining op
673 spMat.getDefiningOp()->print(llvm::errs());
674 llvm_unreachable("cannot find spmat def");
675}
676
677static bool isSpMMCusparseLtOp(Value op) {
678 for (Operation *user : op.getUsers()) {
679 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
680 // If the other operator is 50% sparsity then we should use cusparseLt
681 if (!spmmOp)
682 continue;
683 if (is2To4Sparsity(spmmOp.getSpmatA()))
684 return true;
685 }
686 return false;
687}
688
689// Returns whether all operands are of LLVM type.
690static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
691 ConversionPatternRewriter &rewriter) {
692 if (!llvm::all_of(operands, [](Value value) {
693 return LLVM::isCompatibleType(value.getType());
694 }))
695 return rewriter.notifyMatchFailure(
696 op, "Cannot convert if operands aren't of LLVM type.");
697 return success();
698}
699
700static LogicalResult
701isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
702 gpu::AsyncOpInterface op) {
703 if (op.getAsyncDependencies().size() != 1)
704 return rewriter.notifyMatchFailure(
705 op, "Can only convert with exactly one async dependency.");
706
707 if (!op.getAsyncToken())
708 return rewriter.notifyMatchFailure(op, "Can convert only async version.");
709
710 return success();
711}
712
713LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter) const {
716 auto *op = hostRegisterOp.getOperation();
717 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
718 return failure();
719
720 Location loc = op->getLoc();
721
722 auto memRefType = hostRegisterOp.getValue().getType();
723 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
724 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
725
726 auto arguments = getTypeConverter()->promoteOperands(
727 loc, op->getOperands(), adaptor.getOperands(), rewriter);
728 arguments.push_back(elementSize);
729 hostRegisterCallBuilder.create(loc, rewriter, arguments);
730
731 rewriter.eraseOp(op);
732 return success();
733}
734
735LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
737 ConversionPatternRewriter &rewriter) const {
738 Operation *op = hostUnregisterOp.getOperation();
739 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
740 return failure();
741
742 Location loc = op->getLoc();
743
744 auto memRefType = hostUnregisterOp.getValue().getType();
745 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
746 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
747
748 auto arguments = getTypeConverter()->promoteOperands(
749 loc, op->getOperands(), adaptor.getOperands(), rewriter);
750 arguments.push_back(elementSize);
751 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
752
753 rewriter.eraseOp(op);
754 return success();
755}
756
757LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758 gpu::AllocOp allocOp, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter) const {
760
761 MemRefType memRefType = allocOp.getType();
762
763 if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
764 !isConvertibleAndHasIdentityMaps(memRefType))
765 return failure();
766
767 auto loc = allocOp.getLoc();
768
769 bool isShared = allocOp.getHostShared();
770
771 if (isShared && allocOp.getAsyncToken())
772 return rewriter.notifyMatchFailure(
773 allocOp, "Host Shared allocation cannot be done async");
774 if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
775 return failure();
776
777 // Get shape of the memref as values: static sizes are constant
778 // values and dynamic sizes are passed to 'alloc' as operands.
779 SmallVector<Value, 4> shape;
780 SmallVector<Value, 4> strides;
781 Value sizeBytes;
782 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
783 shape, strides, sizeBytes);
784
785 // Allocate the underlying buffer and store a pointer to it in the MemRef
786 // descriptor.
787 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
788 Value stream = adaptor.getAsyncDependencies().empty()
789 ? nullPtr
790 : adaptor.getAsyncDependencies().front();
791
792 auto isHostShared = mlir::LLVM::ConstantOp::create(
793 rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
794
795 Value allocatedPtr =
796 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
797 .getResult();
798
799 // No alignment.
800 Value alignedPtr = allocatedPtr;
801
802 // Create the MemRef descriptor.
803 auto memRefDescriptor = this->createMemRefDescriptor(
804 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
805
806 if (allocOp.getAsyncToken()) {
807 // Async alloc: make dependent ops use the same stream.
808 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
809 } else {
810 rewriter.replaceOp(allocOp, {memRefDescriptor});
811 }
812
813 return success();
814}
815
816LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
817 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter) const {
819 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
820 failed(isAsyncWithOneDependency(rewriter, deallocOp)))
821 return failure();
822
823 Location loc = deallocOp.getLoc();
824
825 Value pointer =
826 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
827 Value stream = adaptor.getAsyncDependencies().front();
828 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
829
830 rewriter.replaceOp(deallocOp, {stream});
831 return success();
832}
833
834static bool isGpuAsyncTokenType(Value value) {
835 return isa<gpu::AsyncTokenType>(value.getType());
836}
837
838// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
839// !gpu.async.token are lowered to stream within the async.execute region, but
840// are passed as events between them. For each !gpu.async.token operand, we
841// create an event and record it on the stream.
842//
843// This pattern is registered with a higher benefit than the structural
844// async.yield rewriter from populateAsyncStructuralTypeConversionsAndLegality
845// so it wins when both match. Without that benefit override, the structural
846// pattern can win and silently retype gpu.async.token operands without
847// recording an event, leaving the host await to call cuEventSynchronize on
848// a stream pointer (a no-op that returns an error), racing the host against
849// the GPU.
850LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
851 async::YieldOp yieldOp, OpAdaptor adaptor,
852 ConversionPatternRewriter &rewriter) const {
853 if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
854 return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
855
856 Location loc = yieldOp.getLoc();
857 SmallVector<Value, 4> newOperands(adaptor.getOperands());
858 llvm::SmallDenseSet<Value> streams;
859 for (auto &operand : yieldOp->getOpOperands()) {
860 if (!isGpuAsyncTokenType(operand.get()))
861 continue;
862 auto idx = operand.getOperandNumber();
863 auto stream = adaptor.getOperands()[idx];
864 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
865 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
866 newOperands[idx] = event;
867 streams.insert(stream);
868 }
869 for (auto stream : streams)
870 streamDestroyCallBuilder.create(loc, rewriter, {stream});
871
872 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
873 return success();
874}
875
876// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
877static bool isDefinedByCallTo(Value value, StringRef functionName) {
878 assert(isa<LLVM::LLVMPointerType>(value.getType()));
879 if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
880 return *defOp.getCallee() == functionName;
881 return false;
882}
883
884// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
885// with the stream/event operands. The operands are destroyed. That is, it
886// assumes that it is not used afterwards or elsewhere. Otherwise we will get a
887// runtime error. Eventually, we should guarantee this property.
888LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
889 gpu::WaitOp waitOp, OpAdaptor adaptor,
890 ConversionPatternRewriter &rewriter) const {
891 if (waitOp.getAsyncToken())
892 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
893
894 Location loc = waitOp.getLoc();
895
896 for (auto operand : adaptor.getOperands()) {
897 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
898 // The converted operand's definition created a stream.
899 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
900 streamDestroyCallBuilder.create(loc, rewriter, {operand});
901 } else {
902 // Otherwise the converted operand is an event. This assumes that we use
903 // events in control flow code as well.
904 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
905 eventDestroyCallBuilder.create(loc, rewriter, {operand});
906 }
907 }
908
909 rewriter.eraseOp(waitOp);
910 return success();
911}
912
913// Converts `gpu.wait async` to runtime calls. The converted op creates a new
914// stream that is synchronized with stream/event operands. The operands are
915// destroyed. That is, it assumes that it is not used afterwards or elsewhere.
916// Otherwise we will get a runtime error. Eventually, we should guarantee this
917// property.
918LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
919 gpu::WaitOp waitOp, OpAdaptor adaptor,
920 ConversionPatternRewriter &rewriter) const {
921 if (!waitOp.getAsyncToken())
922 return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
923
924 Location loc = waitOp.getLoc();
925
926 auto insertionPoint = rewriter.saveInsertionPoint();
927 SmallVector<Value, 1> events;
928 for (auto pair :
929 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
930 auto operand = std::get<1>(pair);
931 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
932 // The converted operand's definition created a stream. Insert an event
933 // into the stream just after the last use of the original token operand.
934 auto *defOp = std::get<0>(pair).getDefiningOp();
935 rewriter.setInsertionPointAfter(defOp);
936 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
937 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
938 events.push_back(event);
939 } else {
940 // Otherwise the converted operand is an event. This assumes that we use
941 // events in control flow code as well.
942 events.push_back(operand);
943 }
944 }
945 rewriter.restoreInsertionPoint(insertionPoint);
946 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
947 for (auto event : events)
948 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
949 for (auto event : events)
950 eventDestroyCallBuilder.create(loc, rewriter, {event});
951 rewriter.replaceOp(waitOp, {stream});
952
953 return success();
954}
955
956// Legalize the op's operands.
957LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
958 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
959 ConversionPatternRewriter &rewriter) const {
960 if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
961 return failure();
962
963 // Fail when the synchronous version of the op has async dependencies. The
964 // lowering destroys the stream, and we do not want to check that there is no
965 // use of the stream after this op.
966 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
967 return rewriter.notifyMatchFailure(
968 launchOp, "Cannot convert non-async op with async dependencies.");
969
970 Location loc = launchOp.getLoc();
971
972 Value stream = Value();
973 if (!adaptor.getAsyncDependencies().empty()) {
974 stream = adaptor.getAsyncDependencies().front();
975 // Synchronize additional async dependencies onto the primary stream using
976 // events, following the same approach as gpu.wait async lowering.
977 if (adaptor.getAsyncDependencies().size() > 1) {
978 auto insertionPoint = rewriter.saveInsertionPoint();
979 SmallVector<Value, 4> events;
980 for (auto [origDep, convertedDep] :
981 llvm::zip(launchOp.getAsyncDependencies().drop_front(),
982 adaptor.getAsyncDependencies().drop_front())) {
983 if (!isDefinedByCallTo(convertedDep,
984 streamCreateCallBuilder.functionName)) {
985 events.push_back(convertedDep);
986 continue;
987 }
988 Operation *defOp = origDep.getDefiningOp();
989 rewriter.setInsertionPointAfter(defOp);
990 Value event =
991 eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
992 eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
993 events.push_back(event);
994 }
995 rewriter.restoreInsertionPoint(insertionPoint);
996 for (Value event : events)
997 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
998 for (Value event : events)
999 eventDestroyCallBuilder.create(loc, rewriter, {event});
1000 }
1001 }
1002 // If the async keyword is present and there are no dependencies, then a
1003 // stream must be created to pass to subsequent operations.
1004 else if (launchOp.getAsyncToken())
1005 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1006
1007 // Lower the kernel operands to match kernel parameters.
1008 // Note: If `useBarePtrCallConv` is set in the type converter's options,
1009 // the value of `kernelBarePtrCallConv` will be ignored.
1010 OperandRange origArguments = launchOp.getKernelOperands();
1011 bool effectiveBarePtr = kernelBarePtrCallConv ||
1012 getTypeConverter()->getOptions().useBarePtrCallConv;
1013 if (effectiveBarePtr) {
1014 for (Value arg : origArguments) {
1015 if (isa<UnrankedMemRefType>(arg.getType()))
1016 return rewriter.notifyMatchFailure(
1017 loc, "unranked memref kernel argument is not supported with "
1018 "the bare-pointer calling convention");
1019 }
1020 }
1021 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
1022 loc, origArguments, adaptor.getKernelOperands(), rewriter,
1023 /*useBarePtrCallConv=*/kernelBarePtrCallConv);
1024 SmallVector<Value, 8> llvmArgumentsWithSizes;
1025
1026 // Intersperse size information if requested.
1027 if (kernelIntersperseSizeCallConv) {
1028 if (origArguments.size() != llvmArguments.size()) {
1029 // This shouldn't happen if the bare-pointer calling convention is used.
1030 return rewriter.notifyMatchFailure(
1031 launchOp,
1032 "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1033 }
1034
1035 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1036 for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1037 auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1038 if (!memrefTy) {
1039 return rewriter.notifyMatchFailure(
1040 launchOp, "Operand to launch op is not a memref.");
1041 }
1042
1043 if (!memrefTy.hasStaticShape() ||
1044 !memrefTy.getElementType().isIntOrFloat()) {
1045 return rewriter.notifyMatchFailure(
1046 launchOp, "Operand to launch op is not a memref with a static "
1047 "shape and an integer or float element type.");
1048 }
1049
1050 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1051 if (bitwidth % 8 != 0) {
1052 return rewriter.notifyMatchFailure(
1053 launchOp, "Operand to launch op is not a memref with a "
1054 "byte-aligned element type.");
1055 }
1056
1057 uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1058 static_cast<uint64_t>(memrefTy.getNumElements());
1059
1060 Value sizeArg = LLVM::ConstantOp::create(
1061 rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1062 llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
1063 llvmArgumentsWithSizes.push_back(sizeArg);
1064 }
1065 }
1066
1067 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1068 if (launchOp.hasClusterSize()) {
1069 clusterSize =
1070 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1071 adaptor.getClusterSizeZ()};
1072 }
1073 auto newLaunchOp = gpu::LaunchFuncOp::create(
1074 rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1075 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1076 adaptor.getGridSizeZ()},
1077 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1078 adaptor.getBlockSizeZ()},
1079 adaptor.getDynamicSharedMemorySize(),
1080 llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1081 stream, clusterSize);
1082 if (launchOp.getCooperative())
1083 newLaunchOp.setCooperative(true);
1084 if (launchOp.getAsyncToken())
1085 rewriter.replaceOp(launchOp, {stream});
1086 else
1087 rewriter.eraseOp(launchOp);
1088 return success();
1089}
1090
1092 ConversionPatternRewriter &rewriter,
1093 LLVM::LLVMPointerType destinationType,
1094 Value sourcePtr,
1095 const LLVMTypeConverter &typeConverter) {
1096 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1097 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1098 sourcePtr = LLVM::AddrSpaceCastOp::create(
1099 rewriter, loc,
1100 LLVM::LLVMPointerType::get(rewriter.getContext(),
1101 destinationType.getAddressSpace()),
1102 sourcePtr);
1103 return sourcePtr;
1104}
1105
1106LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1107 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1108 ConversionPatternRewriter &rewriter) const {
1109 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1110
1111 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1112 !isConvertibleAndHasIdentityMaps(memRefType) ||
1113 failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1114 return failure();
1115
1116 auto loc = memcpyOp.getLoc();
1117
1118 MemRefDescriptor srcDesc(adaptor.getSrc());
1119 Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1120
1121 Type elementPtrType = getElementPtrType(memRefType);
1122 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1123 Value gepPtr = LLVM::GEPOp::create(
1124 rewriter, loc, elementPtrType,
1125 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1126 numElements);
1127 auto sizeBytes =
1128 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1129
1130 auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1131 srcDesc.alignedPtr(rewriter, loc),
1132 *getTypeConverter());
1133 auto dst = bitAndAddrspaceCast(
1134 loc, rewriter, llvmPointerType,
1135 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1136 *getTypeConverter());
1137
1138 auto stream = adaptor.getAsyncDependencies().front();
1139 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1140
1141 rewriter.replaceOp(memcpyOp, {stream});
1142
1143 return success();
1144}
1145
1146LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1147 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const {
1149 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1150
1151 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1152 !isConvertibleAndHasIdentityMaps(memRefType) ||
1153 failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1154 return failure();
1155
1156 auto loc = memsetOp.getLoc();
1157
1158 Type valueType = adaptor.getValue().getType();
1159 unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1160 // Ints and floats of 16 or 32 bit width are allowed.
1161 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1162 return rewriter.notifyMatchFailure(
1163 memsetOp, "value must be a 16 or 32 bit int or float");
1164 }
1165
1166 unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1167 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1168
1169 MemRefDescriptor dstDesc(adaptor.getDst());
1170 Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1171
1172 auto value =
1173 LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1174 auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1175 dstDesc.alignedPtr(rewriter, loc),
1176 *getTypeConverter());
1177
1178 auto stream = adaptor.getAsyncDependencies().front();
1179 FunctionCallBuilder builder =
1180 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1181 builder.create(loc, rewriter, {dst, value, numElements, stream});
1182
1183 rewriter.replaceOp(memsetOp, {stream});
1184 return success();
1185}
1186
1187LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1188 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1189 ConversionPatternRewriter &rewriter) const {
1190 Location loc = op.getLoc();
1191 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1192 {adaptor.getDevIndex()});
1193 rewriter.replaceOp(op, call);
1194 return success();
1195}
1196
1197template <typename T>
1198static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1199 Type llvmInt32Type = builder.getIntegerType(32);
1200 return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1201 static_cast<int32_t>(tValue));
1202}
1203
1204template <typename T>
1205static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1206 Type llvmFloat32Type = builder.getF32Type();
1207 return LLVM::ConstantOp::create(
1208 builder, loc, llvmFloat32Type,
1209 builder.getF32FloatAttr(static_cast<float>(tValue)));
1210}
1211
1212LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1213 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1214 ConversionPatternRewriter &rewriter) const {
1215 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1216 failed(isAsyncWithOneDependency(rewriter, op)))
1217 return failure();
1218 Location loc = op.getLoc();
1219 auto stream = adaptor.getAsyncDependencies().front();
1220 Value pTensor =
1221 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1222 Type dType = op.getMemref().getType().getElementType();
1223 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1224
1225 SmallVector<Value, 4> dims;
1226 for (Value dim : adaptor.getDims()) {
1227 dims.push_back(dim);
1228 }
1229
1230 Value handle;
1231 // TODO: For now, we track the use of the handle and lower it to cusparse /
1232 // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
1233 // used, we require two separate Creation ops to be the correct logic. In
1234 // future, we may add support to using one handle in sparse tensor / GPU
1235 // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
1236 // the dnmat is used with spmat with 2:4 sparsity
1237 if (dims.size() == 2) {
1238 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1239 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1240 rewriter.getIndexAttr(11032));
1241 handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1242 llvmInt8Type, handleSz, /*alignment=*/16);
1243 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1244
1245 createLtDnMatCallBuilder
1246 .create(loc, rewriter,
1247 {handle, dims[0], dims[1], pTensor, dtp, stream})
1248 .getResult();
1249 } else {
1250 handle =
1251 createDnMatCallBuilder
1252 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1253 .getResult();
1254 }
1255 } else {
1256 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1257 handle = createDnVecCallBuilder
1258 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1259 .getResult();
1260 }
1261 rewriter.replaceOp(op, {handle, stream});
1262 return success();
1263}
1264
1265LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1266 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1267 ConversionPatternRewriter &rewriter) const {
1268 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1269 failed(isAsyncWithOneDependency(rewriter, op)))
1270 return failure();
1271 Location loc = op.getLoc();
1272 auto stream = adaptor.getAsyncDependencies().front();
1273 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1274 SmallVector<Value, 4> dims;
1275 for (Value dim : definingOp.getDims()) {
1276 dims.push_back(dim);
1277 }
1278 if (dims.size() == 2) {
1279 // Use the cusparseLt destroy call if the dnmat is used with spmat with
1280 // 2:4 sparsity
1281 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1282 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1283 {adaptor.getDnTensor(), stream});
1284 } else {
1285 destroyDnMatCallBuilder.create(loc, rewriter,
1286 {adaptor.getDnTensor(), stream});
1287 }
1288 } else {
1289 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1290 destroyDnVecCallBuilder.create(loc, rewriter,
1291 {adaptor.getDnTensor(), stream});
1292 }
1293 rewriter.replaceOp(op, {stream});
1294 return success();
1295}
1296
1297LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1298 gpu::CreateCooOp op, OpAdaptor adaptor,
1299 ConversionPatternRewriter &rewriter) const {
1300 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1301 failed(isAsyncWithOneDependency(rewriter, op)))
1302 return failure();
1303 Location loc = op.getLoc();
1304 auto stream = adaptor.getAsyncDependencies().front();
1305 Value pRowIdxs =
1306 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1307 Value pColIdxs =
1308 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1309 Value pValues =
1310 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1311 Type iType =
1312 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1313 Type dType =
1314 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1315 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1316 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1317 auto handle =
1318 createCooCallBuilder
1319 .create(loc, rewriter,
1320 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1321 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1322 .getResult();
1323 rewriter.replaceOp(op, {handle, stream});
1324 return success();
1325}
1326
1327LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1328 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1329 ConversionPatternRewriter &rewriter) const {
1330 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1331 failed(isAsyncWithOneDependency(rewriter, op)))
1332 return failure();
1333 Location loc = op.getLoc();
1334 auto stream = adaptor.getAsyncDependencies().front();
1335 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1336 Value pValues =
1337 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1338 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1339 Type dType =
1340 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1341 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1342 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1343 auto handle =
1344 createCooAoSCallBuilder
1345 .create(loc, rewriter,
1346 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1347 pIdxs, pValues, itp, dtp, stream})
1348 .getResult();
1349 rewriter.replaceOp(op, {handle, stream});
1350 return success();
1351}
1352
1353LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1354 gpu::CreateCsrOp op, OpAdaptor adaptor,
1355 ConversionPatternRewriter &rewriter) const {
1356 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1357 failed(isAsyncWithOneDependency(rewriter, op)))
1358 return failure();
1359 Location loc = op.getLoc();
1360 auto stream = adaptor.getAsyncDependencies().front();
1361 Value pRowPos =
1362 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1363 Value pColIdxs =
1364 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1365 Value pValues =
1366 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1367 Type pType =
1368 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1369 Type iType =
1370 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1371 Type dType =
1372 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1373 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1374 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1375 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1376 auto handle =
1377 createCsrCallBuilder
1378 .create(loc, rewriter,
1379 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1380 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1381 .getResult();
1382 rewriter.replaceOp(op, {handle, stream});
1383 return success();
1384}
1385
1386LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1387 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1388 ConversionPatternRewriter &rewriter) const {
1389 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1390 failed(isAsyncWithOneDependency(rewriter, op)))
1391 return failure();
1392 Location loc = op.getLoc();
1393 auto stream = adaptor.getAsyncDependencies().front();
1394 Value pMat =
1395 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1396 Type dType =
1397 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1398 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1399
1400 // CUDA runner asserts the size is 44104 bytes.
1401 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1402 rewriter.getIndexAttr(44104));
1403 Value handle = LLVM::AllocaOp::create(
1404 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1405 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1406
1407 create2To4SpMatCallBuilder
1408 .create(loc, rewriter,
1409 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1410 .getResult();
1411 rewriter.replaceOp(op, {handle, stream});
1412 return success();
1413}
1414
1415LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1416 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1417 ConversionPatternRewriter &rewriter) const {
1418 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1419 failed(isAsyncWithOneDependency(rewriter, op)))
1420 return failure();
1421 Location loc = op.getLoc();
1422 auto stream = adaptor.getAsyncDependencies().front();
1423 // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1424 if (is2To4Sparsity(op.getSpmat())) {
1425 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1426 {adaptor.getSpmat(), stream});
1427
1428 } else {
1429 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1430 }
1431 rewriter.replaceOp(op, {stream});
1432 return success();
1433}
1434
1435LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1436 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1437 ConversionPatternRewriter &rewriter) const {
1438 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1439 failed(isAsyncWithOneDependency(rewriter, op)))
1440 return failure();
1441 Location loc = op.getLoc();
1442 auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1443 auto computeType = genConstInt32From(
1444 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1445 auto stream = adaptor.getAsyncDependencies().front();
1446 auto bufferSize = spMVBufferSizeCallBuilder
1447 .create(loc, rewriter,
1448 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1449 adaptor.getDnY(), computeType, stream})
1450 .getResult();
1451 rewriter.replaceOp(op, {bufferSize, stream});
1452 return success();
1453}
1454
1455LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1456 gpu::SpMVOp op, OpAdaptor adaptor,
1457 ConversionPatternRewriter &rewriter) const {
1458 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1459 failed(isAsyncWithOneDependency(rewriter, op)))
1460 return failure();
1461 Location loc = op.getLoc();
1462 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1463 auto computeType = genConstInt32From(
1464 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1465 auto stream = adaptor.getAsyncDependencies().front();
1466 Value pBuf =
1467 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1468 spMVCallBuilder.create(loc, rewriter,
1469 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1470 adaptor.getDnY(), computeType, pBuf, stream});
1471 rewriter.replaceOp(op, {stream});
1472 return success();
1473}
1474
1475LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1476 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1477 ConversionPatternRewriter &rewriter) const {
1478 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1479 failed(isAsyncWithOneDependency(rewriter, op)))
1480 return failure();
1481 Location loc = op.getLoc();
1482 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1483 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1484 auto stream = adaptor.getAsyncDependencies().front();
1485 Value bufferSize;
1486 if (is2To4Sparsity(op.getSpmatA())) {
1487 auto pruneFlag =
1488 genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1489 auto computeType = genConstInt32From(
1490 rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1491 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1492 rewriter.getIndexAttr(3));
1493 auto bufferSize =
1494 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1495 three, /*alignment=*/16);
1496 createCuSparseLtSpMMBufferSizeBuilder
1497 .create(loc, rewriter,
1498 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1499 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1500 pruneFlag, stream})
1501 .getResult();
1502
1503 auto bufferSizePtr1 = LLVM::GEPOp::create(
1504 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1505 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1506 rewriter.getIndexAttr(1))});
1507 auto bufferSizePtr2 = LLVM::GEPOp::create(
1508 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1509 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1510 rewriter.getIndexAttr(2))});
1511 auto bufferSize0 =
1512 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1513 auto bufferSize1 =
1514 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1515 auto bufferSize2 =
1516 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1517
1518 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1519 } else {
1520 auto computeType = genConstInt32From(
1521 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1522 bufferSize =
1523 createSpMMBufferSizeCallBuilder
1524 .create(loc, rewriter,
1525 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1526 adaptor.getDnmatC(), computeType, stream})
1527 .getResult();
1528 rewriter.replaceOp(op, {bufferSize, stream});
1529 }
1530 return success();
1531}
1532
1533LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1534 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1535 ConversionPatternRewriter &rewriter) const {
1536 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1537 failed(isAsyncWithOneDependency(rewriter, op)))
1538 return failure();
1539 Location loc = op.getLoc();
1540 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1541 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1542 auto computeType = genConstInt32From(
1543 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1544 auto stream = adaptor.getAsyncDependencies().front();
1545 auto bufferSize =
1546 createSDDMMBufferSizeCallBuilder
1547 .create(loc, rewriter,
1548 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1549 adaptor.getSpmatC(), computeType, stream})
1550 .getResult();
1551 rewriter.replaceOp(op, {bufferSize, stream});
1552 return success();
1553}
1554
1555LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1556 gpu::SpMMOp op, OpAdaptor adaptor,
1557 ConversionPatternRewriter &rewriter) const {
1558 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1559 failed(isAsyncWithOneDependency(rewriter, op)))
1560 return failure();
1561 Location loc = op.getLoc();
1562 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1563 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1564 auto computeType = genConstInt32From(
1565 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1566
1567 auto stream = adaptor.getAsyncDependencies().front();
1568
1569 // Lower to cusparseLt if applicable
1570 if (is2To4Sparsity(op.getSpmatA())) {
1571 SmallVector<Value> pBufs;
1572 for (Value buffer : adaptor.getBuffers()) {
1573 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1574 pBufs.push_back(pBuf);
1575 }
1576 createCuSparseLtSpMMBuilder.create(
1577 loc, rewriter,
1578 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1579 pBufs[0], pBufs[1], pBufs[2], stream});
1580 } else {
1581 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1582 .allocatedPtr(rewriter, loc);
1583 createSpMMCallBuilder.create(loc, rewriter,
1584 {modeA, modeB, adaptor.getSpmatA(),
1585 adaptor.getDnmatB(), adaptor.getDnmatC(),
1586 computeType, pBuf, stream});
1587 }
1588 rewriter.replaceOp(op, {stream});
1589 return success();
1590}
1591
1592template <typename T>
1594 converter.addConversion([&converter](T) -> Type {
1595 return LLVM::LLVMPointerType::get(&converter.getContext());
1596 });
1597}
1598
1599LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1600 gpu::SDDMMOp op, OpAdaptor adaptor,
1601 ConversionPatternRewriter &rewriter) const {
1602 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1603 failed(isAsyncWithOneDependency(rewriter, op)))
1604 return failure();
1605 Location loc = op.getLoc();
1606 auto computeType = genConstInt32From(
1607 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1608 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1609 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1610 auto stream = adaptor.getAsyncDependencies().front();
1611 Value pBuf =
1612 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1613 createSDDMMCallBuilder.create(loc, rewriter,
1614 {modeA, modeB, adaptor.getDnmatA(),
1615 adaptor.getDnmatB(), adaptor.getSpmatC(),
1616 computeType, pBuf, stream});
1617 rewriter.replaceOp(op, {stream});
1618 return success();
1619}
1620
1621LogicalResult
1622ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1623 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1624 ConversionPatternRewriter &rewriter) const {
1625 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1626 failed(isAsyncWithOneDependency(rewriter, op)))
1627 return failure();
1628 Location loc = op.getLoc();
1629 auto stream = adaptor.getAsyncDependencies().front();
1630 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1631 .getResult();
1632 rewriter.replaceOp(op, {descr, stream});
1633 return success();
1634}
1635
1636LogicalResult
1637ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1638 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1639 ConversionPatternRewriter &rewriter) const {
1640 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1641 failed(isAsyncWithOneDependency(rewriter, op)))
1642 return failure();
1643 Location loc = op.getLoc();
1644 auto stream = adaptor.getAsyncDependencies().front();
1645 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1646 {adaptor.getDesc(), stream});
1647 rewriter.replaceOp(op, {stream});
1648 return success();
1649}
1650
1651LogicalResult
1652ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1653 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1654 ConversionPatternRewriter &rewriter) const {
1655 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1656 failed(isAsyncWithOneDependency(rewriter, op)))
1657 return failure();
1658 Location loc = op.getLoc();
1659 auto computeType = genConstInt32From(
1660 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1661 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1662 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1663 auto stream = adaptor.getAsyncDependencies().front();
1664
1665 Value pBuf =
1666 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1667 Value bufferSizeNew;
1668
1669 if (adaptor.getKind() ==
1670 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1671 bufferSizeNew =
1672 createSpGEMMWorkEstimationBuilder
1673 .create(loc, rewriter,
1674 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1675 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1676 adaptor.getBufferSz(), pBuf, stream})
1677 .getResult();
1678 } else {
1679 bufferSizeNew =
1680 createSpGEMMComputeBuilder
1681 .create(loc, rewriter,
1682 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1683 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1684 adaptor.getBufferSz(), pBuf, stream})
1685 .getResult();
1686 }
1687 rewriter.replaceOp(op, {bufferSizeNew, stream});
1688 return success();
1689}
1690
1691LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1692 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1693 ConversionPatternRewriter &rewriter) const {
1694 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1695 failed(isAsyncWithOneDependency(rewriter, op)))
1696 return failure();
1697 Location loc = op.getLoc();
1698 auto computeType = genConstInt32From(
1699 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1700 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1701 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1702 auto stream = adaptor.getAsyncDependencies().front();
1703 createSpGEMMCopyBuilder.create(loc, rewriter,
1704 {adaptor.getDesc(), modeA, modeB,
1705 adaptor.getSpmatA(), adaptor.getSpmatB(),
1706 adaptor.getSpmatC(), computeType, stream});
1707 rewriter.replaceOp(op, {stream});
1708 return success();
1709}
1710
1711LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1712 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1713 ConversionPatternRewriter &rewriter) const {
1714 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1715 failed(isAsyncWithOneDependency(rewriter, op)))
1716 return failure();
1717 Location loc = op.getLoc();
1718 auto stream = adaptor.getAsyncDependencies().front();
1719
1720 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1721 rewriter.getIndexAttr(3));
1722 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1723 llvmInt64Type, three, /*alignment=*/16);
1724
1725 auto rowsPtr = LLVM::GEPOp::create(
1726 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1727 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1728 rewriter.getIndexAttr(0))});
1729 auto colsPtr = LLVM::GEPOp::create(
1730 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1731 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1732 rewriter.getIndexAttr(1))});
1733 auto nnzsPtr = LLVM::GEPOp::create(
1734 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1735 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1736 rewriter.getIndexAttr(2))});
1737 createSpMatGetSizeBuilder.create(
1738 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1739 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1740 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1741 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1742
1743 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1744 return success();
1745}
1746
1747LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1748 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1749 ConversionPatternRewriter &rewriter) const {
1750 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1751 failed(isAsyncWithOneDependency(rewriter, op)))
1752 return failure();
1753 Location loc = op.getLoc();
1754 auto stream = adaptor.getAsyncDependencies().front();
1755 Value pPos =
1756 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1757 Value pCrd =
1758 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1759 Value pVal =
1760 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1761 createSetCsrPointersBuilder.create(
1762 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1763 rewriter.replaceOp(op, {stream});
1764 return success();
1765}
1766
1767LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1768 gpu::CreateCscOp op, OpAdaptor adaptor,
1769 ConversionPatternRewriter &rewriter) const {
1770 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1771 failed(isAsyncWithOneDependency(rewriter, op)))
1772 return failure();
1773 Location loc = op.getLoc();
1774 auto stream = adaptor.getAsyncDependencies().front();
1775 Value pColPos =
1776 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1777 Value pRowIdxs =
1778 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1779 Value pValues =
1780 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1781 Type pType =
1782 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1783 Type iType =
1784 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1785 Type dType =
1786 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1787 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1788 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1789 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1790 auto handle =
1791 createCscCallBuilder
1792 .create(loc, rewriter,
1793 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1794 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1795 .getResult();
1796 rewriter.replaceOp(op, {handle, stream});
1797 return success();
1798}
1799
1800LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1801 gpu::CreateBsrOp op, OpAdaptor adaptor,
1802 ConversionPatternRewriter &rewriter) const {
1803 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1804 failed(isAsyncWithOneDependency(rewriter, op)))
1805 return failure();
1806 Location loc = op.getLoc();
1807 auto stream = adaptor.getAsyncDependencies().front();
1808 Value pRowPos =
1809 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1810 Value pColIdxs =
1811 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1812 Value pValues =
1813 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1814 Type pType =
1815 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1816 Type iType =
1817 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1818 Type dType =
1819 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1820 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1821 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1822 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1823 auto handle =
1824 createBsrCallBuilder
1825 .create(loc, rewriter,
1826 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1827 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1828 pColIdxs, pValues, ptp, itp, dtp, stream})
1829 .getResult();
1830 rewriter.replaceOp(op, {handle, stream});
1831 return success();
1832}
1833
1835 LLVMTypeConverter &converter, RewritePatternSet &patterns,
1836 bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1841
1842 // Higher benefit so this pattern wins over the structural async.yield
1843 // rewriter from populateAsyncStructuralTypeConversionsAndLegality on yields
1844 // with gpu.async.token operands. The structural rewriter would silently
1845 // retype operands without recording an event on the underlying stream.
1846 patterns.add<ConvertAsyncYieldToGpuRuntimeCallPattern>(converter,
1847 /*benefit=*/2);
1848
1849 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1850 ConvertDeallocOpToGpuRuntimeCallPattern,
1851 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1852 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1853 ConvertMemcpyOpToGpuRuntimeCallPattern,
1854 ConvertMemsetOpToGpuRuntimeCallPattern,
1855 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1856 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1857 ConvertWaitOpToGpuRuntimeCallPattern,
1858 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1859 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1860 ConvertCreateCooOpToGpuRuntimeCallPattern,
1861 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1862 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1863 ConvertCreateCscOpToGpuRuntimeCallPattern,
1864 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1865 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1866 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1867 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1868 ConvertSpMVOpToGpuRuntimeCallPattern,
1869 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1870 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1871 ConvertSpMMOpToGpuRuntimeCallPattern,
1872 ConvertSDDMMOpToGpuRuntimeCallPattern,
1873 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1874 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1875 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1876 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1877 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1878 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1879 patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1880 kernelIntersperseSizeCallConv);
1881}
1882
1883//===----------------------------------------------------------------------===//
1884// GPUModuleOp convert to LLVM op interface
1885//===----------------------------------------------------------------------===//
1886
1887namespace {
1888struct GPUModuleOpConvertToLLVMInterface
1889 : public ConvertToLLVMOpInterface::ExternalModel<
1890 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1891 /// Get the conversion patterns from the target attribute.
1892 void getConvertToLLVMConversionAttrs(
1894};
1895} // namespace
1896
1897void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1898 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1899 auto module = cast<gpu::GPUModuleOp>(op);
1900 ArrayAttr targetsAttr = module.getTargetsAttr();
1901 // Fail if there are no target attributes or there is more than one target.
1902 if (!targetsAttr || targetsAttr.size() != 1)
1903 return;
1904 if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1905 attrs.push_back(patternAttr);
1906}
1907
1909 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1910 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1911 });
1912}
return success()
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operations on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
b getContext())
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:250
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition Pattern.h:227
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the type converter in use.
Definition Pattern.cpp:38
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:58
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
Definition Builders.h:209
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block, but still inside the block.
Definition Builders.h:248
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOp operation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate the converter for GPU types.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const