20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicsNVPTX.h"
29 NVVM::ReduxKind kind) {
30 if (!resultType->isIntegerTy(32))
31 llvm_unreachable(
"unsupported data type for redux");
34 case NVVM::ReduxKind::ADD:
35 return llvm::Intrinsic::nvvm_redux_sync_add;
36 case NVVM::ReduxKind::UMAX:
37 return llvm::Intrinsic::nvvm_redux_sync_umax;
38 case NVVM::ReduxKind::UMIN:
39 return llvm::Intrinsic::nvvm_redux_sync_umin;
40 case NVVM::ReduxKind::AND:
41 return llvm::Intrinsic::nvvm_redux_sync_and;
42 case NVVM::ReduxKind::OR:
43 return llvm::Intrinsic::nvvm_redux_sync_or;
44 case NVVM::ReduxKind::XOR:
45 return llvm::Intrinsic::nvvm_redux_sync_xor;
46 case NVVM::ReduxKind::MAX:
47 return llvm::Intrinsic::nvvm_redux_sync_max;
48 case NVVM::ReduxKind::MIN:
49 return llvm::Intrinsic::nvvm_redux_sync_min;
51 llvm_unreachable(
"unknown redux kind");
59 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
61 case NVVM::ShflKind::bfly:
62 return resultType->isFloatTy()
63 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
64 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
65 case NVVM::ShflKind::up:
66 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
67 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
68 case NVVM::ShflKind::down:
69 return resultType->isFloatTy()
70 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
71 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
72 case NVVM::ShflKind::idx:
73 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
74 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
78 case NVVM::ShflKind::bfly:
79 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
80 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
81 case NVVM::ShflKind::up:
82 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
83 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
84 case NVVM::ShflKind::down:
85 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
86 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
87 case NVVM::ShflKind::idx:
88 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
89 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
92 llvm_unreachable(
"unknown shuffle kind");
98 if (layout == NVVM::MMALayout::row) {
101 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
103 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
105 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
107 llvm_unreachable(
"unsupported number of matrix");
113 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
115 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
117 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
119 llvm_unreachable(
"unsupported number of matrix");
125 NVVM::ProxyKind toProxy,
126 NVVM::MemScopeKind scope,
128 if (fromProxy == NVVM::ProxyKind::GENERIC &&
129 toProxy == NVVM::ProxyKind::TENSORMAP) {
131 case NVVM::MemScopeKind::CTA: {
133 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
134 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
136 case NVVM::MemScopeKind::CLUSTER: {
138 return llvm::Intrinsic::
139 nvvm_fence_proxy_tensormap_generic_release_cluster;
140 return llvm::Intrinsic::
141 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
143 case NVVM::MemScopeKind::GPU: {
145 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
146 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
148 case NVVM::MemScopeKind::SYS: {
150 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
151 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
154 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
156 llvm_unreachable(
"Unsupported proxy kinds");
162 class NVVMDialectLLVMIRTranslationInterface
170 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
173 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
183 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
186 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
187 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
189 auto generateMetadata = [&](
int dim, StringRef name) {
190 llvm::Metadata *llvmMetadata[] = {
194 llvm::Type::getInt32Ty(llvmContext), dim))};
195 llvm::MDNode *llvmMetadataNode =
197 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations")
198 ->addOperand(llvmMetadataNode);
200 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
201 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
203 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
204 generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
205 if (values.size() > 1)
206 generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
207 if (values.size() > 2)
208 generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
209 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
210 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
212 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
213 generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
214 if (values.size() > 1)
215 generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
216 if (values.size() > 2)
217 generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
218 }
else if (attribute.getName() ==
219 NVVM::NVVMDialect::getClusterDimAttrName()) {
220 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
222 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
223 generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
224 if (values.size() > 1)
225 generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
226 if (values.size() > 2)
227 generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
228 }
else if (attribute.getName() ==
229 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
230 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
231 llvmFunc->addFnAttr(
"nvvm.maxclusterrank", llvm::utostr(value.getInt()));
232 }
else if (attribute.getName() ==
233 NVVM::NVVMDialect::getMinctasmAttrName()) {
234 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
235 llvmFunc->addFnAttr(
"nvvm.minctasm", llvm::utostr(value.getInt()));
236 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
237 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
238 llvmFunc->addFnAttr(
"nvvm.maxnreg", llvm::utostr(value.getInt()));
239 }
else if (attribute.getName() ==
240 NVVM::NVVMDialect::getKernelFuncAttrName()) {
241 llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
250 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
251 llvm::Function *llvmFunc =
252 moduleTranslation.lookupFunction(funcOp.getName());
253 llvm::NamedMDNode *nvvmAnnotations =
254 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations");
256 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
257 llvm::MDNode *gridConstantMetaData =
nullptr;
260 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
261 if (opnd->getNumOperands() == 3 &&
263 opnd->getOperand(1) ==
265 gridConstantMetaData = opnd;
275 if (gridConstantMetaData ==
nullptr) {
278 llvm::ValueAsMetadata::getConstant(
280 llvm::Metadata *llvmMetadata[] = {
284 llvm::MDNode *llvmMetadataNode =
286 nvvmAnnotations->addOperand(llvmMetadataNode);
290 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
291 llvm::TempMDTuple clonedArgList = argList->clone();
292 clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
294 gridConstantMetaData->replaceOperandWith(
295 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
305 registry.
insert<NVVM::NVVMDialect>();
307 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...