19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/iterator_range.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/IntrinsicsNVPTX.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/NVVMAttributes.h"
// Helper macros selecting the f32 redux.sync intrinsic variant depending on
// whether the abs modifier and/or NaN propagation is requested.
// NOTE(review): the digits fused onto the front of each line ("30", "31", ...)
// look like extraction artifacts (original source line numbers); strip them
// after verifying against the upstream file.
30#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
31 hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
32 : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
34#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
35 hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)
// Fragment of getReduxIntrinsicId: maps an NVVM::ReductionKind (plus abs/NaN
// modifiers for the float kinds) to the corresponding redux.sync intrinsic ID.
// NOTE(review): extraction dropped interior lines — the embedded numbering
// jumps 39->41 (missing the switch header), 57->59 and 59->62 (missing the
// FMIN/FMAX return statements, presumably GET_REDUX_F32_ID uses). Restore
// from the upstream source before compiling.
38 NVVM::ReductionKind kind,
39 bool hasAbs,
bool hasNaN) {
41 case NVVM::ReductionKind::ADD:
42 return llvm::Intrinsic::nvvm_redux_sync_add;
43 case NVVM::ReductionKind::UMAX:
44 return llvm::Intrinsic::nvvm_redux_sync_umax;
45 case NVVM::ReductionKind::UMIN:
46 return llvm::Intrinsic::nvvm_redux_sync_umin;
47 case NVVM::ReductionKind::AND:
48 return llvm::Intrinsic::nvvm_redux_sync_and;
49 case NVVM::ReductionKind::OR:
50 return llvm::Intrinsic::nvvm_redux_sync_or;
51 case NVVM::ReductionKind::XOR:
52 return llvm::Intrinsic::nvvm_redux_sync_xor;
53 case NVVM::ReductionKind::MAX:
54 return llvm::Intrinsic::nvvm_redux_sync_max;
55 case NVVM::ReductionKind::MIN:
56 return llvm::Intrinsic::nvvm_redux_sync_min;
57 case NVVM::ReductionKind::FMIN:
59 case NVVM::ReductionKind::FMAX:
62 llvm_unreachable(
"unknown reduction kind");
// Fragment of getShflIntrinsicId: selects the shfl.sync intrinsic for a given
// shuffle kind, result type (f32 vs i32), and whether a predicate result is
// requested (the *_p variants; the result is then a struct whose element 0 is
// the shuffled value — hence the getElementType(0) unwrap below).
// NOTE(review): the function signature, the withPredicate branch, and the
// switch headers are missing from this extraction (numbering jumps 70->72 and
// 85->89); compare with the doxygen signature getShflIntrinsicId(llvm::Type*,
// NVVM::ShflKind, bool withPredicate) listed at the end of this file.
70 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
72 case NVVM::ShflKind::bfly:
73 return resultType->isFloatTy()
74 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
75 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
76 case NVVM::ShflKind::up:
77 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
78 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
79 case NVVM::ShflKind::down:
80 return resultType->isFloatTy()
81 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
82 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
83 case NVVM::ShflKind::idx:
84 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
85 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
89 case NVVM::ShflKind::bfly:
90 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
91 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
92 case NVVM::ShflKind::up:
93 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
94 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
95 case NVVM::ShflKind::down:
96 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
97 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
98 case NVVM::ShflKind::idx:
99 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
100 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
103 llvm_unreachable(
"unknown shuffle kind");
// Fragment of getMatchSyncIntrinsicId: picks the match.{any,all}.sync
// intrinsic based on the MLIR value type (i32 vs i64). The "all" variants
// carry a predicate result (*_p suffix).
// NOTE(review): numbering jumps 112->116 — at least one line of the "all"
// case is missing from this extraction; restore from upstream.
107 NVVM::MatchSyncKind kind) {
109 case NVVM::MatchSyncKind::any:
110 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
111 : llvm::Intrinsic::nvvm_match_any_sync_i64;
112 case NVVM::MatchSyncKind::all:
116 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
117 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
119 llvm_unreachable(
"unsupported match sync kind");
// Fragment of getVoteSyncIntrinsicId: one-to-one mapping from
// NVVM::VoteSyncKind to the vote.*.sync intrinsic.
// NOTE(review): the function signature and switch header were lost in
// extraction (content starts mid-switch).
124 case NVVM::VoteSyncKind::any:
125 return llvm::Intrinsic::nvvm_vote_any_sync;
126 case NVVM::VoteSyncKind::all:
127 return llvm::Intrinsic::nvvm_vote_all_sync;
128 case NVVM::VoteSyncKind::ballot:
129 return llvm::Intrinsic::nvvm_vote_ballot_sync;
130 case NVVM::VoteSyncKind::uni:
131 return llvm::Intrinsic::nvvm_vote_uni_sync;
133 llvm_unreachable(
"unsupported vote kind");
// Fragment of getLdMatrixIntrinsicId: selects the ldmatrix.sync.aligned
// intrinsic from (shape MxN, element type, num fragments x1/x2/x4, and — for
// the m8n8/b16 case — row vs col layout, where col maps to the _trans
// variants).
// NOTE(review): this extraction dropped the dispatch scaffolding — the
// "if (num == 1/2/4)" tests and several closing braces (numbering jumps such
// as 144->146, 159->162, 185->188 mark the holes). The visible returns show
// the intended mapping, but the control flow must be restored from upstream
// before this can compile.
136static llvm::Intrinsic::ID
138 NVVM::LdStMatrixShapeAttr
shape,
139 NVVM::LdStMatrixEltType eltType) {
143 return (layout == NVVM::MMALayout::row)
144 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
146 nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
148 return (layout == NVVM::MMALayout::row)
149 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
151 nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
153 return (layout == NVVM::MMALayout::row)
154 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
156 nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
158 }
else if (
shape.getM() == 8 &&
shape.getN() == 16) {
159 if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
162 return llvm::Intrinsic::
163 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
165 return llvm::Intrinsic::
166 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
168 return llvm::Intrinsic::
169 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
171 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
174 return llvm::Intrinsic::
175 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
177 return llvm::Intrinsic::
178 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
180 return llvm::Intrinsic::
181 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
184 }
else if (
shape.getM() == 16 &&
shape.getN() == 16) {
185 if (eltType == NVVM::LdStMatrixEltType::B8) {
188 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
190 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
192 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
195 return llvm::Intrinsic::
196 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
198 return llvm::Intrinsic::
199 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
201 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
204 return llvm::Intrinsic::
205 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
207 return llvm::Intrinsic::
208 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
212 llvm_unreachable(
"unknown ldmatrix kind");
// Fragment of getStMatrixIntrinsicId: same structure as the ldmatrix mapper
// above, but for stmatrix.sync.aligned — m8n8/b16 with row/col (trans)
// layouts, and m16n8/b8 which only has trans variants.
// NOTE(review): the "if (num == ...)" dispatch lines and closing braces are
// missing from this extraction (numbering jumps, e.g. 224->226, 238->241);
// restore from upstream.
216static llvm::Intrinsic::ID
218 NVVM::LdStMatrixShapeAttr
shape,
219 NVVM::LdStMatrixEltType eltType) {
223 return (layout == NVVM::MMALayout::row)
224 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
226 nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
228 return (layout == NVVM::MMALayout::row)
229 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
231 nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
233 return (layout == NVVM::MMALayout::row)
234 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
236 nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
238 }
else if (
shape.getM() == 16 &&
shape.getN() == 8) {
241 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
243 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
245 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
248 llvm_unreachable(
"unknown stmatrix kind");
// Fragment of getStBulkIntrinsicId: chooses between the shared-memory and
// generic st.bulk intrinsics based on the pointer's address space.
// NOTE(review): the address-space comparison and the true-branch intrinsic
// (presumably nvvm_st_bulk_sharedcta or similar) were dropped by the
// extraction; only the shared-space constant and the fallback remain.
252static llvm::Intrinsic::ID
255 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
257 : llvm::Intrinsic::nvvm_st_bulk;
// Fragment of getUnidirectionalFenceProxyID: for generic->tensormap proxy
// fences, selects the release or acquire variant at the requested scope
// (cta/cluster/gpu/sys). Any other proxy pair is unsupported.
// NOTE(review): each scope case shows both the release and acquire returns
// back-to-back — the isRelease conditional between them was dropped by the
// extraction (e.g. numbering jumps 267->269); restore from upstream.
261 NVVM::ProxyKind toProxy,
262 NVVM::MemScopeKind scope,
264 if (fromProxy == NVVM::ProxyKind::GENERIC &&
265 toProxy == NVVM::ProxyKind::TENSORMAP) {
267 case NVVM::MemScopeKind::CTA: {
269 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
270 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
272 case NVVM::MemScopeKind::CLUSTER: {
274 return llvm::Intrinsic::
275 nvvm_fence_proxy_tensormap_generic_release_cluster;
276 return llvm::Intrinsic::
277 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
279 case NVVM::MemScopeKind::GPU: {
281 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
282 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
284 case NVVM::MemScopeKind::SYS: {
286 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
287 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
290 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
292 llvm_unreachable(
"Unsupported proxy kinds");
// Fragment of getMembarIntrinsicID: maps a memory scope to the corresponding
// barrier intrinsic. Note the CLUSTER scope intentionally lowers to
// fence.sc.cluster rather than a membar variant.
// NOTE(review): the function signature and switch header are missing from
// this extraction.
297 case NVVM::MemScopeKind::CTA:
298 return llvm::Intrinsic::nvvm_membar_cta;
299 case NVVM::MemScopeKind::CLUSTER:
300 return llvm::Intrinsic::nvvm_fence_sc_cluster;
301 case NVVM::MemScopeKind::GPU:
302 return llvm::Intrinsic::nvvm_membar_gl;
303 case NVVM::MemScopeKind::SYS:
304 return llvm::Intrinsic::nvvm_membar_sys;
306 llvm_unreachable(
"Unknown scope for memory barrier");
// Fragment of getTcgen05LdIntrinsicID: builds per-shape tables of tcgen05.ld
// intrinsics (via the TCGEN05LD macro) and indexes them by log2(num). The
// - 1 / - 2 offsets account for shapes whose smallest supported num is 2 or 4.
// NOTE(review): the table initializers themselves (the TCGEN05LD(...) entries)
// were dropped by the extraction — only the array declarations and the switch
// survive; restore the entries from upstream.
309#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
311static llvm::Intrinsic::ID
313 llvm::Intrinsic::ID Shape16x64b[] = {
319 llvm::Intrinsic::ID Shape16x128b[] = {
325 llvm::Intrinsic::ID Shape16x256b[] = {
330 llvm::Intrinsic::ID Shape16x32bx2[] = {
337 llvm::Intrinsic::ID Shape32x32b[] = {
345 unsigned Idx = std::log2(num);
348 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
349 return Shape16x64b[Idx];
350 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
351 return Shape16x128b[Idx - 1];
352 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
353 return Shape16x256b[Idx - 2];
354 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
355 return Shape32x32b[Idx];
356 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
357 return Shape16x32bx2[Idx];
359 llvm_unreachable(
"unhandled tcgen05.ld lowering");
// Fragment of getTcgen05StIntrinsicID: mirror of the tcgen05.ld mapper above,
// using the TCGEN05ST macro to populate per-shape intrinsic tables indexed by
// log2(num).
// NOTE(review): as with the ld variant, the table entries were dropped by the
// extraction; restore from upstream.
362#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
364static llvm::Intrinsic::ID
366 llvm::Intrinsic::ID Shape16x64b[] = {
372 llvm::Intrinsic::ID Shape16x128b[] = {
378 llvm::Intrinsic::ID Shape16x256b[] = {
383 llvm::Intrinsic::ID Shape16x32bx2[] = {
390 llvm::Intrinsic::ID Shape32x32b[] = {
398 unsigned Idx = std::log2(num);
401 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
402 return Shape16x64b[Idx];
403 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
404 return Shape16x128b[Idx - 1];
405 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
406 return Shape16x256b[Idx - 2];
407 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
408 return Shape32x32b[Idx];
409 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
410 return Shape16x32bx2[Idx];
412 llvm_unreachable(
"unhandled tcgen05.st lowering");
// Fragment of getFenceSyncRestrictID: acquire maps to the cluster-scoped
// acquire fence, anything else (presumably release) to the cta-scoped
// release fence. Signature lost in extraction.
416 return order == NVVM::MemOrderKind::ACQUIRE
418 nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster
420 nvvm_fence_release_sync_restrict_space_cta_scope_cluster;
// Fragment of getFenceProxyID: maps a proxy kind (and, for async_shared, the
// optional shared space) to a fence.proxy intrinsic.
// NOTE(review): the function signature with the std::optional<SharedSpace>
// parameter and the switch header are missing (numbering jumps 423->426);
// the *space dereference implies async_shared requires the space to be set.
423static llvm::Intrinsic::ID
426 case NVVM::ProxyKind::alias:
427 return llvm::Intrinsic::nvvm_fence_proxy_alias;
428 case NVVM::ProxyKind::async:
429 return llvm::Intrinsic::nvvm_fence_proxy_async;
430 case NVVM::ProxyKind::async_global:
431 return llvm::Intrinsic::nvvm_fence_proxy_async_global;
432 case NVVM::ProxyKind::async_shared:
433 return *space == NVVM::SharedSpace::shared_cta
434 ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta
435 : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster;
437 llvm_unreachable(
"unsupported proxy kind");
// Fragment of getFenceProxySyncRestrictID: acquire/release selection for the
// async-generic proxy fence with sync_restrict, analogous to
// getFenceSyncRestrictID above. Interior lines lost in extraction.
441static llvm::Intrinsic::ID
443 return order == NVVM::MemOrderKind::ACQUIRE
445 nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster
447 nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
// Fragment of createScalarizedIntrinsicCall: for 2-element float/double
// vectors, calls the scalar intrinsic once per lane (extractelement each
// operand, call, insertelement into a poison-initialized result vector).
// NOTE(review): the scalarArgs declaration, the intrinsic call producing
// `res`, and the non-vector fallback path are missing from this extraction
// (numbering jumps 467->469); restore from upstream.
456 llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
458 llvm::Type *retType) {
459 if (opTypeLLVM->isVectorTy() && (opTypeLLVM->getScalarType()->isFloatTy() ||
460 opTypeLLVM->getScalarType()->isDoubleTy())) {
461 llvm::Value *
result = llvm::PoisonValue::get(
462 llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
463 for (
int64_t i = 0; i < 2; ++i) {
465 for (llvm::Value *op : operands)
466 scalarArgs.push_back(
467 builder.CreateExtractElement(op, builder.getInt32(i)));
469 result = builder.CreateInsertElement(
result, res, builder.getInt32(i));
// Fragment of NVVM::AddFOp::lowerAddFToLLVMIR: lowers an NVVM add.f op to
// either a plain LLVM fadd (half without sat, bf16) or a rounding/ftz/sat
// specific nvvm_add_* intrinsic, chosen by indexing the per-type ID tables:
//   f16IDs:  index = (isVectorOp << 1) | isFTZ      (sat-only variants)
//   f32IDs:  index = ((isFTZ << 1) | isSat) * 5 + rndMode
//   f64IDs:  index = rndMode                        (5 rounding entries)
// NOTE(review): many interior lines are missing (e.g. the result mapping via
// mt.mapValue, the intrinsic-call helper body, the table lookups for f16/f64)
// — the indexing formulas above are inferred from the visible index
// computations and table sizes; confirm against upstream.
477void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
478 Value res, NVVM::FPRoundingMode rndMode,
479 NVVM::SaturationMode satMode,
bool isFTZ,
481 llvm::IRBuilderBase &builder) {
482 llvm::Type *opTypeLLVM = argLHS->getType();
483 bool isVectorOp = opTypeLLVM->isVectorTy();
484 bool isSat = satMode != NVVM::SaturationMode::NONE;
488 static constexpr llvm::Intrinsic::ID f16IDs[] = {
489 llvm::Intrinsic::nvvm_add_rn_sat_f16,
490 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f16,
491 llvm::Intrinsic::nvvm_add_rn_sat_v2f16,
492 llvm::Intrinsic::nvvm_add_rn_ftz_sat_v2f16,
495 static constexpr llvm::Intrinsic::ID f32IDs[] = {
496 llvm::Intrinsic::nvvm_add_rn_f,
497 llvm::Intrinsic::nvvm_add_rn_f,
498 llvm::Intrinsic::nvvm_add_rm_f,
499 llvm::Intrinsic::nvvm_add_rp_f,
500 llvm::Intrinsic::nvvm_add_rz_f,
501 llvm::Intrinsic::nvvm_add_rn_sat_f,
502 llvm::Intrinsic::nvvm_add_rn_sat_f,
503 llvm::Intrinsic::nvvm_add_rm_sat_f,
504 llvm::Intrinsic::nvvm_add_rp_sat_f,
505 llvm::Intrinsic::nvvm_add_rz_sat_f,
506 llvm::Intrinsic::nvvm_add_rn_ftz_f,
507 llvm::Intrinsic::nvvm_add_rn_ftz_f,
508 llvm::Intrinsic::nvvm_add_rm_ftz_f,
509 llvm::Intrinsic::nvvm_add_rp_ftz_f,
510 llvm::Intrinsic::nvvm_add_rz_ftz_f,
511 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
512 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
513 llvm::Intrinsic::nvvm_add_rm_ftz_sat_f,
514 llvm::Intrinsic::nvvm_add_rp_ftz_sat_f,
515 llvm::Intrinsic::nvvm_add_rz_ftz_sat_f,
518 static constexpr llvm::Intrinsic::ID f64IDs[] = {
519 llvm::Intrinsic::nvvm_add_rn_d,
520 llvm::Intrinsic::nvvm_add_rn_d, llvm::Intrinsic::nvvm_add_rm_d,
521 llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
523 auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
525 {argLHS, argRHS}, opTypeLLVM);
531 if (opTypeLLVM->getScalarType()->isHalfTy()) {
534 unsigned index = (isVectorOp << 1) | isFTZ;
537 result = builder.CreateFAdd(argLHS, argRHS);
544 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
545 mt.
mapValue(res, builder.CreateFAdd(argLHS, argRHS));
550 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
551 unsigned index =
static_cast<unsigned>(rndMode);
557 const unsigned numRndModes = 5;
558 if (opTypeLLVM->getScalarType()->isFloatTy()) {
560 ((isFTZ << 1) | isSat) * numRndModes + static_cast<unsigned>(rndMode);
// Fragment of the NVVM::FmaOp lowering: reads rnd/sat/ftz/relu/oob modifiers
// off the op, then selects an nvvm_fma_* intrinsic from per-type tables:
//   f16IDs:  index = (isRelu << 3) | (isSat << 2) | (isFTZ << 1) | isVectorFma
//   bf16IDs: index = (isRelu << 1) | isVectorFma
//   f32IDs:  index = ((isFTZ << 1) | isSat) * 4 + rndIndex
//   f64IDs:  index = rndIndex          (rndIndex = rndMode - 1)
// OOB variants short-circuit to nvvm_fma_rn_oob[_relu] for f16/bf16.
// NOTE(review): the enclosing function signature, the operand lookups
// (argA/argB/argC via moduleTranslation.lookupValue presumably), and the
// result-mapping lines are missing from this extraction; the index formulas
// above are read off the visible computations — confirm against upstream.
567 llvm::IRBuilderBase &builder) {
568 auto thisOp = cast<NVVM::FmaOp>(op);
569 mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
570 unsigned rndIndex =
static_cast<unsigned>(rndMode) - 1;
571 mlir::NVVM::SaturationMode satMode = thisOp.getSat();
572 bool isFTZ = thisOp.getFtz();
573 bool isRelu = thisOp.getRelu();
574 bool isSat = satMode == NVVM::SaturationMode::SAT;
575 bool isOOB = thisOp.getOob();
577 mlir::Type opType = thisOp.getRes().getType();
579 bool isVectorFma = opTypeLLVM->isVectorTy();
585 static constexpr llvm::Intrinsic::ID f16IDs[] = {
586 llvm::Intrinsic::nvvm_fma_rn_f16,
587 llvm::Intrinsic::nvvm_fma_rn_f16x2,
588 llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
589 llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
590 llvm::Intrinsic::nvvm_fma_rn_sat_f16,
591 llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
592 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
593 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
594 llvm::Intrinsic::nvvm_fma_rn_relu_f16,
595 llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
596 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
597 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
599 static constexpr llvm::Intrinsic::ID bf16IDs[] = {
600 llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
601 llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
602 llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
604 static constexpr llvm::Intrinsic::ID f32IDs[] = {
605 llvm::Intrinsic::nvvm_fma_rn_f,
606 llvm::Intrinsic::nvvm_fma_rm_f,
607 llvm::Intrinsic::nvvm_fma_rp_f,
608 llvm::Intrinsic::nvvm_fma_rz_f,
609 llvm::Intrinsic::nvvm_fma_rn_sat_f,
610 llvm::Intrinsic::nvvm_fma_rm_sat_f,
611 llvm::Intrinsic::nvvm_fma_rp_sat_f,
612 llvm::Intrinsic::nvvm_fma_rz_sat_f,
613 llvm::Intrinsic::nvvm_fma_rn_ftz_f,
614 llvm::Intrinsic::nvvm_fma_rm_ftz_f,
615 llvm::Intrinsic::nvvm_fma_rp_ftz_f,
616 llvm::Intrinsic::nvvm_fma_rz_ftz_f,
617 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
618 llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
619 llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
620 llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
623 static constexpr llvm::Intrinsic::ID f64IDs[] = {
624 llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
625 llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
627 auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID,
628 llvm::Type *retType) -> llvm::Value * {
630 builder, IID, opTypeLLVM, {argA, argB, argC}, retType);
634 if (opTypeLLVM->getScalarType()->isHalfTy()) {
637 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
638 : llvm::Intrinsic::nvvm_fma_rn_oob,
642 (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
651 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
654 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
655 : llvm::Intrinsic::nvvm_fma_rn_oob,
658 unsigned index = (isRelu << 1) | isVectorFma;
666 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
668 fmaIntrinsic(f64IDs[rndIndex], opTypeLLVM->getScalarType()));
673 const unsigned numRndModes = 4;
674 if (opTypeLLVM->getScalarType()->isFloatTy()) {
675 unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
677 fmaIntrinsic(f32IDs[
index], opTypeLLVM->getScalarType()));
// Fragment of the translation interface implementation:
//  - convertOperation: dispatches NVVM ops to LLVM IR via the tablegen'd
//    NVVMConversions.inc, after checking there is an active insertion point.
//  - amendOperation: translates NVVM function attributes (maxntid, reqntid,
//    cluster_dim, cluster max blocks, minctasm, maxnreg, kernel,
//    blocks_are_clusters) into the corresponding llvm::NVVMAttr function
//    attributes; the i32-array ones are formatted as comma-joined strings.
//  - the final convertParameterAttribute fragment maps the grid_constant
//    parameter attribute onto the LLVM parameter.
// NOTE(review): error-return lines, closing braces, and the
// convertParameterAttribute signature are missing from this extraction
// (numbering jumps, e.g. 698->700, 719->721, 762->770); restore from
// upstream before compiling.
685class NVVMDialectLLVMIRTranslationInterface
686 :
public LLVMTranslationDialectInterface {
688 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
693 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
694 LLVM::ModuleTranslation &moduleTranslation)
const final {
698 if (!builder.GetInsertBlock())
700 "cannot be translated to LLVM IR without an active insertion "
701 "point; make sure the op is inside a function");
702 Operation &opInst = *op;
703#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
710 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
711 NamedAttribute attribute,
712 LLVM::ModuleTranslation &moduleTranslation)
const final {
713 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
716 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
718 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
719 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
721 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
722 const std::string attr = llvm::formatv(
723 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
724 values.asArrayRef().end()));
725 llvmFunc->addFnAttr(llvm::NVVMAttr::MaxNTID, attr);
726 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
727 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
729 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
730 const std::string attr = llvm::formatv(
731 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
732 values.asArrayRef().end()));
733 llvmFunc->addFnAttr(llvm::NVVMAttr::ReqNTID, attr);
734 }
else if (attribute.getName() ==
735 NVVM::NVVMDialect::getClusterDimAttrName()) {
736 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
738 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
739 const std::string attr = llvm::formatv(
740 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
741 values.asArrayRef().end()));
742 llvmFunc->addFnAttr(llvm::NVVMAttr::ClusterDim, attr);
743 }
else if (attribute.getName() ==
744 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
745 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
746 llvmFunc->addFnAttr(llvm::NVVMAttr::MaxClusterRank,
747 llvm::utostr(value.getInt()));
748 }
else if (attribute.getName() ==
749 NVVM::NVVMDialect::getMinctasmAttrName()) {
750 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
751 llvmFunc->addFnAttr(llvm::NVVMAttr::MinCTASm,
752 llvm::utostr(value.getInt()));
753 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
754 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
755 llvmFunc->addFnAttr(llvm::NVVMAttr::MaxNReg,
756 llvm::utostr(value.getInt()));
757 }
else if (attribute.getName() ==
758 NVVM::NVVMDialect::getKernelFuncAttrName()) {
759 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
760 }
else if (attribute.getName() ==
761 NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
762 llvmFunc->addFnAttr(llvm::NVVMAttr::BlocksAreClusters);
770 LLVM::ModuleTranslation &moduleTranslation)
const final {
772 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
773 llvm::Function *llvmFunc =
774 moduleTranslation.lookupFunction(funcOp.getName());
776 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
777 llvmFunc->addParamAttr(
779 llvm::Attribute::get(llvmContext, llvm::NVVMAttr::GridConstant));
// Fragment of registerNVVMDialectTranslation: registers the NVVM dialect and
// attaches the translation interface via a dialect extension. The enclosing
// function signature and the addExtension callback wrapper are missing from
// this extraction.
787 registry.
insert<NVVM::NVVMDialect>();
789 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
static llvm::Intrinsic::ID getFenceProxyID(NVVM::ProxyKind kind, std::optional< NVVM::SharedSpace > space)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::Intrinsic::ID getFenceProxySyncRestrictID(NVVM::MemOrderKind order)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReductionKind kind, bool hasAbs, bool hasNaN)
static llvm::Value * createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM, ArrayRef< llvm::Value * > operands, llvm::Type *retType)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry ®istry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry;.