MLIR  20.0.0git
JitRunner.cpp
Go to the documentation of this file.
1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
18 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/FileUtilities.h"
37 #include "llvm/Support/SourceMgr.h"
38 #include "llvm/Support/StringSaver.h"
39 #include "llvm/Support/ToolOutputFile.h"
40 #include <cstdint>
41 #include <numeric>
42 #include <optional>
43 #include <utility>
44 
45 #define DEBUG_TYPE "jit-runner"
46 
47 using namespace mlir;
48 using llvm::Error;
49 
50 namespace {
51 /// This options struct prevents the need for global static initializers, and
52 /// is only initialized if the JITRunner is invoked.
53 struct Options {
54  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
55  llvm::cl::desc("<input file>"),
56  llvm::cl::init("-")};
57  llvm::cl::opt<std::string> mainFuncName{
58  "e", llvm::cl::desc("The function to be called"),
59  llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
60  llvm::cl::opt<std::string> mainFuncType{
61  "entry-point-result",
62  llvm::cl::desc("Textual description of the function type to be called"),
63  llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
64 
65  llvm::cl::OptionCategory optFlags{"opt-like flags"};
66 
67  // CLI variables for -On options.
68  llvm::cl::opt<bool> optO0{"O0",
69  llvm::cl::desc("Run opt passes and codegen at O0"),
70  llvm::cl::cat(optFlags)};
71  llvm::cl::opt<bool> optO1{"O1",
72  llvm::cl::desc("Run opt passes and codegen at O1"),
73  llvm::cl::cat(optFlags)};
74  llvm::cl::opt<bool> optO2{"O2",
75  llvm::cl::desc("Run opt passes and codegen at O2"),
76  llvm::cl::cat(optFlags)};
77  llvm::cl::opt<bool> optO3{"O3",
78  llvm::cl::desc("Run opt passes and codegen at O3"),
79  llvm::cl::cat(optFlags)};
80 
81  llvm::cl::list<std::string> mAttrs{
82  "mattr", llvm::cl::MiscFlags::CommaSeparated,
83  llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
84  llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};
85 
86  llvm::cl::opt<std::string> mArch{
87  "march",
88  llvm::cl::desc("Architecture to generate code for (see --version)")};
89 
90  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
91  llvm::cl::list<std::string> clSharedLibs{
92  "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
93  llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};
94 
95  /// CLI variables for debugging.
96  llvm::cl::opt<bool> dumpObjectFile{
97  "dump-object-file",
98  llvm::cl::desc("Dump JITted-compiled object to file specified with "
99  "-object-filename (<input file>.o by default).")};
100 
101  llvm::cl::opt<std::string> objectFilename{
102  "object-filename",
103  llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
104 
105  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
106  llvm::cl::desc("Report host JIT support"),
107  llvm::cl::Hidden};
108 
109  llvm::cl::opt<bool> noImplicitModule{
110  "no-implicit-module",
111  llvm::cl::desc(
112  "Disable implicit addition of a top-level module op during parsing"),
113  llvm::cl::init(false)};
114 };
115 
116 struct CompileAndExecuteConfig {
117  /// LLVM module transformer that is passed to ExecutionEngine.
118  std::function<llvm::Error(llvm::Module *)> transformer;
119 
120  /// A custom function that is passed to ExecutionEngine. It processes MLIR
121  /// module and creates LLVM IR module.
123  llvm::LLVMContext &)>
124  llvmModuleBuilder;
125 
126  /// A custom function that is passed to ExecutinEngine to register symbols at
127  /// runtime.
128  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
129  runtimeSymbolMap;
130 };
131 
132 } // namespace
133 
134 static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
135  bool insertImplicitModule,
136  MLIRContext *context) {
137  // Set up the input file.
138  std::string errorMessage;
139  auto file = openInputFile(inputFilename, &errorMessage);
140  if (!file) {
141  llvm::errs() << errorMessage << "\n";
142  return nullptr;
143  }
144 
145  auto sourceMgr = std::make_shared<llvm::SourceMgr>();
146  sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
147  OwningOpRef<Operation *> module =
148  parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
149  if (!module)
150  return nullptr;
151  if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
152  llvm::errs() << "Error: top-level op must be a symbol table.\n";
153  return nullptr;
154  }
155  return module;
156 }
157 
158 static inline Error makeStringError(const Twine &message) {
159  return llvm::make_error<llvm::StringError>(message.str(),
160  llvm::inconvertibleErrorCode());
161 }
162 
163 static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
164  std::optional<unsigned> optLevel;
166  options.optO0, options.optO1, options.optO2, options.optO3};
167 
168  // Determine if there is an optimization flag present.
169  for (unsigned j = 0; j < 4; ++j) {
170  auto &flag = optFlags[j].get();
171  if (flag) {
172  optLevel = j;
173  break;
174  }
175  }
176  return optLevel;
177 }
178 
179 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
180 static Error
181 compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
182  CompileAndExecuteConfig config, void **args,
183  std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
184  std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
185  if (auto clOptLevel = getCommandLineOptLevel(options))
186  jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);
187 
188  SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
189  options.clSharedLibs.end());
190 
191  mlir::ExecutionEngineOptions engineOptions;
192  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
193  if (config.transformer)
194  engineOptions.transformer = config.transformer;
195  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
196  engineOptions.sharedLibPaths = sharedLibs;
197  engineOptions.enableObjectDump = true;
198  auto expectedEngine =
199  mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
200  if (!expectedEngine)
201  return expectedEngine.takeError();
202 
203  auto engine = std::move(*expectedEngine);
204 
205  auto expectedFPtr = engine->lookupPacked(entryPoint);
206  if (!expectedFPtr)
207  return expectedFPtr.takeError();
208 
209  if (options.dumpObjectFile)
210  engine->dumpToObjectFile(options.objectFilename.empty()
211  ? options.inputFilename + ".o"
212  : options.objectFilename);
213 
214  void (*fptr)(void **) = *expectedFPtr;
215  (*fptr)(args);
216 
217  return Error::success();
218 }
219 
221  Options &options, Operation *module, StringRef entryPoint,
222  CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
223  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
224  SymbolTable::lookupSymbolIn(module, entryPoint));
225  if (!mainFunction || mainFunction.empty())
226  return makeStringError("entry point not found");
227 
228  auto resultType = dyn_cast<LLVM::LLVMVoidType>(
229  mainFunction.getFunctionType().getReturnType());
230  if (!resultType)
231  return makeStringError("expected void function");
232 
233  void *empty = nullptr;
234  return compileAndExecute(options, module, entryPoint, std::move(config),
235  &empty, std::move(tm));
236 }
237 
238 template <typename Type>
239 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
240 template <>
241 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
242  auto resultType = dyn_cast<IntegerType>(
243  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
244  .getReturnType());
245  if (!resultType || resultType.getWidth() != 32)
246  return makeStringError("only single i32 function result supported");
247  return Error::success();
248 }
249 template <>
250 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
251  auto resultType = dyn_cast<IntegerType>(
252  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
253  .getReturnType());
254  if (!resultType || resultType.getWidth() != 64)
255  return makeStringError("only single i64 function result supported");
256  return Error::success();
257 }
258 template <>
259 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
260  if (!isa<Float32Type>(
261  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
262  .getReturnType()))
263  return makeStringError("only single f32 function result supported");
264  return Error::success();
265 }
266 template <typename Type>
268  Options &options, Operation *module, StringRef entryPoint,
269  CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
270  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
271  SymbolTable::lookupSymbolIn(module, entryPoint));
272  if (!mainFunction || mainFunction.isExternal())
273  return makeStringError("entry point not found");
274 
275  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
276  .getNumParams() != 0)
277  return makeStringError("function inputs not supported");
278 
279  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
280  return error;
281 
282  Type res;
283  struct {
284  void *data;
285  } data;
286  data.data = &res;
287  if (auto error =
288  compileAndExecute(options, module, entryPoint, std::move(config),
289  (void **)&data, std::move(tm)))
290  return error;
291 
292  // Intentional printing of the output so we can test.
293  llvm::outs() << res << '\n';
294 
295  return Error::success();
296 }
297 
298 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
299 /// standard C++ main functions.
300 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
301  JitRunnerConfig config) {
302  llvm::ExitOnError exitOnErr;
303 
304  // Create the options struct containing the command line options for the
305  // runner. This must come before the command line options are parsed.
306  Options options;
307  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
308 
309  if (options.hostSupportsJit) {
310  auto j = llvm::orc::LLJITBuilder().create();
311  if (j)
312  llvm::outs() << "true\n";
313  else {
314  llvm::outs() << "false\n";
315  exitOnErr(j.takeError());
316  }
317  return 0;
318  }
319 
320  std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
322  options.optO0, options.optO1, options.optO2, options.optO3};
323 
324  MLIRContext context(registry);
325 
326  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
327  &context);
328  if (!m) {
329  llvm::errs() << "could not parse the input IR\n";
330  return 1;
331  }
332 
333  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
334  if (config.mlirTransformer)
335  if (failed(config.mlirTransformer(m.get(), runnerOptions)))
336  return EXIT_FAILURE;
337 
338  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
339  if (!tmBuilderOrError) {
340  llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
341  return EXIT_FAILURE;
342  }
343 
344  // Configure TargetMachine builder based on the command line options
345  llvm::SubtargetFeatures features;
346  if (!options.mAttrs.empty()) {
347  for (StringRef attr : options.mAttrs)
348  features.AddFeature(attr);
349  tmBuilderOrError->addFeatures(features.getFeatures());
350  }
351 
352  if (!options.mArch.empty()) {
353  tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
354  }
355 
356  // Build TargetMachine
357  auto tmOrError = tmBuilderOrError->createTargetMachine();
358 
359  if (!tmOrError) {
360  llvm::errs() << "Failed to create a TargetMachine for the host\n";
361  exitOnErr(tmOrError.takeError());
362  }
363 
364  LLVM_DEBUG({
365  llvm::dbgs() << " JITTargetMachineBuilder is "
366  << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
367  "\n");
368  });
369 
370  CompileAndExecuteConfig compileAndExecuteConfig;
371  if (optLevel) {
372  compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
373  *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
374  }
375  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
376  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
377 
378  // Get the function used to compile and execute the module.
379  using CompileAndExecuteFnT =
380  Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
381  std::unique_ptr<llvm::TargetMachine> tm);
382  auto compileAndExecuteFn =
383  StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
384  .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
385  .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
386  .Case("f32", compileAndExecuteSingleReturnFunction<float>)
387  .Case("void", compileAndExecuteVoidFunction)
388  .Default(nullptr);
389 
390  Error error = compileAndExecuteFn
391  ? compileAndExecuteFn(
392  options, m.get(), options.mainFuncName.getValue(),
393  compileAndExecuteConfig, std::move(tmOrError.get()))
394  : makeStringError("unsupported function type");
395 
396  int exitCode = EXIT_SUCCESS;
397  llvm::handleAllErrors(std::move(error),
398  [&exitCode](const llvm::ErrorInfoBase &info) {
399  llvm::errs() << "Error: ";
400  info.log(llvm::errs());
401  llvm::errs() << '\n';
402  exitCode = EXIT_FAILURE;
403  });
404 
405  return exitCode;
406 }
Error checkCompatibleReturnType< int64_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:250
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)
static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Definition: JitRunner.cpp:181
Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Definition: JitRunner.cpp:267
static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Definition: JitRunner.cpp:220
Error checkCompatibleReturnType< float >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:259
static std::optional< unsigned > getCommandLineOptLevel(Options &options)
Definition: JitRunner.cpp:163
static Error makeStringError(const Twine &message)
Definition: JitRunner.cpp:158
Error checkCompatibleReturnType< int32_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:241
static OwningOpRef< Operation * > parseMLIRInput(StringRef inputFilename, bool insertImplicitModule, MLIRContext *context)
Definition: JitRunner.cpp:134
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={}, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Creates an execution engine for the given MLIR IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:51
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
int JitRunnerMain(int argc, char **argv, const DialectRegistry &registry, JitRunnerConfig config={})
Entry point for all CPU runners.
Definition: JitRunner.cpp:300
std::function< llvm::Error(llvm::Module *)> makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, llvm::TargetMachine *targetMachine)
Create a module transformer function for MLIR ExecutionEngine that runs LLVM IR passes corresponding ...
OwningOpRef< Operation * > parseSourceFileForTool(const std::shared_ptr< llvm::SourceMgr > &sourceMgr, const ParserConfig &config, bool insertImplicitModule)
This parses the file specified by the indicated SourceMgr.
std::optional< llvm::CodeGenOptLevel > jitCodeGenOptLevel
jitCodeGenOptLevel, when provided, is used as the optimization level for target code generation.
ArrayRef< StringRef > sharedLibPaths
If sharedLibPaths are provided, the underlying JIT-compilation will open and link the shared librarie...
bool enableObjectDump
If enableObjectDump is set, the JIT compiler will create an object cache to store the object generated for the g...
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
If llvmModuleBuilder is provided, it will be used to create an LLVM module from the given MLIR IR.
llvm::function_ref< llvm::Error(llvm::Module *)> transformer
If transformer is provided, it will be called on the LLVM module during JIT-compilation and can be us...
Configuration to override functionality of the JitRunner.
Definition: JitRunner.h:48
llvm::function_ref< llvm::LogicalResult(mlir::Operation *, JitRunnerOptions &options)> mlirTransformer
MLIR transformer applied after parsing the input into MLIR IR and before passing the MLIR IR to the E...
Definition: JitRunner.h:53
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
A custom function that is passed to ExecutionEngine.
Definition: JitRunner.h:59
llvm::function_ref< llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> runtimesymbolMap
A callback to register symbols with ExecutionEngine at runtime.
Definition: JitRunner.h:63
JitRunner command line options used by JitRunnerConfig methods.
Definition: JitRunner.h:40
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.