//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//
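//
// Illustrative sketch (not verbatim pass output; assumes a 64-bit index
// type): host IR such as
//
//   %t0 = gpu.wait async
//   %mem, %t1 = gpu.alloc async [%t0] () : memref<16xf32>
//   gpu.wait [%t1]
//
// is rewritten by the patterns below into calls to the mgpu* wrapper layer,
// roughly:
//
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   %ptr = llvm.call @mgpuMemAlloc(%bytes, %stream, %c0_i8)
//       : (i64, !llvm.ptr, i8) -> !llvm.ptr
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr) -> ()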

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter, benefit) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    if (type.hasStaticShape())
      return ConvertToLLVMPattern::createIndexAttrConstant(
          rewriter, loc, indexType, type.getNumElements());
    // Compute the number of elements by multiplying all the dim sizes.
    uint64_t rank = type.getRank();
    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
    for (unsigned i = 1; i < rank; i++)
      numElements = LLVM::MulOp::create(rewriter, loc, numElements,
                                        desc.size(rewriter, loc, /*pos=*/i));
    return numElements;
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */,
       llvmInt16Type /* unsigned short value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
       llvmPointerType /*crd*/, llvmPointerType /*val*/,
       llvmPointerType /*void *stream*/}};
};
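
// The call builders above assume a thin C wrapper layer (e.g. the CUDA/ROCm
// runtime wrappers in mlir/lib/ExecutionEngine) whose entry points match the
// declared LLVM types. An illustrative, non-exhaustive sketch of the assumed
// C signatures, derived from the types above:
//
//   void *mgpuStreamCreate();
//   void mgpuStreamSynchronize(void *stream);
//   void *mgpuMemAlloc(intptr_t sizeBytes, void *stream, bool isHostShared);
//   void mgpuMemcpy(void *dst, void *src, intptr_t sizeBytes, void *stream);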

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
  }

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter,
                                                         benefit) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv,
                              bool kernelIntersperseSizeCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv),
        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
  bool kernelIntersperseSizeCallConv;
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operation on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                             \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
  public:                                                                     \
    Convert##op_name##ToGpuRuntimeCallPattern(                                \
        const LLVMTypeConverter &typeConverter)                               \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
                                                                              \
  private:                                                                    \
    LogicalResult                                                             \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
                    ConversionPatternRewriter &rewriter) const override;      \
  };

DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)

} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();

  // Perform progressive lowering of vector transfer operations.
  {
    RewritePatternSet patterns(&getContext());
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    vector::populateVectorTransferLoweringPatterns(patterns,
                                                   /*maxTransferRank=*/1);
    // Transform N-D vector.from_elements to 1-D vector.from_elements before
    // conversion.
    vector::populateVectorFromElementsUnrollPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }

  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate all patterns from all dialects that implement the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                   patterns);
  }

  // Preserve GPU modules and binaries. Modules are preserved as they can be
  // converted later by `gpu-module-to-binary`.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept LaunchFuncOps as legal once their operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    auto builder = OpBuilder::atBlockEnd(module.getBody());
    return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
  }();
  return LLVM::CallOp::create(builder, loc, function, arguments);
}
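
// Illustrative usage of the builder above: on first use it declares the
// runtime function at the end of the module, and every use emits a call. For
// example, `streamCreateCallBuilder.create(loc, rewriter, {})` yields
// roughly:
//
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
//   ...
//   %0 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr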

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3; // CUSPARSE_INDEX_64I
}

static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  llvm_unreachable("unsupported type");
  // TODO: add support for TF32.
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I

  llvm_unreachable("unsupported element type");
}

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// TODO: We may want a compile-time (of the mlir compiler) disablement or
// warning: cusparseLt currently does not work for CUDA architectures below
// 8.0 and triggers a runtime error (of the CUDA program). It would be good
// to at least emit a warning when the target architecture is below 8.0 and
// the user still requests cusparseLt, i.e., to ensure that, when lowering
// the GPU sparse dialect to LLVM calls, the cusparseLt calls are disabled
// for CUDA architectures below 8.0.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the spMat defining op.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}

static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    if (!spmmOp)
      continue;
    // If the A operand has 2:4 (50%) sparsity, we should use cusparseLt.
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}
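
// Illustrative sketch of the form accepted by isAsyncWithOneDependency: the
// async version with exactly one dependency, e.g.
//
//   %t1 = gpu.memcpy async [%t0] %dst, %src : memref<16xf32>, memref<16xf32>
//
// The synchronous form (no !gpu.async.token result) and ops with multiple
// async dependencies are rejected by the corresponding patterns.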

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = mlir::LLVM::ConstantOp::create(
      rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}
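
// Illustrative lowering sketch for the pattern above (async form, static
// shape, 64-bit index type):
//
//   %memref, %t1 = gpu.alloc async [%t0] () : memref<16xf32>
//
// becomes, roughly,
//
//   %ptr = llvm.call @mgpuMemAlloc(%sizeBytes, %stream, %c0_i8)
//       : (i64, !llvm.ptr, i8) -> !llvm.ptr
//
// followed by the construction of a memref descriptor around %ptr, with the
// incoming stream forwarded as the result token.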

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token operands are lowered to streams within the async.execute
// region, but are passed as events between regions. For each
// !gpu.async.token operand, we create an event and record it on the stream.
//
// This pattern is registered with a higher benefit than the structural
// async.yield rewriter from populateAsyncStructuralTypeConversionsAndLegality
// so it wins when both match. Without that benefit override, the structural
// pattern can win and silently retype gpu.async.token operands without
// recording an event, leaving the host await to call cuEventSynchronize on
// a stream pointer (a no-op that returns an error), racing the host against
// the GPU.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}
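
// Illustrative sketch for the pattern above: yielding a token whose lowered
// value is a stream %s becomes
//
//   %e = llvm.call @mgpuEventCreate() : () -> !llvm.ptr
//   llvm.call @mgpuEventRecord(%e, %s) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%s) : (!llvm.ptr) -> ()
//
// and the async.yield then yields the event %e instead of the stream.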

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the
// host with the stream/event operands. The operands are destroyed, i.e., it
// is assumed that they are not used afterwards or elsewhere; otherwise we
// will get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}
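
// Illustrative sketch for the pattern above: `gpu.wait [%t]`, where %t was
// lowered to a stream produced by @mgpuStreamCreate, becomes
//
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr) -> ()
//
// while an event operand is handled with @mgpuEventSynchronize and
// @mgpuEventDestroy instead.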

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed, i.e., it is assumed that they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
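
// Illustrative sketch for the pattern above: `%t = gpu.wait async [%t0, %t1]`
// creates one new stream that waits on an event per dependency, roughly:
//
//   %s = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   llvm.call @mgpuStreamWaitEvent(%s, %e0) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuStreamWaitEvent(%s, %e1) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%e0) : (!llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%e1) : (!llvm.ptr) -> ()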

// Legalize the op's operands.
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty()) {
    stream = adaptor.getAsyncDependencies().front();
    // Synchronize additional async dependencies onto the primary stream using
    // events, following the same approach as the gpu.wait async lowering.
    if (adaptor.getAsyncDependencies().size() > 1) {
      auto insertionPoint = rewriter.saveInsertionPoint();
      SmallVector<Value, 4> events;
      for (auto [origDep, convertedDep] :
           llvm::zip(launchOp.getAsyncDependencies().drop_front(),
                     adaptor.getAsyncDependencies().drop_front())) {
        if (!isDefinedByCallTo(convertedDep,
                               streamCreateCallBuilder.functionName)) {
          events.push_back(convertedDep);
          continue;
        }
        Operation *defOp = origDep.getDefiningOp();
        rewriter.setInsertionPointAfter(defOp);
        Value event =
            eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
        eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
        events.push_back(event);
      }
      rewriter.restoreInsertionPoint(insertionPoint);
      for (Value event : events)
        streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
      for (Value event : events)
        eventDestroyCallBuilder.create(loc, rewriter, {event});
    }
  }
  // If the async keyword is present and there are no dependencies, then a
  // stream must be created to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

  // Lower the kernel operands to match kernel parameters.
  // Note: If `useBarePtrCallConv` is set in the type converter's options,
  // the value of `kernelBarePtrCallConv` will be ignored.
  OperandRange origArguments = launchOp.getKernelOperands();
  bool effectiveBarePtr = kernelBarePtrCallConv ||
                          getTypeConverter()->getOptions().useBarePtrCallConv;
  if (effectiveBarePtr) {
    for (Value arg : origArguments) {
      if (isa<UnrankedMemRefType>(arg.getType()))
        return rewriter.notifyMatchFailure(
            loc, "unranked memref kernel argument is not supported with "
                 "the bare-pointer calling convention");
    }
  }
  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;

  // Intersperse size information if requested.
  if (kernelIntersperseSizeCallConv) {
    if (origArguments.size() != llvmArguments.size()) {
      // This shouldn't happen if the bare-pointer calling convention is used.
      return rewriter.notifyMatchFailure(
          launchOp,
          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
    }

    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
      if (!memrefTy) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref.");
      }

      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat()) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");
      }

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");
      }

      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
                            static_cast<uint64_t>(memrefTy.getNumElements());

      Value sizeArg = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
      llvmArgumentsWithSizes.push_back(sizeArg);
    }
  }

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  gpu::LaunchFuncOp::create(
      rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(),
      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
      stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}
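
// Illustrative sketch for the pattern above: after legalization the launch
// survives as a gpu.launch_func with LLVM-typed operands, e.g. (bare-pointer
// convention with interspersed sizes):
//
//   gpu.launch_func @kernels::@k blocks in (%gx, %gy, %gz)
//       threads in (%bx, %by, %bz)
//       args(%ptr : !llvm.ptr, %size : i64)
//
// Later stages (e.g. after gpu-module-to-binary) turn this into an actual
// kernel launch.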

static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
  Value gepPtr = LLVM::GEPOp::create(
      rewriter, loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);

  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}
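
// Note on the size computation in the pattern above: it uses the classic
// null-pointer GEP idiom to derive the byte size of `numElements` elements
// without hard-coding sizeof, roughly (for an f32 element type):
//
//   %null = llvm.mlir.zero : !llvm.ptr
//   %gep = llvm.getelementptr %null[%numElements]
//       : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %sizeBytes = llvm.ptrtoint %gep : !llvm.ptr to i64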

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  // Ints and floats of 16 or 32 bit width are allowed. Check isIntOrFloat
  // first, since getIntOrFloatBitWidth asserts on other types.
  Type valueType = adaptor.getValue().getType();
  unsigned bitWidth =
      valueType.isIntOrFloat() ? valueType.getIntOrFloatBitWidth() : 0;
  if (bitWidth != 16 && bitWidth != 32) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  Type bitCastType = bitWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      bitWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}
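
// Illustrative sketch for the pattern above (assuming a 64-bit intptr type):
// a 32-bit float fill value is reinterpreted as i32 before the runtime call,
// e.g. for `gpu.memset async [%t] %dst, %val : memref<16xf32>, f32`:
//
//   %bits = llvm.bitcast %val : f32 to i32
//   llvm.call @mgpuMemset32(%dst, %bits, %n, %stream)
//       : (!llvm.ptr, i32, i64, !llvm.ptr) -> ()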

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}

template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
                                  static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return LLVM::ConstantOp::create(
      builder, loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}

LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // TODO: For now, we track the use of the handle and lower it to cusparse /
  // cusparseLt accordingly. If both cusparse and cusparseLt are used in the
  // same block, two separate creation ops are required for the logic to be
  // correct. In the future, we may add support for using one handle in the
  // sparse tensor / GPU dialects with both cusparse and cusparseLt. Use the
  // cusparseLt create call if the dnmat is used with a spmat with 2:4
  // sparsity.
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                               rewriter.getIndexAttr(11032));
      handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
                                      llvmInt8Type, handleSz, /*alignment=*/16);
      handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);

      createLtDnMatCallBuilder
          .create(loc, rewriter,
                  {handle, dims[0], dims[1], pTensor, dtp, stream})
          .getResult();
    } else {
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
1262
1263LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1264 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter) const {
1266 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1267 failed(isAsyncWithOneDependency(rewriter, op)))
1268 return failure();
1269 Location loc = op.getLoc();
1270 auto stream = adaptor.getAsyncDependencies().front();
1271 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1272 SmallVector<Value, 4> dims;
1273 for (Value dim : definingOp.getDims()) {
1274 dims.push_back(dim);
1275 }
1276 if (dims.size() == 2) {
1277 // Use the cusparseLt destroy call if the dnmat is used with spmat with
1278 // 2:4 sparsity
1279 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1280 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1281 {adaptor.getDnTensor(), stream});
1282 } else {
1283 destroyDnMatCallBuilder.create(loc, rewriter,
1284 {adaptor.getDnTensor(), stream});
1285 }
1286 } else {
1287 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1288 destroyDnVecCallBuilder.create(loc, rewriter,
1289 {adaptor.getDnTensor(), stream});
1290 }
1291 rewriter.replaceOp(op, {stream});
1292 return success();
1293}
1294
1295LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1296 gpu::CreateCooOp op, OpAdaptor adaptor,
1297 ConversionPatternRewriter &rewriter) const {
1298 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1299 failed(isAsyncWithOneDependency(rewriter, op)))
1300 return failure();
1301 Location loc = op.getLoc();
1302 auto stream = adaptor.getAsyncDependencies().front();
1303 Value pRowIdxs =
1304 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1305 Value pColIdxs =
1306 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1307 Value pValues =
1308 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1309 Type iType =
1310 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1311 Type dType =
1312 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1313 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1314 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1315 auto handle =
1316 createCooCallBuilder
1317 .create(loc, rewriter,
1318 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1319 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1320 .getResult();
1321 rewriter.replaceOp(op, {handle, stream});
1322 return success();
1323}
1324
1325LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1326 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1327 ConversionPatternRewriter &rewriter) const {
1328 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1329 failed(isAsyncWithOneDependency(rewriter, op)))
1330 return failure();
1331 Location loc = op.getLoc();
1332 auto stream = adaptor.getAsyncDependencies().front();
1333 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1334 Value pValues =
1335 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1336 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1337 Type dType =
1338 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1339 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1340 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1341 auto handle =
1342 createCooAoSCallBuilder
1343 .create(loc, rewriter,
1344 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1345 pIdxs, pValues, itp, dtp, stream})
1346 .getResult();
1347 rewriter.replaceOp(op, {handle, stream});
1348 return success();
1349}
1350
1351LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1352 gpu::CreateCsrOp op, OpAdaptor adaptor,
1353 ConversionPatternRewriter &rewriter) const {
1354 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1355 failed(isAsyncWithOneDependency(rewriter, op)))
1356 return failure();
1357 Location loc = op.getLoc();
1358 auto stream = adaptor.getAsyncDependencies().front();
1359 Value pRowPos =
1360 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1361 Value pColIdxs =
1362 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1363 Value pValues =
1364 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1365 Type pType =
1366 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1367 Type iType =
1368 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1369 Type dType =
1370 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1371 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1372 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1373 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1374 auto handle =
1375 createCsrCallBuilder
1376 .create(loc, rewriter,
1377 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1378 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1379 .getResult();
1380 rewriter.replaceOp(op, {handle, stream});
1381 return success();
1382}
1383
1384LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1385 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1386 ConversionPatternRewriter &rewriter) const {
1387 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1388 failed(isAsyncWithOneDependency(rewriter, op)))
1389 return failure();
1390 Location loc = op.getLoc();
1391 auto stream = adaptor.getAsyncDependencies().front();
1392 Value pMat =
1393 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1394 Type dType =
1395 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1396 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1397
1398 // CUDA runner asserts the size is 44104 bytes.
1399 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1400 rewriter.getIndexAttr(44104));
1401 Value handle = LLVM::AllocaOp::create(
1402 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1403 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1404
1405 create2To4SpMatCallBuilder
1406 .create(loc, rewriter,
1407 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1408 .getResult();
1409 rewriter.replaceOp(op, {handle, stream});
1410 return success();
1411}
1412
1413LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1414 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1415 ConversionPatternRewriter &rewriter) const {
1416 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1417 failed(isAsyncWithOneDependency(rewriter, op)))
1418 return failure();
1419 Location loc = op.getLoc();
1420 auto stream = adaptor.getAsyncDependencies().front();
1421 // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1422 if (is2To4Sparsity(op.getSpmat())) {
1423 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1424 {adaptor.getSpmat(), stream});
1425
1426 } else {
1427 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1428 }
1429 rewriter.replaceOp(op, {stream});
1430 return success();
1431}
1432
1433LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1434 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1435 ConversionPatternRewriter &rewriter) const {
1436 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1437 failed(isAsyncWithOneDependency(rewriter, op)))
1438 return failure();
1439 Location loc = op.getLoc();
1440 auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1441 auto computeType = genConstInt32From(
1442 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1443 auto stream = adaptor.getAsyncDependencies().front();
1444 auto bufferSize = spMVBufferSizeCallBuilder
1445 .create(loc, rewriter,
1446 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1447 adaptor.getDnY(), computeType, stream})
1448 .getResult();
1449 rewriter.replaceOp(op, {bufferSize, stream});
1450 return success();
1451}
1452
1453LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1454 gpu::SpMVOp op, OpAdaptor adaptor,
1455 ConversionPatternRewriter &rewriter) const {
1456 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1457 failed(isAsyncWithOneDependency(rewriter, op)))
1458 return failure();
1459 Location loc = op.getLoc();
1460 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1461 auto computeType = genConstInt32From(
1462 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1463 auto stream = adaptor.getAsyncDependencies().front();
1464 Value pBuf =
1465 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1466 spMVCallBuilder.create(loc, rewriter,
1467 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1468 adaptor.getDnY(), computeType, pBuf, stream});
1469 rewriter.replaceOp(op, {stream});
1470 return success();
1471}
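// Every sparse pattern here requires the async form with exactly one token
// dependency, so a well-formed chain keeps all runtime calls on one stream.
// Roughly, at the gpu-dialect level (syntax abbreviated):
//
//   %spmat, %t1 = gpu.create_csr async [%t0] %rows, %cols, %nnz, %p, %i, %v
//   %sz, %t2    = gpu.spmv_buffer_size async [%t1] %spmat, %dnX, %dnY into f64
//   %buf, %t3   = gpu.alloc async [%t2] (%sz) : memref<?xi8>
//   %t4         = gpu.spmv async [%t3] %spmat, %dnX, %dnY, %buf into f64
//
// After conversion, each token is rewritten to the stream created by
// mgpuStreamCreate and threaded through the calls.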
1472
1473LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1474 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1475 ConversionPatternRewriter &rewriter) const {
1476 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1477 failed(isAsyncWithOneDependency(rewriter, op)))
1478 return failure();
1479 Location loc = op.getLoc();
1480 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1481 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1482 auto stream = adaptor.getAsyncDependencies().front();
1483 Value bufferSize;
1484 if (is2To4Sparsity(op.getSpmatA())) {
1485 auto pruneFlag =
1486 genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1487 auto computeType = genConstInt32From(
1488 rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1489 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1490 rewriter.getIndexAttr(3));
1491 auto bufferSizes =
1492 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1493 three, /*alignment=*/16);
1494 createCuSparseLtSpMMBufferSizeBuilder
1495 .create(loc, rewriter,
1496 {bufferSizes, modeA, modeB, adaptor.getSpmatA(),
1497 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1498 pruneFlag, stream})
1499 .getResult();
1500
1501 auto bufferSizePtr1 = LLVM::GEPOp::create(
1502 rewriter, loc, llvmPointerType, llvmPointerType, bufferSizes,
1503 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1504 rewriter.getIndexAttr(1))});
1505 auto bufferSizePtr2 = LLVM::GEPOp::create(
1506 rewriter, loc, llvmPointerType, llvmPointerType, bufferSizes,
1507 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1508 rewriter.getIndexAttr(2))});
1509 auto bufferSize0 =
1510 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizes);
1511 auto bufferSize1 =
1512 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1513 auto bufferSize2 =
1514 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1515
1516 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1517 } else {
1518 auto computeType = genConstInt32From(
1519 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1520 bufferSize =
1521 createSpMMBufferSizeCallBuilder
1522 .create(loc, rewriter,
1523 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1524 adaptor.getDnmatC(), computeType, stream})
1525 .getResult();
1526 rewriter.replaceOp(op, {bufferSize, stream});
1527 }
1528 return success();
1529}
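// In the 2:4 branch the runtime reports three workspace sizes through a
// caller-provided triple instead of a single scalar, which is why the op
// produces three buffer-size results there. A sketch of the host-side
// contract assumed above (illustrative names and types):
//
//   extern "C" void mgpuCuSparseLtSpMMBufferSize(
//       void *sizes /* int64_t[3], filled by the wrapper */,
//       int32_t modeA, int32_t modeB, void *spmatA, void *dnmatB,
//       void *dnmatC, int32_t computeType, int32_t pruneFlag, void *stream);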
1530
1531LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1532 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1533 ConversionPatternRewriter &rewriter) const {
1534 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1535 failed(isAsyncWithOneDependency(rewriter, op)))
1536 return failure();
1537 Location loc = op.getLoc();
1538 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1539 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1540 auto computeType = genConstInt32From(
1541 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1542 auto stream = adaptor.getAsyncDependencies().front();
1543 auto bufferSize =
1544 createSDDMMBufferSizeCallBuilder
1545 .create(loc, rewriter,
1546 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1547 adaptor.getSpmatC(), computeType, stream})
1548 .getResult();
1549 rewriter.replaceOp(op, {bufferSize, stream});
1550 return success();
1551}
1552
1553LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1554 gpu::SpMMOp op, OpAdaptor adaptor,
1555 ConversionPatternRewriter &rewriter) const {
1556 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1557 failed(isAsyncWithOneDependency(rewriter, op)))
1558 return failure();
1559 Location loc = op.getLoc();
1560 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1561 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1562 auto computeType = genConstInt32From(
1563 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1564
1565 auto stream = adaptor.getAsyncDependencies().front();
1566
1567 // Lower to cusparseLt if applicable
1568 if (is2To4Sparsity(op.getSpmatA())) {
1569 SmallVector<Value> pBufs;
1570 for (Value buffer : adaptor.getBuffers()) {
1571 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1572 pBufs.push_back(pBuf);
1573 }
1574 createCuSparseLtSpMMBuilder.create(
1575 loc, rewriter,
1576 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1577 pBufs[0], pBufs[1], pBufs[2], stream});
1578 } else {
1579 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1580 .allocatedPtr(rewriter, loc);
1581 createSpMMCallBuilder.create(loc, rewriter,
1582 {modeA, modeB, adaptor.getSpmatA(),
1583 adaptor.getDnmatB(), adaptor.getDnmatC(),
1584 computeType, pBuf, stream});
1585 }
1586 rewriter.replaceOp(op, {stream});
1587 return success();
1588}
1589
1590template <typename T>
1591static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1592 converter.addConversion([&converter](T) -> Type {
1593 return LLVM::LLVMPointerType::get(&converter.getContext());
1594 });
1595}
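// Each opaque GPU handle type is erased to a bare !llvm.ptr, e.g.
//
//   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
//
// after which converted values can be passed directly to runtime calls taking
// void *. See the instantiations in populateGpuToLLVMConversionPatterns below.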
1596
1597LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1598 gpu::SDDMMOp op, OpAdaptor adaptor,
1599 ConversionPatternRewriter &rewriter) const {
1600 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1601 failed(isAsyncWithOneDependency(rewriter, op)))
1602 return failure();
1603 Location loc = op.getLoc();
1604 auto computeType = genConstInt32From(
1605 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1606 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1607 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1608 auto stream = adaptor.getAsyncDependencies().front();
1609 Value pBuf =
1610 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1611 createSDDMMCallBuilder.create(loc, rewriter,
1612 {modeA, modeB, adaptor.getDnmatA(),
1613 adaptor.getDnmatB(), adaptor.getSpmatC(),
1614 computeType, pBuf, stream});
1615 rewriter.replaceOp(op, {stream});
1616 return success();
1617}
1618
1619LogicalResult
1620ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1621 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1622 ConversionPatternRewriter &rewriter) const {
1623 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1624 failed(isAsyncWithOneDependency(rewriter, op)))
1625 return failure();
1626 Location loc = op.getLoc();
1627 auto stream = adaptor.getAsyncDependencies().front();
1628 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1629 .getResult();
1630 rewriter.replaceOp(op, {descr, stream});
1631 return success();
1632}
1633
1634LogicalResult
1635ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1636 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1637 ConversionPatternRewriter &rewriter) const {
1638 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1639 failed(isAsyncWithOneDependency(rewriter, op)))
1640 return failure();
1641 Location loc = op.getLoc();
1642 auto stream = adaptor.getAsyncDependencies().front();
1643 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1644 {adaptor.getDesc(), stream});
1645 rewriter.replaceOp(op, {stream});
1646 return success();
1647}
1648
1649LogicalResult
1650ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1651 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1652 ConversionPatternRewriter &rewriter) const {
1653 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1654 failed(isAsyncWithOneDependency(rewriter, op)))
1655 return failure();
1656 Location loc = op.getLoc();
1657 auto computeType = genConstInt32From(
1658 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1659 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1660 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1661 auto stream = adaptor.getAsyncDependencies().front();
1662
1663 Value pBuf =
1664 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1665 Value bufferSizeNew;
1666
1667 if (adaptor.getKind() ==
1668 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1669 bufferSizeNew =
1670 createSpGEMMWorkEstimationBuilder
1671 .create(loc, rewriter,
1672 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1673 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1674 adaptor.getBufferSz(), pBuf, stream})
1675 .getResult();
1676 } else {
1677 bufferSizeNew =
1678 createSpGEMMComputeBuilder
1679 .create(loc, rewriter,
1680 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1681 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1682 adaptor.getBufferSz(), pBuf, stream})
1683 .getResult();
1684 }
1685 rewriter.replaceOp(op, {bufferSizeNew, stream});
1686 return success();
1687}
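// cuSPARSE's SpGEMM protocol runs each phase twice: first with a zero-sized
// buffer to learn the required size (returned here as bufferSizeNew), then
// again with a buffer of that size to do the actual work. Roughly, per phase
// (gpu-dialect sketch, attributes abbreviated):
//
//   %sz, %t1  = gpu.spgemm_work_estimation_or_compute async [%t0]
//                 %desc, ... , %c0, %dummy      // query required size
//   %buf, %t2 = gpu.alloc async [%t1] (%sz) : memref<?xi8>
//   %_, %t3   = gpu.spgemm_work_estimation_or_compute async [%t2]
//                 %desc, ... , %sz, %buf        // run the phase for real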
1688
1689LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1690 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1691 ConversionPatternRewriter &rewriter) const {
1692 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1693 failed(isAsyncWithOneDependency(rewriter, op)))
1694 return failure();
1695 Location loc = op.getLoc();
1696 auto computeType = genConstInt32From(
1697 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1698 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1699 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1700 auto stream = adaptor.getAsyncDependencies().front();
1701 createSpGEMMCopyBuilder.create(loc, rewriter,
1702 {adaptor.getDesc(), modeA, modeB,
1703 adaptor.getSpmatA(), adaptor.getSpmatB(),
1704 adaptor.getSpmatC(), computeType, stream});
1705 rewriter.replaceOp(op, {stream});
1706 return success();
1707}
1708
1709LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1710 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1711 ConversionPatternRewriter &rewriter) const {
1712 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1713 failed(isAsyncWithOneDependency(rewriter, op)))
1714 return failure();
1715 Location loc = op.getLoc();
1716 auto stream = adaptor.getAsyncDependencies().front();
1717
1718 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1719 rewriter.getIndexAttr(3));
1720 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1721 llvmInt64Type, three, /*alignment=*/16);
1722
1723 auto rowsPtr = LLVM::GEPOp::create(
1724 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1725 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1726 rewriter.getIndexAttr(0))});
1727 auto colsPtr = LLVM::GEPOp::create(
1728 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1729 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1730 rewriter.getIndexAttr(1))});
1731 auto nnzsPtr = LLVM::GEPOp::create(
1732 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1733 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1734 rewriter.getIndexAttr(2))});
1735 createSpMatGetSizeBuilder.create(
1736 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1737 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1738 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1739 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1740
1741 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1742 return success();
1743}
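// The alloca/GEP/load sequence above is the LLVM-dialect spelling of
// (wrapper name per createSpMatGetSizeBuilder; exact signature illustrative):
//
//   int64_t sizes[3];
//   mgpuSpMatGetSize(spmat, &sizes[0], &sizes[1], &sizes[2], stream);
//   // rows = sizes[0], cols = sizes[1], nnz = sizes[2]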
1744
1745LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1746 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1747 ConversionPatternRewriter &rewriter) const {
1748 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1749 failed(isAsyncWithOneDependency(rewriter, op)))
1750 return failure();
1751 Location loc = op.getLoc();
1752 auto stream = adaptor.getAsyncDependencies().front();
1753 Value pPos =
1754 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1755 Value pCrd =
1756 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1757 Value pVal =
1758 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1759 createSetCsrPointersBuilder.create(
1760 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1761 rewriter.replaceOp(op, {stream});
1762 return success();
1763}
1764
1765LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1766 gpu::CreateCscOp op, OpAdaptor adaptor,
1767 ConversionPatternRewriter &rewriter) const {
1768 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1769 failed(isAsyncWithOneDependency(rewriter, op)))
1770 return failure();
1771 Location loc = op.getLoc();
1772 auto stream = adaptor.getAsyncDependencies().front();
1773 Value pColPos =
1774 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1775 Value pRowIdxs =
1776 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1777 Value pValues =
1778 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1779 Type pType =
1780 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1781 Type iType =
1782 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1783 Type dType =
1784 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1785 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1786 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1787 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1788 auto handle =
1789 createCscCallBuilder
1790 .create(loc, rewriter,
1791 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1792 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1793 .getResult();
1794 rewriter.replaceOp(op, {handle, stream});
1795 return success();
1796}
1797
1798LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1799 gpu::CreateBsrOp op, OpAdaptor adaptor,
1800 ConversionPatternRewriter &rewriter) const {
1801 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1802 failed(isAsyncWithOneDependency(rewriter, op)))
1803 return failure();
1804 Location loc = op.getLoc();
1805 auto stream = adaptor.getAsyncDependencies().front();
1806 Value pRowPos =
1807 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1808 Value pColIdxs =
1809 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1810 Value pValues =
1811 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1812 Type pType =
1813 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1814 Type iType =
1815 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1816 Type dType =
1817 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1818 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1819 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1820 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1821 auto handle =
1822 createBsrCallBuilder
1823 .create(loc, rewriter,
1824 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1825 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1826 pColIdxs, pValues, ptp, itp, dtp, stream})
1827 .getResult();
1828 rewriter.replaceOp(op, {handle, stream});
1829 return success();
1830}
1831
1832void mlir::populateGpuToLLVMConversionPatterns(
1833 LLVMTypeConverter &converter, RewritePatternSet &patterns,
1834 bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1835 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1836 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1837 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1838 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1839
1840 // Higher benefit so this pattern wins over the structural async.yield
1841 // rewriter from populateAsyncStructuralTypeConversionsAndLegality on yields
1842 // with gpu.async.token operands. The structural rewriter would silently
1843 // retype operands without recording an event on the underlying stream.
1844 patterns.add<ConvertAsyncYieldToGpuRuntimeCallPattern>(converter,
1845 /*benefit=*/2);
1846
1847 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1848 ConvertDeallocOpToGpuRuntimeCallPattern,
1849 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1850 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1851 ConvertMemcpyOpToGpuRuntimeCallPattern,
1852 ConvertMemsetOpToGpuRuntimeCallPattern,
1853 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1854 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1855 ConvertWaitOpToGpuRuntimeCallPattern,
1856 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1857 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1858 ConvertCreateCooOpToGpuRuntimeCallPattern,
1859 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1860 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1861 ConvertCreateCscOpToGpuRuntimeCallPattern,
1862 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1863 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1864 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1865 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1866 ConvertSpMVOpToGpuRuntimeCallPattern,
1867 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1868 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1869 ConvertSpMMOpToGpuRuntimeCallPattern,
1870 ConvertSDDMMOpToGpuRuntimeCallPattern,
1871 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1872 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1873 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1874 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1875 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1876 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
1877 patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1878 kernelIntersperseSizeCallConv);
1879}
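// Typical wiring from a conversion pass, as a minimal sketch (target setup
// reduced to the essentials):
//
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   LLVMConversionTarget target(getContext());
//   populateGpuToLLVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();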
1880
1881//===----------------------------------------------------------------------===//
1882// GPUModuleOp convert to LLVM op interface
1883//===----------------------------------------------------------------------===//
1884
1885namespace {
1886struct GPUModuleOpConvertToLLVMInterface
1887 : public ConvertToLLVMOpInterface::ExternalModel<
1888 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1889 /// Get the conversion patterns from the target attribute.
1890 void getConvertToLLVMConversionAttrs(
1891 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
1892};
1893} // namespace
1894
1895void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1896 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1897 auto module = cast<gpu::GPUModuleOp>(op);
1898 ArrayAttr targetsAttr = module.getTargetsAttr();
1899 // Bail out if there are no target attributes or if there is more than one
1899 // target.
1900 if (!targetsAttr || targetsAttr.size() != 1)
1901 return;
1902 if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1903 attrs.push_back(patternAttr);
1904}
1905
1906void mlir::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
1907 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1908 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1909 });
1910}
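// Callers opt in through their dialect registry, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<gpu::GPUDialect>();
//   registerConvertGpuToLLVMInterface(registry);
//   MLIRContext context(registry);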