MLIR  21.0.0git
NVVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Operation.h"
19 
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicsNVPTX.h"
23 
24 using namespace mlir;
25 using namespace mlir::LLVM;
27 
28 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
29  NVVM::ReduxKind kind) {
30  if (!resultType->isIntegerTy(32))
31  llvm_unreachable("unsupported data type for redux");
32 
33  switch (kind) {
34  case NVVM::ReduxKind::ADD:
35  return llvm::Intrinsic::nvvm_redux_sync_add;
36  case NVVM::ReduxKind::UMAX:
37  return llvm::Intrinsic::nvvm_redux_sync_umax;
38  case NVVM::ReduxKind::UMIN:
39  return llvm::Intrinsic::nvvm_redux_sync_umin;
40  case NVVM::ReduxKind::AND:
41  return llvm::Intrinsic::nvvm_redux_sync_and;
42  case NVVM::ReduxKind::OR:
43  return llvm::Intrinsic::nvvm_redux_sync_or;
44  case NVVM::ReduxKind::XOR:
45  return llvm::Intrinsic::nvvm_redux_sync_xor;
46  case NVVM::ReduxKind::MAX:
47  return llvm::Intrinsic::nvvm_redux_sync_max;
48  case NVVM::ReduxKind::MIN:
49  return llvm::Intrinsic::nvvm_redux_sync_min;
50  }
51  llvm_unreachable("unknown redux kind");
52 }
53 
54 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
55  NVVM::ShflKind kind,
56  bool withPredicate) {
57 
58  if (withPredicate) {
59  resultType = cast<llvm::StructType>(resultType)->getElementType(0);
60  switch (kind) {
61  case NVVM::ShflKind::bfly:
62  return resultType->isFloatTy()
63  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
64  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
65  case NVVM::ShflKind::up:
66  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
67  : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
68  case NVVM::ShflKind::down:
69  return resultType->isFloatTy()
70  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
71  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
72  case NVVM::ShflKind::idx:
73  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
74  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
75  }
76  } else {
77  switch (kind) {
78  case NVVM::ShflKind::bfly:
79  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
80  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
81  case NVVM::ShflKind::up:
82  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
83  : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
84  case NVVM::ShflKind::down:
85  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
86  : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
87  case NVVM::ShflKind::idx:
88  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
89  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
90  }
91  }
92  llvm_unreachable("unknown shuffle kind");
93 }
94 
95 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
96 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
97  int32_t num) {
98  if (layout == NVVM::MMALayout::row) {
99  switch (num) {
100  case 1:
101  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
102  case 2:
103  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
104  case 4:
105  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
106  default:
107  llvm_unreachable("unsupported number of matrix");
108  }
109 
110  } else {
111  switch (num) {
112  case 1:
113  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
114  case 2:
115  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
116  case 4:
117  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
118  default:
119  llvm_unreachable("unsupported number of matrix");
120  }
121  }
122 }
123 
124 static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
125  NVVM::ProxyKind toProxy,
126  NVVM::MemScopeKind scope,
127  bool isRelease) {
128  if (fromProxy == NVVM::ProxyKind::GENERIC &&
129  toProxy == NVVM::ProxyKind::TENSORMAP) {
130  switch (scope) {
131  case NVVM::MemScopeKind::CTA: {
132  if (isRelease)
133  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
134  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
135  }
136  case NVVM::MemScopeKind::CLUSTER: {
137  if (isRelease)
138  return llvm::Intrinsic::
139  nvvm_fence_proxy_tensormap_generic_release_cluster;
140  return llvm::Intrinsic::
141  nvvm_fence_proxy_tensormap_generic_acquire_cluster;
142  }
143  case NVVM::MemScopeKind::GPU: {
144  if (isRelease)
145  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
146  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
147  }
148  case NVVM::MemScopeKind::SYS: {
149  if (isRelease)
150  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
151  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
152  }
153  }
154  llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
155  }
156  llvm_unreachable("Unsupported proxy kinds");
157 }
158 
159 namespace {
160 /// Implementation of the dialect interface that converts operations belonging
161 /// to the NVVM dialect to LLVM IR.
162 class NVVMDialectLLVMIRTranslationInterface
164 public:
166 
167  /// Translates the given operation to LLVM IR using the provided IR builder
168  /// and saving the state in `moduleTranslation`.
  ///
  /// The per-op lowering code is textually included from the generated
  /// NVVMConversions.inc below. `opInst` is not used by any visible code, so
  /// presumably the generated code refers to it (and to `builder` and
  /// `moduleTranslation`) by name — TODO confirm before renaming any of them.
  /// Control reaches `return failure()` only when the included code does not
  /// return first (assumes it returns on a matched op — verify against the
  /// generated file).
169  LogicalResult
170  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
171  LLVM::ModuleTranslation &moduleTranslation) const final {
  // Reference alias for the generated conversion code (see note above).
172  Operation &opInst = *op;
173 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
174 
  // No generated conversion handled `op`.
175  return failure();
176  }
177 
178  /// Attaches module-level metadata for functions marked as kernels.
179  LogicalResult
180  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
181  NamedAttribute attribute,
182  LLVM::ModuleTranslation &moduleTranslation) const final {
183  auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
184  if (!func)
185  return failure();
186  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
187  llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
188 
189  auto generateMetadata = [&](int dim, StringRef name) {
190  llvm::Metadata *llvmMetadata[] = {
191  llvm::ValueAsMetadata::get(llvmFunc),
192  llvm::MDString::get(llvmContext, name),
194  llvm::Type::getInt32Ty(llvmContext), dim))};
195  llvm::MDNode *llvmMetadataNode =
196  llvm::MDNode::get(llvmContext, llvmMetadata);
197  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
198  ->addOperand(llvmMetadataNode);
199  };
200  if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
201  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
202  return failure();
203  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
204  generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
205  if (values.size() > 1)
206  generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
207  if (values.size() > 2)
208  generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
209  } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
210  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
211  return failure();
212  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
213  generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
214  if (values.size() > 1)
215  generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
216  if (values.size() > 2)
217  generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
218  } else if (attribute.getName() ==
219  NVVM::NVVMDialect::getClusterDimAttrName()) {
220  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
221  return failure();
222  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
223  generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
224  if (values.size() > 1)
225  generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
226  if (values.size() > 2)
227  generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
228  } else if (attribute.getName() ==
229  NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
230  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
231  llvmFunc->addFnAttr("nvvm.maxclusterrank", llvm::utostr(value.getInt()));
232  } else if (attribute.getName() ==
233  NVVM::NVVMDialect::getMinctasmAttrName()) {
234  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
235  llvmFunc->addFnAttr("nvvm.minctasm", llvm::utostr(value.getInt()));
236  } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
237  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
238  llvmFunc->addFnAttr("nvvm.maxnreg", llvm::utostr(value.getInt()));
239  } else if (attribute.getName() ==
240  NVVM::NVVMDialect::getKernelFuncAttrName()) {
241  llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
242  }
243  return success();
244  }
245 
246  LogicalResult
247  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
248  LLVM::ModuleTranslation &moduleTranslation) const final {
249 
250  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
251  llvm::Function *llvmFunc =
252  moduleTranslation.lookupFunction(funcOp.getName());
253  llvm::NamedMDNode *nvvmAnnotations =
254  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
255 
256  if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
257  llvm::MDNode *gridConstantMetaData = nullptr;
258 
259  // Check if a 'grid_constant' metadata node exists for the given function
260  for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
261  if (opnd->getNumOperands() == 3 &&
262  opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
263  opnd->getOperand(1) ==
264  llvm::MDString::get(llvmContext, "grid_constant")) {
265  gridConstantMetaData = opnd;
266  break;
267  }
268  }
269 
270  // 'grid_constant' is a function-level meta data node with a list of
271  // integers, where each integer n denotes that the nth parameter has the
272  // grid_constant annotation (numbering from 1). This requires aggregating
273  // the indices of the individual parameters that have this attribute.
274  llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
275  if (gridConstantMetaData == nullptr) {
276  // Create a new 'grid_constant' metadata node
277  SmallVector<llvm::Metadata *> gridConstMetadata = {
278  llvm::ValueAsMetadata::getConstant(
279  llvm::ConstantInt::get(i32, argIdx + 1))};
280  llvm::Metadata *llvmMetadata[] = {
281  llvm::ValueAsMetadata::get(llvmFunc),
282  llvm::MDString::get(llvmContext, "grid_constant"),
283  llvm::MDNode::get(llvmContext, gridConstMetadata)};
284  llvm::MDNode *llvmMetadataNode =
285  llvm::MDNode::get(llvmContext, llvmMetadata);
286  nvvmAnnotations->addOperand(llvmMetadataNode);
287  } else {
288  // Append argIdx + 1 to the 'grid_constant' argument list
289  if (auto argList =
290  dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
291  llvm::TempMDTuple clonedArgList = argList->clone();
292  clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
293  llvm::ConstantInt::get(i32, argIdx + 1))));
294  gridConstantMetaData->replaceOperandWith(
295  2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
296  }
297  }
298  }
299  return success();
300  }
301 };
302 } // namespace
303 
// NOTE(review): the signature line of this function (source line 304) is
// missing from this extraction; per the listing's symbol index it is
// `void registerNVVMDialectTranslation(DialectRegistry &registry)`.
// Registers the NVVM dialect in `registry`.
305  registry.insert<NVVM::NVVMDialect>();
// Registers an extension whose callback attaches the NVVM -> LLVM IR
// translation interface to the NVVMDialect instance it is given (`ctx` is
// unused).
306  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
307  dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
308  });
309 }
310 
// Builds a local dialect registry and appends its contents to the registry
// of `context`.
// NOTE(review): this extraction skips the signature (source line 311) and the
// statement between the declaration and the append (source line 313);
// presumably that statement populates `registry` via
// registerNVVMDialectTranslation(registry). As shown, an empty registry would
// be appended — confirm against the original file.
312  DialectRegistry registry;
314  context.appendDialectRegistry(registry);
315 }
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.