20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/IntrinsicsNVPTX.h"
28 NVVM::ReduxKind kind) {
29 if (!resultType->isIntegerTy(32))
30 llvm_unreachable(
"unsupported data type for redux");
33 case NVVM::ReduxKind::ADD:
34 return llvm::Intrinsic::nvvm_redux_sync_add;
35 case NVVM::ReduxKind::UMAX:
36 return llvm::Intrinsic::nvvm_redux_sync_umax;
37 case NVVM::ReduxKind::UMIN:
38 return llvm::Intrinsic::nvvm_redux_sync_umin;
39 case NVVM::ReduxKind::AND:
40 return llvm::Intrinsic::nvvm_redux_sync_and;
41 case NVVM::ReduxKind::OR:
42 return llvm::Intrinsic::nvvm_redux_sync_or;
43 case NVVM::ReduxKind::XOR:
44 return llvm::Intrinsic::nvvm_redux_sync_xor;
45 case NVVM::ReduxKind::MAX:
46 return llvm::Intrinsic::nvvm_redux_sync_max;
47 case NVVM::ReduxKind::MIN:
48 return llvm::Intrinsic::nvvm_redux_sync_min;
50 llvm_unreachable(
"unknown redux kind");
58 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
60 case NVVM::ShflKind::bfly:
61 return resultType->isFloatTy()
62 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
63 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
64 case NVVM::ShflKind::up:
65 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
66 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
67 case NVVM::ShflKind::down:
68 return resultType->isFloatTy()
69 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
70 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
71 case NVVM::ShflKind::idx:
72 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
73 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
77 case NVVM::ShflKind::bfly:
78 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
79 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
80 case NVVM::ShflKind::up:
81 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
82 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
83 case NVVM::ShflKind::down:
84 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
85 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
86 case NVVM::ShflKind::idx:
87 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
88 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
91 llvm_unreachable(
"unknown shuffle kind");
97 if (layout == NVVM::MMALayout::row) {
100 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
102 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
104 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
106 llvm_unreachable(
"unsupported number of matrix");
112 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
114 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
116 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
118 llvm_unreachable(
"unsupported number of matrix");
124 NVVM::ProxyKind toProxy,
125 NVVM::MemScopeKind scope,
127 if (fromProxy == NVVM::ProxyKind::GENERIC &&
128 toProxy == NVVM::ProxyKind::TENSORMAP) {
130 case NVVM::MemScopeKind::CTA: {
132 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
133 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
135 case NVVM::MemScopeKind::CLUSTER: {
137 return llvm::Intrinsic::
138 nvvm_fence_proxy_tensormap_generic_release_cluster;
139 return llvm::Intrinsic::
140 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
142 case NVVM::MemScopeKind::GPU: {
144 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
145 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
147 case NVVM::MemScopeKind::SYS: {
149 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
150 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
153 llvm_unreachable(
"Unknown scope for uni-directional fence.proxy operation");
155 llvm_unreachable(
"Unsupported proxy kinds");
161 class NVVMDialectLLVMIRTranslationInterface
169 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
172 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
182 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
185 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
186 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
188 auto generateMetadata = [&](
int dim, StringRef name) {
189 llvm::Metadata *llvmMetadata[] = {
193 llvm::Type::getInt32Ty(llvmContext), dim))};
194 llvm::MDNode *llvmMetadataNode =
196 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations")
197 ->addOperand(llvmMetadataNode);
199 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
200 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
202 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
203 generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
204 if (values.size() > 1)
205 generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
206 if (values.size() > 2)
207 generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
208 }
else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
209 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
211 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
212 generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
213 if (values.size() > 1)
214 generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
215 if (values.size() > 2)
216 generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
217 }
else if (attribute.getName() ==
218 NVVM::NVVMDialect::getMinctasmAttrName()) {
219 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
220 generateMetadata(value.getInt(),
"minctasm");
221 }
else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
222 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
223 generateMetadata(value.getInt(),
"maxnreg");
224 }
else if (attribute.getName() ==
225 NVVM::NVVMDialect::getKernelFuncAttrName()) {
226 llvm::Metadata *llvmMetadataKernel[] = {
231 llvm::MDNode *llvmMetadataNode =
233 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations")
234 ->addOperand(llvmMetadataNode);
240 convertParameterAttr(LLVMFuncOp funcOp,
int argIdx,
NamedAttribute attribute,
243 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
244 llvm::Function *llvmFunc =
245 moduleTranslation.lookupFunction(funcOp.getName());
246 llvm::NamedMDNode *nvvmAnnotations =
247 moduleTranslation.getOrInsertNamedModuleMetadata(
"nvvm.annotations");
249 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
250 llvm::MDNode *gridConstantMetaData =
nullptr;
253 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
254 if (opnd->getNumOperands() == 3 &&
256 opnd->getOperand(1) ==
258 gridConstantMetaData = opnd;
268 if (gridConstantMetaData ==
nullptr) {
271 llvm::ValueAsMetadata::getConstant(
273 llvm::Metadata *llvmMetadata[] = {
277 llvm::MDNode *llvmMetadataNode =
279 nvvmAnnotations->addOperand(llvmMetadataNode);
283 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
284 llvm::TempMDTuple clonedArgList = argList->clone();
285 clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
287 gridConstantMetaData->replaceOperandWith(
288 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
298 registry.
insert<NVVM::NVVMDialect>();
300 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...