20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/iterator_range.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/IntrinsicsNVPTX.h"
24 #include "llvm/Support/FormatVariadic.h"
// Selects the f32 `redux.sync` intrinsic ID for reduction `op`.
//
// REDUX_F32_ID_IMPL(op, abs, hasNaN) token-pastes `op` and the (possibly
// empty) `abs` suffix onto `llvm::Intrinsic::nvvm_redux_sync_f`, choosing the
// `_NaN`-suffixed variant when `hasNaN` is true.
//
// Both the condition argument and the whole expansion are parenthesized so
// the macro always behaves as a single expression: an unparenthesized `?:`
// has very low precedence and would misparse if the macro were used as an
// operand of a larger expression.
#define REDUX_F32_ID_IMPL(op, abs, hasNaN)                                     \
  ((hasNaN) ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN                \
            : llvm::Intrinsic::nvvm_redux_sync_f##op##abs)

// GET_REDUX_F32_ID(op, hasAbs, hasNaN) picks among the four f32 redux
// intrinsic IDs for `op`: with or without the `_abs` suffix (per `hasAbs`),
// and with or without the `_NaN` suffix (per `hasNaN`). Passing an empty
// `abs` argument to REDUX_F32_ID_IMPL yields the non-`_abs` name.
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)                                   \
  ((hasAbs) ? REDUX_F32_ID_IMPL(op, _abs, hasNaN)                              \
            : REDUX_F32_ID_IMPL(op, , hasNaN))
