MLIR 23.0.0git
GPUToLLVMConversion.cpp
Go to the documentation of this file.
1//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert gpu.launch_func op into a sequence of
10// GPU runtime calls. As most GPU runtimes do not have a stable published
11// ABI, this pass uses a slim runtime layer that builds on top of the public
12// API from GPU runtime headers.
13//
14//===----------------------------------------------------------------------===//
15
17
34#include "mlir/IR/Attributes.h"
35#include "mlir/IR/Builders.h"
36#include "mlir/IR/BuiltinOps.h"
39
40#include "llvm/ADT/STLExtras.h"
41
42#define DEBUG_TYPE "gpu-to-llvm"
43
44namespace mlir {
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
47} // namespace mlir
48
49using namespace mlir;
50
51namespace {
52class GpuToLLVMConversionPass
53 : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
54public:
55 using Base::Base;
56 void getDependentDialects(DialectRegistry &registry) const final {
57 Base::getDependentDialects(registry);
59 }
60 // Run the dialect converter on the module.
61 void runOnOperation() override;
62};
63
64template <typename OpTy>
65class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
66public:
67 explicit ConvertOpToGpuRuntimeCallPattern(
68 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
69 : ConvertOpToLLVMPattern<OpTy>(typeConverter, benefit) {}
70
71protected:
72 Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73 MemRefType type, MemRefDescriptor desc) const {
74 Type indexType = ConvertToLLVMPattern::getIndexType();
75 if (type.hasStaticShape())
77 rewriter, loc, indexType, type.getNumElements());
78 // Compute the number of elements by multiplying all the dim sizes.
79 uint64_t rank = type.getRank();
80 Value numElements = desc.size(rewriter, loc, /*pos=*/0);
81 for (unsigned i = 1; i < rank; i++)
82 numElements = LLVM::MulOp::create(rewriter, loc, numElements,
83 desc.size(rewriter, loc, /*pos=*/i));
84 return numElements;
85 }
86
87 MLIRContext *context = &this->getTypeConverter()->getContext();
88
89 Type llvmVoidType = LLVM::LLVMVoidType::get(context);
90 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91 Type llvmInt8Type = IntegerType::get(context, 8);
92 Type llvmInt16Type = IntegerType::get(context, 16);
93 Type llvmInt32Type = IntegerType::get(context, 32);
94 Type llvmInt64Type = IntegerType::get(context, 64);
95 Type llvmFloat32Type = Float32Type::get(context);
96 Type llvmIntPtrType = IntegerType::get(
97 context, this->getTypeConverter()->getPointerBitwidth(0));
98
99 FunctionCallBuilder streamCreateCallBuilder = {
100 "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
101 FunctionCallBuilder streamDestroyCallBuilder = {
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
103 FunctionCallBuilder streamSynchronizeCallBuilder = {
104 "mgpuStreamSynchronize",
105 llvmVoidType,
106 {llvmPointerType /* void *stream */}};
107 FunctionCallBuilder streamWaitEventCallBuilder = {
108 "mgpuStreamWaitEvent",
109 llvmVoidType,
110 {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
111 FunctionCallBuilder eventCreateCallBuilder = {
112 "mgpuEventCreate", llvmPointerType /* void *event */, {}};
113 FunctionCallBuilder eventDestroyCallBuilder = {
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
115 FunctionCallBuilder eventSynchronizeCallBuilder = {
116 "mgpuEventSynchronize",
117 llvmVoidType,
118 {llvmPointerType /* void *event */}};
119 FunctionCallBuilder eventRecordCallBuilder = {
120 "mgpuEventRecord",
121 llvmVoidType,
122 {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
123 FunctionCallBuilder hostRegisterCallBuilder = {
124 "mgpuMemHostRegisterMemRef",
125 llvmVoidType,
126 {llvmIntPtrType /* intptr_t rank */,
127 llvmPointerType /* void *memrefDesc */,
128 llvmIntPtrType /* intptr_t elementSizeBytes */}};
129 FunctionCallBuilder hostUnregisterCallBuilder = {
130 "mgpuMemHostUnregisterMemRef",
131 llvmVoidType,
132 {llvmIntPtrType /* intptr_t rank */,
133 llvmPointerType /* void *memrefDesc */,
134 llvmIntPtrType /* intptr_t elementSizeBytes */}};
135 FunctionCallBuilder allocCallBuilder = {
136 "mgpuMemAlloc",
137 llvmPointerType /* void * */,
138 {llvmIntPtrType /* intptr_t sizeBytes */,
139 llvmPointerType /* void *stream */,
140 llvmInt8Type /* bool isHostShared */}};
141 FunctionCallBuilder deallocCallBuilder = {
142 "mgpuMemFree",
143 llvmVoidType,
144 {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
145 FunctionCallBuilder memcpyCallBuilder = {
146 "mgpuMemcpy",
147 llvmVoidType,
148 {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
149 llvmIntPtrType /* intptr_t sizeBytes */,
150 llvmPointerType /* void *stream */}};
151 FunctionCallBuilder memset16CallBuilder = {
152 "mgpuMemset16",
153 llvmVoidType,
154 {llvmPointerType /* void *dst */,
155 llvmInt16Type /* unsigned short value */,
156 llvmIntPtrType /* intptr_t sizeBytes */,
157 llvmPointerType /* void *stream */}};
158 FunctionCallBuilder memset32CallBuilder = {
159 "mgpuMemset32",
160 llvmVoidType,
161 {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
162 llvmIntPtrType /* intptr_t sizeBytes */,
163 llvmPointerType /* void *stream */}};
164 FunctionCallBuilder setDefaultDeviceCallBuilder = {
165 "mgpuSetDefaultDevice",
166 llvmVoidType,
167 {llvmInt32Type /* uint32_t devIndex */}};
168 FunctionCallBuilder createDnVecCallBuilder = {
169 "mgpuCreateDnVec",
170 llvmPointerType,
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
172 llvmPointerType /* void *stream */}};
173 FunctionCallBuilder destroyDnVecCallBuilder = {
174 "mgpuDestroyDnVec",
175 llvmVoidType,
176 {llvmPointerType, llvmPointerType /* void *stream */}};
177 FunctionCallBuilder createDnMatCallBuilder = {
178 "mgpuCreateDnMat",
179 llvmPointerType,
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
181 llvmPointerType /* void *stream */}};
182 FunctionCallBuilder destroyDnMatCallBuilder = {
183 "mgpuDestroyDnMat",
184 llvmVoidType,
185 {llvmPointerType, llvmPointerType /* void *stream */}};
186 FunctionCallBuilder createCooCallBuilder = {
187 "mgpuCreateCoo",
188 llvmPointerType,
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
191 llvmPointerType /* void *stream */}};
192 FunctionCallBuilder createCooAoSCallBuilder = {
193 "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
194 llvmPointerType,
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
197 llvmPointerType /* void *stream */}};
198 FunctionCallBuilder createCsrCallBuilder = {
199 "mgpuCreateCsr",
200 llvmPointerType,
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType /* void *stream */}};
204 FunctionCallBuilder createCscCallBuilder = {
205 "mgpuCreateCsc",
206 llvmPointerType,
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType /* void *stream */}};
210 FunctionCallBuilder createBsrCallBuilder = {
211 "mgpuCreateBsr",
212 llvmPointerType,
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
216 llvmPointerType /* void *stream */}};
217 FunctionCallBuilder destroySpMatCallBuilder = {
218 "mgpuDestroySpMat",
219 llvmVoidType,
220 {llvmPointerType, llvmPointerType /* void *stream */}};
221 FunctionCallBuilder spMVBufferSizeCallBuilder = {
222 "mgpuSpMVBufferSize",
223 llvmIntPtrType,
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType /* void *stream */}};
226 FunctionCallBuilder spMVCallBuilder = {
227 "mgpuSpMV",
228 llvmVoidType,
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
231 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232 "mgpuSpMMBufferSize",
233 llvmIntPtrType,
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
236 FunctionCallBuilder createSpMMCallBuilder = {
237 "mgpuSpMM",
238 llvmVoidType,
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
241 llvmPointerType /* void *stream */}};
242 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243 "mgpuSDDMMBufferSize",
244 llvmIntPtrType,
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
247 FunctionCallBuilder createSDDMMCallBuilder = {
248 "mgpuSDDMM",
249 llvmVoidType,
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
252 llvmPointerType /* void *stream */}};
253 FunctionCallBuilder createLtDnMatCallBuilder = {
254 "mgpuCreateCuSparseLtDnMat",
255 llvmVoidType,
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType /* void *stream */}};
258 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259 "mgpuDestroyCuSparseLtSpMat",
260 llvmVoidType,
261 {llvmPointerType, llvmPointerType /* void *stream */}};
262 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263 "mgpuDestroyCuSparseLtDnMat",
264 llvmVoidType,
265 {llvmPointerType, llvmPointerType /* void *stream */}};
266 FunctionCallBuilder create2To4SpMatCallBuilder = {
267 "mgpuCusparseLtCreate2To4SpMat",
268 llvmVoidType,
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType /* void *stream */}};
271 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272 "mgpuCuSparseLtSpMMBufferSize",
273 llvmVoidType,
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
276 llvmPointerType /*void *stream*/}};
277 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278 "mgpuCuSparseLtSpMM",
279 llvmVoidType,
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
282 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283 "mgpuSpGEMMCreateDescr",
284 llvmPointerType,
285 {llvmPointerType /*void *stream*/}};
286 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287 "mgpuSpGEMMDestroyDescr",
288 llvmVoidType,
289 {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
290 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291 "mgpuSpGEMMWorkEstimation",
292 llvmIntPtrType,
293 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
294 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
295 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
296 llvmPointerType /*void *stream*/}};
297 FunctionCallBuilder createSpGEMMComputeBuilder = {
298 "mgpuSpGEMMCompute",
299 llvmIntPtrType,
300 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
301 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
302 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
303 llvmPointerType /*void *stream*/}};
304 FunctionCallBuilder createSpGEMMCopyBuilder = {
305 "mgpuSpGEMMCopy",
306 llvmVoidType,
307 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
308 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
309 llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
310 FunctionCallBuilder createSpMatGetSizeBuilder = {
311 "mgpuSpMatGetSize",
312 llvmVoidType,
313 {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
314 llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
315 FunctionCallBuilder createSetCsrPointersBuilder = {
316 "mgpuSetCsrPointers",
317 llvmVoidType,
318 {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
319 llvmPointerType /*crd*/, llvmPointerType /*val*/,
320 llvmPointerType /*void *stream*/}};
321};
322
323/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
324/// call. Currently it supports CUDA and ROCm (HIP).
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327public:
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
331
332private:
333 LogicalResult
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override;
336};
337
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340public:
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
344 }
345
346private:
347 LogicalResult
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override;
350};
351
352/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
353/// call. Currently it supports CUDA and ROCm (HIP).
354class ConvertAllocOpToGpuRuntimeCallPattern
355 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
356public:
357 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
359
360private:
361 LogicalResult
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override;
364};
365
366/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
367/// call. Currently it supports CUDA and ROCm (HIP).
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370public:
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
374
375private:
376 LogicalResult
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter) const override;
379};
380
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383public:
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter,
387 benefit) {}
388
389private:
390 LogicalResult
391 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter) const override;
393};
394
395/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
396/// call. Currently it supports CUDA and ROCm (HIP).
397class ConvertWaitOpToGpuRuntimeCallPattern
398 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
399public:
400 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
401 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
402
403private:
404 LogicalResult
405 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const override;
407};
408
409/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
410/// call. Currently it supports CUDA and ROCm (HIP).
411class ConvertWaitAsyncOpToGpuRuntimeCallPattern
412 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
413public:
414 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
415 const LLVMTypeConverter &typeConverter)
416 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
417
418private:
419 LogicalResult
420 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override;
422};
423
424/// A rewrite patter to legalize gpu.launch_func with LLVM types.
425class LegalizeLaunchFuncOpPattern
426 : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
427public:
428 LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
429 bool kernelBarePtrCallConv,
430 bool kernelIntersperseSizeCallConv)
431 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
432 kernelBarePtrCallConv(kernelBarePtrCallConv),
433 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
434
435private:
436 LogicalResult
437 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
438 ConversionPatternRewriter &rewriter) const override;
439
440 bool kernelBarePtrCallConv;
441 bool kernelIntersperseSizeCallConv;
442};
443
444/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
445/// call. Currently it supports CUDA and ROCm (HIP).
446class ConvertMemcpyOpToGpuRuntimeCallPattern
447 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
448public:
449 ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
450 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
451
452private:
453 LogicalResult
454 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
455 ConversionPatternRewriter &rewriter) const override;
456};
457
458/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
459/// call. Currently it supports CUDA and ROCm (HIP).
460class ConvertMemsetOpToGpuRuntimeCallPattern
461 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
462public:
463 ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
464 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
465
466private:
467 LogicalResult
468 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override;
470};
471
472/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
473/// Currently supports CUDA and ROCm (HIP)
474class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
475 : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
476public:
477 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
478 const LLVMTypeConverter &typeConverter)
479 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
480 typeConverter) {}
481
482 LogicalResult
483 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override;
485};
486
/// Declares `Convert<OpName>ToGpuRuntimeCallPattern`, a conversion pattern for
/// `gpu::<OpName>` whose `matchAndRewrite` is only declared here (its body must
/// be provided separately). Used for operations on sparse matrices, which
/// currently lower to CUDA only (by means of cuSPARSE and cuSPARSELt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(OpName)                 \
  class Convert##OpName##ToGpuRuntimeCallPattern                               \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::OpName> {                 \
  public:                                                                      \
    Convert##OpName##ToGpuRuntimeCallPattern(                                  \
        const LLVMTypeConverter &converter)                                    \
        : ConvertOpToGpuRuntimeCallPattern<gpu::OpName>(converter) {}          \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::OpName op, OpAdaptor adaptor,                         \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
