MLIR  17.0.0git
JitRunner.cpp
Go to the documentation of this file.
1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
18 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/FileUtilities.h"
36 #include "llvm/Support/SourceMgr.h"
37 #include "llvm/Support/StringSaver.h"
38 #include "llvm/Support/ToolOutputFile.h"
39 #include <cstdint>
40 #include <numeric>
41 #include <utility>
42 #include <optional>
43 
44 using namespace mlir;
45 using llvm::Error;
46 
47 namespace {
48 /// This options struct prevents the need for global static initializers, and
49 /// is only initialized if the JITRunner is invoked.
50 struct Options {
51  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
52  llvm::cl::desc("<input file>"),
53  llvm::cl::init("-")};
54  llvm::cl::opt<std::string> mainFuncName{
55  "e", llvm::cl::desc("The function to be called"),
56  llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
57  llvm::cl::opt<std::string> mainFuncType{
58  "entry-point-result",
59  llvm::cl::desc("Textual description of the function type to be called"),
60  llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
61 
62  llvm::cl::OptionCategory optFlags{"opt-like flags"};
63 
64  // CLI variables for -On options.
65  llvm::cl::opt<bool> optO0{"O0",
66  llvm::cl::desc("Run opt passes and codegen at O0"),
67  llvm::cl::cat(optFlags)};
68  llvm::cl::opt<bool> optO1{"O1",
69  llvm::cl::desc("Run opt passes and codegen at O1"),
70  llvm::cl::cat(optFlags)};
71  llvm::cl::opt<bool> optO2{"O2",
72  llvm::cl::desc("Run opt passes and codegen at O2"),
73  llvm::cl::cat(optFlags)};
74  llvm::cl::opt<bool> optO3{"O3",
75  llvm::cl::desc("Run opt passes and codegen at O3"),
76  llvm::cl::cat(optFlags)};
77 
78  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
79  llvm::cl::list<std::string> clSharedLibs{
80  "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
81  llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};
82 
83  /// CLI variables for debugging.
84  llvm::cl::opt<bool> dumpObjectFile{
85  "dump-object-file",
86  llvm::cl::desc("Dump JITted-compiled object to file specified with "
87  "-object-filename (<input file>.o by default).")};
88 
89  llvm::cl::opt<std::string> objectFilename{
90  "object-filename",
91  llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
92 
93  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
94  llvm::cl::desc("Report host JIT support"),
95  llvm::cl::Hidden};
96 
97  llvm::cl::opt<bool> noImplicitModule{
98  "no-implicit-module",
99  llvm::cl::desc(
100  "Disable implicit addition of a top-level module op during parsing"),
101  llvm::cl::init(false)};
102 };
103 
104 struct CompileAndExecuteConfig {
105  /// LLVM module transformer that is passed to ExecutionEngine.
106  std::function<llvm::Error(llvm::Module *)> transformer;
107 
108  /// A custom function that is passed to ExecutionEngine. It processes MLIR
109  /// module and creates LLVM IR module.
111  llvm::LLVMContext &)>
112  llvmModuleBuilder;
113 
114  /// A custom function that is passed to ExecutinEngine to register symbols at
115  /// runtime.
116  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
117  runtimeSymbolMap;
118 };
119 
120 } // namespace
121 
122 static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
123  bool insertImplicitModule,
124  MLIRContext *context) {
125  // Set up the input file.
126  std::string errorMessage;
127  auto file = openInputFile(inputFilename, &errorMessage);
128  if (!file) {
129  llvm::errs() << errorMessage << "\n";
130  return nullptr;
131  }
132 
133  auto sourceMgr = std::make_shared<llvm::SourceMgr>();
134  sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
135  OwningOpRef<Operation *> module =
136  parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
137  if (!module)
138  return nullptr;
139  if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
140  llvm::errs() << "Error: top-level op must be a symbol table.\n";
141  return nullptr;
142  }
143  return module;
144 }
145 
146 static inline Error makeStringError(const Twine &message) {
147  return llvm::make_error<llvm::StringError>(message.str(),
148  llvm::inconvertibleErrorCode());
149 }
150 
151 static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
152  std::optional<unsigned> optLevel;
154  options.optO0, options.optO1, options.optO2, options.optO3};
155 
156  // Determine if there is an optimization flag present.
157  for (unsigned j = 0; j < 4; ++j) {
158  auto &flag = optFlags[j].get();
159  if (flag) {
160  optLevel = j;
161  break;
162  }
163  }
164  return optLevel;
165 }
166 
167 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
168 static Error compileAndExecute(Options &options, Operation *module,
169  StringRef entryPoint,
170  CompileAndExecuteConfig config, void **args) {
171  std::optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
172  if (auto clOptLevel = getCommandLineOptLevel(options))
173  jitCodeGenOptLevel = static_cast<llvm::CodeGenOpt::Level>(*clOptLevel);
174 
175  // If shared library implements custom mlir-runner library init and destroy
176  // functions, we'll use them to register the library with the execution
177  // engine. Otherwise we'll pass library directly to the execution engine.
178  SmallVector<SmallString<256>, 4> libPaths;
179 
180  // Use absolute library path so that gdb can find the symbol table.
181  transform(
182  options.clSharedLibs, std::back_inserter(libPaths),
183  [](std::string libPath) {
184  SmallString<256> absPath(libPath.begin(), libPath.end());
185  cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
186  return absPath;
187  });
188 
189  // Libraries that we'll pass to the ExecutionEngine for loading.
190  SmallVector<StringRef, 4> executionEngineLibs;
191 
192  using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
193  using MlirRunnerDestroyFn = void (*)();
194 
195  llvm::StringMap<void *> exportSymbols;
197 
198  // Handle libraries that do support mlir-runner init/destroy callbacks.
199  for (auto &libPath : libPaths) {
200  auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
201  void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
202  void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
203 
204  // Library does not support mlir runner, load it with ExecutionEngine.
205  if (!initSym || !destroySim) {
206  executionEngineLibs.push_back(libPath);
207  continue;
208  }
209 
210  auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
211  initFn(exportSymbols);
212 
213  auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
214  destroyFns.push_back(destroyFn);
215  }
216 
217  // Build a runtime symbol map from the config and exported symbols.
218  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
219  auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
220  : llvm::orc::SymbolMap();
221  for (auto &exportSymbol : exportSymbols)
222  symbolMap[interner(exportSymbol.getKey())] =
223  llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
224  return symbolMap;
225  };
226 
227  mlir::ExecutionEngineOptions engineOptions;
228  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
229  if (config.transformer)
230  engineOptions.transformer = config.transformer;
231  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
232  engineOptions.sharedLibPaths = executionEngineLibs;
233  engineOptions.enableObjectDump = true;
234  auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
235  if (!expectedEngine)
236  return expectedEngine.takeError();
237 
238  auto engine = std::move(*expectedEngine);
239  engine->registerSymbols(runtimeSymbolMap);
240 
241  auto expectedFPtr = engine->lookupPacked(entryPoint);
242  if (!expectedFPtr)
243  return expectedFPtr.takeError();
244 
245  if (options.dumpObjectFile)
246  engine->dumpToObjectFile(options.objectFilename.empty()
247  ? options.inputFilename + ".o"
248  : options.objectFilename);
249 
250  void (*fptr)(void **) = *expectedFPtr;
251  (*fptr)(args);
252 
253  // Run all dynamic library destroy callbacks to prepare for the shutdown.
254  for (MlirRunnerDestroyFn destroy : destroyFns)
255  destroy();
256 
257  return Error::success();
258 }
259 
261  StringRef entryPoint,
262  CompileAndExecuteConfig config) {
263  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
264  SymbolTable::lookupSymbolIn(module, entryPoint));
265  if (!mainFunction || mainFunction.empty())
266  return makeStringError("entry point not found");
267  void *empty = nullptr;
268  return compileAndExecute(options, module, entryPoint, std::move(config),
269  &empty);
270 }
271 
272 template <typename Type>
273 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
274 template <>
275 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
276  auto resultType = mainFunction.getFunctionType()
277  .cast<LLVM::LLVMFunctionType>()
278  .getReturnType()
279  .dyn_cast<IntegerType>();
280  if (!resultType || resultType.getWidth() != 32)
281  return makeStringError("only single i32 function result supported");
282  return Error::success();
283 }
284 template <>
285 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
286  auto resultType = mainFunction.getFunctionType()
287  .cast<LLVM::LLVMFunctionType>()
288  .getReturnType()
289  .dyn_cast<IntegerType>();
290  if (!resultType || resultType.getWidth() != 64)
291  return makeStringError("only single i64 function result supported");
292  return Error::success();
293 }
294 template <>
295 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
296  if (!mainFunction.getFunctionType()
297  .cast<LLVM::LLVMFunctionType>()
298  .getReturnType()
299  .isa<Float32Type>())
300  return makeStringError("only single f32 function result supported");
301  return Error::success();
302 }
303 template <typename Type>
305  StringRef entryPoint,
306  CompileAndExecuteConfig config) {
307  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
308  SymbolTable::lookupSymbolIn(module, entryPoint));
309  if (!mainFunction || mainFunction.isExternal())
310  return makeStringError("entry point not found");
311 
312  if (mainFunction.getFunctionType()
313  .cast<LLVM::LLVMFunctionType>()
314  .getNumParams() != 0)
315  return makeStringError("function inputs not supported");
316 
317  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
318  return error;
319 
320  Type res;
321  struct {
322  void *data;
323  } data;
324  data.data = &res;
325  if (auto error = compileAndExecute(options, module, entryPoint,
326  std::move(config), (void **)&data))
327  return error;
328 
329  // Intentional printing of the output so we can test.
330  llvm::outs() << res << '\n';
331 
332  return Error::success();
333 }
334 
335 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
336 /// standard C++ main functions.
337 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
338  JitRunnerConfig config) {
339  // Create the options struct containing the command line options for the
340  // runner. This must come before the command line options are parsed.
341  Options options;
342  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
343 
344  if (options.hostSupportsJit) {
345  auto j = llvm::orc::LLJITBuilder().create();
346  if (j)
347  llvm::outs() << "true\n";
348  else {
349  llvm::consumeError(j.takeError());
350  llvm::outs() << "false\n";
351  }
352  return 0;
353  }
354 
355  std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
357  options.optO0, options.optO1, options.optO2, options.optO3};
358 
359  MLIRContext context(registry);
360 
361  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
362  &context);
363  if (!m) {
364  llvm::errs() << "could not parse the input IR\n";
365  return 1;
366  }
367 
368  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
369  if (config.mlirTransformer)
370  if (failed(config.mlirTransformer(m.get(), runnerOptions)))
371  return EXIT_FAILURE;
372 
373  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
374  if (!tmBuilderOrError) {
375  llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
376  return EXIT_FAILURE;
377  }
378  auto tmOrError = tmBuilderOrError->createTargetMachine();
379  if (!tmOrError) {
380  llvm::errs() << "Failed to create a TargetMachine for the host\n";
381  return EXIT_FAILURE;
382  }
383 
384  CompileAndExecuteConfig compileAndExecuteConfig;
385  if (optLevel) {
386  compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
387  *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
388  }
389  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
390  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
391 
392  // Get the function used to compile and execute the module.
393  using CompileAndExecuteFnT =
394  Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig);
395  auto compileAndExecuteFn =
396  StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
397  .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
398  .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
399  .Case("f32", compileAndExecuteSingleReturnFunction<float>)
400  .Case("void", compileAndExecuteVoidFunction)
401  .Default(nullptr);
402 
403  Error error = compileAndExecuteFn
404  ? compileAndExecuteFn(options, m.get(),
405  options.mainFuncName.getValue(),
406  compileAndExecuteConfig)
407  : makeStringError("unsupported function type");
408 
409  int exitCode = EXIT_SUCCESS;
410  llvm::handleAllErrors(std::move(error),
411  [&exitCode](const llvm::ErrorInfoBase &info) {
412  llvm::errs() << "Error: ";
413  info.log(llvm::errs());
414  llvm::errs() << '\n';
415  exitCode = EXIT_FAILURE;
416  });
417 
418  return exitCode;
419 }
Error checkCompatibleReturnType< int64_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:285
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)
Error checkCompatibleReturnType< float >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:295
static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config)
Definition: JitRunner.cpp:260
static std::optional< unsigned > getCommandLineOptLevel(Options &options)
Definition: JitRunner.cpp:151
static Error makeStringError(const Twine &message)
Definition: JitRunner.cpp:146
Error checkCompatibleReturnType< int32_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:275
Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config)
Definition: JitRunner.cpp:304
static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args)
Definition: JitRunner.cpp:168
static OwningOpRef< Operation * > parseMLIRInput(StringRef inputFilename, bool insertImplicitModule, MLIRContext *context)
Definition: JitRunner.cpp:122
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={})
Creates an execution engine for the given MLIR IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:343
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:28
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:50
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
int JitRunnerMain(int argc, char **argv, const DialectRegistry &registry, JitRunnerConfig config={})
Entry point for all CPU runners.
Definition: JitRunner.cpp:337
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::function< llvm::Error(llvm::Module *)> makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, llvm::TargetMachine *targetMachine)
Create a module transformer function for MLIR ExecutionEngine that runs LLVM IR passes corresponding ...
OwningOpRef< Operation * > parseSourceFileForTool(const std::shared_ptr< llvm::SourceMgr > &sourceMgr, const ParserConfig &config, bool insertImplicitModule)
This parses the file specified by the indicated SourceMgr.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::optional< llvm::CodeGenOpt::Level > jitCodeGenOptLevel
jitCodeGenOptLevel, when provided, is used as the optimization level for target code generation.
ArrayRef< StringRef > sharedLibPaths
If sharedLibPaths are provided, the underlying JIT-compilation will open and link the shared librarie...
bool enableObjectDump
If enableObjectDump is set, the JIT compiler will create an object cache to store the object generated for the g...
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
If llvmModuleBuilder is provided, it will be used to create an LLVM module from the given MLIR IR.
llvm::function_ref< llvm::Error(llvm::Module *)> transformer
If transformer is provided, it will be called on the LLVM module during JIT-compilation and can be us...
Configuration to override functionality of the JitRunner.
Definition: JitRunner.h:48
llvm::function_ref< LogicalResult(mlir::Operation *, JitRunnerOptions &options)> mlirTransformer
MLIR transformer applied after parsing the input into MLIR IR and before passing the MLIR IR to the E...
Definition: JitRunner.h:53
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
A custom function that is passed to ExecutionEngine.
Definition: JitRunner.h:59
llvm::function_ref< llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> runtimesymbolMap
A callback to register symbols with ExecutionEngine at runtime.
Definition: JitRunner.h:63
JitRunner command line options used by JitRunnerConfig methods.
Definition: JitRunner.h:40
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.