MLIR  20.0.0git
NVVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Operation.h"
19 
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/IntrinsicsNVPTX.h"
22 
23 using namespace mlir;
24 using namespace mlir::LLVM;
26 
27 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
28  NVVM::ReduxKind kind) {
29  if (!resultType->isIntegerTy(32))
30  llvm_unreachable("unsupported data type for redux");
31 
32  switch (kind) {
33  case NVVM::ReduxKind::ADD:
34  return llvm::Intrinsic::nvvm_redux_sync_add;
35  case NVVM::ReduxKind::UMAX:
36  return llvm::Intrinsic::nvvm_redux_sync_umax;
37  case NVVM::ReduxKind::UMIN:
38  return llvm::Intrinsic::nvvm_redux_sync_umin;
39  case NVVM::ReduxKind::AND:
40  return llvm::Intrinsic::nvvm_redux_sync_and;
41  case NVVM::ReduxKind::OR:
42  return llvm::Intrinsic::nvvm_redux_sync_or;
43  case NVVM::ReduxKind::XOR:
44  return llvm::Intrinsic::nvvm_redux_sync_xor;
45  case NVVM::ReduxKind::MAX:
46  return llvm::Intrinsic::nvvm_redux_sync_max;
47  case NVVM::ReduxKind::MIN:
48  return llvm::Intrinsic::nvvm_redux_sync_min;
49  }
50  llvm_unreachable("unknown redux kind");
51 }
52 
53 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
54  NVVM::ShflKind kind,
55  bool withPredicate) {
56 
57  if (withPredicate) {
58  resultType = cast<llvm::StructType>(resultType)->getElementType(0);
59  switch (kind) {
60  case NVVM::ShflKind::bfly:
61  return resultType->isFloatTy()
62  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
63  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
64  case NVVM::ShflKind::up:
65  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
66  : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
67  case NVVM::ShflKind::down:
68  return resultType->isFloatTy()
69  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
70  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
71  case NVVM::ShflKind::idx:
72  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
73  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
74  }
75  } else {
76  switch (kind) {
77  case NVVM::ShflKind::bfly:
78  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
79  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
80  case NVVM::ShflKind::up:
81  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
82  : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
83  case NVVM::ShflKind::down:
84  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
85  : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
86  case NVVM::ShflKind::idx:
87  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
88  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
89  }
90  }
91  llvm_unreachable("unknown shuffle kind");
92 }
93 
94 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
95 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
96  int32_t num) {
97  if (layout == NVVM::MMALayout::row) {
98  switch (num) {
99  case 1:
100  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
101  case 2:
102  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
103  case 4:
104  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
105  default:
106  llvm_unreachable("unsupported number of matrix");
107  }
108 
109  } else {
110  switch (num) {
111  case 1:
112  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
113  case 2:
114  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
115  case 4:
116  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
117  default:
118  llvm_unreachable("unsupported number of matrix");
119  }
120  }
121 }
122 
123 static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
124  NVVM::ProxyKind toProxy,
125  NVVM::MemScopeKind scope,
126  bool isRelease) {
127  if (fromProxy == NVVM::ProxyKind::GENERIC &&
128  toProxy == NVVM::ProxyKind::TENSORMAP) {
129  switch (scope) {
130  case NVVM::MemScopeKind::CTA: {
131  if (isRelease)
132  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
133  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
134  }
135  case NVVM::MemScopeKind::CLUSTER: {
136  if (isRelease)
137  return llvm::Intrinsic::
138  nvvm_fence_proxy_tensormap_generic_release_cluster;
139  return llvm::Intrinsic::
140  nvvm_fence_proxy_tensormap_generic_acquire_cluster;
141  }
142  case NVVM::MemScopeKind::GPU: {
143  if (isRelease)
144  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
145  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
146  }
147  case NVVM::MemScopeKind::SYS: {
148  if (isRelease)
149  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
150  return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
151  }
152  }
153  llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
154  }
155  llvm_unreachable("Unsupported proxy kinds");
156 }
157 
158 namespace {
159 /// Implementation of the dialect interface that converts operations belonging
160 /// to the NVVM dialect to LLVM IR.
161 class NVVMDialectLLVMIRTranslationInterface
163 public:
165 
166  /// Translates the given operation to LLVM IR using the provided IR builder
167  /// and saving the state in `moduleTranslation`.
  ///
  /// The per-operation conversion code is not written here; it is textually
  /// included from "mlir/Dialect/LLVMIR/NVVMConversions.inc". If control
  /// reaches the end of that included code, the function returns failure().
168  LogicalResult
169  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
170  LLVM::ModuleTranslation &moduleTranslation) const final {
    // NOTE(review): `opInst` (along with `builder` and `moduleTranslation`)
    // is presumably the name the included conversion code refers to; it is
    // not used anywhere else in this function -- confirm against the
    // generated file before renaming.
171  Operation &opInst = *op;
172 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
173 
    // NOTE(review): presumably reached only when the included code did not
    // return for this operation -- confirm.
174  return failure();
175  }
176 
177  /// Attaches module-level metadata for functions marked as kernels.
178  LogicalResult
179  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
180  NamedAttribute attribute,
181  LLVM::ModuleTranslation &moduleTranslation) const final {
182  auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
183  if (!func)
184  return failure();
185  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
186  llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
187 
188  auto generateMetadata = [&](int dim, StringRef name) {
189  llvm::Metadata *llvmMetadata[] = {
190  llvm::ValueAsMetadata::get(llvmFunc),
191  llvm::MDString::get(llvmContext, name),
193  llvm::Type::getInt32Ty(llvmContext), dim))};
194  llvm::MDNode *llvmMetadataNode =
195  llvm::MDNode::get(llvmContext, llvmMetadata);
196  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
197  ->addOperand(llvmMetadataNode);
198  };
199  if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
200  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
201  return failure();
202  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
203  generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
204  if (values.size() > 1)
205  generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
206  if (values.size() > 2)
207  generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
208  } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
209  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
210  return failure();
211  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
212  generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
213  if (values.size() > 1)
214  generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
215  if (values.size() > 2)
216  generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
217  } else if (attribute.getName() ==
218  NVVM::NVVMDialect::getMinctasmAttrName()) {
219  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
220  generateMetadata(value.getInt(), "minctasm");
221  } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
222  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
223  generateMetadata(value.getInt(), "maxnreg");
224  } else if (attribute.getName() ==
225  NVVM::NVVMDialect::getKernelFuncAttrName()) {
226  llvm::Metadata *llvmMetadataKernel[] = {
227  llvm::ValueAsMetadata::get(llvmFunc),
228  llvm::MDString::get(llvmContext, "kernel"),
230  llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
231  llvm::MDNode *llvmMetadataNode =
232  llvm::MDNode::get(llvmContext, llvmMetadataKernel);
233  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
234  ->addOperand(llvmMetadataNode);
235  }
236  return success();
237  }
238 
239  LogicalResult
240  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
241  LLVM::ModuleTranslation &moduleTranslation) const final {
242 
243  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
244  llvm::Function *llvmFunc =
245  moduleTranslation.lookupFunction(funcOp.getName());
246  llvm::NamedMDNode *nvvmAnnotations =
247  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
248 
249  if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
250  llvm::MDNode *gridConstantMetaData = nullptr;
251 
252  // Check if a 'grid_constant' metadata node exists for the given function
253  for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
254  if (opnd->getNumOperands() == 3 &&
255  opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
256  opnd->getOperand(1) ==
257  llvm::MDString::get(llvmContext, "grid_constant")) {
258  gridConstantMetaData = opnd;
259  break;
260  }
261  }
262 
263  // 'grid_constant' is a function-level meta data node with a list of
264  // integers, where each integer n denotes that the nth parameter has the
265  // grid_constant annotation (numbering from 1). This requires aggregating
266  // the indices of the individual parameters that have this attribute.
267  llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
268  if (gridConstantMetaData == nullptr) {
269  // Create a new 'grid_constant' metadata node
270  SmallVector<llvm::Metadata *> gridConstMetadata = {
271  llvm::ValueAsMetadata::getConstant(
272  llvm::ConstantInt::get(i32, argIdx + 1))};
273  llvm::Metadata *llvmMetadata[] = {
274  llvm::ValueAsMetadata::get(llvmFunc),
275  llvm::MDString::get(llvmContext, "grid_constant"),
276  llvm::MDNode::get(llvmContext, gridConstMetadata)};
277  llvm::MDNode *llvmMetadataNode =
278  llvm::MDNode::get(llvmContext, llvmMetadata);
279  nvvmAnnotations->addOperand(llvmMetadataNode);
280  } else {
281  // Append argIdx + 1 to the 'grid_constant' argument list
282  if (auto argList =
283  dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
284  llvm::TempMDTuple clonedArgList = argList->clone();
285  clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
286  llvm::ConstantInt::get(i32, argIdx + 1))));
287  gridConstantMetaData->replaceOperandWith(
288  2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
289  }
290  }
291  }
292  return success();
293  }
294 };
295 } // namespace
296 
298  registry.insert<NVVM::NVVMDialect>();
299  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
300  dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
301  });
302 }
303 
305  DialectRegistry registry;
307  context.appendDialectRegistry(registry);
308 }
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind)
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, NVVM::ProxyKind toProxy, NVVM::MemScopeKind scope, bool isRelease)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...