MLIR  21.0.0git
JitRunner.cpp
Go to the documentation of this file.
1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
18 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/FileUtilities.h"
37 #include "llvm/Support/SourceMgr.h"
38 #include "llvm/Support/StringSaver.h"
39 #include "llvm/Support/ToolOutputFile.h"
40 #include <cstdint>
41 #include <numeric>
42 #include <optional>
43 #include <utility>
44 
45 #define DEBUG_TYPE "jit-runner"
46 
47 using namespace mlir;
48 using llvm::Error;
49 
50 namespace {
51 /// This options struct prevents the need for global static initializers, and
52 /// is only initialized if the JITRunner is invoked.
53 struct Options {
54  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
55  llvm::cl::desc("<input file>"),
56  llvm::cl::init("-")};
57  llvm::cl::opt<std::string> mainFuncName{
58  "e", llvm::cl::desc("The function to be called"),
59  llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
60  llvm::cl::opt<std::string> mainFuncType{
61  "entry-point-result",
62  llvm::cl::desc("Textual description of the function type to be called"),
63  llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
64 
65  llvm::cl::OptionCategory optFlags{"opt-like flags"};
66 
67  // CLI variables for -On options.
68  llvm::cl::opt<bool> optO0{"O0",
69  llvm::cl::desc("Run opt passes and codegen at O0"),
70  llvm::cl::cat(optFlags)};
71  llvm::cl::opt<bool> optO1{"O1",
72  llvm::cl::desc("Run opt passes and codegen at O1"),
73  llvm::cl::cat(optFlags)};
74  llvm::cl::opt<bool> optO2{"O2",
75  llvm::cl::desc("Run opt passes and codegen at O2"),
76  llvm::cl::cat(optFlags)};
77  llvm::cl::opt<bool> optO3{"O3",
78  llvm::cl::desc("Run opt passes and codegen at O3"),
79  llvm::cl::cat(optFlags)};
80 
81  llvm::cl::list<std::string> mAttrs{
82  "mattr", llvm::cl::MiscFlags::CommaSeparated,
83  llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
84  llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};
85 
86  llvm::cl::opt<std::string> mArch{
87  "march",
88  llvm::cl::desc("Architecture to generate code for (see --version)")};
89 
90  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
91  llvm::cl::list<std::string> clSharedLibs{
92  "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
93  llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};
94 
95  /// CLI variables for debugging.
96  llvm::cl::opt<bool> dumpObjectFile{
97  "dump-object-file",
98  llvm::cl::desc("Dump JITted-compiled object to file specified with "
99  "-object-filename (<input file>.o by default).")};
100 
101  llvm::cl::opt<std::string> objectFilename{
102  "object-filename",
103  llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
104 
105  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
106  llvm::cl::desc("Report host JIT support"),
107  llvm::cl::Hidden};
108 
109  llvm::cl::opt<bool> noImplicitModule{
110  "no-implicit-module",
111  llvm::cl::desc(
112  "Disable implicit addition of a top-level module op during parsing"),
113  llvm::cl::init(false)};
114 };
115 
116 struct CompileAndExecuteConfig {
117  /// LLVM module transformer that is passed to ExecutionEngine.
118  std::function<llvm::Error(llvm::Module *)> transformer;
119 
120  /// A custom function that is passed to ExecutionEngine. It processes MLIR
121  /// module and creates LLVM IR module.
123  llvm::LLVMContext &)>
124  llvmModuleBuilder;
125 
126  /// A custom function that is passed to ExecutinEngine to register symbols at
127  /// runtime.
128  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
129  runtimeSymbolMap;
130 };
131 
132 } // namespace
133 
134 static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
135  bool insertImplicitModule,
136  MLIRContext *context) {
137  // Set up the input file.
138  std::string errorMessage;
139  auto file = openInputFile(inputFilename, &errorMessage);
140  if (!file) {
141  llvm::errs() << errorMessage << "\n";
142  return nullptr;
143  }
144 
145  auto sourceMgr = std::make_shared<llvm::SourceMgr>();
146  sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
147  OwningOpRef<Operation *> module =
148  parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
149  if (!module)
150  return nullptr;
151  if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
152  llvm::errs() << "Error: top-level op must be a symbol table.\n";
153  return nullptr;
154  }
155  return module;
156 }
157 
158 static inline Error makeStringError(const Twine &message) {
159  return llvm::make_error<llvm::StringError>(message.str(),
160  llvm::inconvertibleErrorCode());
161 }
162 
163 static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
164  std::optional<unsigned> optLevel;
166  options.optO0, options.optO1, options.optO2, options.optO3};
167 
168  // Determine if there is an optimization flag present.
169  for (unsigned j = 0; j < 4; ++j) {
170  auto &flag = optFlags[j].get();
171  if (flag) {
172  optLevel = j;
173  break;
174  }
175  }
176  return optLevel;
177 }
178 
179 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
180 static Error
181 compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
182  CompileAndExecuteConfig config, void **args,
183  std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
184  std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
185  if (auto clOptLevel = getCommandLineOptLevel(options))
186  jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);
187 
188  SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
189  options.clSharedLibs.end());
190 
191  mlir::ExecutionEngineOptions engineOptions;
192  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
193  if (config.transformer)
194  engineOptions.transformer = config.transformer;
195  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
196  engineOptions.sharedLibPaths = sharedLibs;
197  engineOptions.enableObjectDump = true;
198  auto expectedEngine =
199  mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
200  if (!expectedEngine)
201  return expectedEngine.takeError();
202 
203  auto engine = std::move(*expectedEngine);
204 
205  auto expectedFPtr = engine->lookupPacked(entryPoint);
206  if (!expectedFPtr)
207  return expectedFPtr.takeError();
208 
209  if (options.dumpObjectFile)
210  engine->dumpToObjectFile(options.objectFilename.empty()
211  ? options.inputFilename + ".o"
212  : options.objectFilename);
213 
214  void (*fptr)(void **) = *expectedFPtr;
215  (*fptr)(args);
216 
217  return Error::success();
218 }
219 
221  Options &options, Operation *module, StringRef entryPoint,
222  CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
223  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
224  SymbolTable::lookupSymbolIn(module, entryPoint));
225  if (!mainFunction || mainFunction.isExternal())
226  return makeStringError("entry point not found");
227 
228  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
229  .getNumParams() != 0)
230  return makeStringError(
231  "JIT can't invoke a main function expecting arguments");
232 
233  auto resultType = dyn_cast<LLVM::LLVMVoidType>(
234  mainFunction.getFunctionType().getReturnType());
235  if (!resultType)
236  return makeStringError("expected void function");
237 
238  void *empty = nullptr;
239  return compileAndExecute(options, module, entryPoint, std::move(config),
240  &empty, std::move(tm));
241 }
242 
243 template <typename Type>
244 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
245 template <>
246 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
247  auto resultType = dyn_cast<IntegerType>(
248  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
249  .getReturnType());
250  if (!resultType || resultType.getWidth() != 32)
251  return makeStringError("only single i32 function result supported");
252  return Error::success();
253 }
254 template <>
255 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
256  auto resultType = dyn_cast<IntegerType>(
257  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
258  .getReturnType());
259  if (!resultType || resultType.getWidth() != 64)
260  return makeStringError("only single i64 function result supported");
261  return Error::success();
262 }
263 template <>
264 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
265  if (!isa<Float32Type>(
266  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
267  .getReturnType()))
268  return makeStringError("only single f32 function result supported");
269  return Error::success();
270 }
271 template <typename Type>
273  Options &options, Operation *module, StringRef entryPoint,
274  CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
275  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
276  SymbolTable::lookupSymbolIn(module, entryPoint));
277  if (!mainFunction || mainFunction.isExternal())
278  return makeStringError("entry point not found");
279 
280  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
281  .getNumParams() != 0)
282  return makeStringError(
283  "JIT can't invoke a main function expecting arguments");
284 
285  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
286  return error;
287 
288  Type res;
289  struct {
290  void *data;
291  } data;
292  data.data = &res;
293  if (auto error =
294  compileAndExecute(options, module, entryPoint, std::move(config),
295  (void **)&data, std::move(tm)))
296  return error;
297 
298  // Intentional printing of the output so we can test.
299  llvm::outs() << res << '\n';
300 
301  return Error::success();
302 }
303 
304 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
305 /// standard C++ main functions.
306 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
308  llvm::ExitOnError exitOnErr;
309 
310  // Create the options struct containing the command line options for the
311  // runner. This must come before the command line options are parsed.
312  Options options;
313  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
314 
315  if (options.hostSupportsJit) {
316  auto j = llvm::orc::LLJITBuilder().create();
317  if (j)
318  llvm::outs() << "true\n";
319  else {
320  llvm::outs() << "false\n";
321  exitOnErr(j.takeError());
322  }
323  return 0;
324  }
325 
326  std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
328  options.optO0, options.optO1, options.optO2, options.optO3};
329 
330  MLIRContext context(registry);
331 
332  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
333  &context);
334  if (!m) {
335  llvm::errs() << "could not parse the input IR\n";
336  return 1;
337  }
338 
339  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
340  if (config.mlirTransformer)
341  if (failed(config.mlirTransformer(m.get(), runnerOptions)))
342  return EXIT_FAILURE;
343 
344  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
345  if (!tmBuilderOrError) {
346  llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
347  return EXIT_FAILURE;
348  }
349 
350  // Configure TargetMachine builder based on the command line options
351  llvm::SubtargetFeatures features;
352  if (!options.mAttrs.empty()) {
353  for (StringRef attr : options.mAttrs)
354  features.AddFeature(attr);
355  tmBuilderOrError->addFeatures(features.getFeatures());
356  }
357 
358  if (!options.mArch.empty()) {
359  tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
360  }
361 
362  // Build TargetMachine
363  auto tmOrError = tmBuilderOrError->createTargetMachine();
364 
365  if (!tmOrError) {
366  llvm::errs() << "Failed to create a TargetMachine for the host\n";
367  exitOnErr(tmOrError.takeError());
368  }
369 
370  LLVM_DEBUG({
371  llvm::dbgs() << " JITTargetMachineBuilder is "
372  << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
373  "\n");
374  });
375 
376  CompileAndExecuteConfig compileAndExecuteConfig;
377  if (optLevel) {
378  compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
379  *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
380  }
381  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
382  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
383 
384  // Get the function used to compile and execute the module.
385  using CompileAndExecuteFnT =
386  Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
387  std::unique_ptr<llvm::TargetMachine> tm);
388  auto compileAndExecuteFn =
389  StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
390  .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
391  .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
392  .Case("f32", compileAndExecuteSingleReturnFunction<float>)
393  .Case("void", compileAndExecuteVoidFunction)
394  .Default(nullptr);
395 
396  Error error = compileAndExecuteFn
397  ? compileAndExecuteFn(
398  options, m.get(), options.mainFuncName.getValue(),
399  compileAndExecuteConfig, std::move(tmOrError.get()))
400  : makeStringError("unsupported function type");
401 
402  int exitCode = EXIT_SUCCESS;
403  llvm::handleAllErrors(std::move(error),
404  [&exitCode](const llvm::ErrorInfoBase &info) {
405  llvm::errs() << "Error: ";
406  info.log(llvm::errs());
407  llvm::errs() << '\n';
408  exitCode = EXIT_FAILURE;
409  });
410 
411  return exitCode;
412 }
Error checkCompatibleReturnType< int64_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:255
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)
static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Definition: JitRunner.cpp:181
Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Definition: JitRunner.cpp:272
static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Definition: JitRunner.cpp:220
Error checkCompatibleReturnType< float >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:264
static std::optional< unsigned > getCommandLineOptLevel(Options &options)
Definition: JitRunner.cpp:163
static Error makeStringError(const Twine &message)
Definition: JitRunner.cpp:158
Error checkCompatibleReturnType< int32_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:246
static OwningOpRef< Operation * > parseMLIRInput(StringRef inputFilename, bool insertImplicitModule, MLIRContext *context)
Definition: JitRunner.cpp:134
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={}, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Creates an execution engine for the given MLIR IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:442
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:51
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
const FrozenRewritePatternSet GreedyRewriteConfig config
int JitRunnerMain(int argc, char **argv, const DialectRegistry &registry, JitRunnerConfig config={})
Entry point for all CPU runners.
Definition: JitRunner.cpp:306
std::function< llvm::Error(llvm::Module *)> makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, llvm::TargetMachine *targetMachine)
Create a module transformer function for MLIR ExecutionEngine that runs LLVM IR passes corresponding ...
OwningOpRef< Operation * > parseSourceFileForTool(const std::shared_ptr< llvm::SourceMgr > &sourceMgr, const ParserConfig &config, bool insertImplicitModule)
This parses the file specified by the indicated SourceMgr.
std::optional< llvm::CodeGenOptLevel > jitCodeGenOptLevel
jitCodeGenOptLevel, when provided, is used as the optimization level for target code generation.
ArrayRef< StringRef > sharedLibPaths
If sharedLibPaths are provided, the underlying JIT-compilation will open and link the shared librarie...
bool enableObjectDump
If enableObjectDump is set, the JIT compiler will create an object cache to store the object generated for the g...
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
If llvmModuleBuilder is provided, it will be used to create an LLVM module from the given MLIR IR.
llvm::function_ref< llvm::Error(llvm::Module *)> transformer
If transformer is provided, it will be called on the LLVM module during JIT-compilation and can be us...
Configuration to override functionality of the JitRunner.
Definition: JitRunner.h:48
JitRunner command line options used by JitRunnerConfig methods.
Definition: JitRunner.h:40
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.