14 #include "mlir/Config/mlir-config.h"
19 #if MLIR_ENABLE_ROCM_CONVERSIONS
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/IRReader/IRReader.h"
30 #include "llvm/Linker/Linker.h"
32 #include "llvm/MC/MCAsmBackend.h"
33 #include "llvm/MC/MCAsmInfo.h"
34 #include "llvm/MC/MCCodeEmitter.h"
35 #include "llvm/MC/MCContext.h"
36 #include "llvm/MC/MCInstrInfo.h"
37 #include "llvm/MC/MCObjectFileInfo.h"
38 #include "llvm/MC/MCObjectWriter.h"
39 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
40 #include "llvm/MC/MCRegisterInfo.h"
41 #include "llvm/MC/MCStreamer.h"
42 #include "llvm/MC/MCSubtargetInfo.h"
43 #include "llvm/MC/TargetRegistry.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/FileSystem.h"
47 #include "llvm/Support/FileUtilities.h"
48 #include "llvm/Support/Path.h"
49 #include "llvm/Support/Program.h"
50 #include "llvm/Support/SourceMgr.h"
51 #include "llvm/Support/TargetSelect.h"
52 #include "llvm/Support/Threading.h"
53 #include "llvm/Support/WithColor.h"
55 #include "llvm/Target/TargetMachine.h"
56 #include "llvm/Target/TargetOptions.h"
58 #include "llvm/Transforms/IPO/Internalize.h"
// Pass that serializes GPU kernel code to an HSACO binary. It derives from
// gpu::SerializeToBlobPass through the PassWrapper CRTP base and is exposed
// on the command line as "gpu-to-hsaco".
65 class SerializeToHsacoPass
66 :
public PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass> {
// Shared by all instances; used with llvm::call_once in the constructor so
// the AMDGPU backend initialization runs only once per process.
67 static llvm::once_flag initializeBackendOnce;
// Constructors:
//  - one taking default target settings: triple, chip (arch), feature
//    string and, per the out-of-line definition, an optimization level;
//  - a copy constructor.
72 SerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features,
74 SerializeToHsacoPass(
const SerializeToHsacoPass &other);
// Pass-registration metadata: command-line argument and description.
75 StringRef getArgument()
const override {
return "gpu-to-hsaco"; }
76 StringRef getDescription()
const override {
77 return "Lower GPU kernel function to HSACO binary annotations";
// Optional ROCm installation root. getRocmPath() returns this value when it
// was passed explicitly.
81 Option<std::string> rocmPath{*
this,
"rocm-path",
82 llvm::cl::desc(
"Path to ROCm install")};
// Translates the module to LLVM IR and links ROCm device bitcode libraries
// into it when the module references them.
85 std::unique_ptr<llvm::Module>
86 translateToLLVMIR(llvm::LLVMContext &llvmContext)
override;
// Loads the named bitcode libraries from the directory held in `path`.
90 std::optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
91 loadLibraries(SmallVectorImpl<char> &path,
92 SmallVectorImpl<StringRef> &libraries,
93 llvm::LLVMContext &context);
// Assembles ISA text and links the resulting object into an HSACO blob.
96 std::unique_ptr<std::vector<char>>
97 serializeISA(
const std::string &isa)
override;
100 SmallVectorImpl<char> &result);
// Links an assembled object into an HSACO using ld.lld and returns its bytes.
101 std::unique_ptr<std::vector<char>> createHsaco(ArrayRef<char> isaBinary);
// Returns the ROCm installation root used to locate ld.lld and device bitcode.
103 std::string getRocmPath();
// Copy constructor: copies through the PassWrapper base copy constructor.
// The body is empty.
107 SerializeToHsacoPass::SerializeToHsacoPass(
const SerializeToHsacoPass &other)
108 :
PassWrapper<SerializeToHsacoPass, gpu::SerializeToBlobPass>(other) {}
// Returns the ROCm root: the "rocm-path" option value when it was given
// explicitly. The fallback return is not visible in this view (presumably
// the __DEFAULT_ROCM_PATH__ macro; confirm).
113 std::string SerializeToHsacoPass::getRocmPath() {
114 if (rocmPath.getNumOccurrences() > 0)
115 return rocmPath.getValue();
// NOTE(review): the next line belongs to a separate helper whose header is
// not visible here (presumably maybeSetOption, used by the constructor).
// It only acts when the option does not already have a value.
123 if (!option.hasValue())
// Out-of-class definition of the static once_flag declared in the class.
// It guards the one-time backend initialization in the constructor.
127 llvm::once_flag SerializeToHsacoPass::initializeBackendOnce;
// Constructs the pass with default target settings.
//
// The AMDGPU asm parser, asm printer, target, target info and MC layer are
// initialized once per process through llvm::call_once.
//
// The triple, chip and features options are then passed through
// maybeSetOption. Presumably each only takes effect when the option has no
// value yet, e.g. from the command line; TODO confirm.
//
// `optLevel` is applied only when the optLevel option was not given on the
// command line.
129 SerializeToHsacoPass::SerializeToHsacoPass(StringRef triple, StringRef arch,
130 StringRef features,
int optLevel) {
133 llvm::call_once(initializeBackendOnce, []() {
135 LLVMInitializeAMDGPUAsmParser();
136 LLVMInitializeAMDGPUAsmPrinter();
137 LLVMInitializeAMDGPUTarget();
138 LLVMInitializeAMDGPUTargetInfo();
139 LLVMInitializeAMDGPUTargetMC();
141 maybeSetOption(this->triple, [&triple] {
return triple.str(); });
142 maybeSetOption(this->chip, [&arch] {
return arch.str(); });
143 maybeSetOption(this->features, [&features] {
return features.str(); });
// An explicit command-line optimization level wins over the constructor
// argument.
144 if (this->optLevel.getNumOccurrences() == 0)
145 this->optLevel.setValue(optLevel);
// Lazily loads (llvm::getLazyIRFileModule) the bitcode files named in
// `libraries` from the directory in `path` into `context`.
//
// `path` is used as a scratch buffer. Each file name is appended, and the
// buffer is then truncated back to the original directory length.
//
// Each loaded module has its "opencl.ocl.version" and "llvm.ident" named
// metadata erased before it is collected.
//
// NOTE(review): the failure-path return statements are not visible in this
// view. Presumably they return std::nullopt; confirm.
148 std::optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
149 SerializeToHsacoPass::loadLibraries(SmallVectorImpl<char> &path,
150 SmallVectorImpl<StringRef> &libraries,
151 llvm::LLVMContext &context) {
152 SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
// Remember the directory length so `path` can be reset after each append.
153 size_t dirLength = path.size();
// A missing or non-directory bitcode path is reported as a remark.
155 if (!llvm::sys::fs::is_directory(path)) {
156 getOperation().emitRemark() <<
"Bitcode path: " << path
157 <<
" does not exist or is not a directory\n";
161 for (
const StringRef file : libraries) {
162 llvm::SMDiagnostic error;
163 llvm::sys::path::append(path, file);
164 llvm::StringRef pathRef(path.data(), path.size());
165 std::unique_ptr<llvm::Module> library =
166 llvm::getLazyIRFileModule(pathRef, error, context);
// Restore `path` to the bare directory. This happens before the error
// below is emitted, so the error reports the directory, not the full file.
167 path.truncate(dirLength);
169 getOperation().emitError() <<
"Failed to load library " << file
170 <<
" from " << path << error.getMessage();
// Drop these named metadata nodes from each library.
// NOTE(review): presumably to avoid conflicts when linking into the kernel
// module; confirm.
174 if (
auto *openclVersion = library->getNamedMetadata(
"opencl.ocl.version"))
175 library->eraseNamedMetadata(openclVersion);
177 if (
auto *ident = library->getNamedMetadata(
"llvm.ident"))
178 library->eraseNamedMetadata(ident);
179 ret.push_back(std::move(library));
182 return std::move(ret);
// Translates the GPU module to LLVM IR via the base class, then links in
// the ROCm device bitcode libraries (opencl.bc, ocml.bc, ockl.bc) that the
// module appears to need.
//
// Need is detected from functions that have external linkage, have a name,
// and have no exact definition in the module: "printf", and names starting
// with "__ockl_" or "__ocml_". The statements that set each need* flag are
// only partly visible here.
//
// When libraries are needed, the function then:
//  - adds __oclc_* control constants to the module;
//  - loads the libraries from an "amdgcn/bitcode" directory;
//  - links each one with Linker::Flags::LinkOnlyNeeded.
185 std::unique_ptr<llvm::Module>
186 SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
188 std::unique_ptr<llvm::Module> ret =
189 gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
192 getOperation().emitOpError(
"Module lowering failed");
// Scan external declarations to decide which device libraries are needed.
197 bool needOpenCl =
false;
198 bool needOckl =
false;
199 bool needOcml =
false;
200 for (llvm::Function &f : ret->functions()) {
201 if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
202 StringRef funcName = f.getName();
203 if (
"printf" == funcName)
205 if (funcName.starts_with(
"__ockl_"))
207 if (funcName.starts_with(
"__ocml_"))
213 needOcml = needOckl =
true;
216 if (!(needOpenCl || needOcml || needOckl))
// Helper that adds a constant global `name` of type iN (N = bitwidth)
// holding `value`. The global gets LinkOnceODR linkage, protected
// visibility, local unnamed_addr, and alignment of bitwidth / 8 bytes.
// It first checks whether a global with that name already exists; the
// handling of that case is not visible here.
223 auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
225 using llvm::GlobalVariable;
226 if (module.getNamedGlobal(name)) {
229 llvm::IntegerType *type =
230 llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
232 auto *constant =
new GlobalVariable(
234 true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
237 GlobalVariable::ThreadLocalMode::NotThreadLocal,
239 constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
240 constant->setVisibility(
241 GlobalVariable::VisibilityTypes::ProtectedVisibility);
242 constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
// Control constants added with no further condition in the visible code.
// Only correctly_rounded_sqrt32 is set to 1; the others are 0.
249 addControlConstant(
"__oclc_finite_only_opt", 0, 8);
250 addControlConstant(
"__oclc_daz_opt", 0, 8);
251 addControlConstant(
"__oclc_correctly_rounded_sqrt32", 1, 8);
252 addControlConstant(
"__oclc_unsafe_math_opt", 0, 8);
// The wavefront-size and ISA-version constants are only added when ocml or
// ockl is needed. __oclc_wavefrontsize64 is hard-coded to 1.
// NOTE(review): confirm this is intended for chips that execute in wave32.
254 if (needOcml || needOckl) {
255 addControlConstant(
"__oclc_wavefrontsize64", 1, 8);
// The ISA version is derived from the chip name after stripping a "gfx"
// prefix:
//  - the last two characters are parsed as hex (minor);
//  - the rest is parsed as decimal (major);
//  - the result is minor + 1000 * major,
//    e.g. "gfx90a" -> major 9, minor 0x0a -> 9010.
// NOTE(review): assumes at least two characters remain after the prefix.
// A shorter chip string would hand an empty string to the APInt
// constructor; confirm the chip is validated upstream.
256 StringRef chipSet = this->chip.getValue();
257 if (chipSet.starts_with(
"gfx"))
258 chipSet = chipSet.substr(3);
260 llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue();
261 uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10)
263 uint32_t isaNumber = minor + 1000 * major;
264 addControlConstant(
"__oclc_ISA_version", isaNumber, 32);
268 addControlConstant(
"__oclc_ABI_version", 500, 32);
// Select the bitcode files to load. The conditions guarding each push are
// not visible here.
274 libraries.push_back(
"opencl.bc");
276 libraries.push_back(
"ocml.bc");
278 libraries.push_back(
"ockl.bc");
// Load the selected libraries from the "amdgcn/bitcode" subdirectory,
// presumably of the ROCm root from getRocmPath(); the line that seeds
// bitcodePath is not visible. A load failure is reported as a warning.
280 std::optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
281 std::string theRocmPath = getRocmPath();
283 llvm::sys::path::append(bitcodePath,
"amdgcn",
"bitcode");
284 mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
288 .emitWarning(
"Could not load required device libraries")
290 <<
"This will probably cause link-time or run-time failures";
// Link each library with LinkOnlyNeeded. internalizeModule's callback
// returns true for globals that must keep their linkage, so only named
// globals whose names are in `gvs` get internalized. `gvs` is populated on
// lines not visible here (presumably the names brought in from the
// library). A link failure is reported as an error.
294 llvm::Linker linker(*ret);
295 for (std::unique_ptr<llvm::Module> &libModule : *mbModules) {
303 bool err = linker.linkInModule(
304 std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
306 llvm::internalizeModule(m, [&gvs](
const llvm::GlobalValue &gv) {
307 return !gv.hasName() || (gvs.count(gv.getName()) == 0);
312 getOperation().emitError(
313 "Unrecoverable failure during device library linking.");
// Assembles the textual ISA in `isa` into an object file for this pass's
// triple, chip and feature string. The object bytes are written into
// `result` through a raw_svector_ostream. Failures are reported as errors
// at the operation's location.
322 LogicalResult SerializeToHsacoPass::assembleIsa(
const std::string &isa,
323 SmallVectorImpl<char> &result) {
324 auto loc = getOperation().getLoc();
326 llvm::raw_svector_ostream os(result);
328 llvm::Triple triple(llvm::Triple::normalize(this->triple));
330 const llvm::Target *target =
331 llvm::TargetRegistry::lookupTarget(triple.normalize(), error);
333 return emitError(loc, Twine(
"failed to lookup target: ") + error);
// The ISA text is registered as a source buffer for the assembler parser.
335 llvm::SourceMgr srcMgr;
336 srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa), SMLoc());
// Target MC objects. The register, asm, and subtarget info are created
// from the raw this->triple string. MCContext and the object streamer
// below use the normalized Triple.
338 const llvm::MCTargetOptions mcOptions;
339 std::unique_ptr<llvm::MCRegisterInfo> mri(
340 target->createMCRegInfo(this->triple));
341 std::unique_ptr<llvm::MCAsmInfo> mai(
342 target->createMCAsmInfo(*mri, this->triple, mcOptions));
343 std::unique_ptr<llvm::MCSubtargetInfo> sti(
344 target->createMCSubtargetInfo(this->triple, this->chip, this->features));
346 llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr,
348 std::unique_ptr<llvm::MCObjectFileInfo> mofi(target->createMCObjectFileInfo(
350 ctx.setObjectFileInfo(mofi.get());
// Record the current directory as the compilation dir when it can be
// determined. current_path returns an error code, so a falsy result means
// success.
352 SmallString<128> cwd;
353 if (!llvm::sys::fs::current_path(cwd))
354 ctx.setCompilationDir(cwd);
// Object streamer writing to `os`. The asm backend and code emitter are
// handed to it wrapped in unique_ptr.
356 std::unique_ptr<llvm::MCStreamer> mcStreamer;
357 std::unique_ptr<llvm::MCInstrInfo> mcii(target->createMCInstrInfo());
359 llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, ctx);
360 llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions);
361 mcStreamer.reset(target->createMCObjectStreamer(
362 triple, ctx, std::unique_ptr<llvm::MCAsmBackend>(mab),
363 mab->createObjectWriter(os), std::unique_ptr<llvm::MCCodeEmitter>(ce),
364 *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
366 mcStreamer->setUseAssemblerInfoForParsing(
true);
// Generic asm parser plus the AMDGPU target parser. The condition that
// triggers the initialization error is not visible here.
368 std::unique_ptr<llvm::MCAsmParser> parser(
369 createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
370 std::unique_ptr<llvm::MCTargetAsmParser> tap(
371 target->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
374 return emitError(loc,
"assembler initialization error");
376 parser->setTargetParser(*tap);
// Links the assembled object in `isaBinary` into an HSACO with ld.lld and
// returns the resulting file's bytes.
//
// Both the input object and the linked output live in temporary files
// (prefix "kernel", suffixes "o" and "hsaco"). llvm::FileRemover deletes
// them when this function returns.
//
// ld.lld is located under "llvm/bin" of a path presumably seeded with the
// ROCm root; that line is not visible here.
//
// Errors are reported at the operation's location.
382 std::unique_ptr<std::vector<char>>
383 SerializeToHsacoPass::createHsaco(ArrayRef<char> isaBinary) {
384 auto loc = getOperation().getLoc();
// Temporary file for the assembled object.
387 int tempIsaBinaryFd = -1;
388 SmallString<128> tempIsaBinaryFilename;
389 if (llvm::sys::fs::createTemporaryFile(
"kernel",
"o", tempIsaBinaryFd,
390 tempIsaBinaryFilename)) {
391 emitError(loc,
"temporary file for ISA binary creation error");
// Delete the object file on exit. Write the bytes through a stream that
// owns the descriptor (shouldClose = true), and close it before linking.
394 llvm::FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
395 llvm::raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd,
true);
396 tempIsaBinaryOs << StringRef(isaBinary.data(), isaBinary.size());
397 tempIsaBinaryOs.close();
// Temporary output file for the linked code object; also removed on exit.
400 SmallString<128> tempHsacoFilename;
401 if (llvm::sys::fs::createTemporaryFile(
"kernel",
"hsaco",
402 tempHsacoFilename)) {
403 emitError(loc,
"temporary file for HSA code object creation error");
406 llvm::FileRemover cleanupHsaco(tempHsacoFilename);
// Run `ld.lld -shared <object> -o <hsaco>`. A nonzero exit status is
// treated as failure. The program-path argument of ExecuteAndWait is on a
// line not visible here (presumably lldPath).
// After linking, the hsaco file is read back. The second getFile argument
// is false, which per MemoryBuffer::getFile's signature means the file is
// not opened in text mode.
408 std::string theRocmPath = getRocmPath();
410 llvm::sys::path::append(lldPath,
"llvm",
"bin",
"ld.lld");
411 int lldResult = llvm::sys::ExecuteAndWait(
413 {
"ld.lld",
"-shared", tempIsaBinaryFilename,
"-o", tempHsacoFilename});
414 if (lldResult != 0) {
421 llvm::MemoryBuffer::getFile(tempHsacoFilename,
false);
423 emitError(loc,
"read HSA code object from temp file error");
// Copy the linked code object's bytes into the returned vector.
427 StringRef buffer = (*hsacoFile)->getBuffer();
428 return std::make_unique<std::vector<char>>(buffer.begin(), buffer.end());
// Serializes ISA text to an HSACO blob in two steps:
//  1. assemble it into an object with assembleIsa;
//  2. link that object with createHsaco.
// The body of the failure branch after assembleIsa is not visible here
// (presumably it returns an empty pointer; confirm).
431 std::unique_ptr<std::vector<char>>
432 SerializeToHsacoPass::serializeISA(
const std::string &isa) {
433 SmallVector<char, 0> isaBinary;
434 if (
failed(assembleIsa(isa, isaBinary)))
436 return createHsaco(isaBinary);
// Fragments of pass construction. Surrounding lines are not visible here.
//  - Registration: constructs the pass with triple "amdgcn-amd-amdhsa" and
//    empty chip and feature strings; the last argument is not visible.
//  - createGpuSerializeToHsacoPass: forwards its triple, arch and features
//    (and, per its declared signature, optLevel) to the constructor.
442 return std::make_unique<SerializeToHsacoPass>(
"amdgcn-amd-amdhsa",
"",
"",
453 return std::make_unique<SerializeToHsacoPass>(triple, arch, features,
#define __DEFAULT_ROCM_PATH__
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
This class provides a CRTP wrapper around a base pass class to define several necessary utility methods for the derived pass.
Include the generated interface declarations.
void registerGpuSerializeToHsacoPass()
Register a pass to serialize GPU kernel functions to an HSACO binary annotation.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::unique_ptr< Pass > createGpuSerializeToHsacoPass(StringRef triple, StringRef arch, StringRef features, int optLevel)
Create an instance of the pass that serializes GPU kernel functions to an HSACO binary.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
PassRegistration provides a global initializer that registers a Pass allocation routine for a concrete pass instance.
This class represents a specific pass option, with a provided data type.