19#include "llvm/ADT/StringExtras.h"
20#include "llvm/ADT/iterator_range.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/IntrinsicsNVPTX.h"
23#include "llvm/Support/FormatVariadic.h"
// Selects the f32 redux.sync intrinsic variant for `op` (min/max): the
// "_NaN" form propagates NaN inputs, the plain form does not. The whole
// expansion is parenthesized so it is safe inside larger expressions.
#define REDUX_F32_ID_IMPL(op, abs, hasNaN)                                     \
  ((hasNaN) ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN               \
            : llvm::Intrinsic::nvvm_redux_sync_f##op##abs)
// Picks between the "_abs" and plain f32 redux.sync intrinsic families and
// defers to REDUX_F32_ID_IMPL for the NaN-propagation choice. Parenthesized
// so the ternary expands safely inside larger expressions.
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)                                   \
  ((hasAbs) ? REDUX_F32_ID_IMPL(op, _abs, hasNaN)                              \
            : REDUX_F32_ID_IMPL(op, , hasNaN))
37 NVVM::ReductionKind kind,
38 bool hasAbs,
bool hasNaN) {
40 case NVVM::ReductionKind::ADD:
41 return llvm::Intrinsic::nvvm_redux_sync_add;
42 case NVVM::ReductionKind::UMAX:
43 return llvm::Intrinsic::nvvm_redux_sync_umax;
44 case NVVM::ReductionKind::UMIN:
45 return llvm::Intrinsic::nvvm_redux_sync_umin;
46 case NVVM::ReductionKind::AND:
47 return llvm::Intrinsic::nvvm_redux_sync_and;
48 case NVVM::ReductionKind::OR:
49 return llvm::Intrinsic::nvvm_redux_sync_or;
50 case NVVM::ReductionKind::XOR:
51 return llvm::Intrinsic::nvvm_redux_sync_xor;
52 case NVVM::ReductionKind::MAX:
53 return llvm::Intrinsic::nvvm_redux_sync_max;
54 case NVVM::ReductionKind::MIN:
55 return llvm::Intrinsic::nvvm_redux_sync_min;
56 case NVVM::ReductionKind::FMIN:
58 case NVVM::ReductionKind::FMAX:
61 llvm_unreachable(
"unknown reduction kind");
69 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
71 case NVVM::ShflKind::bfly:
72 return resultType->isFloatTy()
73 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
74 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
75 case NVVM::ShflKind::up:
76 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
77 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
78 case NVVM::ShflKind::down:
79 return resultType->isFloatTy()
80 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
81 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
82 case NVVM::ShflKind::idx:
83 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
84 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
88 case NVVM::ShflKind::bfly:
89 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
90 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
91 case NVVM::ShflKind::up:
92 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
93 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
94 case NVVM::ShflKind::down:
95 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
96 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
97 case NVVM::ShflKind::idx:
98 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
99 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
102 llvm_unreachable(
"unknown shuffle kind");
106 NVVM::MatchSyncKind kind) {
108 case NVVM::MatchSyncKind::any:
109 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
110 : llvm::Intrinsic::nvvm_match_any_sync_i64;
111 case NVVM::MatchSyncKind::all:
115 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
116 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
118 llvm_unreachable(
"unsupported match sync kind");
123 case NVVM::VoteSyncKind::any:
124 return llvm::Intrinsic::nvvm_vote_any_sync;
125 case NVVM::VoteSyncKind::all:
126 return llvm::Intrinsic::nvvm_vote_all_sync;
127 case NVVM::VoteSyncKind::ballot:
128 return llvm::Intrinsic::nvvm_vote_ballot_sync;
129 case NVVM::VoteSyncKind::uni:
130 return llvm::Intrinsic::nvvm_vote_uni_sync;
132 llvm_unreachable(
"unsupported vote kind");
135static llvm::Intrinsic::ID
137 NVVM::LdStMatrixShapeAttr
shape,
138 NVVM::LdStMatrixEltType eltType) {
142 return (layout == NVVM::MMALayout::row)
143 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
145 nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
147 return (layout == NVVM::MMALayout::row)
148 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
150 nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
152 return (layout == NVVM::MMALayout::row)
153 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
155 nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
157 }
else if (
shape.getM() == 8 &&
shape.getN() == 16) {
158 if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
161 return llvm::Intrinsic::
162 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
164 return llvm::Intrinsic::
165 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
167 return llvm::Intrinsic::
168 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
170 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
173 return llvm::Intrinsic::
174 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
176 return llvm::Intrinsic::
177 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
179 return llvm::Intrinsic::
180 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
183 }
else if (
shape.getM() == 16 &&
shape.getN() == 16) {
184 if (eltType == NVVM::LdStMatrixEltType::B8) {
187 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
189 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
191 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
194 return llvm::Intrinsic::
195 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
197 return llvm::Intrinsic::
198 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
200 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
203 return llvm::Intrinsic::
204 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
206 return llvm::Intrinsic::
207 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
211 llvm_unreachable(
"unknown ldmatrix kind");
215static llvm::Intrinsic::ID
217 NVVM::LdStMatrixShapeAttr
shape,
218 NVVM::LdStMatrixEltType eltType) {
222 return (layout == NVVM::MMALayout::row)
223 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
225 nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
227 return (layout == NVVM::MMALayout::row)
228 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
230 nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
232 return (layout == NVVM::MMALayout::row)
233 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
235 nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
237 }
else if (
shape.getM() == 16 &&
shape.getN() == 8) {
240 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
242 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
244 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
247 llvm_unreachable(
"unknown stmatrix kind");
251static llvm::Intrinsic::ID
254 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
256 : llvm::Intrinsic::nvvm_st_bulk;
260 NVVM::ProxyKind toProxy,
261 NVVM::MemScopeKind scope,
263 if (fromProxy == NVVM::ProxyKind::GENERIC &&
264 toProxy == NVVM::ProxyKind::TENSORMAP) {
266 case NVVM::MemScopeKind::CTA: {
268 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
269 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
271 case NVVM::MemScopeKind::CLUSTER: {
273 return llvm::Intrinsic::
274 nvvm_fence_proxy_tensormap_generic_release_cluster;
275 return llvm::Intrinsic::
276 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
278 case NVVM::MemScopeKind::GPU: {
280 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
281 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
283 case NVVM::MemScopeKind::SYS: {
285 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
286 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
289 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
291 llvm_unreachable(
"Unsupported proxy kinds");
296 case NVVM::MemScopeKind::CTA:
297 return llvm::Intrinsic::nvvm_membar_cta;
298 case NVVM::MemScopeKind::CLUSTER:
299 return llvm::Intrinsic::nvvm_fence_sc_cluster;
300 case NVVM::MemScopeKind::GPU:
301 return llvm::Intrinsic::nvvm_membar_gl;
302 case NVVM::MemScopeKind::SYS:
303 return llvm::Intrinsic::nvvm_membar_sys;
305 llvm_unreachable(
"Unknown scope for memory barrier");
// Expands to the tcgen05.ld intrinsic ID for a given shape/num suffix pair.
#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
310static llvm::Intrinsic::ID
312 llvm::Intrinsic::ID Shape16x64b[] = {
318 llvm::Intrinsic::ID Shape16x128b[] = {
324 llvm::Intrinsic::ID Shape16x256b[] = {
329 llvm::Intrinsic::ID Shape16x32bx2[] = {
336 llvm::Intrinsic::ID Shape32x32b[] = {
344 unsigned Idx = std::log2(num);
347 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
348 return Shape16x64b[Idx];
349 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
350 return Shape16x128b[Idx - 1];
351 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
352 return Shape16x256b[Idx - 2];
353 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
354 return Shape32x32b[Idx];
355 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
356 return Shape16x32bx2[Idx];
358 llvm_unreachable(
"unhandled tcgen05.ld lowering");
// Expands to the tcgen05.st intrinsic ID for a given shape/num suffix pair.
#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
363static llvm::Intrinsic::ID
365 llvm::Intrinsic::ID Shape16x64b[] = {
371 llvm::Intrinsic::ID Shape16x128b[] = {
377 llvm::Intrinsic::ID Shape16x256b[] = {
382 llvm::Intrinsic::ID Shape16x32bx2[] = {
389 llvm::Intrinsic::ID Shape32x32b[] = {
397 unsigned Idx = std::log2(num);
400 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
401 return Shape16x64b[Idx];
402 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
403 return Shape16x128b[Idx - 1];
404 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
405 return Shape16x256b[Idx - 2];
406 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
407 return Shape32x32b[Idx];
408 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
409 return Shape16x32bx2[Idx];
411 llvm_unreachable(
"unhandled tcgen05.st lowering");
415 return order == NVVM::MemOrderKind::ACQUIRE
417 nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster
419 nvvm_fence_release_sync_restrict_space_cta_scope_cluster;
422static llvm::Intrinsic::ID
425 case NVVM::ProxyKind::alias:
426 return llvm::Intrinsic::nvvm_fence_proxy_alias;
427 case NVVM::ProxyKind::async:
428 return llvm::Intrinsic::nvvm_fence_proxy_async;
429 case NVVM::ProxyKind::async_global:
430 return llvm::Intrinsic::nvvm_fence_proxy_async_global;
431 case NVVM::ProxyKind::async_shared:
432 return *space == NVVM::SharedSpace::shared_cta
433 ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta
434 : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster;
436 llvm_unreachable(
"unsupported proxy kind");
440static llvm::Intrinsic::ID
442 return order == NVVM::MemOrderKind::ACQUIRE
444 nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster
446 nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
455 llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
457 llvm::Type *retType) {
458 if (opTypeLLVM->isVectorTy() && (opTypeLLVM->getScalarType()->isFloatTy() ||
459 opTypeLLVM->getScalarType()->isDoubleTy())) {
460 llvm::Value *
result = llvm::PoisonValue::get(
461 llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
462 for (
int64_t i = 0; i < 2; ++i) {
464 for (llvm::Value *op : operands)
465 scalarArgs.push_back(
466 builder.CreateExtractElement(op, builder.getInt32(i)));
468 result = builder.CreateInsertElement(
result, res, builder.getInt32(i));
476void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
477 Value res, NVVM::FPRoundingMode rndMode,
478 NVVM::SaturationMode satMode,
bool isFTZ,
480 llvm::IRBuilderBase &builder) {
481 llvm::Type *opTypeLLVM = argLHS->getType();
482 bool isVectorOp = opTypeLLVM->isVectorTy();
483 bool isSat = satMode != NVVM::SaturationMode::NONE;
487 static constexpr llvm::Intrinsic::ID f16IDs[] = {
488 llvm::Intrinsic::nvvm_add_rn_sat_f16,
489 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f16,
490 llvm::Intrinsic::nvvm_add_rn_sat_v2f16,
491 llvm::Intrinsic::nvvm_add_rn_ftz_sat_v2f16,
494 static constexpr llvm::Intrinsic::ID f32IDs[] = {
495 llvm::Intrinsic::nvvm_add_rn_f,
496 llvm::Intrinsic::nvvm_add_rn_f,
497 llvm::Intrinsic::nvvm_add_rm_f,
498 llvm::Intrinsic::nvvm_add_rp_f,
499 llvm::Intrinsic::nvvm_add_rz_f,
500 llvm::Intrinsic::nvvm_add_rn_sat_f,
501 llvm::Intrinsic::nvvm_add_rn_sat_f,
502 llvm::Intrinsic::nvvm_add_rm_sat_f,
503 llvm::Intrinsic::nvvm_add_rp_sat_f,
504 llvm::Intrinsic::nvvm_add_rz_sat_f,
505 llvm::Intrinsic::nvvm_add_rn_ftz_f,
506 llvm::Intrinsic::nvvm_add_rn_ftz_f,
507 llvm::Intrinsic::nvvm_add_rm_ftz_f,
508 llvm::Intrinsic::nvvm_add_rp_ftz_f,
509 llvm::Intrinsic::nvvm_add_rz_ftz_f,
510 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
511 llvm::Intrinsic::nvvm_add_rn_ftz_sat_f,
512 llvm::Intrinsic::nvvm_add_rm_ftz_sat_f,
513 llvm::Intrinsic::nvvm_add_rp_ftz_sat_f,
514 llvm::Intrinsic::nvvm_add_rz_ftz_sat_f,
517 static constexpr llvm::Intrinsic::ID f64IDs[] = {
518 llvm::Intrinsic::nvvm_add_rn_d,
519 llvm::Intrinsic::nvvm_add_rn_d, llvm::Intrinsic::nvvm_add_rm_d,
520 llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
522 auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
524 {argLHS, argRHS}, opTypeLLVM);
530 if (opTypeLLVM->getScalarType()->isHalfTy()) {
533 unsigned index = (isVectorOp << 1) | isFTZ;
536 result = builder.CreateFAdd(argLHS, argRHS);
543 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
544 mt.
mapValue(res, builder.CreateFAdd(argLHS, argRHS));
549 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
550 unsigned index =
static_cast<unsigned>(rndMode);
556 const unsigned numRndModes = 5;
557 if (opTypeLLVM->getScalarType()->isFloatTy()) {
559 ((isFTZ << 1) | isSat) * numRndModes + static_cast<unsigned>(rndMode);
566 llvm::IRBuilderBase &builder) {
567 auto thisOp = cast<NVVM::FmaOp>(op);
568 mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
569 unsigned rndIndex =
static_cast<unsigned>(rndMode) - 1;
570 mlir::NVVM::SaturationMode satMode = thisOp.getSat();
571 bool isFTZ = thisOp.getFtz();
572 bool isRelu = thisOp.getRelu();
573 bool isSat = satMode == NVVM::SaturationMode::SAT;
574 bool isOOB = thisOp.getOob();
576 mlir::Type opType = thisOp.getRes().getType();
578 bool isVectorFma = opTypeLLVM->isVectorTy();
584 static constexpr llvm::Intrinsic::ID f16IDs[] = {
585 llvm::Intrinsic::nvvm_fma_rn_f16,
586 llvm::Intrinsic::nvvm_fma_rn_f16x2,
587 llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
588 llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
589 llvm::Intrinsic::nvvm_fma_rn_sat_f16,
590 llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
591 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
592 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
593 llvm::Intrinsic::nvvm_fma_rn_relu_f16,
594 llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
595 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
596 llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
598 static constexpr llvm::Intrinsic::ID bf16IDs[] = {
599 llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
600 llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
601 llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
603 static constexpr llvm::Intrinsic::ID f32IDs[] = {
604 llvm::Intrinsic::nvvm_fma_rn_f,
605 llvm::Intrinsic::nvvm_fma_rm_f,
606 llvm::Intrinsic::nvvm_fma_rp_f,
607 llvm::Intrinsic::nvvm_fma_rz_f,
608 llvm::Intrinsic::nvvm_fma_rn_sat_f,
609 llvm::Intrinsic::nvvm_fma_rm_sat_f,
610 llvm::Intrinsic::nvvm_fma_rp_sat_f,
611 llvm::Intrinsic::nvvm_fma_rz_sat_f,
612 llvm::Intrinsic::nvvm_fma_rn_ftz_f,
613 llvm::Intrinsic::nvvm_fma_rm_ftz_f,
614 llvm::Intrinsic::nvvm_fma_rp_ftz_f,
615 llvm::Intrinsic::nvvm_fma_rz_ftz_f,
616 llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
617 llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
618 llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
619 llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
622 static constexpr llvm::Intrinsic::ID f64IDs[] = {
623 llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
624 llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
626 auto fmaIntrinsic = [&](llvm::Intrinsic::ID IID,
627 llvm::Type *retType) -> llvm::Value * {
629 builder, IID, opTypeLLVM, {argA, argB, argC}, retType);
633 if (opTypeLLVM->getScalarType()->isHalfTy()) {
636 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
637 : llvm::Intrinsic::nvvm_fma_rn_oob,
641 (isRelu << 3) | (isSat << 2) | (isFTZ << 1) |
650 if (opTypeLLVM->getScalarType()->isBFloatTy()) {
653 result = fmaIntrinsic(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
654 : llvm::Intrinsic::nvvm_fma_rn_oob,
657 unsigned index = (isRelu << 1) | isVectorFma;
665 if (opTypeLLVM->getScalarType()->isDoubleTy()) {
667 fmaIntrinsic(f64IDs[rndIndex], opTypeLLVM->getScalarType()));
672 const unsigned numRndModes = 4;
673 if (opTypeLLVM->getScalarType()->isFloatTy()) {
674 unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
676 fmaIntrinsic(f32IDs[
index], opTypeLLVM->getScalarType()));
684class NVVMDialectLLVMIRTranslationInterface
685 :
public LLVMTranslationDialectInterface {
687 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
692 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
693 LLVM::ModuleTranslation &moduleTranslation)
const final {
694 Operation &opInst = *op;
695#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
702 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
703 NamedAttribute attribute,
704 LLVM::ModuleTranslation &moduleTranslation)
const final {
705 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
708 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
710 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
711 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
713 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
714 const std::string attr = llvm::formatv(
715 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
716 values.asArrayRef().end()));
717 llvmFunc->addFnAttr(
"nvvm.maxntid", attr);
718 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
719 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
721 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
722 const std::string attr = llvm::formatv(
723 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
724 values.asArrayRef().end()));
725 llvmFunc->addFnAttr(
"nvvm.reqntid", attr);
726 }
else if (attribute.getName() ==
727 NVVM::NVVMDialect::getClusterDimAttrName()) {
728 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
730 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
731 const std::string attr = llvm::formatv(
732 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
733 values.asArrayRef().end()));
734 llvmFunc->addFnAttr(
"nvvm.cluster_dim", attr);
735 }
else if (attribute.getName() ==
736 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
737 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
738 llvmFunc->addFnAttr(
"nvvm.maxclusterrank", llvm::utostr(value.getInt()));
739 }
else if (attribute.getName() ==
740 NVVM::NVVMDialect::getMinctasmAttrName()) {
741 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
742 llvmFunc->addFnAttr(
"nvvm.minctasm", llvm::utostr(value.getInt()));
743 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
744 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
745 llvmFunc->addFnAttr(
"nvvm.maxnreg", llvm::utostr(value.getInt()));
746 }
else if (attribute.getName() ==
747 NVVM::NVVMDialect::getKernelFuncAttrName()) {
748 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
749 }
else if (attribute.getName() ==
750 NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
751 llvmFunc->addFnAttr(
"nvvm.blocksareclusters");
759 LLVM::ModuleTranslation &moduleTranslation)
const final {
761 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
762 llvm::Function *llvmFunc =
763 moduleTranslation.lookupFunction(funcOp.getName());
765 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
766 llvmFunc->addParamAttr(
767 argIdx, llvm::Attribute::get(llvmContext,
"nvvm.grid_constant"));
775 registry.
insert<NVVM::NVVMDialect>();
777 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
static llvm::Intrinsic::ID getFenceProxyID(NVVM::ProxyKind kind, std::optional< NVVM::SharedSpace > space)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::Intrinsic::ID getFenceProxySyncRestrictID(NVVM::MemOrderKind order)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReductionKind kind, bool hasAbs, bool hasNaN)
static llvm::Value * createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM, ArrayRef< llvm::Value * > operands, llvm::Type *retType)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry ®istry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.