39 bool hasAbs,
bool hasNaN) {
40 if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
41 llvm_unreachable(
"unsupported data type for redux");
44 case NVVM::ReduxKind::ADD:
45 return llvm::Intrinsic::nvvm_redux_sync_add;
46 case NVVM::ReduxKind::UMAX:
47 return llvm::Intrinsic::nvvm_redux_sync_umax;
48 case NVVM::ReduxKind::UMIN:
49 return llvm::Intrinsic::nvvm_redux_sync_umin;
50 case NVVM::ReduxKind::AND:
51 return llvm::Intrinsic::nvvm_redux_sync_and;
52 case NVVM::ReduxKind::OR:
53 return llvm::Intrinsic::nvvm_redux_sync_or;
54 case NVVM::ReduxKind::XOR:
55 return llvm::Intrinsic::nvvm_redux_sync_xor;
56 case NVVM::ReduxKind::MAX:
57 return llvm::Intrinsic::nvvm_redux_sync_max;
58 case NVVM::ReduxKind::MIN:
59 return llvm::Intrinsic::nvvm_redux_sync_min;
60 case NVVM::ReduxKind::FMIN:
62 case NVVM::ReduxKind::FMAX:
65 llvm_unreachable(
"unknown redux kind");
73 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
75 case NVVM::ShflKind::bfly:
76 return resultType->isFloatTy()
77 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
78 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
79 case NVVM::ShflKind::up:
80 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
81 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
82 case NVVM::ShflKind::down:
83 return resultType->isFloatTy()
84 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
85 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
86 case NVVM::ShflKind::idx:
87 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
88 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
92 case NVVM::ShflKind::bfly:
93 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
94 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
95 case NVVM::ShflKind::up:
96 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
97 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
98 case NVVM::ShflKind::down:
99 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
100 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
101 case NVVM::ShflKind::idx:
102 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
103 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
106 llvm_unreachable(
"unknown shuffle kind");
110 NVVM::MatchSyncKind
kind) {
112 case NVVM::MatchSyncKind::any:
113 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
114 : llvm::Intrinsic::nvvm_match_any_sync_i64;
115 case NVVM::MatchSyncKind::all:
119 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
120 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
126 case NVVM::VoteSyncKind::any:
127 return llvm::Intrinsic::nvvm_vote_any_sync;
128 case NVVM::VoteSyncKind::all:
129 return llvm::Intrinsic::nvvm_vote_all_sync;
130 case NVVM::VoteSyncKind::ballot:
131 return llvm::Intrinsic::nvvm_vote_ballot_sync;
132 case NVVM::VoteSyncKind::uni:
133 return llvm::Intrinsic::nvvm_vote_uni_sync;
135 llvm_unreachable(
"unsupported vote kind");
141 if (layout == NVVM::MMALayout::row) {
144 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
146 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
148 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
150 llvm_unreachable(
"unsupported number of matrix");
156 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
158 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
160 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
162 llvm_unreachable(
"unsupported number of matrix");
173 : llvm::Intrinsic::nvvm_st_bulk;
177 NVVM::ProxyKind toProxy,
178 NVVM::MemScopeKind scope,
180 if (fromProxy == NVVM::ProxyKind::GENERIC &&
181 toProxy == NVVM::ProxyKind::TENSORMAP) {
183 case NVVM::MemScopeKind::CTA: {
185 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
186 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
188 case NVVM::MemScopeKind::CLUSTER: {
190 return llvm::Intrinsic::
191 nvvm_fence_proxy_tensormap_generic_release_cluster;
192 return llvm::Intrinsic::
193 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
195 case NVVM::MemScopeKind::GPU: {
197 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
198 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
200 case NVVM::MemScopeKind::SYS: {
202 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
203 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
206 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
208 llvm_unreachable(
"Unsupported proxy kinds");
211 #define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
247 unsigned Idx = std::log2(num);
250 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
251 return Shape16x64b[Idx];
252 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
253 return Shape16x128b[Idx - 1];
254 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
255 return Shape16x256b[Idx - 2];
256 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
257 return Shape32x32b[Idx];
258 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
259 return Shape16x32bx2[Idx];
261 llvm_unreachable(
"unhandled tcgen05.ld lowering");
264 #define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
300 unsigned Idx = std::log2(num);
303 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
304 return Shape16x64b[Idx];
305 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
306 return Shape16x128b[Idx - 1];
307 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
308 return Shape16x256b[Idx - 2];
309 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
310 return Shape32x32b[Idx];
311 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
312 return Shape16x32bx2[Idx];
314 llvm_unreachable(
"unhandled tcgen05.st lowering");
320 class NVVMDialectLLVMIRTranslationInterface
328 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
331 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
341 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
344 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
346 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
347 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
349 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
350 const std::string attr = llvm::formatv(
351 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
352 values.asArrayRef().end()));
353 llvmFunc->addFnAttr(
"nvvm.maxntid", attr);
354 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
355 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
357 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
358 const std::string attr = llvm::formatv(
359 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
360 values.asArrayRef().end()));
361 llvmFunc->addFnAttr(
"nvvm.reqntid", attr);
362 }
else if (attribute.getName() ==
363 NVVM::NVVMDialect::getClusterDimAttrName()) {
364 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
366 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
367 const std::string attr = llvm::formatv(
368 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
369 values.asArrayRef().end()));
370 llvmFunc->addFnAttr(
"nvvm.cluster_dim", attr);
371 }
else if (attribute.getName() ==
372 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
373 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
374 llvmFunc->addFnAttr(
"nvvm.maxclusterrank", llvm::utostr(value.getInt()));
375 }
else if (attribute.getName() ==
376 NVVM::NVVMDialect::getMinctasmAttrName()) {
377 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
378 llvmFunc->addFnAttr(
"nvvm.minctasm", llvm::utostr(value.getInt()));
379 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
380 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
381 llvmFunc->addFnAttr(
"nvvm.maxnreg", llvm::utostr(value.getInt()));
382 }
else if (attribute.getName() ==
383 NVVM::NVVMDialect::getKernelFuncAttrName()) {
384 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
393 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
394 llvm::Function *llvmFunc =
395 moduleTranslation.lookupFunction(funcOp.getName());
396 llvm::NamedMDNode *nvvmAnnotations =
397 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations");
399 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
400 llvm::MDNode *gridConstantMetaData =
nullptr;
403 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
404 if (opnd->getNumOperands() == 3 &&
406 opnd->getOperand(1) ==
408 gridConstantMetaData = opnd;
418 if (gridConstantMetaData ==
nullptr) {
421 llvm::ValueAsMetadata::getConstant(
423 llvm::Metadata *llvmMetadata[] = {
427 llvm::MDNode *llvmMetadataNode =
429 nvvmAnnotations->addOperand(llvmMetadataNode);
433 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
434 llvm::TempMDTuple clonedArgList = argList->clone();
435 clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
437 gridConstantMetaData->replaceOperandWith(
438 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
448 registry.
insert<NVVM::NVVMDialect>();
450 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static constexpr int64_t kSharedMemorySpace
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind, bool hasAbs, bool hasNaN)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...