//===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR NVVM dialect and
// LLVM IR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;

#define REDUX_F32_ID_IMPL(op, abs, hasNaN)                                     \
  hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN                   \
         : llvm::Intrinsic::nvvm_redux_sync_f##op##abs

#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)                                   \
  hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)
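// For example, GET_REDUX_F32_ID(min, /*hasAbs=*/true, /*hasNaN=*/false)
// expands to a conditional that selects
// llvm::Intrinsic::nvvm_redux_sync_fmin_abs.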

static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
                                               NVVM::ReduxKind kind,
                                               bool hasAbs, bool hasNaN) {
  switch (kind) {
  case NVVM::ReduxKind::ADD:
    return llvm::Intrinsic::nvvm_redux_sync_add;
  case NVVM::ReduxKind::UMAX:
    return llvm::Intrinsic::nvvm_redux_sync_umax;
  case NVVM::ReduxKind::UMIN:
    return llvm::Intrinsic::nvvm_redux_sync_umin;
  case NVVM::ReduxKind::AND:
    return llvm::Intrinsic::nvvm_redux_sync_and;
  case NVVM::ReduxKind::OR:
    return llvm::Intrinsic::nvvm_redux_sync_or;
  case NVVM::ReduxKind::XOR:
    return llvm::Intrinsic::nvvm_redux_sync_xor;
  case NVVM::ReduxKind::MAX:
    return llvm::Intrinsic::nvvm_redux_sync_max;
  case NVVM::ReduxKind::MIN:
    return llvm::Intrinsic::nvvm_redux_sync_min;
  case NVVM::ReduxKind::FMIN:
    return GET_REDUX_F32_ID(min, hasAbs, hasNaN);
  case NVVM::ReduxKind::FMAX:
    return GET_REDUX_F32_ID(max, hasAbs, hasNaN);
  }
  llvm_unreachable("unknown redux kind");
}

static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
                                              NVVM::ShflKind kind,
                                              bool withPredicate) {

  if (withPredicate) {
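    // The predicate ("p") variants return a {value, i1} struct; pick the
    // f32 vs. i32 intrinsic from the type of the value element.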
    resultType = cast<llvm::StructType>(resultType)->getElementType(0);
    switch (kind) {
    case NVVM::ShflKind::bfly:
      return resultType->isFloatTy()
                 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
                 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
    case NVVM::ShflKind::up:
      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
                                     : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
    case NVVM::ShflKind::down:
      return resultType->isFloatTy()
                 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
                 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
    case NVVM::ShflKind::idx:
      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
                                     : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
    }
  } else {
    switch (kind) {
    case NVVM::ShflKind::bfly:
      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
                                     : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
    case NVVM::ShflKind::up:
      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
                                     : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
    case NVVM::ShflKind::down:
      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
                                     : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
    case NVVM::ShflKind::idx:
      return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
                                     : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
    }
  }
  llvm_unreachable("unknown shuffle kind");
}

static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
                                                   NVVM::MatchSyncKind kind) {
  switch (kind) {
  case NVVM::MatchSyncKind::any:
    return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
                                 : llvm::Intrinsic::nvvm_match_any_sync_i64;
  case NVVM::MatchSyncKind::all:
    // The match.all instruction has two variants -- one returns a single
    // value, the other returns a pair {value, predicate}. We currently only
    // implement the latter, as that is the variant exposed by the CUDA API.
    return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
                                 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
  }
  llvm_unreachable("unsupported match sync kind");
}

static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
  switch (kind) {
  case NVVM::VoteSyncKind::any:
    return llvm::Intrinsic::nvvm_vote_any_sync;
  case NVVM::VoteSyncKind::all:
    return llvm::Intrinsic::nvvm_vote_all_sync;
  case NVVM::VoteSyncKind::ballot:
    return llvm::Intrinsic::nvvm_vote_ballot_sync;
  case NVVM::VoteSyncKind::uni:
    return llvm::Intrinsic::nvvm_vote_uni_sync;
  }
  llvm_unreachable("unsupported vote kind");
}

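/// Return the intrinsic ID associated with ldmatrix for the given parameters.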
static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 8 && shape.getN() == 16) {
    if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
      case 4:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
      case 4:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
      }
    }
  } else if (shape.getM() == 16 && shape.getN() == 16) {
    if (eltType == NVVM::LdStMatrixEltType::B8) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
      }
    } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
      switch (num) {
      case 1:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
      case 2:
        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
      }
    }
  }
  llvm_unreachable("unknown ldmatrix kind");
}

/// Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  if (shape.getM() == 8 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
    case 2:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
    case 4:
      return (layout == NVVM::MMALayout::row)
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
    }
  } else if (shape.getM() == 16 && shape.getN() == 8) {
    switch (num) {
    case 1:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
    case 2:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
    case 4:
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
    }
  }
  llvm_unreachable("unknown stmatrix kind");
}

/// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
  bool isSharedMemory = addrType.getAddressSpace() ==
                        static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
  return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta
                        : llvm::Intrinsic::nvvm_st_bulk;
}

static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
                                              NVVM::ProxyKind toProxy,
                                              NVVM::MemScopeKind scope,
                                              bool isRelease) {
  if (fromProxy == NVVM::ProxyKind::GENERIC &&
      toProxy == NVVM::ProxyKind::TENSORMAP) {
    switch (scope) {
    case NVVM::MemScopeKind::CTA: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
    }
    case NVVM::MemScopeKind::CLUSTER: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cluster;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cluster;
    }
    case NVVM::MemScopeKind::GPU: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
    }
    case NVVM::MemScopeKind::SYS: {
      if (isRelease)
        return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
      return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
    }
    }
    llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
  }
  llvm_unreachable("Unsupported proxy kinds");
}

static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
  switch (scope) {
  case NVVM::MemScopeKind::CTA:
    return llvm::Intrinsic::nvvm_membar_cta;
  case NVVM::MemScopeKind::CLUSTER:
    return llvm::Intrinsic::nvvm_fence_sc_cluster;
  case NVVM::MemScopeKind::GPU:
    return llvm::Intrinsic::nvvm_membar_gl;
  case NVVM::MemScopeKind::SYS:
    return llvm::Intrinsic::nvvm_membar_sys;
  }
  llvm_unreachable("Unknown scope for memory barrier");
}

#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
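// For example, TCGEN05LD(16x64b, x2) expands to
// llvm::Intrinsic::nvvm_tcgen05_ld_16x64b_x2.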

static llvm::Intrinsic::ID
getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05LD(16x64b, x1),  TCGEN05LD(16x64b, x2),  TCGEN05LD(16x64b, x4),
      TCGEN05LD(16x64b, x8),  TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32),
      TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128),
  };

  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2),  TCGEN05LD(16x128b, x4),
      TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32),
      TCGEN05LD(16x128b, x64),
  };

  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2),  TCGEN05LD(16x256b, x4),
      TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05LD(16x32bx2, x1),  TCGEN05LD(16x32bx2, x2),
      TCGEN05LD(16x32bx2, x4),  TCGEN05LD(16x32bx2, x8),
      TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32),
      TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05LD(32x32b, x1),  TCGEN05LD(32x32b, x2),  TCGEN05LD(32x32b, x4),
      TCGEN05LD(32x32b, x8),  TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32),
      TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128),
  };

  // `num` is the length of the result vector; log2(num) indexes into the
  // shape arrays. The x1 variants of 16x128b and 16x256b already produce
  // vectors of 2 and 4 elements, hence the index offsets below.
  unsigned Idx = std::log2(num);
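  // For example, num == 8 with SHAPE_16X64B gives Idx == 3 and selects
  // nvvm_tcgen05_ld_16x64b_x8.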

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.ld lowering");
}

#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM

static llvm::Intrinsic::ID
getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05ST(16x64b, x1),  TCGEN05ST(16x64b, x2),  TCGEN05ST(16x64b, x4),
      TCGEN05ST(16x64b, x8),  TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32),
      TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128),
  };

  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2),  TCGEN05ST(16x128b, x4),
      TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32),
      TCGEN05ST(16x128b, x64),
  };

  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2),  TCGEN05ST(16x256b, x4),
      TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05ST(16x32bx2, x1),  TCGEN05ST(16x32bx2, x2),
      TCGEN05ST(16x32bx2, x4),  TCGEN05ST(16x32bx2, x8),
      TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32),
      TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05ST(32x32b, x1),  TCGEN05ST(32x32b, x2),  TCGEN05ST(32x32b, x4),
      TCGEN05ST(32x32b, x8),  TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32),
      TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128),
  };

  // `num` is the length of the source vector; log2(num) indexes into the
  // shape arrays, with the same offsets as in getTcgen05LdIntrinsicID.
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.st lowering");
}

namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
class NVVMDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final {
    Operation &opInst = *op;
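    // The generated NVVMConversions.inc expands here into a chain of
    // dyn_cast checks over `opInst`, one per NVVM op, each building the
    // corresponding LLVM intrinsic call (typically via createIntrinsicCall)
    // and returning success() on a match.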
#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"

    return failure();
  }

  /// Attaches module-level metadata for functions marked as kernels.
  LogicalResult
  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                 NamedAttribute attribute,
                 LLVM::ModuleTranslation &moduleTranslation) const final {
    auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
    if (!func)
      return failure();
    llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());

    if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
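      // The "$[,]" range format joins the array elements with commas, e.g. a
      // maxntid value of [256, 1, 1] becomes the string "256,1,1".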
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr("nvvm.maxntid", attr);
    } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr("nvvm.reqntid", attr);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getClusterDimAttrName()) {
      if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
        return failure();
      auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
      const std::string attr = llvm::formatv(
          "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
                                       values.asArrayRef().end()));
      llvmFunc->addFnAttr("nvvm.cluster_dim", attr);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr("nvvm.maxclusterrank", llvm::utostr(value.getInt()));
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getMinctasmAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr("nvvm.minctasm", llvm::utostr(value.getInt()));
    } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      llvmFunc->addFnAttr("nvvm.maxnreg", llvm::utostr(value.getInt()));
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getKernelFuncAttrName()) {
      llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
    } else if (attribute.getName() ==
               NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
      llvmFunc->addFnAttr("nvvm.blocksareclusters");
    }

    return success();
  }

  LogicalResult
  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
                       LLVM::ModuleTranslation &moduleTranslation) const final {

    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
    llvm::Function *llvmFunc =
        moduleTranslation.lookupFunction(funcOp.getName());

    if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
      llvmFunc->addParamAttr(
          argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant"));
    }
    return success();
  }
};
} // namespace

void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
  registry.insert<NVVM::NVVMDialect>();
  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
    dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
  });
}

void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
  DialectRegistry registry;
  registerNVVMDialectTranslation(registry);
  context.appendDialectRegistry(registry);
}
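
// Typical client usage (a sketch, not part of this file's API;
// translateModuleToLLVMIR is declared in mlir/Target/LLVMIR/Export.h):
//
//   DialectRegistry registry;
//   registerNVVMDialectTranslation(registry);
//   registerLLVMDialectTranslation(registry);
//   MLIRContext context(registry);
//   // ...build or parse `module` containing NVVM ops...
//   llvm::LLVMContext llvmContext;
//   std::unique_ptr<llvm::Module> llvmModule =
//       translateModuleToLLVMIR(module, llvmContext);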