MLIR 23.0.0git
ROCDLToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- ROCDLToLLVMIRTranslation.cpp - Translate ROCDL to LLVM IR ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR ROCDL dialect and
10// LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
17#include "mlir/IR/Operation.h"
19
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/IntrinsicsAMDGPU.h"
22#include "llvm/Support/raw_ostream.h"
23#include <cstdint>
24
25using namespace mlir;
26using namespace mlir::LLVM;
28
29namespace {
30/// Implementation of the dialect interface that converts operations belonging
31/// to the ROCDL dialect to LLVM IR.
32class ROCDLDialectLLVMIRTranslationInterface
33 : public LLVMTranslationDialectInterface {
34public:
35 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
36
37 /// Translates the given operation to LLVM IR using the provided IR builder
38 /// and saving the state in `moduleTranslation`.
39 LogicalResult
40 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
41 LLVM::ModuleTranslation &moduleTranslation) const final {
42 Operation &opInst = *op;
43#include "mlir/Dialect/LLVMIR/ROCDLConversions.inc"
44
45 return failure();
46 }
47
48 /// Attaches module-level metadata for functions marked as kernels.
49 LogicalResult
50 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
51 NamedAttribute attribute,
52 LLVM::ModuleTranslation &moduleTranslation) const final {
53 auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
54 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
55 if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
56 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
57 if (!func)
58 return op->emitOpError(Twine(attribute.getName()) +
59 " is only supported on `llvm.func` operations");
60 ;
61
62 // For GPU kernels,
63 // 1. Insert AMDGPU_KERNEL calling convention.
64 // 2. Insert amdgpu-flat-work-group-size(1, 256) attribute unless the user
65 // has overriden this value - 256 is the default in clang
66 llvm::Function *llvmFunc =
67 moduleTranslation.lookupFunction(func.getName());
68 llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
69 if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
70 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256");
71 }
72
73 // MLIR's GPU kernel APIs all assume and produce uniformly-sized
74 // workgroups, so the lowering of the `rocdl.kernel` marker encodes this
75 // assumption. This assumption may be overridden by setting
76 // `rocdl.uniform_work_group_size` on a given function.
77 if (!llvmFunc->hasFnAttribute("uniform-work-group-size"))
78 llvmFunc->addFnAttr("uniform-work-group-size");
79 }
80 // Override flat-work-group-size
81 // TODO: update clients to rocdl.flat_work_group_size instead,
82 // then remove this half of the branch
83 if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
84 attribute.getName()) {
85 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
86 if (!func)
87 return op->emitOpError(Twine(attribute.getName()) +
88 " is only supported on `llvm.func` operations");
89 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
90 if (!value)
91 return op->emitOpError(Twine(attribute.getName()) +
92 " must be an integer");
93
94 llvm::Function *llvmFunc =
95 moduleTranslation.lookupFunction(func.getName());
96 llvm::SmallString<8> llvmAttrValue;
97 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
98 attrValueStream << "1," << value.getInt();
99 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
100 }
101 if (dialect->getWavesPerEuAttrHelper().getName() == attribute.getName()) {
102 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
103 if (!func)
104 return op->emitOpError(Twine(attribute.getName()) +
105 " is only supported on `llvm.func` operations");
106 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
107 if (!value)
108 return op->emitOpError(Twine(attribute.getName()) +
109 " must be an integer");
110
111 llvm::Function *llvmFunc =
112 moduleTranslation.lookupFunction(func.getName());
113 llvm::SmallString<8> llvmAttrValue;
114 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
115 attrValueStream << value.getInt();
116 llvmFunc->addFnAttr("amdgpu-waves-per-eu", llvmAttrValue);
117 }
118 if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
119 attribute.getName()) {
120 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
121 if (!func)
122 return op->emitOpError(Twine(attribute.getName()) +
123 " is only supported on `llvm.func` operations");
124 auto value = dyn_cast<StringAttr>(attribute.getValue());
125 if (!value)
126 return op->emitOpError(Twine(attribute.getName()) +
127 " must be a string");
128
129 llvm::Function *llvmFunc =
130 moduleTranslation.lookupFunction(func.getName());
131 llvm::SmallString<8> llvmAttrValue;
132 llvmAttrValue.append(value.getValue());
133 llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
134 }
135 if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() ==
136 attribute.getName()) {
137 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
138 if (!func)
139 return op->emitOpError(Twine(attribute.getName()) +
140 " is only supported on `llvm.func` operations");
141 auto value = dyn_cast<BoolAttr>(attribute.getValue());
142 if (!value)
143 return op->emitOpError(Twine(attribute.getName()) +
144 " must be a boolean");
145 llvm::Function *llvmFunc =
146 moduleTranslation.lookupFunction(func.getName());
147 if (value.getValue())
148 llvmFunc->addFnAttr("uniform-work-group-size");
149 else
150 llvmFunc->removeFnAttr("uniform-work-group-size");
151 }
152 if (dialect->getUnsafeFpAtomicsAttrHelper().getName() ==
153 attribute.getName()) {
154 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
155 if (!func)
156 return op->emitOpError(Twine(attribute.getName()) +
157 " is only supported on `llvm.func` operations");
158 auto value = dyn_cast<BoolAttr>(attribute.getValue());
159 if (!value)
160 return op->emitOpError(Twine(attribute.getName()) +
161 " must be a boolean");
162 llvm::Function *llvmFunc =
163 moduleTranslation.lookupFunction(func.getName());
164 llvmFunc->addFnAttr("amdgpu-unsafe-fp-atomics",
165 value.getValue() ? "true" : "false");
166 }
167 // Set reqd_work_group_size metadata
168 if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
169 attribute.getName()) {
170 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
171 if (!func)
172 return op->emitOpError(Twine(attribute.getName()) +
173 " is only supported on `llvm.func` operations");
174 auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue());
175 if (!value)
176 return op->emitOpError(Twine(attribute.getName()) +
177 " must be a dense i32 array attribute");
178 if (value.asArrayRef().size() != 3)
179 return op->emitOpError(Twine(attribute.getName()) +
180 " must contain exactly three values");
181
182 uint64_t FlatWorkGroupSize = 1;
183 SmallVector<llvm::Metadata *, 3> metadata;
184 llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
185 for (int32_t i : value.asArrayRef()) {
186 FlatWorkGroupSize *= static_cast<uint32_t>(i);
187 llvm::Constant *constant = llvm::ConstantInt::get(i32, i);
188 metadata.push_back(llvm::ConstantAsMetadata::get(constant));
189 }
190 llvm::Function *llvmFunc =
191 moduleTranslation.lookupFunction(func.getName());
192 llvm::SmallString<16> expectedFlatWorkGroupSize;
193 llvm::raw_svector_ostream attrValueStream(expectedFlatWorkGroupSize);
194 attrValueStream << FlatWorkGroupSize << "," << FlatWorkGroupSize;
195
196 StringRef flatAttrName =
197 dialect->getFlatWorkGroupSizeAttrHelper().getName();
198 if (auto flatAttr =
199 dyn_cast_if_present<StringAttr>(op->getAttr(flatAttrName))) {
200 if (flatAttr.getValue() != expectedFlatWorkGroupSize)
201 return op->emitOpError(Twine(flatAttrName) +
202 " must match rocdl.reqd_work_group_size");
203 }
204
205 StringRef maxFlatAttrName =
206 dialect->getMaxFlatWorkGroupSizeAttrHelper().getName();
207 if (auto maxFlatAttr =
208 dyn_cast_if_present<IntegerAttr>(op->getAttr(maxFlatAttrName))) {
209 llvm::SmallString<16> expectedMaxFlatWorkGroupSize;
210 llvm::raw_svector_ostream maxAttrValueStream(
211 expectedMaxFlatWorkGroupSize);
212 maxAttrValueStream << "1," << maxFlatAttr.getInt();
213 if (expectedMaxFlatWorkGroupSize != expectedFlatWorkGroupSize)
214 return op->emitOpError(Twine(maxFlatAttrName) +
215 " must match rocdl.reqd_work_group_size");
216 }
217
218 llvmFunc->addFnAttr("amdgpu-flat-work-group-size",
219 expectedFlatWorkGroupSize);
220 llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata);
221 llvmFunc->setMetadata("reqd_work_group_size", node);
222 }
223
224 // Atomic and nontemporal metadata
225 if (dialect->getLastUseAttrHelper().getName() == attribute.getName()) {
226 for (llvm::Instruction *i : instructions)
227 i->setMetadata("amdgpu.last.use", llvm::MDNode::get(llvmContext, {}));
228 }
229 if (dialect->getNoRemoteMemoryAttrHelper().getName() ==
230 attribute.getName()) {
231 for (llvm::Instruction *i : instructions)
232 i->setMetadata("amdgpu.no.remote.memory",
233 llvm::MDNode::get(llvmContext, {}));
234 }
235 if (dialect->getNoFineGrainedMemoryAttrHelper().getName() ==
236 attribute.getName()) {
237 for (llvm::Instruction *i : instructions)
238 i->setMetadata("amdgpu.no.fine.grained.memory",
239 llvm::MDNode::get(llvmContext, {}));
240 }
241 if (dialect->getIgnoreDenormalModeAttrHelper().getName() ==
242 attribute.getName()) {
243 for (llvm::Instruction *i : instructions)
244 i->setMetadata("amdgpu.ignore.denormal.mode",
245 llvm::MDNode::get(llvmContext, {}));
246 }
247
248 return success();
249 }
250};
251} // namespace
252
254 registry.insert<ROCDL::ROCDLDialect>();
255 registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) {
256 dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
257 });
258}
259
return success()
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerROCDLDialectTranslation(DialectRegistry &registry)
Register the ROCDL dialect and its translation to LLVM IR in the given registry.