502
520DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
524
525} // namespace
526
527void GpuToLLVMConversionPass::runOnOperation() {
528 MLIRContext *context = &getContext();
529
530 // Perform progressive lowering of vector transfer operations.
531 {
532 RewritePatternSet patterns(&getContext());
533 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
535 /*maxTransferRank=*/1);
536 // Transform N-D vector.from_elements to 1-D vector.from_elements before
537 // conversion.
538 vector::populateVectorFromElementsUnrollPatterns(patterns);
539 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
540 return signalPassFailure();
541 }
542
543 LowerToLLVMOptions options(context);
544 options.useBarePtrCallConv = hostBarePtrCallConv;
545 RewritePatternSet patterns(context);
546 ConversionTarget target(*context);
547 target.addLegalDialect<LLVM::LLVMDialect>();
548 LLVMTypeConverter converter(context, options);
549
550 // Populate all patterns from all dialects that implement the
551 // `ConvertToLLVMPatternInterface` interface.
552 for (Dialect *dialect : context->getLoadedDialects()) {
553 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
554 if (!iface)
555 continue;
556 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
557 }
558
559 // Preserve GPU modules and binaries. Modules are preserved as they can be
560 // converted later by `gpu-module-to-binary`.
561 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
562 // Accept as legal LaunchFuncOps if the operands have been lowered.
563 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
564 [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
565
566 // These aren't covered by the ConvertToLLVMPatternInterface right now.
567 populateVectorToLLVMConversionPatterns(converter, patterns);
570 target);
571 populateGpuToLLVMConversionPatterns(converter, patterns,
572 kernelBarePtrCallConv,
573 kernelIntersperseSizeCallConv);
574
575 if (failed(
576 applyPartialConversion(getOperation(), target, std::move(patterns))))
577 signalPassFailure();
578}
579
581 ArrayRef<Value> arguments) const {
582 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
583 auto function = [&] {
584 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
585 return function;
586 auto builder = OpBuilder::atBlockEnd(module.getBody());
587 return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
588 }();
589 return LLVM::CallOp::create(builder, loc, function, arguments);
590}
591
592// Corresponding to cusparseIndexType_t defined in cusparse.h.
593static int32_t getCuSparseIndexTypeFrom(Type type) {
594 if (type.isInteger(16))
595 return 1; // CUSPARSE_INDEX_16U
596 if (type.isInteger(32))
597 return 2; // CUSPARSE_INDEX_32I
598 return 3; // CUSPARSE_INDEX_64I
599}
600
601static int32_t getCuSparseLtDataTypeFrom(Type type) {
602 if (type.isF16())
603 return 0; // CUSPARSE_COMPUTE_16F,
604 if (type.isInteger(32))
605 return 1; // CUSPARSE_COMPUTE_32I
606 llvm_unreachable("unsupported type");
607 // TODO: add support to TF32
608}
609
610// Corresponding to cudaDataType_t defined in CUDA library_types.h.
611static int32_t getCuSparseDataTypeFrom(Type type) {
612 if (llvm::isa<ComplexType>(type)) {
613 // get the element type
614 auto elementType = cast<ComplexType>(type).getElementType();
615 if (elementType.isBF16())
616 return 15; // CUDA_C_16BF
617 if (elementType.isF16())
618 return 6; // CUDA_C_16F
619 if (elementType.isF32())
620 return 4; // CUDA_C_32F
621 if (elementType.isF64())
622 return 5; // CUDA_C_64F
623 if (elementType.isInteger(8))
624 return 7; // CUDA_C_8I
625 if (elementType.isInteger(16))
626 return 21; // CUDA_C_16I
627 if (elementType.isInteger(32))
628 return 11; // CUDA_C_32I
629 }
630 if (type.isBF16())
631 return 14; // CUDA_R_16BF
632 if (type.isF16())
633 return 2; // CUDA_R_16F
634 if (type.isF32())
635 return 0; // CUDA_R_32F
636 if (type.isF64())
637 return 1; // CUDA_R_64F
638 if (type.isInteger(8))
639 return 3; // CUDA_R_8I
640 if (type.isInteger(16))
641 return 20; // CUDA_R_16I
642 if (type.isInteger(32))
643 return 10; // CUDA_R_32I
644
645 llvm_unreachable("unsupported element type");
646}
647
648static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
649 return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
650}
651
652// TODO: We may want a run-time (of the mlir compiler) disablement/warning:
653// cusparseLt currently won't work for cuda architecture <8.0 and will trigger a
654// runtime (of the CUDA program) error , but it might be great if we could at
655// least output a warning when we found the target architecture is <8.0 and the
656// user still wants to use cusparseLt. to make sure when lowering gpu sparse
657// dialect to llvm calls, the cusparselt calls are disabled for cuda
658// architecture <8.0
659static bool is2To4Sparsity(Value spMat) {
660 if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
661 return true;
662 if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
663 return false;
664 if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
665 return false;
666 if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
667 return false;
668 if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
669 return false;
670 if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
671 return false;
672 // Print the spMat defining op
673 spMat.getDefiningOp()->print(llvm::errs());
674 llvm_unreachable("cannot find spmat def");
675}
676
677static bool isSpMMCusparseLtOp(Value op) {
678 for (Operation *user : op.getUsers()) {
679 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
680 // If the other operator is 50% sparsity then we should use cusparseLt
681 if (!spmmOp)
682 continue;
683 if (is2To4Sparsity(spmmOp.getSpmatA()))
684 return true;
685 }
686 return false;
687}
688
689// Returns whether all operands are of LLVM type.
690static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
691 ConversionPatternRewriter &rewriter) {
692 if (!llvm::all_of(operands, [](Value value) {
693 return LLVM::isCompatibleType(value.getType());
694 }))
695 return rewriter.notifyMatchFailure(
696 op, "Cannot convert if operands aren't of LLVM type.");
697 return success();
698}
699
700static LogicalResult
701isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
702 gpu::AsyncOpInterface op) {
703 if (op.getAsyncDependencies().size() != 1)
704 return rewriter.notifyMatchFailure(
705 op, "Can only convert with exactly one async dependency.");
706
707 if (!op.getAsyncToken())
708 return rewriter.notifyMatchFailure(op, "Can convert only async version.");
709
710 return success();
711}
712
713LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
714 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter) const {
716 auto *op = hostRegisterOp.getOperation();
717 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
718 return failure();
719
720 Location loc = op->getLoc();
721
722 auto memRefType = hostRegisterOp.getValue().getType();
723 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
724 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
725
726 auto arguments = getTypeConverter()->promoteOperands(
727 loc, op->getOperands(), adaptor.getOperands(), rewriter);
728 arguments.push_back(elementSize);
729 hostRegisterCallBuilder.create(loc, rewriter, arguments);
730
731 rewriter.eraseOp(op);
732 return success();
733}
734
735LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
736 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
737 ConversionPatternRewriter &rewriter) const {
738 Operation *op = hostUnregisterOp.getOperation();
739 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
740 return failure();
741
742 Location loc = op->getLoc();
743
744 auto memRefType = hostUnregisterOp.getValue().getType();
745 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
746 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
747
748 auto arguments = getTypeConverter()->promoteOperands(
749 loc, op->getOperands(), adaptor.getOperands(), rewriter);
750 arguments.push_back(elementSize);
751 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
752
753 rewriter.eraseOp(op);
754 return success();
755}
756
757LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
758 gpu::AllocOp allocOp, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter) const {
760
761 MemRefType memRefType = allocOp.getType();
762
763 if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
764 !isConvertibleAndHasIdentityMaps(memRefType))
765 return failure();
766
767 auto loc = allocOp.getLoc();
768
769 bool isShared = allocOp.getHostShared();
770
771 if (isShared && allocOp.getAsyncToken())
772 return rewriter.notifyMatchFailure(
773 allocOp, "Host Shared allocation cannot be done async");
774 if (adaptor.getAsyncDependencies().size() > 1)
775 return rewriter.notifyMatchFailure(
776 allocOp, "Can convert with at most one async dependency.");
777
778 // Get shape of the memref as values: static sizes are constant
779 // values and dynamic sizes are passed to 'alloc' as operands.
780 SmallVector<Value, 4> shape;
781 SmallVector<Value, 4> strides;
782 Value sizeBytes;
783 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
784 shape, strides, sizeBytes);
785
786 // Allocate the underlying buffer and store a pointer to it in the MemRef
787 // descriptor.
788 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
789 Value stream = adaptor.getAsyncDependencies().empty()
790 ? nullPtr
791 : adaptor.getAsyncDependencies().front();
792
793 auto isHostShared = mlir::LLVM::ConstantOp::create(
794 rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
795
796 Value allocatedPtr =
797 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
798 .getResult();
799
800 // Cast the runtime-returned pointer to the memref's address space if it
801 // isn't there already, so it fits the descriptor's pointer slots.
802 unsigned dstAddrSpace = memRefType.getMemorySpaceAsInt();
803 unsigned srcAddrSpace =
804 cast<LLVM::LLVMPointerType>(allocatedPtr.getType()).getAddressSpace();
805 if (dstAddrSpace != srcAddrSpace) {
806 auto targetPtrTy =
807 LLVM::LLVMPointerType::get(rewriter.getContext(), dstAddrSpace);
808 allocatedPtr =
809 LLVM::AddrSpaceCastOp::create(rewriter, loc, targetPtrTy, allocatedPtr);
810 }
811
812 // No alignment.
813 Value alignedPtr = allocatedPtr;
814
815 // Create the MemRef descriptor.
816 auto memRefDescriptor = this->createMemRefDescriptor(
817 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
818
819 if (allocOp.getAsyncToken()) {
820 // Async alloc: make dependent ops use the same stream.
821 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
822 } else {
823 rewriter.replaceOp(allocOp, {memRefDescriptor});
824 }
825
826 return success();
827}
828
829LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
830 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
831 ConversionPatternRewriter &rewriter) const {
832 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)))
833 return failure();
834 if (adaptor.getAsyncDependencies().size() > 1)
835 return rewriter.notifyMatchFailure(
836 deallocOp, "Can convert with at most one async dependency.");
837
838 Location loc = deallocOp.getLoc();
839
840 Value pointer =
841 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
842 auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
843 Value stream = adaptor.getAsyncDependencies().empty()
844 ? nullPtr
845 : adaptor.getAsyncDependencies().front();
846 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
847
848 if (deallocOp.getAsyncToken()) {
849 // Async dealloc: propagate the stream as the async token replacement.
850 rewriter.replaceOp(deallocOp, {stream});
851 } else {
852 // Sync dealloc: no results to replace, just remove the op.
853 rewriter.eraseOp(deallocOp);
854 }
855 return success();
856}
857
858static bool isGpuAsyncTokenType(Value value) {
859 return isa<gpu::AsyncTokenType>(value.getType());
860}
861
862// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
863// !gpu.async.token are lowered to stream within the async.execute region, but
864// are passed as events between them. For each !gpu.async.token operand, we
865// create an event and record it on the stream.
866//
867// This pattern is registered with a higher benefit than the structural
868// async.yield rewriter from populateAsyncStructuralTypeConversionsAndLegality
869// so it wins when both match. Without that benefit override, the structural
870// pattern can win and silently retype gpu.async.token operands without
871// recording an event, leaving the host await to call cuEventSynchronize on
872// a stream pointer (a no-op that returns an error), racing the host against
873// the GPU.
874LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
875 async::YieldOp yieldOp, OpAdaptor adaptor,
876 ConversionPatternRewriter &rewriter) const {
877 if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
878 return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
879
880 Location loc = yieldOp.getLoc();
881 SmallVector<Value, 4> newOperands(adaptor.getOperands());
882 llvm::SmallDenseSet<Value> streams;
883 for (auto &operand : yieldOp->getOpOperands()) {
884 if (!isGpuAsyncTokenType(operand.get()))
885 continue;
886 auto idx = operand.getOperandNumber();
887 auto stream = adaptor.getOperands()[idx];
888 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
889 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
890 newOperands[idx] = event;
891 streams.insert(stream);
892 }
893 for (auto stream : streams)
894 streamDestroyCallBuilder.create(loc, rewriter, {stream});
895
896 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
897 return success();
898}
899
900// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
901static bool isDefinedByCallTo(Value value, StringRef functionName) {
902 assert(isa<LLVM::LLVMPointerType>(value.getType()));
903 if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
904 return *defOp.getCallee() == functionName;
905 return false;
906}
907
908// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
909// with the stream/event operands. The operands are destroyed. That is, it
910// assumes that it is not used afterwards or elsewhere. Otherwise we will get a
911// runtime error. Eventually, we should guarantee this property.
912LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
913 gpu::WaitOp waitOp, OpAdaptor adaptor,
914 ConversionPatternRewriter &rewriter) const {
915 if (waitOp.getAsyncToken())
916 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
917
918 Location loc = waitOp.getLoc();
919
920 for (auto operand : adaptor.getOperands()) {
921 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
922 // The converted operand's definition created a stream.
923 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
924 streamDestroyCallBuilder.create(loc, rewriter, {operand});
925 } else {
926 // Otherwise the converted operand is an event. This assumes that we use
927 // events in control flow code as well.
928 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
929 eventDestroyCallBuilder.create(loc, rewriter, {operand});
930 }
931 }
932
933 rewriter.eraseOp(waitOp);
934 return success();
935}
936
937// Converts `gpu.wait async` to runtime calls. The converted op creates a new
938// stream that is synchronized with stream/event operands. The operands are
939// destroyed. That is, it assumes that it is not used afterwards or elsewhere.
940// Otherwise we will get a runtime error. Eventually, we should guarantee this
941// property.
942LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
943 gpu::WaitOp waitOp, OpAdaptor adaptor,
944 ConversionPatternRewriter &rewriter) const {
945 if (!waitOp.getAsyncToken())
946 return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
947
948 Location loc = waitOp.getLoc();
949
950 auto insertionPoint = rewriter.saveInsertionPoint();
951 SmallVector<Value, 1> events;
952 for (auto pair :
953 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
954 auto operand = std::get<1>(pair);
955 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
956 // The converted operand's definition created a stream. Insert an event
957 // into the stream just after the last use of the original token operand.
958 auto *defOp = std::get<0>(pair).getDefiningOp();
959 rewriter.setInsertionPointAfter(defOp);
960 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
961 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
962 events.push_back(event);
963 } else {
964 // Otherwise the converted operand is an event. This assumes that we use
965 // events in control flow code as well.
966 events.push_back(operand);
967 }
968 }
969 rewriter.restoreInsertionPoint(insertionPoint);
970 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
971 for (auto event : events)
972 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
973 for (auto event : events)
974 eventDestroyCallBuilder.create(loc, rewriter, {event});
975 rewriter.replaceOp(waitOp, {stream});
976
977 return success();
978}
979
980// Legalize the op's operands.
981LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
982 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
983 ConversionPatternRewriter &rewriter) const {
984 if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
985 return failure();
986
987 // Fail when the synchronous version of the op has async dependencies. The
988 // lowering destroys the stream, and we do not want to check that there is no
989 // use of the stream after this op.
990 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
991 return rewriter.notifyMatchFailure(
992 launchOp, "Cannot convert non-async op with async dependencies.");
993
994 Location loc = launchOp.getLoc();
995
996 Value stream = Value();
997 if (!adaptor.getAsyncDependencies().empty()) {
998 stream = adaptor.getAsyncDependencies().front();
999 // Synchronize additional async dependencies onto the primary stream using
1000 // events, following the same approach as gpu.wait async lowering.
1001 if (adaptor.getAsyncDependencies().size() > 1) {
1002 auto insertionPoint = rewriter.saveInsertionPoint();
1003 SmallVector<Value, 4> events;
1004 for (auto [origDep, convertedDep] :
1005 llvm::zip(launchOp.getAsyncDependencies().drop_front(),
1006 adaptor.getAsyncDependencies().drop_front())) {
1007 if (!isDefinedByCallTo(convertedDep,
1008 streamCreateCallBuilder.functionName)) {
1009 events.push_back(convertedDep);
1010 continue;
1011 }
1012 Operation *defOp = origDep.getDefiningOp();
1013 rewriter.setInsertionPointAfter(defOp);
1014 Value event =
1015 eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
1016 eventRecordCallBuilder.create(loc, rewriter, {event, convertedDep});
1017 events.push_back(event);
1018 }
1019 rewriter.restoreInsertionPoint(insertionPoint);
1020 for (Value event : events)
1021 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
1022 for (Value event : events)
1023 eventDestroyCallBuilder.create(loc, rewriter, {event});
1024 }
1025 }
1026 // If the async keyword is present and there are no dependencies, then a
1027 // stream must be created to pass to subsequent operations.
1028 else if (launchOp.getAsyncToken())
1029 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1030
1031 // Lower the kernel operands to match kernel parameters.
1032 // Note: If `useBarePtrCallConv` is set in the type converter's options,
1033 // the value of `kernelBarePtrCallConv` will be ignored.
1034 OperandRange origArguments = launchOp.getKernelOperands();
1035 bool effectiveBarePtr = kernelBarePtrCallConv ||
1036 getTypeConverter()->getOptions().useBarePtrCallConv;
1037 if (effectiveBarePtr) {
1038 for (Value arg : origArguments) {
1039 if (isa<UnrankedMemRefType>(arg.getType()))
1040 return rewriter.notifyMatchFailure(
1041 loc, "unranked memref kernel argument is not supported with "
1042 "the bare-pointer calling convention");
1043 }
1044 }
1045 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
1046 loc, origArguments, adaptor.getKernelOperands(), rewriter,
1047 /*useBarePtrCallConv=*/kernelBarePtrCallConv);
1048 SmallVector<Value, 8> llvmArgumentsWithSizes;
1049
1050 // Intersperse size information if requested.
1051 if (kernelIntersperseSizeCallConv) {
1052 if (origArguments.size() != llvmArguments.size()) {
1053 // This shouldn't happen if the bare-pointer calling convention is used.
1054 return rewriter.notifyMatchFailure(
1055 launchOp,
1056 "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
1057 }
1058
1059 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
1060 for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
1061 auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
1062 if (!memrefTy) {
1063 return rewriter.notifyMatchFailure(
1064 launchOp, "Operand to launch op is not a memref.");
1065 }
1066
1067 if (!memrefTy.hasStaticShape() ||
1068 !memrefTy.getElementType().isIntOrFloat()) {
1069 return rewriter.notifyMatchFailure(
1070 launchOp, "Operand to launch op is not a memref with a static "
1071 "shape and an integer or float element type.");
1072 }
1073
1074 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1075 if (bitwidth % 8 != 0) {
1076 return rewriter.notifyMatchFailure(
1077 launchOp, "Operand to launch op is not a memref with a "
1078 "byte-aligned element type.");
1079 }
1080
1081 uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1082 static_cast<uint64_t>(memrefTy.getNumElements());
1083
1084 Value sizeArg = LLVM::ConstantOp::create(
1085 rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1086 llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
1087 llvmArgumentsWithSizes.push_back(sizeArg);
1088 }
1089 }
1090
1091 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1092 if (launchOp.hasClusterSize()) {
1093 clusterSize =
1094 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1095 adaptor.getClusterSizeZ()};
1096 }
1097 auto newLaunchOp = gpu::LaunchFuncOp::create(
1098 rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
1099 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1100 adaptor.getGridSizeZ()},
1101 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1102 adaptor.getBlockSizeZ()},
1103 adaptor.getDynamicSharedMemorySize(),
1104 llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1105 stream, clusterSize);
1106 if (launchOp.getCooperative())
1107 newLaunchOp.setCooperative(true);
1108 if (launchOp.getAsyncToken())
1109 rewriter.replaceOp(launchOp, {stream});
1110 else
1111 rewriter.eraseOp(launchOp);
1112 return success();
1113}
1114
1116 ConversionPatternRewriter &rewriter,
1117 LLVM::LLVMPointerType destinationType,
1118 Value sourcePtr,
1119 const LLVMTypeConverter &typeConverter) {
1120 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1121 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1122 sourcePtr = LLVM::AddrSpaceCastOp::create(
1123 rewriter, loc,
1124 LLVM::LLVMPointerType::get(rewriter.getContext(),
1125 destinationType.getAddressSpace()),
1126 sourcePtr);
1127 return sourcePtr;
1128}
1129
1130LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1131 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1132 ConversionPatternRewriter &rewriter) const {
1133 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1134
1135 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1136 !isConvertibleAndHasIdentityMaps(memRefType) ||
1137 failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1138 return failure();
1139
1140 auto loc = memcpyOp.getLoc();
1141
1142 MemRefDescriptor srcDesc(adaptor.getSrc());
1143 Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1144
1145 Type elementPtrType = getElementPtrType(memRefType);
1146 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
1147 Value gepPtr = LLVM::GEPOp::create(
1148 rewriter, loc, elementPtrType,
1149 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1150 numElements);
1151 auto sizeBytes =
1152 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
1153
1154 auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1155 srcDesc.alignedPtr(rewriter, loc),
1156 *getTypeConverter());
1157 auto dst = bitAndAddrspaceCast(
1158 loc, rewriter, llvmPointerType,
1159 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1160 *getTypeConverter());
1161
1162 auto stream = adaptor.getAsyncDependencies().front();
1163 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1164
1165 rewriter.replaceOp(memcpyOp, {stream});
1166
1167 return success();
1168}
1169
1170LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1171 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1172 ConversionPatternRewriter &rewriter) const {
1173 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1174
1175 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1176 !isConvertibleAndHasIdentityMaps(memRefType) ||
1177 failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1178 return failure();
1179
1180 auto loc = memsetOp.getLoc();
1181
1182 Type valueType = adaptor.getValue().getType();
1183 unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1184 // Ints and floats of 16 or 32 bit width are allowed.
1185 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1186 return rewriter.notifyMatchFailure(
1187 memsetOp, "value must be a 16 or 32 bit int or float");
1188 }
1189
1190 unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1191 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1192
1193 MemRefDescriptor dstDesc(adaptor.getDst());
1194 Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1195
1196 auto value =
1197 LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
1198 auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1199 dstDesc.alignedPtr(rewriter, loc),
1200 *getTypeConverter());
1201
1202 auto stream = adaptor.getAsyncDependencies().front();
1203 FunctionCallBuilder builder =
1204 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1205 builder.create(loc, rewriter, {dst, value, numElements, stream});
1206
1207 rewriter.replaceOp(memsetOp, {stream});
1208 return success();
1209}
1210
1211LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1212 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1213 ConversionPatternRewriter &rewriter) const {
1214 Location loc = op.getLoc();
1215 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1216 {adaptor.getDevIndex()});
1217 rewriter.replaceOp(op, call);
1218 return success();
1219}
1220
1221template <typename T>
1222static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1223 Type llvmInt32Type = builder.getIntegerType(32);
1224 return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
1225 static_cast<int32_t>(tValue));
1226}
1227
1228template <typename T>
1229static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1230 Type llvmFloat32Type = builder.getF32Type();
1231 return LLVM::ConstantOp::create(
1232 builder, loc, llvmFloat32Type,
1233 builder.getF32FloatAttr(static_cast<float>(tValue)));
1234}
1235
1236LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1237 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1238 ConversionPatternRewriter &rewriter) const {
1239 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1240 failed(isAsyncWithOneDependency(rewriter, op)))
1241 return failure();
1242 Location loc = op.getLoc();
1243 auto stream = adaptor.getAsyncDependencies().front();
1244 Value pTensor =
1245 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1246 Type dType = op.getMemref().getType().getElementType();
1247 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1248
1249 SmallVector<Value, 4> dims;
1250 for (Value dim : adaptor.getDims()) {
1251 dims.push_back(dim);
1252 }
1253
1254 Value handle;
1255 // TODO: For now, we track the use of the handle and lower it to cusparse /
1256 // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
1257 // used, we require two separate Creation ops to be the correct logic. In
1258 // future, we may add support to using one handle in sparse tensor / GPU
1259 // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
1260 // the dnmat is used with spmat with 2:4 sparsity
1261 if (dims.size() == 2) {
1262 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1263 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1264 rewriter.getIndexAttr(11032));
1265 handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1266 llvmInt8Type, handleSz, /*alignment=*/16);
1267 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1268
1269 createLtDnMatCallBuilder
1270 .create(loc, rewriter,
1271 {handle, dims[0], dims[1], pTensor, dtp, stream})
1272 .getResult();
1273 } else {
1274 handle =
1275 createDnMatCallBuilder
1276 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1277 .getResult();
1278 }
1279 } else {
1280 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1281 handle = createDnVecCallBuilder
1282 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1283 .getResult();
1284 }
1285 rewriter.replaceOp(op, {handle, stream});
1286 return success();
1287}
1288
1289LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1290 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1291 ConversionPatternRewriter &rewriter) const {
1292 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1293 failed(isAsyncWithOneDependency(rewriter, op)))
1294 return failure();
1295 Location loc = op.getLoc();
1296 auto stream = adaptor.getAsyncDependencies().front();
1297 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1298 SmallVector<Value, 4> dims;
1299 for (Value dim : definingOp.getDims()) {
1300 dims.push_back(dim);
1301 }
1302 if (dims.size() == 2) {
1303 // Use the cusparseLt destroy call if the dnmat is used with spmat with
1304 // 2:4 sparsity
1305 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1306 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1307 {adaptor.getDnTensor(), stream});
1308 } else {
1309 destroyDnMatCallBuilder.create(loc, rewriter,
1310 {adaptor.getDnTensor(), stream});
1311 }
1312 } else {
1313 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1314 destroyDnVecCallBuilder.create(loc, rewriter,
1315 {adaptor.getDnTensor(), stream});
1316 }
1317 rewriter.replaceOp(op, {stream});
1318 return success();
1319}
1320
1321LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1322 gpu::CreateCooOp op, OpAdaptor adaptor,
1323 ConversionPatternRewriter &rewriter) const {
1324 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1325 failed(isAsyncWithOneDependency(rewriter, op)))
1326 return failure();
1327 Location loc = op.getLoc();
1328 auto stream = adaptor.getAsyncDependencies().front();
1329 Value pRowIdxs =
1330 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1331 Value pColIdxs =
1332 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1333 Value pValues =
1334 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1335 Type iType =
1336 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1337 Type dType =
1338 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1339 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1340 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1341 auto handle =
1342 createCooCallBuilder
1343 .create(loc, rewriter,
1344 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1345 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1346 .getResult();
1347 rewriter.replaceOp(op, {handle, stream});
1348 return success();
1349}
1350
1351LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1352 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1353 ConversionPatternRewriter &rewriter) const {
1354 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1355 failed(isAsyncWithOneDependency(rewriter, op)))
1356 return failure();
1357 Location loc = op.getLoc();
1358 auto stream = adaptor.getAsyncDependencies().front();
1359 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1360 Value pValues =
1361 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1362 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1363 Type dType =
1364 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1365 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1366 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1367 auto handle =
1368 createCooAoSCallBuilder
1369 .create(loc, rewriter,
1370 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1371 pIdxs, pValues, itp, dtp, stream})
1372 .getResult();
1373 rewriter.replaceOp(op, {handle, stream});
1374 return success();
1375}
1376
1377LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1378 gpu::CreateCsrOp op, OpAdaptor adaptor,
1379 ConversionPatternRewriter &rewriter) const {
1380 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1381 failed(isAsyncWithOneDependency(rewriter, op)))
1382 return failure();
1383 Location loc = op.getLoc();
1384 auto stream = adaptor.getAsyncDependencies().front();
1385 Value pRowPos =
1386 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1387 Value pColIdxs =
1388 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1389 Value pValues =
1390 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1391 Type pType =
1392 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1393 Type iType =
1394 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1395 Type dType =
1396 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1397 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1398 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1399 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1400 auto handle =
1401 createCsrCallBuilder
1402 .create(loc, rewriter,
1403 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1404 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1405 .getResult();
1406 rewriter.replaceOp(op, {handle, stream});
1407 return success();
1408}
1409
1410LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1411 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1412 ConversionPatternRewriter &rewriter) const {
1413 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1414 failed(isAsyncWithOneDependency(rewriter, op)))
1415 return failure();
1416 Location loc = op.getLoc();
1417 auto stream = adaptor.getAsyncDependencies().front();
1418 Value pMat =
1419 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1420 Type dType =
1421 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1422 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1423
1424 // CUDA runner asserts the size is 44104 bytes.
1425 auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1426 rewriter.getIndexAttr(44104));
1427 Value handle = LLVM::AllocaOp::create(
1428 rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1429 handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
1430
1431 create2To4SpMatCallBuilder
1432 .create(loc, rewriter,
1433 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1434 .getResult();
1435 rewriter.replaceOp(op, {handle, stream});
1436 return success();
1437}
1438
1439LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1440 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1441 ConversionPatternRewriter &rewriter) const {
1442 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1443 failed(isAsyncWithOneDependency(rewriter, op)))
1444 return failure();
1445 Location loc = op.getLoc();
1446 auto stream = adaptor.getAsyncDependencies().front();
1447 // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1448 if (is2To4Sparsity(op.getSpmat())) {
1449 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1450 {adaptor.getSpmat(), stream});
1451
1452 } else {
1453 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1454 }
1455 rewriter.replaceOp(op, {stream});
1456 return success();
1457}
1458
1459LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1460 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1461 ConversionPatternRewriter &rewriter) const {
1462 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1463 failed(isAsyncWithOneDependency(rewriter, op)))
1464 return failure();
1465 Location loc = op.getLoc();
1466 auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1467 auto computeType = genConstInt32From(
1468 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1469 auto stream = adaptor.getAsyncDependencies().front();
1470 auto bufferSize = spMVBufferSizeCallBuilder
1471 .create(loc, rewriter,
1472 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1473 adaptor.getDnY(), computeType, stream})
1474 .getResult();
1475 rewriter.replaceOp(op, {bufferSize, stream});
1476 return success();
1477}
1478
1479LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1480 gpu::SpMVOp op, OpAdaptor adaptor,
1481 ConversionPatternRewriter &rewriter) const {
1482 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1483 failed(isAsyncWithOneDependency(rewriter, op)))
1484 return failure();
1485 Location loc = op.getLoc();
1486 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1487 auto computeType = genConstInt32From(
1488 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1489 auto stream = adaptor.getAsyncDependencies().front();
1490 Value pBuf =
1491 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1492 spMVCallBuilder.create(loc, rewriter,
1493 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1494 adaptor.getDnY(), computeType, pBuf, stream});
1495 rewriter.replaceOp(op, {stream});
1496 return success();
1497}
1498
1499LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1500 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1501 ConversionPatternRewriter &rewriter) const {
1502 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1503 failed(isAsyncWithOneDependency(rewriter, op)))
1504 return failure();
1505 Location loc = op.getLoc();
1506 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1507 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1508 auto stream = adaptor.getAsyncDependencies().front();
1509 Value bufferSize;
1510 if (is2To4Sparsity(op.getSpmatA())) {
1511 auto pruneFlag =
1512 genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1513 auto computeType = genConstInt32From(
1514 rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1515 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1516 rewriter.getIndexAttr(3));
1517 auto bufferSize =
1518 LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
1519 three, /*alignment=*/16);
1520 createCuSparseLtSpMMBufferSizeBuilder
1521 .create(loc, rewriter,
1522 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1523 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1524 pruneFlag, stream})
1525 .getResult();
1526
1527 auto bufferSizePtr1 = LLVM::GEPOp::create(
1528 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1529 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1530 rewriter.getIndexAttr(1))});
1531 auto bufferSizePtr2 = LLVM::GEPOp::create(
1532 rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
1533 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1534 rewriter.getIndexAttr(2))});
1535 auto bufferSize0 =
1536 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
1537 auto bufferSize1 =
1538 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
1539 auto bufferSize2 =
1540 LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
1541
1542 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1543 } else {
1544 auto computeType = genConstInt32From(
1545 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1546 bufferSize =
1547 createSpMMBufferSizeCallBuilder
1548 .create(loc, rewriter,
1549 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1550 adaptor.getDnmatC(), computeType, stream})
1551 .getResult();
1552 rewriter.replaceOp(op, {bufferSize, stream});
1553 }
1554 return success();
1555}
1556
1557LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1558 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1559 ConversionPatternRewriter &rewriter) const {
1560 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1561 failed(isAsyncWithOneDependency(rewriter, op)))
1562 return failure();
1563 Location loc = op.getLoc();
1564 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1565 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1566 auto computeType = genConstInt32From(
1567 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1568 auto stream = adaptor.getAsyncDependencies().front();
1569 auto bufferSize =
1570 createSDDMMBufferSizeCallBuilder
1571 .create(loc, rewriter,
1572 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1573 adaptor.getSpmatC(), computeType, stream})
1574 .getResult();
1575 rewriter.replaceOp(op, {bufferSize, stream});
1576 return success();
1577}
1578
1579LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1580 gpu::SpMMOp op, OpAdaptor adaptor,
1581 ConversionPatternRewriter &rewriter) const {
1582 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1583 failed(isAsyncWithOneDependency(rewriter, op)))
1584 return failure();
1585 Location loc = op.getLoc();
1586 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1587 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1588 auto computeType = genConstInt32From(
1589 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1590
1591 auto stream = adaptor.getAsyncDependencies().front();
1592
1593 // Lower to cusparseLt if applicable
1594 if (is2To4Sparsity(op.getSpmatA())) {
1595 SmallVector<Value> pBufs;
1596 for (Value buffer : adaptor.getBuffers()) {
1597 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1598 pBufs.push_back(pBuf);
1599 }
1600 createCuSparseLtSpMMBuilder.create(
1601 loc, rewriter,
1602 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1603 pBufs[0], pBufs[1], pBufs[2], stream});
1604 } else {
1605 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1606 .allocatedPtr(rewriter, loc);
1607 createSpMMCallBuilder.create(loc, rewriter,
1608 {modeA, modeB, adaptor.getSpmatA(),
1609 adaptor.getDnmatB(), adaptor.getDnmatC(),
1610 computeType, pBuf, stream});
1611 }
1612 rewriter.replaceOp(op, {stream});
1613 return success();
1614}
1615
// Registers a type conversion on `converter` that maps every type of kind T
// to the opaque LLVM pointer type of the converter's MLIR context. The lambda
// captures `converter` by reference, so it must not outlive the converter.
1616template <typename T>
1618 converter.addConversion([&converter](T) -> Type {
1619 return LLVM::LLVMPointerType::get(&converter.getContext());
1620 });
1621}
1622
1623LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1624 gpu::SDDMMOp op, OpAdaptor adaptor,
1625 ConversionPatternRewriter &rewriter) const {
1626 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1627 failed(isAsyncWithOneDependency(rewriter, op)))
1628 return failure();
1629 Location loc = op.getLoc();
1630 auto computeType = genConstInt32From(
1631 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1632 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1633 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1634 auto stream = adaptor.getAsyncDependencies().front();
1635 Value pBuf =
1636 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1637 createSDDMMCallBuilder.create(loc, rewriter,
1638 {modeA, modeB, adaptor.getDnmatA(),
1639 adaptor.getDnmatB(), adaptor.getSpmatC(),
1640 computeType, pBuf, stream});
1641 rewriter.replaceOp(op, {stream});
1642 return success();
1643}
1644
1645LogicalResult
1646ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1647 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1648 ConversionPatternRewriter &rewriter) const {
1649 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1650 failed(isAsyncWithOneDependency(rewriter, op)))
1651 return failure();
1652 Location loc = op.getLoc();
1653 auto stream = adaptor.getAsyncDependencies().front();
1654 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1655 .getResult();
1656 rewriter.replaceOp(op, {descr, stream});
1657 return success();
1658}
1659
1660LogicalResult
1661ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1662 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1663 ConversionPatternRewriter &rewriter) const {
1664 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1665 failed(isAsyncWithOneDependency(rewriter, op)))
1666 return failure();
1667 Location loc = op.getLoc();
1668 auto stream = adaptor.getAsyncDependencies().front();
1669 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1670 {adaptor.getDesc(), stream});
1671 rewriter.replaceOp(op, {stream});
1672 return success();
1673}
1674
1675LogicalResult
1676ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1677 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1678 ConversionPatternRewriter &rewriter) const {
1679 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1680 failed(isAsyncWithOneDependency(rewriter, op)))
1681 return failure();
1682 Location loc = op.getLoc();
1683 auto computeType = genConstInt32From(
1684 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1685 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1686 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1687 auto stream = adaptor.getAsyncDependencies().front();
1688
1689 Value pBuf =
1690 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1691 Value bufferSizeNew;
1692
1693 if (adaptor.getKind() ==
1694 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1695 bufferSizeNew =
1696 createSpGEMMWorkEstimationBuilder
1697 .create(loc, rewriter,
1698 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1699 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1700 adaptor.getBufferSz(), pBuf, stream})
1701 .getResult();
1702 } else {
1703 bufferSizeNew =
1704 createSpGEMMComputeBuilder
1705 .create(loc, rewriter,
1706 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1707 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1708 adaptor.getBufferSz(), pBuf, stream})
1709 .getResult();
1710 }
1711 rewriter.replaceOp(op, {bufferSizeNew, stream});
1712 return success();
1713}
1714
1715LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1716 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1717 ConversionPatternRewriter &rewriter) const {
1718 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1719 failed(isAsyncWithOneDependency(rewriter, op)))
1720 return failure();
1721 Location loc = op.getLoc();
1722 auto computeType = genConstInt32From(
1723 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1724 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1725 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1726 auto stream = adaptor.getAsyncDependencies().front();
1727 createSpGEMMCopyBuilder.create(loc, rewriter,
1728 {adaptor.getDesc(), modeA, modeB,
1729 adaptor.getSpmatA(), adaptor.getSpmatB(),
1730 adaptor.getSpmatC(), computeType, stream});
1731 rewriter.replaceOp(op, {stream});
1732 return success();
1733}
1734
1735LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1736 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1737 ConversionPatternRewriter &rewriter) const {
1738 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1739 failed(isAsyncWithOneDependency(rewriter, op)))
1740 return failure();
1741 Location loc = op.getLoc();
1742 auto stream = adaptor.getAsyncDependencies().front();
1743
1744 auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1745 rewriter.getIndexAttr(3));
1746 auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
1747 llvmInt64Type, three, /*alignment=*/16);
1748
1749 auto rowsPtr = LLVM::GEPOp::create(
1750 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1751 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1752 rewriter.getIndexAttr(0))});
1753 auto colsPtr = LLVM::GEPOp::create(
1754 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1755 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1756 rewriter.getIndexAttr(1))});
1757 auto nnzsPtr = LLVM::GEPOp::create(
1758 rewriter, loc, llvmPointerType, llvmPointerType, buffer,
1759 ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1760 rewriter.getIndexAttr(2))});
1761 createSpMatGetSizeBuilder.create(
1762 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1763 auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
1764 auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
1765 auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
1766
1767 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1768 return success();
1769}
1770
1771LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1772 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1773 ConversionPatternRewriter &rewriter) const {
1774 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1775 failed(isAsyncWithOneDependency(rewriter, op)))
1776 return failure();
1777 Location loc = op.getLoc();
1778 auto stream = adaptor.getAsyncDependencies().front();
1779 Value pPos =
1780 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1781 Value pCrd =
1782 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1783 Value pVal =
1784 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1785 createSetCsrPointersBuilder.create(
1786 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1787 rewriter.replaceOp(op, {stream});
1788 return success();
1789}
1790
1791LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1792 gpu::CreateCscOp op, OpAdaptor adaptor,
1793 ConversionPatternRewriter &rewriter) const {
1794 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1795 failed(isAsyncWithOneDependency(rewriter, op)))
1796 return failure();
1797 Location loc = op.getLoc();
1798 auto stream = adaptor.getAsyncDependencies().front();
1799 Value pColPos =
1800 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1801 Value pRowIdxs =
1802 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1803 Value pValues =
1804 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1805 Type pType =
1806 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1807 Type iType =
1808 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1809 Type dType =
1810 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1811 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1812 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1813 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1814 auto handle =
1815 createCscCallBuilder
1816 .create(loc, rewriter,
1817 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1818 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1819 .getResult();
1820 rewriter.replaceOp(op, {handle, stream});
1821 return success();
1822}
1823
1824LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1825 gpu::CreateBsrOp op, OpAdaptor adaptor,
1826 ConversionPatternRewriter &rewriter) const {
1827 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1828 failed(isAsyncWithOneDependency(rewriter, op)))
1829 return failure();
1830 Location loc = op.getLoc();
1831 auto stream = adaptor.getAsyncDependencies().front();
1832 Value pRowPos =
1833 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1834 Value pColIdxs =
1835 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1836 Value pValues =
1837 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1838 Type pType =
1839 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1840 Type iType =
1841 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1842 Type dType =
1843 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1844 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1845 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1846 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1847 auto handle =
1848 createBsrCallBuilder
1849 .create(loc, rewriter,
1850 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1851 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1852 pColIdxs, pValues, ptp, itp, dtp, stream})
1853 .getResult();
1854 rewriter.replaceOp(op, {handle, stream});
1855 return success();
1856}
1857
1859 LLVMTypeConverter &converter, RewritePatternSet &patterns,
1860 bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1865
1866 // Higher benefit so this pattern wins over the structural async.yield
1867 // rewriter from populateAsyncStructuralTypeConversionsAndLegality on yields
1868 // with gpu.async.token operands. The structural rewriter would silently
1869 // retype operands without recording an event on the underlying stream.
1870 patterns.add<ConvertAsyncYieldToGpuRuntimeCallPattern>(converter,
1871 /*benefit=*/2);
1872
// The runtime-call lowerings below are constructed with only the converter,
// so they use the constructor's default pattern benefit.
1873 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1874 ConvertDeallocOpToGpuRuntimeCallPattern,
1875 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1876 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1877 ConvertMemcpyOpToGpuRuntimeCallPattern,
1878 ConvertMemsetOpToGpuRuntimeCallPattern,
1879 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1880 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1881 ConvertWaitOpToGpuRuntimeCallPattern,
1882 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1883 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1884 ConvertCreateCooOpToGpuRuntimeCallPattern,
1885 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1886 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1887 ConvertCreateCscOpToGpuRuntimeCallPattern,
1888 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1889 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1890 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1891 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1892 ConvertSpMVOpToGpuRuntimeCallPattern,
1893 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1894 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1895 ConvertSpMMOpToGpuRuntimeCallPattern,
1896 ConvertSDDMMOpToGpuRuntimeCallPattern,
1897 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1898 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1899 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1900 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1901 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1902 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
// The launch_func legalization pattern additionally receives the kernel
// calling-convention flags passed to this function.
1903 patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1904 kernelIntersperseSizeCallConv);
1905}
1906
1907//===----------------------------------------------------------------------===//
1908// GPUModuleOp convert to LLVM op interface
1909//===----------------------------------------------------------------------===//
1910
1911namespace {
// External model that attaches ConvertToLLVMOpInterface to gpu::GPUModuleOp.
1912struct GPUModuleOpConvertToLLVMInterface
1913 : public ConvertToLLVMOpInterface::ExternalModel<
1914 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1915 /// Collect the ConvertToLLVMAttrInterface provided by the module's single target attribute, if any.
1916 void getConvertToLLVMConversionAttrs(
1918};
1919} // namespace
1920
1921void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1922 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1923 auto module = cast<gpu::GPUModuleOp>(op);
1924 ArrayAttr targetsAttr = module.getTargetsAttr();
1925 // Fail if there are no target attributes or there is more than one target.
1926 if (!targetsAttr || targetsAttr.size() != 1)
1927 return;
1928 if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
1929 attrs.push_back(patternAttr);
1930}
1931
// Register a GPU-dialect extension that attaches
// GPUModuleOpConvertToLLVMInterface to gpu::GPUModuleOp in the given context.
1933 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1934 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
1935 });
1936}
return success()
static void addOpaquePointerConversion(LLVMTypeConverter &converter)
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue)
static int32_t getCuSparseDataTypeFrom(Type type)
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter)
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue)
static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat)
static bool isGpuAsyncTokenType(Value value)
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)
Generic rewriting rule for operation on sparse matrices.
static int32_t getCuSparseLtDataTypeFrom(Type type)
static bool isDefinedByCallTo(Value value, StringRef functionName)
static Value bitAndAddrspaceCast(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMPointerType destinationType, Value sourcePtr, const LLVMTypeConverter &typeConverter)
static bool isSpMMCusparseLtOp(Value op)
static int32_t getCuSparseIndexTypeFrom(Type type)
static bool is2To4Sparsity(Value spMat)
static LogicalResult isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, gpu::AsyncOpInterface op)
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
ArrayAttr()
b getContext())
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:251
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition Pattern.cpp:38
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:58
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
This class helps build Operations.
Definition Builders.h:209
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition Builders.h:248
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
void print(raw_ostream &os, const OpPrintingFlags &flags={})
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void registerConvertGpuToLLVMInterface(DialectRegistry &registry)
Registers the ConvertToLLVMOpInterface interface on the gpu::GPUModuleOP operation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool kernelBarePtrCallConv=false, bool kernelIntersperseSizeCallConv=false)
Collect a set of patterns to convert from the GPU dialect to LLVM and populate converter for gpu type...
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
void populateAsyncStructuralTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for async structural type conversions.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
LLVM::LLVMFunctionType functionType
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const