//===- ExecutionEngine.cpp - MLIR Execution engine and utils -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the execution engine for MLIR modules based on the LLVM
// Orc JIT engine.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Export.h"

#include "llvm/ExecutionEngine/JITEventListener.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/SubtargetFeature.h"

#define DEBUG_TYPE "execution-engine"

using namespace mlir;
using llvm::dbgs;
using llvm::Error;
using llvm::errs;
using llvm::Expected;
using llvm::LLVMContext;
using llvm::MemoryBuffer;
using llvm::MemoryBufferRef;
using llvm::Module;
using llvm::SectionMemoryManager;
using llvm::StringError;
using llvm::Triple;
using llvm::orc::DynamicLibrarySearchGenerator;
using llvm::orc::ExecutionSession;
using llvm::orc::IRCompileLayer;
using llvm::orc::JITTargetMachineBuilder;
using llvm::orc::MangleAndInterner;
using llvm::orc::RTDyldObjectLinkingLayer;
using llvm::orc::SymbolMap;
using llvm::orc::ThreadSafeModule;
using llvm::orc::TMOwningSimpleCompiler;

/// Wrap a string into an llvm::StringError.
static Error makeStringError(const Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

void SimpleObjectCache::notifyObjectCompiled(const Module *m,
                                             MemoryBufferRef objBuffer) {
  cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
      objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
}

std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
  auto i = cachedObjects.find(m->getModuleIdentifier());
  if (i == cachedObjects.end()) {
    LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
                      << " in cache. Compiling.\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
                    << " loaded from cache.\n");
  return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
}

void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
  // Set up the output file.
  std::string errorMessage;
  auto file = openOutputFile(outputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return;
  }

  // Dump the object generated for a single module to the output file.
  assert(cachedObjects.size() == 1 && "Expected only one object entry.");
  auto &cachedObject = cachedObjects.begin()->second;
  file->os() << cachedObject->getBuffer();
  file->keep();
}

bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }

void ExecutionEngine::dumpToObjectFile(StringRef filename) {
  if (cache == nullptr) {
    llvm::errs() << "cannot dump ExecutionEngine object code to file: "
                    "object cache is disabled\n";
    return;
  }
  // Compilation is lazy and does not populate the object cache unless
  // requested. If an object dump is requested before the cache has been
  // populated, force compilation manually.
  if (cache->isEmpty()) {
    for (std::string &functionName : functionNames) {
      auto result = lookupPacked(functionName);
      if (!result) {
        llvm::errs() << "Could not compile " << functionName << ":\n  "
                     << result.takeError() << "\n";
        return;
      }
    }
  }
  cache->dumpToObjectFile(filename);
}
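
// Illustrative sketch, not part of the upstream file: object code can only be
// dumped when the cache exists, i.e. enableObjectDump was set at creation
// time. The module operation and the output file name are placeholders.
[[maybe_unused]] static Error exampleDumpObjectCode(Operation *op) {
  ExecutionEngineOptions opts;
  opts.enableObjectDump = true; // Required, otherwise dumpToObjectFile errors.
  auto engineOr = ExecutionEngine::create(op, opts);
  if (!engineOr)
    return engineOr.takeError();
  // Forces lazy compilation of the remembered entry points if the cache is
  // still empty, then writes the cached object to disk.
  (*engineOr)->dumpToObjectFile("example.o");
  return Error::success();
}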

void ExecutionEngine::registerSymbols(
    llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
  auto &mainJitDylib = jit->getMainJITDylib();
  cantFail(mainJitDylib.define(
      absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
          mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
}
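
// Illustrative sketch, not part of the upstream file: exposing a host function
// to JIT-compiled code via registerSymbols. The symbol name
// "example_host_callback" and the callback itself are placeholders; the
// pattern mirrors the runtime symbol map built in create() below.
static void exampleHostCallback() {}

[[maybe_unused]] static void
exampleRegisterHostSymbol(ExecutionEngine &engine) {
  engine.registerSymbols([](MangleAndInterner interner) {
    SymbolMap symbolMap;
    symbolMap[interner("example_host_callback")] = {
        llvm::orc::ExecutorAddr::fromPtr(&exampleHostCallback),
        llvm::JITSymbolFlags::Exported};
    return symbolMap;
  });
}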

static void setupTargetTripleAndDataLayout(llvm::Module *llvmModule,
                                           llvm::TargetMachine *tm) {
  llvmModule->setDataLayout(tm->createDataLayout());
  llvmModule->setTargetTriple(tm->getTargetTriple());
}

static std::string makePackedFunctionName(StringRef name) {
  return "_mlir_" + name.str();
}

// For each function in the LLVM module, define an interface function that
// wraps all the arguments of the original function and all its results into
// an i8** pointer to provide a unified invocation interface.
static void packFunctionArguments(Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    if (func.isDeclaration() || func.hasLocalLinkage())
      continue;
    if (interfaceFunctions.count(&func))
      continue;

    // Given a function `foo(<...>)`, define the interface function
    // `_mlir_foo(i8**)`.
    auto *newType =
        llvm::FunctionType::get(builder.getVoidTy(), builder.getPtrTy(),
                                /*isVarArg=*/false);
    auto newName = makePackedFunctionName(func.getName());
    auto funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them
    // to the proper types.
    auto *bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto [index, arg] : llvm::enumerate(func.args())) {
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, index));
      llvm::Value *argPtrPtr =
          builder.CreateGEP(builder.getPtrTy(), argList, argIndex);
      llvm::Value *argPtr = builder.CreateLoad(builder.getPtrTy(), argPtrPtr);
      llvm::Type *argTy = arg.getType();
      llvm::Value *load = builder.CreateLoad(argTy, argPtr);
      args.push_back(load);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`.
    if (!result->getType()->isVoidTy()) {
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr =
          builder.CreateGEP(builder.getPtrTy(), argList, retIndex);
      llvm::Value *retPtr = builder.CreateLoad(builder.getPtrTy(), retPtrPtr);
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}
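
// Illustrative sketch, not part of the upstream file: a hand-written C++
// equivalent of the wrapper that packFunctionArguments generates for a
// hypothetical `float exampleImpl(int, float)`. The single pointer argument is
// an array with one pointer per argument followed by a pointer to the result
// slot, which is exactly the layout invokePacked passes in.
[[maybe_unused]] static float exampleImpl(int a, float b) { return a + b; }

[[maybe_unused]] static void examplePackedWrapper(void **args) {
  int arg0 = *static_cast<int *>(args[0]);     // load argument #0
  float arg1 = *static_cast<float *>(args[1]); // load argument #1
  float result = exampleImpl(arg0, arg1);      // call the implementation
  *static_cast<float *>(args[2]) = result;     // store the single result
}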

ExecutionEngine::ExecutionEngine(bool enableObjectDump,
                                 bool enableGDBNotificationListener,
                                 bool enablePerfNotificationListener)
    : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
      functionNames(),
      gdbListener(enableGDBNotificationListener
                      ? llvm::JITEventListener::createGDBRegistrationListener()
                      : nullptr),
      perfListener(nullptr) {
  if (enablePerfNotificationListener) {
    if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
      perfListener = listener;
    else if (auto *listener =
                 llvm::JITEventListener::createIntelJITEventListener())
      perfListener = listener;
  }
}

ExecutionEngine::~ExecutionEngine() {
  // Execute the global destructors from the module being processed.
  if (jit)
    llvm::consumeError(jit->deinitialize(jit->getMainJITDylib()));
  // Run all dynamic library destroy callbacks to prepare for the shutdown.
  for (LibraryDestroyFn destroy : destroyFns)
    destroy();
}

Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
                        std::unique_ptr<llvm::TargetMachine> tm) {
  auto engine = std::make_unique<ExecutionEngine>(
      options.enableObjectDump, options.enableGDBNotificationListener,
      options.enablePerfNotificationListener);

  // Remember all entry points if object dumping is enabled.
  if (options.enableObjectDump) {
    for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
      if (funcOp.getBlocks().empty())
        continue;
      StringRef funcName = funcOp.getSymName();
      engine->functionNames.push_back(funcName.str());
    }
  }

  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
  auto llvmModule = options.llvmModuleBuilder
                        ? options.llvmModuleBuilder(m, *ctx)
                        : translateModuleToLLVMIR(m, *ctx);
  if (!llvmModule)
    return makeStringError("could not convert to LLVM IR");

  // If no valid TargetMachine was passed, create a default TM ignoring any
  // input arguments from the user.
  if (!tm) {
    auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
    if (!tmBuilderOrError)
      return tmBuilderOrError.takeError();

    auto tmOrError = tmBuilderOrError->createTargetMachine();
    if (!tmOrError)
      return tmOrError.takeError();
    tm = std::move(tmOrError.get());
  }

  // TODO: Currently, the LLVM module created above has no triple associated
  // with it. Instead, the triple is extracted from the TargetMachine, which is
  // either based on the host defaults or command line arguments when specified
  // (set up by callers of this method). It could also be passed to the
  // translation or dialect conversion instead of this.
  setupTargetTripleAndDataLayout(llvmModule.get(), tm.get());
  packFunctionArguments(llvmModule.get());

  auto dataLayout = llvmModule->getDataLayout();

  // Use absolute library paths so that gdb can find the symbol table.
  SmallVector<SmallString<256>, 4> sharedLibPaths;
  transform(
      options.sharedLibPaths, std::back_inserter(sharedLibPaths),
      [](StringRef libPath) {
        SmallString<256> absPath(libPath.begin(), libPath.end());
        cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
        return absPath;
      });

  // If a shared library implements custom execution-layer init and destroy
  // functions, use them to register the library. Otherwise, load the library
  // into a JITDylib below.
  llvm::StringMap<void *> exportSymbols;
  SmallVector<LibraryDestroyFn> destroyFns;
  SmallVector<StringRef> jitDyLibPaths;

  for (auto &libPath : sharedLibPaths) {
    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
        libPath.str().str().c_str());
    void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName);
    void *destroySym = lib.getAddressOfSymbol(kLibraryDestroyFnName);

    // Library does not provide callbacks; rely on symbol visibility.
    if (!initSym || !destroySym) {
      jitDyLibPaths.push_back(libPath);
      continue;
    }

    auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
    initFn(exportSymbols);

    auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySym);
    destroyFns.push_back(destroyFn);
  }
  engine->destroyFns = std::move(destroyFns);

  // Callback to create the object layer with symbol resolution to the current
  // process and dynamically linked libraries.
  auto objectLinkingLayerCreator = [&](ExecutionSession &session) {
    // Needed to respect AArch64 ABI requirements on the distance between
    // TEXT and GOT sections.
    bool reserveAlloc = llvmModule->getTargetTriple().isAArch64();
    auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
        session, [sectionMemoryMapper = options.sectionMemoryMapper,
                  reserveAlloc](const MemoryBuffer &) {
          return std::make_unique<SectionMemoryManager>(sectionMemoryMapper,
                                                        reserveAlloc);
        });

    // Register JIT event listeners if they are enabled.
    if (engine->gdbListener)
      objectLayer->registerJITEventListener(*engine->gdbListener);
    if (engine->perfListener)
      objectLayer->registerJITEventListener(*engine->perfListener);

    // COFF format binaries (Windows) need special handling to deal with
    // exported symbol visibility.
    // See LLJIT::createObjectLinkingLayer in
    // llvm/lib/ExecutionEngine/Orc/LLJIT.cpp.
    const llvm::Triple &targetTriple = llvmModule->getTargetTriple();
    if (targetTriple.isOSBinFormatCOFF()) {
      objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
      objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
    }

    // Resolve symbols from shared libraries.
    for (auto &libPath : jitDyLibPaths) {
      auto mb = llvm::MemoryBuffer::getFile(libPath);
      if (!mb) {
        errs() << "Failed to create MemoryBuffer for: " << libPath
               << "\nError: " << mb.getError().message() << "\n";
        continue;
      }
      auto &jd = session.createBareJITDylib(std::string(libPath));
      auto loaded = DynamicLibrarySearchGenerator::Load(
          libPath.str().c_str(), dataLayout.getGlobalPrefix());
      if (!loaded) {
        errs() << "Could not load " << libPath << ":\n  " << loaded.takeError()
               << "\n";
        continue;
      }
      jd.addGenerator(std::move(*loaded));
      cantFail(objectLayer->add(jd, std::move(mb.get())));
    }

    return objectLayer;
  };

  // Callback to inspect the cache and recompile on demand. This follows Lang's
  // LLJITWithObjectCache example.
  auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
      -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
    if (options.jitCodeGenOptLevel)
      jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
    return std::make_unique<TMOwningSimpleCompiler>(std::move(tm),
                                                    engine->cache.get());
  };

  // Create the LLJIT by calling the LLJITBuilder with the two callbacks.
  auto jit =
      cantFail(llvm::orc::LLJITBuilder()
                   .setCompileFunctionCreator(compileFunctionCreator)
                   .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
                   .setDataLayout(dataLayout)
                   .create());

  // Add a ThreadSafeModule to the engine and return.
  ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
  if (options.transformer)
    cantFail(tsm.withModuleDo(
        [&](llvm::Module &module) { return options.transformer(&module); }));
  cantFail(jit->addIRModule(std::move(tsm)));
  engine->jit = std::move(jit);

  // Resolve symbols that are statically linked in the current process.
  llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
  mainJD.addGenerator(
      cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix())));

  // Build a runtime symbol map from the exported symbols and register them.
  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
    auto symbolMap = llvm::orc::SymbolMap();
    for (auto &exportSymbol : exportSymbols)
      symbolMap[interner(exportSymbol.getKey())] = {
          llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()),
          llvm::JITSymbolFlags::Exported};
    return symbolMap;
  };
  engine->registerSymbols(runtimeSymbolMap);
  return std::move(engine);
}
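
// Illustrative sketch, not part of the upstream file: typical use of create()
// from a tool. It assumes the module behind `op` is already in (or convertible
// to) the LLVM dialect and exports a public, zero-argument `entry` function;
// the shared library name is a placeholder.
[[maybe_unused]] static Error exampleCreateAndRun(Operation *op) {
  ExecutionEngineOptions opts;
  opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;

  // Shared libraries whose symbols (or init/destroy callbacks) should be made
  // visible to the JIT.
  SmallVector<StringRef> libs = {"libexample_runtime.so"};
  opts.sharedLibPaths = libs;

  // Optional transformation applied to the LLVM module before compilation;
  // a no-op here.
  auto noopTransformer = [](llvm::Module *) -> Error {
    return Error::success();
  };
  opts.transformer = noopTransformer;

  auto engineOr = ExecutionEngine::create(op, opts);
  if (!engineOr)
    return engineOr.takeError();

  // Calls the packed interface `_mlir_entry` generated by
  // packFunctionArguments (no arguments, no results).
  return (*engineOr)->invokePacked("entry");
}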

Expected<void (*)(void **)>
ExecutionEngine::lookupPacked(StringRef name) const {
  auto result = lookup(makePackedFunctionName(name));
  if (!result)
    return result.takeError();
  return reinterpret_cast<void (*)(void **)>(result.get());
}

Expected<void *> ExecutionEngine::lookup(StringRef name) const {
  auto expectedSymbol = jit->lookup(name);

  // JIT lookup may return an Error referring to strings stored internally by
  // the JIT. If the Error outlives the ExecutionEngine, it would have a
  // dangling reference, which is currently caught by an assertion inside the
  // JIT thanks to hand-rolled reference counting. Rewrap the error message
  // into a string before returning. Alternatively, ORC JIT should consider
  // copying the string into the error message.
  if (!expectedSymbol) {
    std::string errorMessage;
    llvm::raw_string_ostream os(errorMessage);
    llvm::handleAllErrors(expectedSymbol.takeError(),
                          [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
    return makeStringError(errorMessage);
  }

  if (void *fptr = expectedSymbol->toPtr<void *>())
    return fptr;
  return makeStringError("looked up function is null");
}

Error ExecutionEngine::invokePacked(StringRef name,
                                    MutableArrayRef<void *> args) {
  initialize();
  auto expectedFPtr = lookupPacked(name);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  auto fptr = *expectedFPtr;

  (*fptr)(args.data());

  return Error::success();
}
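
// Illustrative sketch, not part of the upstream file: invoking a JIT-compiled
// function through the packed interface. For a hypothetical `i32 add(i32, i32)`
// in the module, the caller packs a pointer per argument followed by a pointer
// to storage for the result, matching the layout produced by
// packFunctionArguments.
[[maybe_unused]] static Expected<int32_t>
exampleInvokeAdd(ExecutionEngine &engine, int32_t lhs, int32_t rhs) {
  int32_t result = 0;
  SmallVector<void *, 3> args = {&lhs, &rhs, &result};
  // Looks up `_mlir_add` and passes the packed argument list.
  if (Error err = engine.invokePacked("add", args))
    return std::move(err);
  return result;
}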

void ExecutionEngine::initialize() {
  if (isInitialized)
    return;
  cantFail(jit->initialize(jit->getMainJITDylib()));
  isInitialized = true;
}
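
// Illustrative sketch, not part of the upstream file: the callback-based
// registration protocol handled in create() above. A shared library that wants
// to register its symbols explicitly exports two functions, named after
// kLibraryInitFnName and kLibraryDestroyFnName (see ExecutionEngine.h), whose
// signatures match LibraryInitFn and LibraryDestroyFn. The symbol name and
// bodies below are placeholders.
static void exampleRuntimePrint(int64_t value) { errs() << value << "\n"; }

[[maybe_unused]] static void
exampleLibraryInit(llvm::StringMap<void *> &exportSymbols) {
  // Map each JIT-visible symbol name to the host address implementing it.
  exportSymbols["example_runtime_print"] =
      reinterpret_cast<void *>(&exampleRuntimePrint);
}

[[maybe_unused]] static void exampleLibraryDestroy() {
  // Release any state acquired in the init callback.
}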