#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/Support/FormatVariadic.h"

#include <cmath>
// Helper macro: selects the f32 redux.sync intrinsic for `op` (min/max).
// `abs` is either empty or `_abs`; `hasNaN` picks the NaN-propagating
// variant. Parenthesized so expansion is safe inside larger expressions.
#define REDUX_F32_ID_IMPL(op, abs, hasNaN)                                     \
  ((hasNaN) ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN                \
            : llvm::Intrinsic::nvvm_redux_sync_f##op##abs)
// Dispatches to the abs / non-abs family of f32 redux.sync intrinsics.
// Parenthesized so expansion is safe inside larger expressions.
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)                                   \
  ((hasAbs) ? REDUX_F32_ID_IMPL(op, _abs, hasNaN)                              \
            : REDUX_F32_ID_IMPL(op, , hasNaN))
38 bool hasAbs,
bool hasNaN) {
40 case NVVM::ReduxKind::ADD:
41 return llvm::Intrinsic::nvvm_redux_sync_add;
42 case NVVM::ReduxKind::UMAX:
43 return llvm::Intrinsic::nvvm_redux_sync_umax;
44 case NVVM::ReduxKind::UMIN:
45 return llvm::Intrinsic::nvvm_redux_sync_umin;
46 case NVVM::ReduxKind::AND:
47 return llvm::Intrinsic::nvvm_redux_sync_and;
48 case NVVM::ReduxKind::OR:
49 return llvm::Intrinsic::nvvm_redux_sync_or;
50 case NVVM::ReduxKind::XOR:
51 return llvm::Intrinsic::nvvm_redux_sync_xor;
52 case NVVM::ReduxKind::MAX:
53 return llvm::Intrinsic::nvvm_redux_sync_max;
54 case NVVM::ReduxKind::MIN:
55 return llvm::Intrinsic::nvvm_redux_sync_min;
56 case NVVM::ReduxKind::FMIN:
58 case NVVM::ReduxKind::FMAX:
61 llvm_unreachable(
"unknown redux kind");
69 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
71 case NVVM::ShflKind::bfly:
72 return resultType->isFloatTy()
73 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
74 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
75 case NVVM::ShflKind::up:
76 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
77 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
78 case NVVM::ShflKind::down:
79 return resultType->isFloatTy()
80 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
81 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
82 case NVVM::ShflKind::idx:
83 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
84 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
88 case NVVM::ShflKind::bfly:
89 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
90 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
91 case NVVM::ShflKind::up:
92 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
93 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
94 case NVVM::ShflKind::down:
95 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
96 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
97 case NVVM::ShflKind::idx:
98 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
99 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
102 llvm_unreachable(
"unknown shuffle kind");
106 NVVM::MatchSyncKind kind) {
108 case NVVM::MatchSyncKind::any:
109 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
110 : llvm::Intrinsic::nvvm_match_any_sync_i64;
111 case NVVM::MatchSyncKind::all:
115 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
116 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
118 llvm_unreachable(
"unsupported match sync kind");
123 case NVVM::VoteSyncKind::any:
124 return llvm::Intrinsic::nvvm_vote_any_sync;
125 case NVVM::VoteSyncKind::all:
126 return llvm::Intrinsic::nvvm_vote_all_sync;
127 case NVVM::VoteSyncKind::ballot:
128 return llvm::Intrinsic::nvvm_vote_ballot_sync;
129 case NVVM::VoteSyncKind::uni:
130 return llvm::Intrinsic::nvvm_vote_uni_sync;
132 llvm_unreachable(
"unsupported vote kind");
135static llvm::Intrinsic::ID
137 NVVM::LdStMatrixShapeAttr
shape,
138 NVVM::LdStMatrixEltType eltType) {
142 return (layout == NVVM::MMALayout::row)
143 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
145 nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
147 return (layout == NVVM::MMALayout::row)
148 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
150 nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
152 return (layout == NVVM::MMALayout::row)
153 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
155 nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
157 }
else if (
shape.getM() == 8 &&
shape.getN() == 16) {
158 if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
161 return llvm::Intrinsic::
162 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
164 return llvm::Intrinsic::
165 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
167 return llvm::Intrinsic::
168 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
170 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
173 return llvm::Intrinsic::
174 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
176 return llvm::Intrinsic::
177 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
179 return llvm::Intrinsic::
180 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
183 }
else if (
shape.getM() == 16 &&
shape.getN() == 16) {
184 if (eltType == NVVM::LdStMatrixEltType::B8) {
187 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
189 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
191 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
194 return llvm::Intrinsic::
195 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
197 return llvm::Intrinsic::
198 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
200 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
203 return llvm::Intrinsic::
204 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
206 return llvm::Intrinsic::
207 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
211 llvm_unreachable(
"unknown ldmatrix kind");
215static llvm::Intrinsic::ID
217 NVVM::LdStMatrixShapeAttr
shape,
218 NVVM::LdStMatrixEltType eltType) {
222 return (layout == NVVM::MMALayout::row)
223 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
225 nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
227 return (layout == NVVM::MMALayout::row)
228 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
230 nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
232 return (layout == NVVM::MMALayout::row)
233 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
235 nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
237 }
else if (
shape.getM() == 16 &&
shape.getN() == 8) {
240 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
242 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
244 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
247 llvm_unreachable(
"unknown stmatrix kind");
251static llvm::Intrinsic::ID
254 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
256 : llvm::Intrinsic::nvvm_st_bulk;
260 NVVM::ProxyKind toProxy,
261 NVVM::MemScopeKind scope,
263 if (fromProxy == NVVM::ProxyKind::GENERIC &&
264 toProxy == NVVM::ProxyKind::TENSORMAP) {
266 case NVVM::MemScopeKind::CTA: {
268 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
269 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
271 case NVVM::MemScopeKind::CLUSTER: {
273 return llvm::Intrinsic::
274 nvvm_fence_proxy_tensormap_generic_release_cluster;
275 return llvm::Intrinsic::
276 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
278 case NVVM::MemScopeKind::GPU: {
280 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
281 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
283 case NVVM::MemScopeKind::SYS: {
285 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
286 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
289 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
291 llvm_unreachable(
"Unsupported proxy kinds");
294#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
296static llvm::Intrinsic::ID
298 llvm::Intrinsic::ID Shape16x64b[] = {
304 llvm::Intrinsic::ID Shape16x128b[] = {
310 llvm::Intrinsic::ID Shape16x256b[] = {
315 llvm::Intrinsic::ID Shape16x32bx2[] = {
322 llvm::Intrinsic::ID Shape32x32b[] = {
330 unsigned Idx = std::log2(num);
333 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
334 return Shape16x64b[Idx];
335 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
336 return Shape16x128b[Idx - 1];
337 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
338 return Shape16x256b[Idx - 2];
339 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
340 return Shape32x32b[Idx];
341 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
342 return Shape16x32bx2[Idx];
344 llvm_unreachable(
"unhandled tcgen05.ld lowering");
347#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
349static llvm::Intrinsic::ID
351 llvm::Intrinsic::ID Shape16x64b[] = {
357 llvm::Intrinsic::ID Shape16x128b[] = {
363 llvm::Intrinsic::ID Shape16x256b[] = {
368 llvm::Intrinsic::ID Shape16x32bx2[] = {
375 llvm::Intrinsic::ID Shape32x32b[] = {
383 unsigned Idx = std::log2(num);
386 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
387 return Shape16x64b[Idx];
388 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
389 return Shape16x128b[Idx - 1];
390 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
391 return Shape16x256b[Idx - 2];
392 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
393 return Shape32x32b[Idx];
394 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
395 return Shape16x32bx2[Idx];
397 llvm_unreachable(
"unhandled tcgen05.st lowering");
403class NVVMDialectLLVMIRTranslationInterface
411 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
412 LLVM::ModuleTranslation &moduleTranslation)
const final {
413 Operation &opInst = *op;
414#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
421 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
422 NamedAttribute attribute,
423 LLVM::ModuleTranslation &moduleTranslation)
const final {
424 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
427 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
429 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
430 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
432 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
433 const std::string attr = llvm::formatv(
434 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
435 values.asArrayRef().end()));
436 llvmFunc->addFnAttr(
"nvvm.maxntid", attr);
437 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
438 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
440 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
441 const std::string attr = llvm::formatv(
442 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
443 values.asArrayRef().end()));
444 llvmFunc->addFnAttr(
"nvvm.reqntid", attr);
445 }
else if (attribute.getName() ==
446 NVVM::NVVMDialect::getClusterDimAttrName()) {
447 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
449 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
450 const std::string attr = llvm::formatv(
451 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
452 values.asArrayRef().end()));
453 llvmFunc->addFnAttr(
"nvvm.cluster_dim", attr);
454 }
else if (attribute.getName() ==
455 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
456 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
457 llvmFunc->addFnAttr(
"nvvm.maxclusterrank", llvm::utostr(value.getInt()));
458 }
else if (attribute.getName() ==
459 NVVM::NVVMDialect::getMinctasmAttrName()) {
460 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
461 llvmFunc->addFnAttr(
"nvvm.minctasm", llvm::utostr(value.getInt()));
462 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
463 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
464 llvmFunc->addFnAttr(
"nvvm.maxnreg", llvm::utostr(value.getInt()));
465 }
else if (attribute.getName() ==
466 NVVM::NVVMDialect::getKernelFuncAttrName()) {
467 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
468 }
else if (attribute.getName() ==
469 NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
470 llvmFunc->addFnAttr(
"nvvm.blocksareclusters");
478 LLVM::ModuleTranslation &moduleTranslation)
const final {
480 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
481 llvm::Function *llvmFunc =
482 moduleTranslation.lookupFunction(funcOp.getName());
484 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
485 llvmFunc->addParamAttr(
486 argIdx, llvm::Attribute::get(llvmContext,
"nvvm.grid_constant"));
494 registry.
insert<NVVM::NVVMDialect>();
496 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind, bool hasAbs, bool hasNaN)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.