//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the execution engine for MLIR modules based on the
// LLVM Orc JIT engine.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Export.h"

#include "llvm/ExecutionEngine/JITEventListener.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/MC/SubtargetFeature.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/ToolOutputFile.h"

#define DEBUG_TYPE "execution-engine"

using namespace mlir;
using llvm::dbgs;
using llvm::Error;
using llvm::errs;
using llvm::Expected;
using llvm::LLVMContext;
using llvm::MemoryBuffer;
using llvm::MemoryBufferRef;
using llvm::Module;
using llvm::SectionMemoryManager;
using llvm::StringError;
using llvm::Triple;
using llvm::orc::DynamicLibrarySearchGenerator;
using llvm::orc::ExecutionSession;
using llvm::orc::IRCompileLayer;
using llvm::orc::JITTargetMachineBuilder;
using llvm::orc::MangleAndInterner;
using llvm::orc::RTDyldObjectLinkingLayer;
using llvm::orc::SymbolMap;
using llvm::orc::ThreadSafeModule;
using llvm::orc::TMOwningSimpleCompiler;

/// Wrap a string into an llvm::StringError.
static Error makeStringError(const Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

void SimpleObjectCache::notifyObjectCompiled(const Module *m,
                                             MemoryBufferRef objBuffer) {
  cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
      objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
}

std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
  auto i = cachedObjects.find(m->getModuleIdentifier());
  if (i == cachedObjects.end()) {
    LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
                      << " in cache. Compiling.\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
                    << " loaded from cache.\n");
  return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
}

void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
  // Set up the output file.
  std::string errorMessage;
  auto file = openOutputFile(outputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return;
  }

  // Dump the object generated for a single module to the output file.
  assert(cachedObjects.size() == 1 && "Expected only one object entry.");
  auto &cachedObject = cachedObjects.begin()->second;
  file->os() << cachedObject->getBuffer();
  file->keep();
}

bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }

void ExecutionEngine::dumpToObjectFile(StringRef filename) {
  if (cache == nullptr) {
    llvm::errs() << "cannot dump ExecutionEngine object code to file: "
                    "object cache is disabled\n";
    return;
  }
  // Compilation is lazy and does not populate the object cache unless
  // requested. In case an object dump is requested before the cache is
  // populated, we need to force compilation manually.
  if (cache->isEmpty()) {
    for (std::string &functionName : functionNames) {
      auto result = lookupPacked(functionName);
      if (!result) {
        llvm::errs() << "Could not compile " << functionName << ":\n "
                     << result.takeError() << "\n";
        return;
      }
    }
  }
  cache->dumpToObjectFile(filename);
}
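
// Usage sketch (not part of the original file): to obtain a relocatable object
// from the JIT, create the engine with `enableObjectDump = true` in
// ExecutionEngineOptions and then request a dump; as the loop above shows,
// every remembered entry point is force-compiled first if the cache is empty.
//
//   engine->dumpToObjectFile("module.o");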

void ExecutionEngine::registerSymbols(
    llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
  auto &mainJitDylib = jit->getMainJITDylib();
  cantFail(mainJitDylib.define(
      absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
          mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
}
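
// Usage sketch (not part of the original file): exposing a host-process
// function to JIT'ed code under a chosen name. `hostPrint` is a hypothetical
// host function; the SymbolMap value type follows the ORC API of this LLVM
// version (llvm::JITEvaluatedSymbol).
//
//   engine->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
//     llvm::orc::SymbolMap symbolMap;
//     symbolMap[interner("hostPrint")] =
//         llvm::JITEvaluatedSymbol::fromPointer(&hostPrint);
//     return symbolMap;
//   });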

// Set up the LLVM target triple from the current machine.
bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
  // Set up the machine properties from the current architecture.
  auto targetTriple = llvm::sys::getDefaultTargetTriple();
  std::string errorMessage;
  const auto *target =
      llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
  if (!target) {
    errs() << "NO target: " << errorMessage << "\n";
    return true;
  }

  std::string cpu(llvm::sys::getHostCPUName());
  llvm::SubtargetFeatures features;
  llvm::StringMap<bool> hostFeatures;

  if (llvm::sys::getHostCPUFeatures(hostFeatures))
    for (auto &f : hostFeatures)
      features.AddFeature(f.first(), f.second);

  std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
      targetTriple, cpu, features.getString(), {}, {}));
  if (!machine) {
    errs() << "Unable to create target machine\n";
    return true;
  }
  llvmModule->setDataLayout(machine->createDataLayout());
  llvmModule->setTargetTriple(targetTriple);
  return false;
}
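
// Note (sketch, not part of the original file): create() below calls this on
// the freshly translated LLVM module; the boolean result is true on failure,
// e.g.
//
//   if (ExecutionEngine::setupTargetTriple(llvmModule.get()))
//     llvm::errs() << "failed to set up the target triple\n";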

static std::string makePackedFunctionName(StringRef name) {
  return "_mlir_" + name.str();
}

// For each function in the LLVM module, define an interface function that
// wraps all the arguments of the original function and all its results into
// an i8** pointer to provide a unified invocation interface.
static void packFunctionArguments(Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    if (func.isDeclaration()) {
      continue;
    }
    if (interfaceFunctions.count(&func)) {
      continue;
    }

    // Given a function `foo(<...>)`, define the interface function
    // `_mlir_foo(i8**)`.
    auto *newType = llvm::FunctionType::get(
        builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
        /*isVarArg=*/false);
    auto newName = makePackedFunctionName(func.getName());
    auto funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them
    // to the proper types.
    auto *bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto &indexedArg : llvm::enumerate(func.args())) {
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, indexedArg.index()));
      llvm::Value *argPtrPtr =
          builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex);
      llvm::Value *argPtr =
          builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr);
      llvm::Type *argTy = indexedArg.value().getType();
      argPtr = builder.CreateBitCast(argPtr, argTy->getPointerTo());
      llvm::Value *arg = builder.CreateLoad(argTy, argPtr);
      args.push_back(arg);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`.
    if (!result->getType()->isVoidTy()) {
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr =
          builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex);
      llvm::Value *retPtr =
          builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr);
      retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}
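
// For illustration (sketch, not part of the original file): for an
// implementation function such as `float foo(int, float)`, the wrapper
// emitted above behaves like the following hand-written code, with one i8*
// slot per argument followed by one slot for the result:
//
//   extern "C" void _mlir_foo(void **args) {
//     int arg0 = *static_cast<int *>(args[0]);
//     float arg1 = *static_cast<float *>(args[1]);
//     float result = foo(arg0, arg1);
//     *static_cast<float *>(args[2]) = result;
//   }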

ExecutionEngine::ExecutionEngine(bool enableObjectDump,
                                 bool enableGDBNotificationListener,
                                 bool enablePerfNotificationListener)
    : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
      functionNames(),
      gdbListener(enableGDBNotificationListener
                      ? llvm::JITEventListener::createGDBRegistrationListener()
                      : nullptr),
      perfListener(nullptr) {
  if (enablePerfNotificationListener) {
    if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
      perfListener = listener;
    else if (auto *listener =
                 llvm::JITEventListener::createIntelJITEventListener())
      perfListener = listener;
  }
}

Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options) {
  auto engine = std::make_unique<ExecutionEngine>(
      options.enableObjectDump, options.enableGDBNotificationListener,
      options.enablePerfNotificationListener);

  // Remember all entry-points if object dumping is enabled.
  if (options.enableObjectDump) {
    for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
      StringRef funcName = funcOp.getSymName();
      engine->functionNames.push_back(funcName.str());
    }
  }

  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
  auto llvmModule = options.llvmModuleBuilder
                        ? options.llvmModuleBuilder(m, *ctx)
                        : translateModuleToLLVMIR(m, *ctx);
  if (!llvmModule)
    return makeStringError("could not convert to LLVM IR");
  // FIXME: the triple should be passed to the translation or dialect
  // conversion instead of this. Currently, the LLVM module created above has
  // no triple associated with it.
  setupTargetTriple(llvmModule.get());
  packFunctionArguments(llvmModule.get());

  auto dataLayout = llvmModule->getDataLayout();

  // Callback to create the object layer with symbol resolution to the current
  // process and dynamically linked libraries.
  auto objectLinkingLayerCreator = [&](ExecutionSession &session,
                                       const Triple &tt) {
    auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
        session, [sectionMemoryMapper = options.sectionMemoryMapper]() {
          return std::make_unique<SectionMemoryManager>(sectionMemoryMapper);
        });

    // Register JIT event listeners if they are enabled.
    if (engine->gdbListener)
      objectLayer->registerJITEventListener(*engine->gdbListener);
    if (engine->perfListener)
      objectLayer->registerJITEventListener(*engine->perfListener);

    // COFF format binaries (Windows) need special handling to deal with
    // exported symbol visibility.
    // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
    llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
    if (targetTriple.isOSBinFormatCOFF()) {
      objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
      objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
    }

    // Resolve symbols from shared libraries.
    for (auto libPath : options.sharedLibPaths) {
      auto mb = llvm::MemoryBuffer::getFile(libPath);
      if (!mb) {
        errs() << "Failed to create MemoryBuffer for: " << libPath
               << "\nError: " << mb.getError().message() << "\n";
        continue;
      }
      auto &jd = session.createBareJITDylib(std::string(libPath));
      auto loaded = DynamicLibrarySearchGenerator::Load(
          libPath.data(), dataLayout.getGlobalPrefix());
      if (!loaded) {
        errs() << "Could not load " << libPath << ":\n " << loaded.takeError()
               << "\n";
        continue;
      }
      jd.addGenerator(std::move(*loaded));
      cantFail(objectLayer->add(jd, std::move(mb.get())));
    }

    return objectLayer;
  };

  // Callback to inspect the cache and recompile on demand. This follows Lang's
  // LLJITWithObjectCache example.
  auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
      -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
    if (options.jitCodeGenOptLevel)
      jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
    auto tm = jtmb.createTargetMachine();
    if (!tm)
      return tm.takeError();
    return std::make_unique<TMOwningSimpleCompiler>(std::move(*tm),
                                                    engine->cache.get());
  };

  // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
  auto jit =
      cantFail(llvm::orc::LLJITBuilder()
                   .setCompileFunctionCreator(compileFunctionCreator)
                   .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
                   .create());

  // Add a ThreadSafeModule to the engine and return.
  ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
  if (options.transformer)
    cantFail(tsm.withModuleDo(
        [&](llvm::Module &module) { return options.transformer(&module); }));
  cantFail(jit->addIRModule(std::move(tsm)));
  engine->jit = std::move(jit);

  // Resolve symbols that are statically linked in the current process.
  llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
  mainJD.addGenerator(
      cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
          dataLayout.getGlobalPrefix())));

  return std::move(engine);
}
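
// Usage sketch (not part of the original file): creating an engine for a
// ModuleOp already lowered to the LLVM dialect. Only fields used above are
// shown; `module` is a hypothetical mlir::ModuleOp.
//
//   ExecutionEngineOptions engineOptions;
//   engineOptions.jitCodeGenOptLevel = llvm::CodeGenOpt::Aggressive;
//   auto expectedEngine = ExecutionEngine::create(module, engineOptions);
//   if (!expectedEngine)
//     return expectedEngine.takeError();
//   std::unique_ptr<ExecutionEngine> engine = std::move(*expectedEngine);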

Expected<void (*)(void **)>
ExecutionEngine::lookupPacked(StringRef name) const {
  auto result = lookup(makePackedFunctionName(name));
  if (!result)
    return result.takeError();
  return reinterpret_cast<void (*)(void **)>(result.get());
}

Expected<void *> ExecutionEngine::lookup(StringRef name) const {
  auto expectedSymbol = jit->lookup(name);

  // JIT lookup may return an Error referring to strings stored internally by
  // the JIT. If the Error outlives the ExecutionEngine, it would have a
  // dangling reference, which is currently caught by an assertion inside JIT
  // thanks to hand-rolled reference counting. Rewrap the error message into a
  // string before returning. Alternatively, ORC JIT should consider copying
  // the string into the error message.
  if (!expectedSymbol) {
    std::string errorMessage;
    llvm::raw_string_ostream os(errorMessage);
    llvm::handleAllErrors(expectedSymbol.takeError(),
                          [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
    return makeStringError(os.str());
  }

  if (void *fptr = expectedSymbol->toPtr<void *>())
    return fptr;
  return makeStringError("looked up function is null");
}

Error ExecutionEngine::invokePacked(StringRef name,
                                    MutableArrayRef<void *> args) {
  auto expectedFPtr = lookupPacked(name);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  auto fptr = *expectedFPtr;

  (*fptr)(args.data());

  return Error::success();
}
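
// Usage sketch (not part of the original file), assuming the module exports a
// function `add(i32, i32) -> i32`: pointers to the arguments come first, then
// a pointer to storage for the result, matching packFunctionArguments above.
//
//   int32_t a = 2, b = 3, result = 0;
//   llvm::SmallVector<void *> args = {&a, &b, &result};
//   if (llvm::Error err = engine->invokePacked("add", args))
//     llvm::errs() << "invocation failed: "
//                  << llvm::toString(std::move(err)) << "\n";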