MLIR  22.0.0git
ExecutionEngine.cpp
Go to the documentation of this file.
1 //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the execution engine for MLIR modules based on LLVM Orc
10 // JIT engine.
11 //
12 //===----------------------------------------------------------------------===//
15 #include "mlir/IR/BuiltinOps.h"
18 
19 #include "llvm/ExecutionEngine/JITEventListener.h"
20 #include "llvm/ExecutionEngine/ObjectCache.h"
21 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
22 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
26 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/MC/TargetRegistry.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/ToolOutputFile.h"
32 #include "llvm/TargetParser/Host.h"
33 #include "llvm/TargetParser/SubtargetFeature.h"
34 
35 #define DEBUG_TYPE "execution-engine"
36 
37 using namespace mlir;
38 using llvm::dbgs;
39 using llvm::Error;
40 using llvm::errs;
41 using llvm::Expected;
42 using llvm::LLVMContext;
43 using llvm::MemoryBuffer;
44 using llvm::MemoryBufferRef;
45 using llvm::Module;
46 using llvm::SectionMemoryManager;
47 using llvm::StringError;
48 using llvm::Triple;
49 using llvm::orc::DynamicLibrarySearchGenerator;
50 using llvm::orc::ExecutionSession;
51 using llvm::orc::IRCompileLayer;
52 using llvm::orc::JITTargetMachineBuilder;
53 using llvm::orc::MangleAndInterner;
54 using llvm::orc::RTDyldObjectLinkingLayer;
55 using llvm::orc::SymbolMap;
56 using llvm::orc::ThreadSafeModule;
57 using llvm::orc::TMOwningSimpleCompiler;
58 
59 /// Wrap a string into an llvm::StringError.
60 static Error makeStringError(const Twine &message) {
61  return llvm::make_error<StringError>(message.str(),
62  llvm::inconvertibleErrorCode());
63 }
64 
66  MemoryBufferRef objBuffer) {
67  cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
68  objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
69 }
70 
71 std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
72  auto i = cachedObjects.find(m->getModuleIdentifier());
73  if (i == cachedObjects.end()) {
74  LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
75  << " in cache. Compiling.\n");
76  return nullptr;
77  }
78  LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
79  << " loaded from cache.\n");
80  return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
81 }
82 
83 void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
84  // Set up the output file.
85  std::string errorMessage;
86  auto file = openOutputFile(outputFilename, &errorMessage);
87  if (!file) {
88  llvm::errs() << errorMessage << "\n";
89  return;
90  }
91 
92  // Dump the object generated for a single module to the output file.
93  assert(cachedObjects.size() == 1 && "Expected only one object entry.");
94  auto &cachedObject = cachedObjects.begin()->second;
95  file->os() << cachedObject->getBuffer();
96  file->keep();
97 }
98 
99 bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }
100 
101 void ExecutionEngine::dumpToObjectFile(StringRef filename) {
102  if (cache == nullptr) {
103  llvm::errs() << "cannot dump ExecutionEngine object code to file: "
104  "object cache is disabled\n";
105  return;
106  }
107  // Compilation is lazy and it doesn't populate object cache unless requested.
108  // In case object dump is requested before cache is populated, we need to
109  // force compilation manually.
110  if (cache->isEmpty()) {
111  for (std::string &functionName : functionNames) {
112  auto result = lookupPacked(functionName);
113  if (!result) {
114  llvm::errs() << "Could not compile " << functionName << ":\n "
115  << result.takeError() << "\n";
116  return;
117  }
118  }
119  }
120  cache->dumpToObjectFile(filename);
121 }
122 
124  llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
125  auto &mainJitDylib = jit->getMainJITDylib();
126  cantFail(mainJitDylib.define(
127  absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
128  mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
129 }
130 
132  llvm::TargetMachine *tm) {
133  llvmModule->setDataLayout(tm->createDataLayout());
134  llvmModule->setTargetTriple(tm->getTargetTriple());
135 }
136 
137 static std::string makePackedFunctionName(StringRef name) {
138  return "_mlir_" + name.str();
139 }
140 
141 // For each function in the LLVM module, define an interface function that wraps
142 // all the arguments of the original function and all its results into an i8**
143 // pointer to provide a unified invocation interface.
144 static void packFunctionArguments(Module *module) {
145  auto &ctx = module->getContext();
146  llvm::IRBuilder<> builder(ctx);
147  DenseSet<llvm::Function *> interfaceFunctions;
148  for (auto &func : module->getFunctionList()) {
149  if (func.isDeclaration()) {
150  continue;
151  }
152  if (interfaceFunctions.count(&func)) {
153  continue;
154  }
155 
156  // Given a function `foo(<...>)`, define the interface function
157  // `mlir_foo(i8**)`.
158  auto *newType =
159  llvm::FunctionType::get(builder.getVoidTy(), builder.getPtrTy(),
160  /*isVarArg=*/false);
161  auto newName = makePackedFunctionName(func.getName());
162  auto funcCst = module->getOrInsertFunction(newName, newType);
163  llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
164  interfaceFunctions.insert(interfaceFunc);
165 
166  // Extract the arguments from the type-erased argument list and cast them to
167  // the proper types.
168  auto *bb = llvm::BasicBlock::Create(ctx);
169  bb->insertInto(interfaceFunc);
170  builder.SetInsertPoint(bb);
171  llvm::Value *argList = interfaceFunc->arg_begin();
173  args.reserve(llvm::size(func.args()));
174  for (auto [index, arg] : llvm::enumerate(func.args())) {
175  llvm::Value *argIndex = llvm::Constant::getIntegerValue(
176  builder.getInt64Ty(), APInt(64, index));
177  llvm::Value *argPtrPtr =
178  builder.CreateGEP(builder.getPtrTy(), argList, argIndex);
179  llvm::Value *argPtr = builder.CreateLoad(builder.getPtrTy(), argPtrPtr);
180  llvm::Type *argTy = arg.getType();
181  llvm::Value *load = builder.CreateLoad(argTy, argPtr);
182  args.push_back(load);
183  }
184 
185  // Call the implementation function with the extracted arguments.
186  llvm::Value *result = builder.CreateCall(&func, args);
187 
188  // Assuming the result is one value, potentially of type `void`.
189  if (!result->getType()->isVoidTy()) {
190  llvm::Value *retIndex = llvm::Constant::getIntegerValue(
191  builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
192  llvm::Value *retPtrPtr =
193  builder.CreateGEP(builder.getPtrTy(), argList, retIndex);
194  llvm::Value *retPtr = builder.CreateLoad(builder.getPtrTy(), retPtrPtr);
195  builder.CreateStore(result, retPtr);
196  }
197 
198  // The interface function returns void.
199  builder.CreateRetVoid();
200  }
201 }
202 
203 ExecutionEngine::ExecutionEngine(bool enableObjectDump,
204  bool enableGDBNotificationListener,
205  bool enablePerfNotificationListener)
206  : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
207  functionNames(),
208  gdbListener(enableGDBNotificationListener
209  ? llvm::JITEventListener::createGDBRegistrationListener()
210  : nullptr),
211  perfListener(nullptr) {
212  if (enablePerfNotificationListener) {
213  if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
214  perfListener = listener;
215  else if (auto *listener =
216  llvm::JITEventListener::createIntelJITEventListener())
217  perfListener = listener;
218  }
219 }
220 
222  // Execute the global destructors from the module being processed.
223  // TODO: Allow JIT deinitialize for AArch64. Currently there's a bug causing a
224  // crash for AArch64 see related issue #71963.
225  if (jit && !jit->getTargetTriple().isAArch64())
226  llvm::consumeError(jit->deinitialize(jit->getMainJITDylib()));
227  // Run all dynamic library destroy callbacks to prepare for the shutdown.
228  for (LibraryDestroyFn destroy : destroyFns)
229  destroy();
230 }
231 
234  std::unique_ptr<llvm::TargetMachine> tm) {
235  auto engine = std::make_unique<ExecutionEngine>(
236  options.enableObjectDump, options.enableGDBNotificationListener,
237  options.enablePerfNotificationListener);
238 
239  // Remember all entry-points if object dumping is enabled.
240  if (options.enableObjectDump) {
241  for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
242  StringRef funcName = funcOp.getSymName();
243  engine->functionNames.push_back(funcName.str());
244  }
245  }
246 
247  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
248  auto llvmModule = options.llvmModuleBuilder
249  ? options.llvmModuleBuilder(m, *ctx)
250  : translateModuleToLLVMIR(m, *ctx);
251  if (!llvmModule)
252  return makeStringError("could not convert to LLVM IR");
253 
254  // If no valid TargetMachine was passed, create a default TM ignoring any
255  // input arguments from the user.
256  if (!tm) {
257  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
258  if (!tmBuilderOrError)
259  return tmBuilderOrError.takeError();
260 
261  auto tmOrError = tmBuilderOrError->createTargetMachine();
262  if (!tmOrError)
263  return tmOrError.takeError();
264  tm = std::move(tmOrError.get());
265  }
266 
267  // TODO: Currently, the LLVM module created above has no triple associated
268  // with it. Instead, the triple is extracted from the TargetMachine, which is
269  // either based on the host defaults or command line arguments when specified
270  // (set-up by callers of this method). It could also be passed to the
271  // translation or dialect conversion instead of this.
272  setupTargetTripleAndDataLayout(llvmModule.get(), tm.get());
273  packFunctionArguments(llvmModule.get());
274 
275  auto dataLayout = llvmModule->getDataLayout();
276 
277  // Use absolute library path so that gdb can find the symbol table.
278  SmallVector<SmallString<256>, 4> sharedLibPaths;
279  transform(
280  options.sharedLibPaths, std::back_inserter(sharedLibPaths),
281  [](StringRef libPath) {
282  SmallString<256> absPath(libPath.begin(), libPath.end());
283  cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
284  return absPath;
285  });
286 
287  // If shared library implements custom execution layer library init and
288  // destroy functions, we'll use them to register the library. Otherwise, load
289  // the library as JITDyLib below.
290  llvm::StringMap<void *> exportSymbols;
292  SmallVector<StringRef> jitDyLibPaths;
293 
294  for (auto &libPath : sharedLibPaths) {
295  auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
296  libPath.str().str().c_str());
297  void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName);
298  void *destroySim = lib.getAddressOfSymbol(kLibraryDestroyFnName);
299 
300  // Library does not provide callbacks; rely on symbol visibility.
301  if (!initSym || !destroySim) {
302  jitDyLibPaths.push_back(libPath);
303  continue;
304  }
305 
306  auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
307  initFn(exportSymbols);
308 
309  auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySim);
310  destroyFns.push_back(destroyFn);
311  }
312  engine->destroyFns = std::move(destroyFns);
313 
314  // Callback to create the object layer with symbol resolution to current
315  // process and dynamically linked libraries.
316  auto objectLinkingLayerCreator = [&](ExecutionSession &session) {
317  auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
318  session, [sectionMemoryMapper =
319  options.sectionMemoryMapper](const MemoryBuffer &) {
320  return std::make_unique<SectionMemoryManager>(sectionMemoryMapper);
321  });
322 
323  // Register JIT event listeners if they are enabled.
324  if (engine->gdbListener)
325  objectLayer->registerJITEventListener(*engine->gdbListener);
326  if (engine->perfListener)
327  objectLayer->registerJITEventListener(*engine->perfListener);
328 
329  // COFF format binaries (Windows) need special handling to deal with
330  // exported symbol visibility.
331  // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
332  const llvm::Triple &targetTriple = llvmModule->getTargetTriple();
333  if (targetTriple.isOSBinFormatCOFF()) {
334  objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
335  objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
336  }
337 
338  // Resolve symbols from shared libraries.
339  for (auto &libPath : jitDyLibPaths) {
340  auto mb = llvm::MemoryBuffer::getFile(libPath);
341  if (!mb) {
342  errs() << "Failed to create MemoryBuffer for: " << libPath
343  << "\nError: " << mb.getError().message() << "\n";
344  continue;
345  }
346  auto &jd = session.createBareJITDylib(std::string(libPath));
347  auto loaded = DynamicLibrarySearchGenerator::Load(
348  libPath.str().c_str(), dataLayout.getGlobalPrefix());
349  if (!loaded) {
350  errs() << "Could not load " << libPath << ":\n " << loaded.takeError()
351  << "\n";
352  continue;
353  }
354  jd.addGenerator(std::move(*loaded));
355  cantFail(objectLayer->add(jd, std::move(mb.get())));
356  }
357 
358  return objectLayer;
359  };
360 
361  // Callback to inspect the cache and recompile on demand. This follows Lang's
362  // LLJITWithObjectCache example.
363  auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
364  -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
365  if (options.jitCodeGenOptLevel)
366  jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
367  return std::make_unique<TMOwningSimpleCompiler>(std::move(tm),
368  engine->cache.get());
369  };
370 
371  // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
372  auto jit =
373  cantFail(llvm::orc::LLJITBuilder()
374  .setCompileFunctionCreator(compileFunctionCreator)
375  .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
376  .setDataLayout(dataLayout)
377  .create());
378 
379  // Add a ThreadSafeModule to the engine and return.
380  ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
381  if (options.transformer)
382  cantFail(tsm.withModuleDo(
383  [&](llvm::Module &module) { return options.transformer(&module); }));
384  cantFail(jit->addIRModule(std::move(tsm)));
385  engine->jit = std::move(jit);
386 
387  // Resolve symbols that are statically linked in the current process.
388  llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
389  mainJD.addGenerator(
390  cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
391  dataLayout.getGlobalPrefix())));
392 
393  // Build a runtime symbol map from the exported symbols and register them.
394  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
395  auto symbolMap = llvm::orc::SymbolMap();
396  for (auto &exportSymbol : exportSymbols)
397  symbolMap[interner(exportSymbol.getKey())] = {
398  llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()),
399  llvm::JITSymbolFlags::Exported};
400  return symbolMap;
401  };
402  engine->registerSymbols(runtimeSymbolMap);
403  return std::move(engine);
404 }
405 
406 Expected<void (*)(void **)>
407 ExecutionEngine::lookupPacked(StringRef name) const {
408  auto result = lookup(makePackedFunctionName(name));
409  if (!result)
410  return result.takeError();
411  return reinterpret_cast<void (*)(void **)>(result.get());
412 }
413 
415  auto expectedSymbol = jit->lookup(name);
416 
417  // JIT lookup may return an Error referring to strings stored internally by
418  // the JIT. If the Error outlives the ExecutionEngine, it would want have a
419  // dangling reference, which is currently caught by an assertion inside JIT
420  // thanks to hand-rolled reference counting. Rewrap the error message into a
421  // string before returning. Alternatively, ORC JIT should consider copying
422  // the string into the error message.
423  if (!expectedSymbol) {
424  std::string errorMessage;
425  llvm::raw_string_ostream os(errorMessage);
426  llvm::handleAllErrors(expectedSymbol.takeError(),
427  [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
428  return makeStringError(errorMessage);
429  }
430 
431  if (void *fptr = expectedSymbol->toPtr<void *>())
432  return fptr;
433  return makeStringError("looked up function is null");
434 }
435 
438  initialize();
439  auto expectedFPtr = lookupPacked(name);
440  if (!expectedFPtr)
441  return expectedFPtr.takeError();
442  auto fptr = *expectedFPtr;
443 
444  (*fptr)(args.data());
445 
446  return Error::success();
447 }
448 
450  if (isInitialized)
451  return;
452  // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
453  // crash for AArch64 see related issue #71963.
454  if (!jit->getTargetTriple().isAArch64())
455  cantFail(jit->initialize(jit->getMainJITDylib()));
456  isInitialized = true;
457 }
static void packFunctionArguments(Module *module)
static Error makeStringError(const Twine &message)
Wrap a string into an llvm::StringError.
static std::string makePackedFunctionName(StringRef name)
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
llvm::Expected< void(*)(void **)> lookupPacked(StringRef name) const
Looks up a packed-argument function wrapping the function with the given name and returns a pointer t...
void(*)(llvm::StringMap< void * > &) LibraryInitFn
Function type for init functions of shared libraries.
void(*)() LibraryDestroyFn
Function type for destroy functions of shared libraries.
void registerSymbols(llvm::function_ref< llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> symbolMap)
Register symbols with this ExecutionEngine.
llvm::Expected< void * > lookup(StringRef name) const
Looks up the original function with the given name and returns a pointer to it.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={}, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Creates an execution engine for the given MLIR IR.
void dumpToObjectFile(StringRef filename)
Dump object code to output file filename.
static constexpr const char *const kLibraryInitFnName
Name of init functions of shared libraries.
static Result< T > result(T &t)
Helper function to wrap an output operand when using ExecutionEngine::invoke.
static constexpr const char *const kLibraryDestroyFnName
Name of destroy functions of shared libraries.
llvm::Error invokePacked(StringRef name, MutableArrayRef< void * > args={})
Invokes the function with the given name passing it the list of opaque pointers to the actual argumen...
ExecutionEngine(bool enableObjectDump, bool enableGDBNotificationListener, bool enablePerfNotificationListener)
static void setupTargetTripleAndDataLayout(llvm::Module *llvmModule, llvm::TargetMachine *tm)
Set the target triple and the data layout for the input module based on the input TargetMachine.
void initialize()
Initialize the ExecutionEngine.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
iterator_range< OpIterator > getOps()
Definition: Region.h:172
A simple object cache following Lang's LLJITWithObjectCache example.
bool isEmpty()
Returns true if cache hasn't been populated yet.
void notifyObjectCompiled(const llvm::Module *m, llvm::MemoryBufferRef objBuffer) override
void dumpToObjectFile(StringRef filename)
Dump cached object to output file filename.
std::unique_ptr< llvm::MemoryBuffer > getObject(const llvm::Module *m) override
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule", bool disableVerification=false)
Translates a given LLVM dialect module into an LLVM IR module living in the given context.
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...