MLIR  20.0.0git
NVVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Operation.h"
19 
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/IntrinsicsNVPTX.h"
22 
23 using namespace mlir;
24 using namespace mlir::LLVM;
26 
27 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
28  NVVM::ReduxKind kind) {
29  if (!resultType->isIntegerTy(32))
30  llvm_unreachable("unsupported data type for redux");
31 
32  switch (kind) {
33  case NVVM::ReduxKind::ADD:
34  return llvm::Intrinsic::nvvm_redux_sync_add;
35  case NVVM::ReduxKind::UMAX:
36  return llvm::Intrinsic::nvvm_redux_sync_umax;
37  case NVVM::ReduxKind::UMIN:
38  return llvm::Intrinsic::nvvm_redux_sync_umin;
39  case NVVM::ReduxKind::AND:
40  return llvm::Intrinsic::nvvm_redux_sync_and;
41  case NVVM::ReduxKind::OR:
42  return llvm::Intrinsic::nvvm_redux_sync_or;
43  case NVVM::ReduxKind::XOR:
44  return llvm::Intrinsic::nvvm_redux_sync_xor;
45  case NVVM::ReduxKind::MAX:
46  return llvm::Intrinsic::nvvm_redux_sync_max;
47  case NVVM::ReduxKind::MIN:
48  return llvm::Intrinsic::nvvm_redux_sync_min;
49  }
50  llvm_unreachable("unknown redux kind");
51 }
52 
53 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
54  NVVM::ShflKind kind,
55  bool withPredicate) {
56 
57  if (withPredicate) {
58  resultType = cast<llvm::StructType>(resultType)->getElementType(0);
59  switch (kind) {
60  case NVVM::ShflKind::bfly:
61  return resultType->isFloatTy()
62  ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
63  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
64  case NVVM::ShflKind::up:
65  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
66  : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
67  case NVVM::ShflKind::down:
68  return resultType->isFloatTy()
69  ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
70  : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
71  case NVVM::ShflKind::idx:
72  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
73  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
74  }
75  } else {
76  switch (kind) {
77  case NVVM::ShflKind::bfly:
78  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
79  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
80  case NVVM::ShflKind::up:
81  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
82  : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
83  case NVVM::ShflKind::down:
84  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
85  : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
86  case NVVM::ShflKind::idx:
87  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
88  : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
89  }
90  }
91  llvm_unreachable("unknown shuffle kind");
92 }
93 
94 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
95 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
96  int32_t num) {
97  if (layout == NVVM::MMALayout::row) {
98  switch (num) {
99  case 1:
100  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
101  case 2:
102  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
103  case 4:
104  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
105  default:
106  llvm_unreachable("unsupported number of matrix");
107  }
108 
109  } else {
110  switch (num) {
111  case 1:
112  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
113  case 2:
114  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
115  case 4:
116  return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
117  default:
118  llvm_unreachable("unsupported number of matrix");
119  }
120  }
121 }
122 
123 namespace {
124 /// Implementation of the dialect interface that converts operations belonging
125 /// to the NVVM dialect to LLVM IR.
126 class NVVMDialectLLVMIRTranslationInterface
128 public:
130 
131  /// Translates the given operation to LLVM IR using the provided IR builder
132  /// and saving the state in `moduleTranslation`.
133  LogicalResult
134  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
135  LLVM::ModuleTranslation &moduleTranslation) const final {
136  Operation &opInst = *op;
137 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
138 
139  return failure();
140  }
141 
142  /// Attaches module-level metadata for functions marked as kernels.
143  LogicalResult
144  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
145  NamedAttribute attribute,
146  LLVM::ModuleTranslation &moduleTranslation) const final {
147  auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
148  if (!func)
149  return failure();
150  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
151  llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
152 
153  auto generateMetadata = [&](int dim, StringRef name) {
154  llvm::Metadata *llvmMetadata[] = {
155  llvm::ValueAsMetadata::get(llvmFunc),
156  llvm::MDString::get(llvmContext, name),
158  llvm::Type::getInt32Ty(llvmContext), dim))};
159  llvm::MDNode *llvmMetadataNode =
160  llvm::MDNode::get(llvmContext, llvmMetadata);
161  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
162  ->addOperand(llvmMetadataNode);
163  };
164  if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
165  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
166  return failure();
167  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
168  generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
169  if (values.size() > 1)
170  generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
171  if (values.size() > 2)
172  generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
173  } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
174  if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
175  return failure();
176  auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
177  generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
178  if (values.size() > 1)
179  generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
180  if (values.size() > 2)
181  generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
182  } else if (attribute.getName() ==
183  NVVM::NVVMDialect::getMinctasmAttrName()) {
184  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
185  generateMetadata(value.getInt(), "minctasm");
186  } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
187  auto value = dyn_cast<IntegerAttr>(attribute.getValue());
188  generateMetadata(value.getInt(), "maxnreg");
189  } else if (attribute.getName() ==
190  NVVM::NVVMDialect::getKernelFuncAttrName()) {
191  llvm::Metadata *llvmMetadataKernel[] = {
192  llvm::ValueAsMetadata::get(llvmFunc),
193  llvm::MDString::get(llvmContext, "kernel"),
195  llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
196  llvm::MDNode *llvmMetadataNode =
197  llvm::MDNode::get(llvmContext, llvmMetadataKernel);
198  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
199  ->addOperand(llvmMetadataNode);
200  }
201  return success();
202  }
203 
204  LogicalResult
205  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
206  LLVM::ModuleTranslation &moduleTranslation) const final {
207 
208  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
209  llvm::Function *llvmFunc =
210  moduleTranslation.lookupFunction(funcOp.getName());
211  llvm::NamedMDNode *nvvmAnnotations =
212  moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
213 
214  if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
215  llvm::MDNode *gridConstantMetaData = nullptr;
216 
217  // Check if a 'grid_constant' metadata node exists for the given function
218  for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
219  if (opnd->getNumOperands() == 3 &&
220  opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
221  opnd->getOperand(1) ==
222  llvm::MDString::get(llvmContext, "grid_constant")) {
223  gridConstantMetaData = opnd;
224  break;
225  }
226  }
227 
228  // 'grid_constant' is a function-level meta data node with a list of
229  // integers, where each integer n denotes that the nth parameter has the
230  // grid_constant annotation (numbering from 1). This requires aggregating
231  // the indices of the individual parameters that have this attribute.
232  llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
233  if (gridConstantMetaData == nullptr) {
234  // Create a new 'grid_constant' metadata node
235  SmallVector<llvm::Metadata *> gridConstMetadata = {
236  llvm::ValueAsMetadata::getConstant(
237  llvm::ConstantInt::get(i32, argIdx + 1))};
238  llvm::Metadata *llvmMetadata[] = {
239  llvm::ValueAsMetadata::get(llvmFunc),
240  llvm::MDString::get(llvmContext, "grid_constant"),
241  llvm::MDNode::get(llvmContext, gridConstMetadata)};
242  llvm::MDNode *llvmMetadataNode =
243  llvm::MDNode::get(llvmContext, llvmMetadata);
244  nvvmAnnotations->addOperand(llvmMetadataNode);
245  } else {
246  // Append argIdx + 1 to the 'grid_constant' argument list
247  if (auto argList =
248  dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
249  llvm::TempMDTuple clonedArgList = argList->clone();
250  clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
251  llvm::ConstantInt::get(i32, argIdx + 1))));
252  gridConstantMetaData->replaceOperandWith(
253  2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
254  }
255  }
256  }
257  return success();
258  }
259 };
260 } // namespace
261 
263  registry.insert<NVVM::NVVMDialect>();
264  registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
265  dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
266  });
267 }
268 
270  DialectRegistry registry;
272  context.appendDialectRegistry(registry);
273 }
static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, NVVM::ReduxKind kind)
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, NVVM::ShflKind kind, bool withPredicate)
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num)
Return the intrinsic ID associated with ldmatrix for the given parameters.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerNVVMDialectTranslation(DialectRegistry &registry)
Register the NVVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.