MLIR 23.0.0git
ROCDLToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR ROCDL dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
17#include "mlir/IR/Operation.h"
19
20#include "llvm/IR/ConstantRange.h"
21#include "llvm/IR/IRBuilder.h"
22#include "llvm/IR/IntrinsicsAMDGPU.h"
23#include "llvm/Support/raw_ostream.h"
24
25using namespace mlir;
26using namespace mlir::LLVM;
28
29// Create a call to ROCm-Device-Library function that returns an ID.
30// This is intended to specifically call device functions that fetch things like
31// block or grid dimensions, and so is limited to functions that take one
32// integer parameter.
33static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder,
34 Operation *op, StringRef fnName,
35 int parameter) {
36 llvm::Module *module = builder.GetInsertBlock()->getModule();
37 llvm::FunctionType *functionType = llvm::FunctionType::get(
38 llvm::Type::getInt64Ty(module->getContext()), // return type.
39 llvm::Type::getInt32Ty(module->getContext()), // parameter type.
40 false); // no variadic arguments.
41 llvm::Function *fn = dyn_cast<llvm::Function>(
42 module->getOrInsertFunction(fnName, functionType).getCallee());
43 llvm::Value *fnOp0 = llvm::ConstantInt::get(
44 llvm::Type::getInt32Ty(module->getContext()), parameter);
45 auto *call = builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
46 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
47 // Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but
48 // these ockl functions are defined to be 64-bits
49 call->addRangeRetAttr(llvm::ConstantRange(rangeAttr.getLower().zext(64),
50 rangeAttr.getUpper().zext(64)));
51 }
52 return call;
53}
54
55namespace {
56/// Implementation of the dialect interface that converts operations belonging
57/// to the ROCDL dialect to LLVM IR.
58class ROCDLDialectLLVMIRTranslationInterface
59 : public LLVMTranslationDialectInterface {
60public:
61 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
62
63 /// Translates the given operation to LLVM IR using the provided IR builder
64 /// and saving the state in `moduleTranslation`.
65 LogicalResult
66 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
67 LLVM::ModuleTranslation &moduleTranslation) const final {
68 Operation &opInst = *op;
69#include "mlir/Dialect/LLVMIR/ROCDLConversions.inc"
70
71 return failure();
72 }
73
74 /// Attaches module-level metadata for functions marked as kernels.
75 LogicalResult
76 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
77 NamedAttribute attribute,
78 LLVM::ModuleTranslation &moduleTranslation) const final {
79 auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
80 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
81 if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
82 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
83 if (!func)
84 return op->emitOpError(Twine(attribute.getName()) +
85 " is only supported on `llvm.func` operations");
86 ;
87
88 // For GPU kernels,
89 // 1. Insert AMDGPU_KERNEL calling convention.
90 // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user
91 // has overriden this value - 256 is the default in clang
92 llvm::Function *llvmFunc =
93 moduleTranslation.lookupFunction(func.getName());
94 llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
95 if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
96 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256");
97 }
98
99 // MLIR's GPU kernel APIs all assume and produce uniformly-sized
100 // workgroups, so the lowering of the `rocdl.kernel` marker encodes this
101 // assumption. This assumption may be overridden by setting
102 // `rocdl.uniform_work_group_size` on a given function.
103 if (!llvmFunc->hasFnAttribute("uniform-work-group-size"))
104 llvmFunc->addFnAttr("uniform-work-group-size");
105 }
106 // Override flat-work-group-size
107 // TODO: update clients to rocdl.flat_work_group_size instead,
108 // then remove this half of the branch
109 if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
110 attribute.getName()) {
111 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
112 if (!func)
113 return op->emitOpError(Twine(attribute.getName()) +
114 " is only supported on `llvm.func` operations");
115 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
116 if (!value)
117 return op->emitOpError(Twine(attribute.getName()) +
118 " must be an integer");
119
120 llvm::Function *llvmFunc =
121 moduleTranslation.lookupFunction(func.getName());
122 llvm::SmallString<8> llvmAttrValue;
123 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
124 attrValueStream << "1," << value.getInt();
125 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
126 }
127 if (dialect->getWavesPerEuAttrHelper().getName() == attribute.getName()) {
128 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
129 if (!func)
130 return op->emitOpError(Twine(attribute.getName()) +
131 " is only supported on `llvm.func` operations");
132 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
133 if (!value)
134 return op->emitOpError(Twine(attribute.getName()) +
135 " must be an integer");
136
137 llvm::Function *llvmFunc =
138 moduleTranslation.lookupFunction(func.getName());
139 llvm::SmallString<8> llvmAttrValue;
140 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
141 attrValueStream << value.getInt();
142 llvmFunc->addFnAttr("amdgpu-waves-per-eu", llvmAttrValue);
143 }
144 if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
145 attribute.getName()) {
146 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
147 if (!func)
148 return op->emitOpError(Twine(attribute.getName()) +
149 " is only supported on `llvm.func` operations");
150 auto value = dyn_cast<StringAttr>(attribute.getValue());
151 if (!value)
152 return op->emitOpError(Twine(attribute.getName()) +
153 " must be a string");
154
155 llvm::Function *llvmFunc =
156 moduleTranslation.lookupFunction(func.getName());
157 llvm::SmallString<8> llvmAttrValue;
158 llvmAttrValue.append(value.getValue());
159 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
160 }
161 if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() ==
162 attribute.getName()) {
163 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
164 if (!func)
165 return op->emitOpError(Twine(attribute.getName()) +
166 " is only supported on `llvm.func` operations");
167 auto value = dyn_cast<BoolAttr>(attribute.getValue());
168 if (!value)
169 return op->emitOpError(Twine(attribute.getName()) +
170 " must be a boolean");
171 llvm::Function *llvmFunc =
172 moduleTranslation.lookupFunction(func.getName());
173 if (value.getValue())
174 llvmFunc->addFnAttr("uniform-work-group-size");
175 else
176 llvmFunc->removeFnAttr("uniform-work-group-size");
177 }
178 if (dialect->getUnsafeFpAtomicsAttrHelper().getName() ==
179 attribute.getName()) {
180 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
181 if (!func)
182 return op->emitOpError(Twine(attribute.getName()) +
183 " is only supported on `llvm.func` operations");
184 auto value = dyn_cast<BoolAttr>(attribute.getValue());
185 if (!value)
186 return op->emitOpError(Twine(attribute.getName()) +
187 " must be a boolean");
188 llvm::Function *llvmFunc =
189 moduleTranslation.lookupFunction(func.getName());
190 llvmFunc->addFnAttr("amdgpu-unsafe-fp-atomics",
191 value.getValue() ? "true" : "false");
192 }
193 // Set reqd_work_group_size metadata
194 if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
195 attribute.getName()) {
196 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
197 if (!func)
198 return op->emitOpError(Twine(attribute.getName()) +
199 " is only supported on `llvm.func` operations");
200 auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue());
201 if (!value)
202 return op->emitOpError(Twine(attribute.getName()) +
203 " must be a dense i32 array attribute");
204 SmallVector<llvm::Metadata *, 3> metadata;
205 llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
206 for (int32_t i : value.asArrayRef()) {
207 llvm::Constant *constant = llvm::ConstantInt::get(i32, i);
208 metadata.push_back(llvm::ConstantAsMetadata::get(constant));
209 }
210 llvm::Function *llvmFunc =
211 moduleTranslation.lookupFunction(func.getName());
212 llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata);
213 llvmFunc->setMetadata("reqd_work_group_size", node);
214 }
215
216 // Atomic and nontemporal metadata
217 if (dialect->getLastUseAttrHelper().getName() == attribute.getName()) {
218 for (llvm::Instruction *i : instructions)
219 i->setMetadata("amdgpu.last.use", llvm::MDNode::get(llvmContext, {}));
220 }
221 if (dialect->getNoRemoteMemoryAttrHelper().getName() ==
222 attribute.getName()) {
223 for (llvm::Instruction *i : instructions)
224 i->setMetadata("amdgpu.no.remote.memory",
225 llvm::MDNode::get(llvmContext, {}));
226 }
227 if (dialect->getNoFineGrainedMemoryAttrHelper().getName() ==
228 attribute.getName()) {
229 for (llvm::Instruction *i : instructions)
230 i->setMetadata("amdgpu.no.fine.grained.memory",
231 llvm::MDNode::get(llvmContext, {}));
232 }
233 if (dialect->getIgnoreDenormalModeAttrHelper().getName() ==
234 attribute.getName()) {
235 for (llvm::Instruction *i : instructions)
236 i->setMetadata("amdgpu.ignore.denormal.mode",
237 llvm::MDNode::get(llvmContext, {}));
238 }
239
240 return success();
241 }
242};
243} // namespace
244
// Add the ROCDL dialect itself to `registry`.
246 registry.insert<ROCDL::ROCDLDialect>();
// Register an extension whose callback receives the ROCDL dialect instance
// and attaches the ROCDL-to-LLVM-IR translation interface defined in this
// file to it.
247 registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) {
248 dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
249 });
250}
251
return success()
static llvm::Value * createDimGetterFunctionCall(llvm::IRBuilderBase &builder, Operation *op, StringRef fnName, int parameter)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:579
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerROCDLDialectTranslation(DialectRegistry &registry)
Register the ROCDL dialect and the translation from it to LLVM IR in the given registry.