19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicsNVPTX.h"
23 #include "llvm/Support/FormatVariadic.h"
// Builds the name of an f32 redux.sync intrinsic ID by token-pasting `op` and
// `abs` onto `llvm::Intrinsic::nvvm_redux_sync_f`, and selects the variant with
// the `_NaN` suffix when `hasNaN` evaluates to true, the unsuffixed one
// otherwise. `abs` may be empty. The expansion is a bare conditional
// expression with no surrounding parentheses.
29 #define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
30 hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
31 : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
// Selects the f32 redux.sync intrinsic ID for `op`: forwards `_abs` as the
// name infix to REDUX_F32_ID_IMPL when `hasAbs` is true and an empty infix
// otherwise; `hasNaN` is passed through unchanged in both arms. The expansion
// is a nested, unparenthesized conditional expression, so wrap the use in
// parentheses when embedding it in a larger expression.
33 #define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
34 hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)
38 bool hasAbs,
bool hasNaN) {
39 if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
40 llvm_unreachable(
"unsupported data type for redux");
43 case NVVM::ReduxKind::ADD:
44 return llvm::Intrinsic::nvvm_redux_sync_add;
45 case NVVM::ReduxKind::UMAX:
46 return llvm::Intrinsic::nvvm_redux_sync_umax;
47 case NVVM::ReduxKind::UMIN:
48 return llvm::Intrinsic::nvvm_redux_sync_umin;
49 case NVVM::ReduxKind::AND:
50 return llvm::Intrinsic::nvvm_redux_sync_and;
51 case NVVM::ReduxKind::OR:
52 return llvm::Intrinsic::nvvm_redux_sync_or;
53 case NVVM::ReduxKind::XOR:
54 return llvm::Intrinsic::nvvm_redux_sync_xor;
55 case NVVM::ReduxKind::MAX:
56 return llvm::Intrinsic::nvvm_redux_sync_max;
57 case NVVM::ReduxKind::MIN:
58 return llvm::Intrinsic::nvvm_redux_sync_min;
59 case NVVM::ReduxKind::FMIN:
61 case NVVM::ReduxKind::FMAX:
64 llvm_unreachable(
"unknown redux kind");
72 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
74 case NVVM::ShflKind::bfly:
75 return resultType->isFloatTy()
76 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
77 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
78 case NVVM::ShflKind::up:
79 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
80 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
81 case NVVM::ShflKind::down:
82 return resultType->isFloatTy()
83 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
84 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
85 case NVVM::ShflKind::idx:
86 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
87 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
91 case NVVM::ShflKind::bfly:
92 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
93 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
94 case NVVM::ShflKind::up:
95 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
96 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
97 case NVVM::ShflKind::down:
98 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
99 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
100 case NVVM::ShflKind::idx:
101 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
102 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
105 llvm_unreachable(
"unknown shuffle kind");
109 NVVM::MatchSyncKind
kind) {
111 case NVVM::MatchSyncKind::any:
112 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
113 : llvm::Intrinsic::nvvm_match_any_sync_i64;
114 case NVVM::MatchSyncKind::all:
118 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
119 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
125 case NVVM::VoteSyncKind::any:
126 return llvm::Intrinsic::nvvm_vote_any_sync;
127 case NVVM::VoteSyncKind::all:
128 return llvm::Intrinsic::nvvm_vote_all_sync;
129 case NVVM::VoteSyncKind::ballot:
130 return llvm::Intrinsic::nvvm_vote_ballot_sync;
131 case NVVM::VoteSyncKind::uni:
132 return llvm::Intrinsic::nvvm_vote_uni_sync;
134 llvm_unreachable(
"unsupported vote kind");
140 if (layout == NVVM::MMALayout::row) {
143 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
145 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
147 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
149 llvm_unreachable(
"unsupported number of matrix");
155 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
157 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
159 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
161 llvm_unreachable(
"unsupported number of matrix");
172 : llvm::Intrinsic::nvvm_st_bulk;
176 NVVM::ProxyKind toProxy,
177 NVVM::MemScopeKind scope,
179 if (fromProxy == NVVM::ProxyKind::GENERIC &&
180 toProxy == NVVM::ProxyKind::TENSORMAP) {
182 case NVVM::MemScopeKind::CTA: {
184 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
185 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
187 case NVVM::MemScopeKind::CLUSTER: {
189 return llvm::Intrinsic::
190 nvvm_fence_proxy_tensormap_generic_release_cluster;
191 return llvm::Intrinsic::
192 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
194 case NVVM::MemScopeKind::GPU: {
196 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
197 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
199 case NVVM::MemScopeKind::SYS: {
201 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
202 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
205 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
207 llvm_unreachable(
"Unsupported proxy kinds");
// Token-pastes SHAPE and NUM into an intrinsic ID name of the form
// llvm::Intrinsic::nvvm_tcgen05_ld_<SHAPE>_<NUM>.
210 #define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
246 unsigned Idx = std::log2(num);
249 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
250 return Shape16x64b[Idx];
251 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
252 return Shape16x128b[Idx - 1];
253 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
254 return Shape16x256b[Idx - 2];
255 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
256 return Shape32x32b[Idx];
257 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
258 return Shape16x32bx2[Idx];
260 llvm_unreachable(
"unhandled tcgen05.ld lowering");
// Token-pastes SHAPE and NUM into an intrinsic ID name of the form
// llvm::Intrinsic::nvvm_tcgen05_st_<SHAPE>_<NUM>.
263 #define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
299 unsigned Idx = std::log2(num);
302 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
303 return Shape16x64b[Idx];
304 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
305 return Shape16x128b[Idx - 1];
306 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
307 return Shape16x256b[Idx - 2];
308 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
309 return Shape32x32b[Idx];
310 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
311 return Shape16x32bx2[Idx];
313 llvm_unreachable(
"unhandled tcgen05.st lowering");
319 class NVVMDialectLLVMIRTranslationInterface
327 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
330 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
340 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
343 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
345 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
346 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
348 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
349 const std::string attr = llvm::formatv(
350 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
351 values.asArrayRef().end()));
352 llvmFunc->addFnAttr(
"nvvm.maxntid", attr);
353 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
354 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
356 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
357 const std::string attr = llvm::formatv(
358 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
359 values.asArrayRef().end()));
360 llvmFunc->addFnAttr(
"nvvm.reqntid", attr);
361 }
else if (attribute.getName() ==
362 NVVM::NVVMDialect::getClusterDimAttrName()) {
363 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
365 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
366 const std::string attr = llvm::formatv(
367 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
368 values.asArrayRef().end()));
369 llvmFunc->addFnAttr(
"nvvm.cluster_dim", attr);
370 }
else if (attribute.getName() ==
371 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
372 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
373 llvmFunc->addFnAttr(
"nvvm.maxclusterrank", llvm::utostr(value.getInt()));
374 }
else if (attribute.getName() ==
375 NVVM::NVVMDialect::getMinctasmAttrName()) {
376 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
377 llvmFunc->addFnAttr(
"nvvm.minctasm", llvm::utostr(value.getInt()));
378 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
379 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
380 llvmFunc->addFnAttr(
"nvvm.maxnreg", llvm::utostr(value.getInt()));
381 }
else if (attribute.getName() ==
382 NVVM::NVVMDialect::getKernelFuncAttrName()) {
383 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
392 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
393 llvm::Function *llvmFunc =
394 moduleTranslation.lookupFunction(funcOp.getName());
395 llvm::NamedMDNode *nvvmAnnotations =
396 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations");
398 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
399 llvm::MDNode *gridConstantMetaData =
nullptr;
402 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
403 if (opnd->getNumOperands() == 3 &&
405 opnd->getOperand(1) ==
407 gridConstantMetaData = opnd;
417 if (gridConstantMetaData ==
nullptr) {
420 llvm::ValueAsMetadata::getConstant(
422 llvm::Metadata *llvmMetadata[] = {
426 llvm::MDNode *llvmMetadataNode =
428 nvvmAnnotations->addOperand(llvmMetadataNode);
432 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
433 llvm::TempMDTuple clonedArgList = argList->clone();
434 clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
436 gridConstantMetaData->replaceOperandWith(
437 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
447 registry.
insert<NVVM::NVVMDialect>();
449 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static constexpr int64_t kSharedMemorySpace
union mlir::linalg::@1216::ArityGroupAndKind::Kind kind
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind, bool hasAbs, bool hasNaN)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...