19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicsNVPTX.h"
23 #include "llvm/Support/FormatVariadic.h"
29 #define REDUX_F32_ID_IMPL(op, abs, hasNaN) \
30 hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \
31 : llvm::Intrinsic::nvvm_redux_sync_f##op##abs
33 #define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \
34 hasAbs ? REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN)
38 bool hasAbs,
bool hasNaN) {
39 if (!(resultType->isIntegerTy(32) || resultType->isFloatTy()))
40 llvm_unreachable(
"unsupported data type for redux");
43 case NVVM::ReduxKind::ADD:
44 return llvm::Intrinsic::nvvm_redux_sync_add;
45 case NVVM::ReduxKind::UMAX:
46 return llvm::Intrinsic::nvvm_redux_sync_umax;
47 case NVVM::ReduxKind::UMIN:
48 return llvm::Intrinsic::nvvm_redux_sync_umin;
49 case NVVM::ReduxKind::AND:
50 return llvm::Intrinsic::nvvm_redux_sync_and;
51 case NVVM::ReduxKind::OR:
52 return llvm::Intrinsic::nvvm_redux_sync_or;
53 case NVVM::ReduxKind::XOR:
54 return llvm::Intrinsic::nvvm_redux_sync_xor;
55 case NVVM::ReduxKind::MAX:
56 return llvm::Intrinsic::nvvm_redux_sync_max;
57 case NVVM::ReduxKind::MIN:
58 return llvm::Intrinsic::nvvm_redux_sync_min;
59 case NVVM::ReduxKind::FMIN:
61 case NVVM::ReduxKind::FMAX:
64 llvm_unreachable(
"unknown redux kind");
72 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
74 case NVVM::ShflKind::bfly:
75 return resultType->isFloatTy()
76 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
77 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
78 case NVVM::ShflKind::up:
79 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
80 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
81 case NVVM::ShflKind::down:
82 return resultType->isFloatTy()
83 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
84 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
85 case NVVM::ShflKind::idx:
86 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
87 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
91 case NVVM::ShflKind::bfly:
92 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
93 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
94 case NVVM::ShflKind::up:
95 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
96 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
97 case NVVM::ShflKind::down:
98 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
99 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
100 case NVVM::ShflKind::idx:
101 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
102 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
105 llvm_unreachable(
"unknown shuffle kind");
109 NVVM::MatchSyncKind
kind) {
111 case NVVM::MatchSyncKind::any:
112 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
113 : llvm::Intrinsic::nvvm_match_any_sync_i64;
114 case NVVM::MatchSyncKind::all:
118 return valType.
isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
119 : llvm::Intrinsic::nvvm_match_all_sync_i64p;
121 llvm_unreachable(
"unsupported match sync kind");
126 case NVVM::VoteSyncKind::any:
127 return llvm::Intrinsic::nvvm_vote_any_sync;
128 case NVVM::VoteSyncKind::all:
129 return llvm::Intrinsic::nvvm_vote_all_sync;
130 case NVVM::VoteSyncKind::ballot:
131 return llvm::Intrinsic::nvvm_vote_ballot_sync;
132 case NVVM::VoteSyncKind::uni:
133 return llvm::Intrinsic::nvvm_vote_uni_sync;
135 llvm_unreachable(
"unsupported vote kind");
140 NVVM::LdStMatrixShapeAttr shape,
141 NVVM::LdStMatrixEltType eltType) {
142 if (shape.getM() == 8 && shape.getN() == 8) {
145 return (layout == NVVM::MMALayout::row)
146 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
148 nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
150 return (layout == NVVM::MMALayout::row)
151 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
153 nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
155 return (layout == NVVM::MMALayout::row)
156 ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
158 nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
160 }
else if (shape.getM() == 8 && shape.getN() == 16) {
161 if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
164 return llvm::Intrinsic::
165 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
167 return llvm::Intrinsic::
168 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
170 return llvm::Intrinsic::
171 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
173 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
176 return llvm::Intrinsic::
177 nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
179 return llvm::Intrinsic::
180 nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
182 return llvm::Intrinsic::
183 nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
186 }
else if (shape.getM() == 16 && shape.getN() == 16) {
187 if (eltType == NVVM::LdStMatrixEltType::B8) {
190 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
192 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
194 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
197 return llvm::Intrinsic::
198 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
200 return llvm::Intrinsic::
201 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
203 }
else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
206 return llvm::Intrinsic::
207 nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
209 return llvm::Intrinsic::
210 nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
214 llvm_unreachable(
"unknown ldmatrix kind");
220 NVVM::LdStMatrixShapeAttr shape,
221 NVVM::LdStMatrixEltType eltType) {
222 if (shape.getM() == 8 && shape.getN() == 8) {
225 return (layout == NVVM::MMALayout::row)
226 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
228 nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
230 return (layout == NVVM::MMALayout::row)
231 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
233 nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
235 return (layout == NVVM::MMALayout::row)
236 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
238 nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
240 }
else if (shape.getM() == 16 && shape.getN() == 8) {
243 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
245 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
247 return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
250 llvm_unreachable(
"unknown stmatrix kind");
257 static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
259 : llvm::Intrinsic::nvvm_st_bulk;
263 NVVM::ProxyKind toProxy,
264 NVVM::MemScopeKind scope,
266 if (fromProxy == NVVM::ProxyKind::GENERIC &&
267 toProxy == NVVM::ProxyKind::TENSORMAP) {
269 case NVVM::MemScopeKind::CTA: {
271 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
272 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
274 case NVVM::MemScopeKind::CLUSTER: {
276 return llvm::Intrinsic::
277 nvvm_fence_proxy_tensormap_generic_release_cluster;
278 return llvm::Intrinsic::
279 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
281 case NVVM::MemScopeKind::GPU: {
283 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
284 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
286 case NVVM::MemScopeKind::SYS: {
288 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
289 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
292 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
294 llvm_unreachable(
"Unsupported proxy kinds");
297 #define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM
333 unsigned Idx = std::log2(num);
336 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
337 return Shape16x64b[Idx];
338 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
339 return Shape16x128b[Idx - 1];
340 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
341 return Shape16x256b[Idx - 2];
342 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
343 return Shape32x32b[Idx];
344 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
345 return Shape16x32bx2[Idx];
347 llvm_unreachable(
"unhandled tcgen05.ld lowering");
350 #define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM
386 unsigned Idx = std::log2(num);
389 case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
390 return Shape16x64b[Idx];
391 case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
392 return Shape16x128b[Idx - 1];
393 case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
394 return Shape16x256b[Idx - 2];
395 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
396 return Shape32x32b[Idx];
397 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
398 return Shape16x32bx2[Idx];
400 llvm_unreachable(
"unhandled tcgen05.st lowering");
406 class NVVMDialectLLVMIRTranslationInterface
414 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
415 LLVM::ModuleTranslation &moduleTranslation)
const final {
417 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
426 LLVM::ModuleTranslation &moduleTranslation)
const final {
427 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
430 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
432 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
433 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
435 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
436 const std::string attr = llvm::formatv(
437 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
438 values.asArrayRef().end()));
439 llvmFunc->addFnAttr(
"nvvm.maxntid", attr);
440 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
441 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
443 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
444 const std::string attr = llvm::formatv(
445 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
446 values.asArrayRef().end()));
447 llvmFunc->addFnAttr(
"nvvm.reqntid", attr);
448 }
else if (attribute.getName() ==
449 NVVM::NVVMDialect::getClusterDimAttrName()) {
450 if (!isa<DenseI32ArrayAttr>(attribute.getValue()))
452 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
453 const std::string attr = llvm::formatv(
454 "{0:$[,]}", llvm::make_range(values.asArrayRef().begin(),
455 values.asArrayRef().end()));
456 llvmFunc->addFnAttr(
"nvvm.cluster_dim", attr);
457 }
else if (attribute.getName() ==
458 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
459 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
460 llvmFunc->addFnAttr(
"nvvm.maxclusterrank", llvm::utostr(value.getInt()));
461 }
else if (attribute.getName() ==
462 NVVM::NVVMDialect::getMinctasmAttrName()) {
463 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
464 llvmFunc->addFnAttr(
"nvvm.minctasm", llvm::utostr(value.getInt()));
465 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
466 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
467 llvmFunc->addFnAttr(
"nvvm.maxnreg", llvm::utostr(value.getInt()));
468 }
else if (attribute.getName() ==
469 NVVM::NVVMDialect::getKernelFuncAttrName()) {
470 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
471 }
else if (attribute.getName() ==
472 NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
473 llvmFunc->addFnAttr(
"nvvm.blocksareclusters");
481 LLVM::ModuleTranslation &moduleTranslation)
const final {
483 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
484 llvm::Function *llvmFunc =
485 moduleTranslation.lookupFunction(funcOp.getName());
487 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
488 llvmFunc->addParamAttr(
497 registry.
insert<NVVM::NVVMDialect>();
499 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
#define GET_REDUX_F32_ID(op, hasAbs, hasNaN)
static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType)
Return the intrinsic ID associated with stmatrix for the given parameters.
static llvm::Intrinsic::ID getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static llvm::Intrinsic::ID getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
#define TCGEN05ST(SHAPE, NUM)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind, bool hasAbs, bool hasNaN)
#define TCGEN05LD(SHAPE, NUM)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind)
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, NVVM::MatchSyncKind kind)
static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType)
Return the intrinsic ID associated with st.bulk for the given address type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry ®istry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry ®istry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...