20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/IntrinsicsAMDGPU.h"
22#include "llvm/Support/raw_ostream.h"
/// Dialect interface used when translating ROCDL dialect constructs to
/// LLVM IR. Derives from LLVMTranslationDialectInterface and inherits its
/// constructors through the using-declaration below.
// NOTE(review): the access specifier and the closing brace of this class are
// not present in this listing; confirm against the original file.
32class ROCDLDialectLLVMIRTranslationInterface
33 :
public LLVMTranslationDialectInterface {
35 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
/// Translates a single operation `op` using `builder` and
/// `moduleTranslation`. The per-operation logic is not written here: it is
/// pulled in textually from ROCDLConversions.inc.
// NOTE(review): the return type and the statements after the include
// (fallback return, closing brace) are not present in this listing.
40 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
41 LLVM::ModuleTranslation &moduleTranslation)
const final {
// Bind the operation by reference under the name `opInst`; presumably this is
// the name the included conversion snippets refer to — TODO confirm.
42 Operation &opInst = *op;
43#include "mlir/Dialect/LLVMIR/ROCDLConversions.inc"
/// Applies one ROCDL dialect attribute to already-translated LLVM IR.
///
/// The name of `attribute` is compared against the names exposed by the
/// dialect's attribute helpers. On a match the code either updates the
/// llvm::Function looked up through `moduleTranslation` by the name of `op`
/// (calling convention, function attributes, metadata) or attaches metadata
/// to every entry of `instructions`.
///
/// Errors are reported through `op->emitOpError` when the attribute sits on
/// something other than an `llvm.func`, or when its value has the wrong kind.
// NOTE(review): this listing has lost lines (the embedded original line
// numbers are non-consecutive). The return type, the guard conditions that
// precede the `return op->emitOpError(...)` statements (presumably
// `if (!func)` / `if (!value)`), the closing braces and the final return are
// not visible here; confirm against the original file.
50 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
51 NamedAttribute attribute,
52 LLVM::ModuleTranslation &moduleTranslation)
const final {
// NOTE(review): `dialect` comes from dyn_cast and is dereferenced below
// without a visible null check — presumably the attribute is always
// ROCDL-namespaced when this hook is called; verify against the caller.
53 auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
54 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
// --- Kernel marker attribute ---
55 if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
56 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
58 return op->emitOpError(Twine(attribute.getName()) +
59 " is only supported on `llvm.func` operations");
// Mark the translated function with the AMDGPU kernel calling convention.
66 llvm::Function *llvmFunc =
67 moduleTranslation.lookupFunction(func.getName());
68 llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
// Only install the "1,256" flat work-group size when no such function
// attribute is already present, so an existing value is kept.
69 if (!llvmFunc->hasFnAttribute(
"amdgpu-flat-work-group-size")) {
70 llvmFunc->addFnAttr(
"amdgpu-flat-work-group-size",
"1,256");
// Likewise add "uniform-work-group-size" only when it is not already set.
77 if (!llvmFunc->hasFnAttribute(
"uniform-work-group-size"))
78 llvmFunc->addFnAttr(
"uniform-work-group-size");
// --- Maximum flat work-group size: integer N becomes the range "1,N" ---
83 if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
84 attribute.getName()) {
85 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
87 return op->emitOpError(Twine(attribute.getName()) +
88 " is only supported on `llvm.func` operations");
89 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
91 return op->emitOpError(Twine(attribute.getName()) +
92 " must be an integer");
94 llvm::Function *llvmFunc =
95 moduleTranslation.lookupFunction(func.getName());
// Format the attribute text into a small stack buffer via a stream.
96 llvm::SmallString<8> llvmAttrValue;
97 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
98 attrValueStream <<
"1," << value.getInt();
99 llvmFunc->addFnAttr(
"amdgpu-flat-work-group-size", llvmAttrValue);
// --- Waves per EU: integer value is written as decimal text ---
101 if (dialect->getWavesPerEuAttrHelper().getName() == attribute.getName()) {
102 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
104 return op->emitOpError(Twine(attribute.getName()) +
105 " is only supported on `llvm.func` operations");
106 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
108 return op->emitOpError(Twine(attribute.getName()) +
109 " must be an integer");
111 llvm::Function *llvmFunc =
112 moduleTranslation.lookupFunction(func.getName());
113 llvm::SmallString<8> llvmAttrValue;
114 llvm::raw_svector_ostream attrValueStream(llvmAttrValue);
115 attrValueStream << value.getInt();
116 llvmFunc->addFnAttr(
"amdgpu-waves-per-eu", llvmAttrValue);
// --- Flat work-group size: string value is forwarded unchanged ---
118 if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
119 attribute.getName()) {
120 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
122 return op->emitOpError(Twine(attribute.getName()) +
123 " is only supported on `llvm.func` operations");
124 auto value = dyn_cast<StringAttr>(attribute.getValue());
126 return op->emitOpError(Twine(attribute.getName()) +
127 " must be a string");
129 llvm::Function *llvmFunc =
130 moduleTranslation.lookupFunction(func.getName());
131 llvm::SmallString<8> llvmAttrValue;
132 llvmAttrValue.append(value.getValue());
133 llvmFunc->addFnAttr(
"amdgpu-flat-work-group-size", llvmAttrValue);
// --- Uniform work-group size: boolean toggles the function attribute ---
// Unlike the other branches, the name is obtained from a static accessor on
// the dialect class rather than from an attribute helper.
135 if (ROCDL::ROCDLDialect::getUniformWorkGroupSizeAttrName() ==
136 attribute.getName()) {
137 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
139 return op->emitOpError(Twine(attribute.getName()) +
140 " is only supported on `llvm.func` operations");
141 auto value = dyn_cast<BoolAttr>(attribute.getValue());
143 return op->emitOpError(Twine(attribute.getName()) +
144 " must be a boolean");
145 llvm::Function *llvmFunc =
146 moduleTranslation.lookupFunction(func.getName());
147 if (value.getValue())
148 llvmFunc->addFnAttr(
"uniform-work-group-size");
// NOTE(review): original line 149 is missing from this listing; presumably it
// is an `else`, so the attribute is removed when the value is false — confirm.
150 llvmFunc->removeFnAttr(
"uniform-work-group-size");
// --- Unsafe FP atomics: boolean is written as the text "true" or "false" ---
152 if (dialect->getUnsafeFpAtomicsAttrHelper().getName() ==
153 attribute.getName()) {
154 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
156 return op->emitOpError(Twine(attribute.getName()) +
157 " is only supported on `llvm.func` operations");
158 auto value = dyn_cast<BoolAttr>(attribute.getValue());
160 return op->emitOpError(Twine(attribute.getName()) +
161 " must be a boolean");
162 llvm::Function *llvmFunc =
163 moduleTranslation.lookupFunction(func.getName());
164 llvmFunc->addFnAttr(
"amdgpu-unsafe-fp-atomics",
165 value.getValue() ?
"true" :
"false");
// --- Required work-group size: exactly three i32 values ---
168 if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
169 attribute.getName()) {
170 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
172 return op->emitOpError(Twine(attribute.getName()) +
173 " is only supported on `llvm.func` operations");
174 auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue());
176 return op->emitOpError(Twine(attribute.getName()) +
177 " must be a dense i32 array attribute");
178 if (value.asArrayRef().size() != 3)
179 return op->emitOpError(Twine(attribute.getName()) +
180 " must contain exactly three values");
// Accumulate the product of the three values in 64 bits (each value is first
// converted to uint32_t) and collect each value as an i32 constant for the
// metadata node built further down.
182 uint64_t FlatWorkGroupSize = 1;
183 SmallVector<llvm::Metadata *, 3> metadata;
184 llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
185 for (int32_t i : value.asArrayRef()) {
186 FlatWorkGroupSize *=
static_cast<uint32_t
>(i);
187 llvm::Constant *constant = llvm::ConstantInt::get(i32, i);
188 metadata.push_back(llvm::ConstantAsMetadata::get(constant));
190 llvm::Function *llvmFunc =
191 moduleTranslation.lookupFunction(func.getName());
// The flat work-group size text is "<product>,<product>".
192 llvm::SmallString<16> expectedFlatWorkGroupSize;
193 llvm::raw_svector_ostream attrValueStream(expectedFlatWorkGroupSize);
194 attrValueStream << FlatWorkGroupSize <<
"," << FlatWorkGroupSize;
// Consistency check: if `op` also carries the flat work-group size attribute
// as a string, its text must equal the text derived above.
196 StringRef flatAttrName =
197 dialect->getFlatWorkGroupSizeAttrHelper().getName();
// NOTE(review): original line 198 is missing from this listing; presumably it
// opens `if (auto flatAttr =` for the expression below — confirm.
199 dyn_cast_if_present<StringAttr>(op->getAttr(flatAttrName))) {
200 if (flatAttr.getValue() != expectedFlatWorkGroupSize)
201 return op->emitOpError(Twine(flatAttrName) +
202 " must match rocdl.reqd_work_group_size");
// Consistency check against the maximum flat work-group size attribute, when
// `op` carries it as an integer.
// NOTE(review): as written this compares "1,<max>" with
// "<product>,<product>", which are textually equal only when both numbers are
// 1 — confirm this strictness is intended.
205 StringRef maxFlatAttrName =
206 dialect->getMaxFlatWorkGroupSizeAttrHelper().getName();
207 if (
auto maxFlatAttr =
208 dyn_cast_if_present<IntegerAttr>(op->getAttr(maxFlatAttrName))) {
209 llvm::SmallString<16> expectedMaxFlatWorkGroupSize;
210 llvm::raw_svector_ostream maxAttrValueStream(
211 expectedMaxFlatWorkGroupSize);
212 maxAttrValueStream <<
"1," << maxFlatAttr.getInt();
213 if (expectedMaxFlatWorkGroupSize != expectedFlatWorkGroupSize)
214 return op->emitOpError(Twine(maxFlatAttrName) +
215 " must match rocdl.reqd_work_group_size");
// Record the derived flat work-group size as a function attribute and the
// three values as "reqd_work_group_size" metadata on the function.
218 llvmFunc->addFnAttr(
"amdgpu-flat-work-group-size",
219 expectedFlatWorkGroupSize);
220 llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata);
221 llvmFunc->setMetadata(
"reqd_work_group_size", node);
// --- Instruction-level markers ---
// Each of the four branches below attaches an empty metadata node under a
// fixed name to every instruction in `instructions`.
225 if (dialect->getLastUseAttrHelper().getName() == attribute.getName()) {
226 for (llvm::Instruction *i : instructions)
227 i->setMetadata(
"amdgpu.last.use", llvm::MDNode::get(llvmContext, {}));
229 if (dialect->getNoRemoteMemoryAttrHelper().getName() ==
230 attribute.getName()) {
231 for (llvm::Instruction *i : instructions)
232 i->setMetadata(
"amdgpu.no.remote.memory",
233 llvm::MDNode::get(llvmContext, {}));
235 if (dialect->getNoFineGrainedMemoryAttrHelper().getName() ==
236 attribute.getName()) {
237 for (llvm::Instruction *i : instructions)
238 i->setMetadata(
"amdgpu.no.fine.grained.memory",
239 llvm::MDNode::get(llvmContext, {}));
241 if (dialect->getIgnoreDenormalModeAttrHelper().getName() ==
242 attribute.getName()) {
243 for (llvm::Instruction *i : instructions)
244 i->setMetadata(
"amdgpu.ignore.denormal.mode",
245 llvm::MDNode::get(llvmContext, {}));
// Fragment of the registration code: inserts the ROCDL dialect into
// `registry` and adds ROCDLDialectLLVMIRTranslationInterface to `dialect`.
// NOTE(review): the enclosing function signature and the construct that
// introduces `dialect` (original lines 253 and 255) are not present in this
// listing; confirm against the original file.
254 registry.
insert<ROCDL::ROCDLDialect>();
256 dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
Include the generated interface declarations.
void registerROCDLDialectTranslation(DialectRegistry &registry)
Register the ROCDL dialect and the translation from it to the LLVM IR in the given registry.