MLIR  18.0.0git
ExecutionEngine.cpp
Go to the documentation of this file.
1 //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the execution engine for MLIR modules based on LLVM Orc
10 // JIT engine.
11 //
12 //===----------------------------------------------------------------------===//
15 #include "mlir/IR/BuiltinOps.h"
18 
19 #include "llvm/ExecutionEngine/JITEventListener.h"
20 #include "llvm/ExecutionEngine/ObjectCache.h"
21 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
22 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
26 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/MC/TargetRegistry.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/ToolOutputFile.h"
32 #include "llvm/TargetParser/Host.h"
33 #include "llvm/TargetParser/SubtargetFeature.h"
34 
35 #define DEBUG_TYPE "execution-engine"
36 
37 using namespace mlir;
38 using llvm::dbgs;
39 using llvm::Error;
40 using llvm::errs;
41 using llvm::Expected;
42 using llvm::LLVMContext;
43 using llvm::MemoryBuffer;
44 using llvm::MemoryBufferRef;
45 using llvm::Module;
46 using llvm::SectionMemoryManager;
47 using llvm::StringError;
48 using llvm::Triple;
49 using llvm::orc::DynamicLibrarySearchGenerator;
50 using llvm::orc::ExecutionSession;
51 using llvm::orc::IRCompileLayer;
52 using llvm::orc::JITTargetMachineBuilder;
53 using llvm::orc::MangleAndInterner;
54 using llvm::orc::RTDyldObjectLinkingLayer;
55 using llvm::orc::SymbolMap;
56 using llvm::orc::ThreadSafeModule;
57 using llvm::orc::TMOwningSimpleCompiler;
58 
59 /// Wrap a string into an llvm::StringError.
60 static Error makeStringError(const Twine &message) {
61  return llvm::make_error<StringError>(message.str(),
62  llvm::inconvertibleErrorCode());
63 }
64 
66  MemoryBufferRef objBuffer) {
67  cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
68  objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
69 }
70 
71 std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
72  auto i = cachedObjects.find(m->getModuleIdentifier());
73  if (i == cachedObjects.end()) {
74  LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
75  << " in cache. Compiling.\n");
76  return nullptr;
77  }
78  LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
79  << " loaded from cache.\n");
80  return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
81 }
82 
83 void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
84  // Set up the output file.
85  std::string errorMessage;
86  auto file = openOutputFile(outputFilename, &errorMessage);
87  if (!file) {
88  llvm::errs() << errorMessage << "\n";
89  return;
90  }
91 
92  // Dump the object generated for a single module to the output file.
93  assert(cachedObjects.size() == 1 && "Expected only one object entry.");
94  auto &cachedObject = cachedObjects.begin()->second;
95  file->os() << cachedObject->getBuffer();
96  file->keep();
97 }
98 
99 bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }
100 
101 void ExecutionEngine::dumpToObjectFile(StringRef filename) {
102  if (cache == nullptr) {
103  llvm::errs() << "cannot dump ExecutionEngine object code to file: "
104  "object cache is disabled\n";
105  return;
106  }
107  // Compilation is lazy and it doesn't populate object cache unless requested.
108  // In case object dump is requested before cache is populated, we need to
109  // force compilation manually.
110  if (cache->isEmpty()) {
111  for (std::string &functionName : functionNames) {
112  auto result = lookupPacked(functionName);
113  if (!result) {
114  llvm::errs() << "Could not compile " << functionName << ":\n "
115  << result.takeError() << "\n";
116  return;
117  }
118  }
119  }
120  cache->dumpToObjectFile(filename);
121 }
122 
124  llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
125  auto &mainJitDylib = jit->getMainJITDylib();
126  cantFail(mainJitDylib.define(
127  absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
128  mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
129 }
130 
132  llvm::TargetMachine *tm) {
133  llvmModule->setDataLayout(tm->createDataLayout());
134  llvmModule->setTargetTriple(tm->getTargetTriple().getTriple());
135 }
136 
137 static std::string makePackedFunctionName(StringRef name) {
138  return "_mlir_" + name.str();
139 }
140 
141 // For each function in the LLVM module, define an interface function that wraps
142 // all the arguments of the original function and all its results into an i8**
143 // pointer to provide a unified invocation interface.
144 static void packFunctionArguments(Module *module) {
145  auto &ctx = module->getContext();
146  llvm::IRBuilder<> builder(ctx);
147  DenseSet<llvm::Function *> interfaceFunctions;
148  for (auto &func : module->getFunctionList()) {
149  if (func.isDeclaration()) {
150  continue;
151  }
152  if (interfaceFunctions.count(&func)) {
153  continue;
154  }
155 
156  // Given a function `foo(<...>)`, define the interface function
157  // `mlir_foo(i8**)`.
158  auto *newType = llvm::FunctionType::get(
159  builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
160  /*isVarArg=*/false);
161  auto newName = makePackedFunctionName(func.getName());
162  auto funcCst = module->getOrInsertFunction(newName, newType);
163  llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
164  interfaceFunctions.insert(interfaceFunc);
165 
166  // Extract the arguments from the type-erased argument list and cast them to
167  // the proper types.
168  auto *bb = llvm::BasicBlock::Create(ctx);
169  bb->insertInto(interfaceFunc);
170  builder.SetInsertPoint(bb);
171  llvm::Value *argList = interfaceFunc->arg_begin();
173  args.reserve(llvm::size(func.args()));
174  for (auto [index, arg] : llvm::enumerate(func.args())) {
175  llvm::Value *argIndex = llvm::Constant::getIntegerValue(
176  builder.getInt64Ty(), APInt(64, index));
177  llvm::Value *argPtrPtr =
178  builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex);
179  llvm::Value *argPtr =
180  builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr);
181  llvm::Type *argTy = arg.getType();
182  argPtr = builder.CreateBitCast(argPtr, argTy->getPointerTo());
183  llvm::Value *load = builder.CreateLoad(argTy, argPtr);
184  args.push_back(load);
185  }
186 
187  // Call the implementation function with the extracted arguments.
188  llvm::Value *result = builder.CreateCall(&func, args);
189 
190  // Assuming the result is one value, potentially of type `void`.
191  if (!result->getType()->isVoidTy()) {
192  llvm::Value *retIndex = llvm::Constant::getIntegerValue(
193  builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
194  llvm::Value *retPtrPtr =
195  builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex);
196  llvm::Value *retPtr =
197  builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr);
198  retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
199  builder.CreateStore(result, retPtr);
200  }
201 
202  // The interface function returns void.
203  builder.CreateRetVoid();
204  }
205 }
206 
207 ExecutionEngine::ExecutionEngine(bool enableObjectDump,
208  bool enableGDBNotificationListener,
209  bool enablePerfNotificationListener)
210  : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
211  functionNames(),
212  gdbListener(enableGDBNotificationListener
213  ? llvm::JITEventListener::createGDBRegistrationListener()
214  : nullptr),
215  perfListener(nullptr) {
216  if (enablePerfNotificationListener) {
217  if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
218  perfListener = listener;
219  else if (auto *listener =
220  llvm::JITEventListener::createIntelJITEventListener())
221  perfListener = listener;
222  }
223 }
224 
226  // Run all dynamic library destroy callbacks to prepare for the shutdown.
227  for (LibraryDestroyFn destroy : destroyFns)
228  destroy();
229 }
230 
233  std::unique_ptr<llvm::TargetMachine> tm) {
234  auto engine = std::make_unique<ExecutionEngine>(
235  options.enableObjectDump, options.enableGDBNotificationListener,
236  options.enablePerfNotificationListener);
237 
238  // Remember all entry-points if object dumping is enabled.
239  if (options.enableObjectDump) {
240  for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
241  StringRef funcName = funcOp.getSymName();
242  engine->functionNames.push_back(funcName.str());
243  }
244  }
245 
246  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
247  auto llvmModule = options.llvmModuleBuilder
248  ? options.llvmModuleBuilder(m, *ctx)
249  : translateModuleToLLVMIR(m, *ctx);
250  if (!llvmModule)
251  return makeStringError("could not convert to LLVM IR");
252 
253  // If no valid TargetMachine was passed, create a default TM ignoring any
254  // input arguments from the user.
255  if (!tm) {
256  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
257  if (!tmBuilderOrError)
258  return tmBuilderOrError.takeError();
259 
260  auto tmOrError = tmBuilderOrError->createTargetMachine();
261  if (!tmOrError)
262  return tmOrError.takeError();
263  tm = std::move(tmOrError.get());
264  }
265 
266  // TODO: Currently, the LLVM module created above has no triple associated
267  // with it. Instead, the triple is extracted from the TargetMachine, which is
268  // either based on the host defaults or command line arguments when specified
269  // (set-up by callers of this method). It could also be passed to the
270  // translation or dialect conversion instead of this.
271  setupTargetTripleAndDataLayout(llvmModule.get(), tm.get());
272  packFunctionArguments(llvmModule.get());
273 
274  auto dataLayout = llvmModule->getDataLayout();
275 
276  // Use absolute library path so that gdb can find the symbol table.
277  SmallVector<SmallString<256>, 4> sharedLibPaths;
278  transform(
279  options.sharedLibPaths, std::back_inserter(sharedLibPaths),
280  [](StringRef libPath) {
281  SmallString<256> absPath(libPath.begin(), libPath.end());
282  cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
283  return absPath;
284  });
285 
286  // If shared library implements custom execution layer library init and
287  // destroy functions, we'll use them to register the library. Otherwise, load
288  // the library as JITDyLib below.
289  llvm::StringMap<void *> exportSymbols;
291  SmallVector<StringRef> jitDyLibPaths;
292 
293  for (auto &libPath : sharedLibPaths) {
294  auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
295  libPath.str().str().c_str());
296  void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName);
297  void *destroySim = lib.getAddressOfSymbol(kLibraryDestroyFnName);
298 
299  // Library does not provide callbacks, rely on symbol visibility.
300  if (!initSym || !destroySim) {
301  jitDyLibPaths.push_back(libPath);
302  continue;
303  }
304 
305  auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
306  initFn(exportSymbols);
307 
308  auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySim);
309  destroyFns.push_back(destroyFn);
310  }
311  engine->destroyFns = std::move(destroyFns);
312 
313  // Callback to create the object layer with symbol resolution to current
314  // process and dynamically linked libraries.
315  auto objectLinkingLayerCreator = [&](ExecutionSession &session,
316  const Triple &tt) {
317  auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
318  session, [sectionMemoryMapper = options.sectionMemoryMapper]() {
319  return std::make_unique<SectionMemoryManager>(sectionMemoryMapper);
320  });
321 
322  // Register JIT event listeners if they are enabled.
323  if (engine->gdbListener)
324  objectLayer->registerJITEventListener(*engine->gdbListener);
325  if (engine->perfListener)
326  objectLayer->registerJITEventListener(*engine->perfListener);
327 
328  // COFF format binaries (Windows) need special handling to deal with
329  // exported symbol visibility.
330  // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
331  llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
332  if (targetTriple.isOSBinFormatCOFF()) {
333  objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
334  objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
335  }
336 
337  // Resolve symbols from shared libraries.
338  for (auto &libPath : jitDyLibPaths) {
339  auto mb = llvm::MemoryBuffer::getFile(libPath);
340  if (!mb) {
341  errs() << "Failed to create MemoryBuffer for: " << libPath
342  << "\nError: " << mb.getError().message() << "\n";
343  continue;
344  }
345  auto &jd = session.createBareJITDylib(std::string(libPath));
346  auto loaded = DynamicLibrarySearchGenerator::Load(
347  libPath.str().c_str(), dataLayout.getGlobalPrefix());
348  if (!loaded) {
349  errs() << "Could not load " << libPath << ":\n " << loaded.takeError()
350  << "\n";
351  continue;
352  }
353  jd.addGenerator(std::move(*loaded));
354  cantFail(objectLayer->add(jd, std::move(mb.get())));
355  }
356 
357  return objectLayer;
358  };
359 
360  // Callback to inspect the cache and recompile on demand. This follows Lang's
361  // LLJITWithObjectCache example.
362  auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
363  -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
364  if (options.jitCodeGenOptLevel)
365  jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
366  return std::make_unique<TMOwningSimpleCompiler>(std::move(tm),
367  engine->cache.get());
368  };
369 
370  // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
371  auto jit =
372  cantFail(llvm::orc::LLJITBuilder()
373  .setCompileFunctionCreator(compileFunctionCreator)
374  .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
375  .setDataLayout(dataLayout)
376  .create());
377 
378  // Add a ThreadSafemodule to the engine and return.
379  ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
380  if (options.transformer)
381  cantFail(tsm.withModuleDo(
382  [&](llvm::Module &module) { return options.transformer(&module); }));
383  cantFail(jit->addIRModule(std::move(tsm)));
384  engine->jit = std::move(jit);
385 
386  // Resolve symbols that are statically linked in the current process.
387  llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
388  mainJD.addGenerator(
389  cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
390  dataLayout.getGlobalPrefix())));
391 
392  // Build a runtime symbol map from the exported symbols and register them.
393  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
394  auto symbolMap = llvm::orc::SymbolMap();
395  for (auto &exportSymbol : exportSymbols)
396  symbolMap[interner(exportSymbol.getKey())] = {
397  llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()),
398  llvm::JITSymbolFlags::Exported};
399  return symbolMap;
400  };
401  engine->registerSymbols(runtimeSymbolMap);
402 
403  return std::move(engine);
404 }
405 
406 Expected<void (*)(void **)>
407 ExecutionEngine::lookupPacked(StringRef name) const {
408  auto result = lookup(makePackedFunctionName(name));
409  if (!result)
410  return result.takeError();
411  return reinterpret_cast<void (*)(void **)>(result.get());
412 }
413 
415  auto expectedSymbol = jit->lookup(name);
416 
417  // JIT lookup may return an Error referring to strings stored internally by
418  // the JIT. If the Error outlives the ExecutionEngine, it would have a
419  // dangling reference, which is currently caught by an assertion inside JIT
420  // thanks to hand-rolled reference counting. Rewrap the error message into a
421  // string before returning. Alternatively, ORC JIT should consider copying
422  // the string into the error message.
423  if (!expectedSymbol) {
424  std::string errorMessage;
425  llvm::raw_string_ostream os(errorMessage);
426  llvm::handleAllErrors(expectedSymbol.takeError(),
427  [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
428  return makeStringError(os.str());
429  }
430 
431  if (void *fptr = expectedSymbol->toPtr<void *>())
432  return fptr;
433  return makeStringError("looked up function is null");
434 }
435 
438  auto expectedFPtr = lookupPacked(name);
439  if (!expectedFPtr)
440  return expectedFPtr.takeError();
441  auto fptr = *expectedFPtr;
442 
443  (*fptr)(args.data());
444 
445  return Error::success();
446 }
static void packFunctionArguments(Module *module)
static Error makeStringError(const Twine &message)
Wrap a string into an llvm::StringError.
static std::string makePackedFunctionName(StringRef name)
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
llvm::Expected< void(*)(void **)> lookupPacked(StringRef name) const
Looks up a packed-argument function wrapping the function with the given name and returns a pointer t...
void(*)(llvm::StringMap< void * > &) LibraryInitFn
Function type for init functions of shared libraries.
void(*)() LibraryDestroyFn
Function type for destroy functions of shared libraries.
void registerSymbols(llvm::function_ref< llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> symbolMap)
Register symbols with this ExecutionEngine.
llvm::Expected< void * > lookup(StringRef name) const
Looks up the original function with the given name and returns a pointer to it.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={}, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Creates an execution engine for the given MLIR IR.
void dumpToObjectFile(StringRef filename)
Dump object code to output file filename.
static constexpr const char *const kLibraryInitFnName
Name of init functions of shared libraries.
llvm::Error invokePacked(StringRef name, MutableArrayRef< void * > args=std::nullopt)
Invokes the function with the given name passing it the list of opaque pointers to the actual argumen...
static Result< T > result(T &t)
Helper function to wrap an output operand when using ExecutionEngine::invoke.
static constexpr const char *const kLibraryDestroyFnName
Name of destroy functions of shared libraries.
ExecutionEngine(bool enableObjectDump, bool enableGDBNotificationListener, bool enablePerfNotificationListener)
static void setupTargetTripleAndDataLayout(llvm::Module *llvmModule, llvm::TargetMachine *tm)
Set the target triple and the data layout for the input module based on the input TargetMachine.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
iterator_range< OpIterator > getOps()
Definition: Region.h:172
A simple object cache following Lang's LLJITWithObjectCache example.
bool isEmpty()
Returns true if cache hasn't been populated yet.
void notifyObjectCompiled(const llvm::Module *m, llvm::MemoryBufferRef objBuffer) override
void dumpToObjectFile(StringRef filename)
Dump cached object to output file filename.
std::unique_ptr< llvm::MemoryBuffer > getObject(const llvm::Module *m) override
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
This header declares functions that assist transformations in the MemRef dialect.
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule")
Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in ...