MLIR  22.0.0git
JitRunner.cpp
Go to the documentation of this file.
1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
18 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/FileUtilities.h"
37 #include "llvm/Support/SourceMgr.h"
38 #include "llvm/Support/StringSaver.h"
39 #include "llvm/Support/ToolOutputFile.h"
40 #include <cstdint>
41 #include <numeric>
42 #include <optional>
43 #include <utility>
44 
45 #define DEBUG_TYPE "jit-runner"
46 
47 using namespace mlir;
48 using llvm::Error;
49 
50 namespace {
51 /// This options struct prevents the need for global static initializers, and
52 /// is only initialized if the JITRunner is invoked.
53 struct Options {
54  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
55  llvm::cl::desc("<input file>"),
56  llvm::cl::init("-")};
57  llvm::cl::opt<std::string> mainFuncName{
58  "e", llvm::cl::desc("The function to be called"),
59  llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
60  llvm::cl::opt<std::string> mainFuncType{
61  "entry-point-result",
62  llvm::cl::desc("Textual description of the function type to be called"),
63  llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
64 
65  llvm::cl::OptionCategory optFlags{"opt-like flags"};
66 
67  // CLI variables for -On options.
68  llvm::cl::opt<bool> optO0{"O0",
69  llvm::cl::desc("Run opt passes and codegen at O0"),
70  llvm::cl::cat(optFlags)};
71  llvm::cl::opt<bool> optO1{"O1",
72  llvm::cl::desc("Run opt passes and codegen at O1"),
73  llvm::cl::cat(optFlags)};
74  llvm::cl::opt<bool> optO2{"O2",
75  llvm::cl::desc("Run opt passes and codegen at O2"),
76  llvm::cl::cat(optFlags)};
77  llvm::cl::opt<bool> optO3{"O3",
78  llvm::cl::desc("Run opt passes and codegen at O3"),
79  llvm::cl::cat(optFlags)};
80 
81  llvm::cl::list<std::string> mAttrs{
82  "mattr", llvm::cl::MiscFlags::CommaSeparated,
83  llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
84  llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};
85 
86  llvm::cl::opt<std::string> mArch{
87  "march",
88  llvm::cl::desc("Architecture to generate code for (see --version)")};
89 
90  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
91  llvm::cl::list<std::string> clSharedLibs{
92  "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
93  llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};
94 
95  /// CLI variables for debugging.
96  llvm::cl::opt<bool> dumpObjectFile{
97  "dump-object-file",
98  llvm::cl::desc("Dump JITted-compiled object to file specified with "
99  "-object-filename (<input file>.o by default).")};
100 
101  llvm::cl::opt<std::string> objectFilename{
102  "object-filename",
103  llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
104 
105  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
106  llvm::cl::desc("Report host JIT support"),
107  llvm::cl::Hidden};
108 
109  llvm::cl::opt<bool> noImplicitModule{
110  "no-implicit-module",
111  llvm::cl::desc(
112  "Disable implicit addition of a top-level module op during parsing"),
113  llvm::cl::init(false)};
114 };
115 
116 struct CompileAndExecuteConfig {
117  /// LLVM module transformer that is passed to ExecutionEngine.
118  std::function<llvm::Error(llvm::Module *)> transformer;
119 
120  /// A custom function that is passed to ExecutionEngine. It processes MLIR
121  /// module and creates LLVM IR module.
123  llvm::LLVMContext &)>
124  llvmModuleBuilder;
125 
126  /// A custom function that is passed to ExecutinEngine to register symbols at
127  /// runtime.
128  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
129  runtimeSymbolMap;
130 };
131 
132 } // namespace
133 
134 static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
135  bool insertImplicitModule,
136  MLIRContext *context) {
137  // Set up the input file.
138  std::string errorMessage;
139  auto file = openInputFile(inputFilename, &errorMessage);
140  if (!file) {
141  llvm::errs() << errorMessage << "\n";
142  return nullptr;
143  }
144 
145  auto sourceMgr = std::make_shared<llvm::SourceMgr>();
146  sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
147  OwningOpRef<Operation *> module =
148  parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
149  if (!module)
150  return nullptr;
151  if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
152  llvm::errs() << "Error: top-level op must be a symbol table.\n";
153  return nullptr;
154  }
155  return module;
156 }
157 
158 static inline Error makeStringError(const Twine &message) {
159  return llvm::make_error<llvm::StringError>(message.str(),
160  llvm::inconvertibleErrorCode());
161 }
162 
163 static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
164  std::optional<unsigned> optLevel;
166  options.optO0, options.optO1, options.optO2, options.optO3};
167 
168  // Determine if there is an optimization flag present.
169  for (unsigned j = 0; j < 4; ++j) {
170  auto &flag = optFlags[j].get();
171  if (flag) {
172  optLevel = j;
173  break;
174  }
175  }
176  return optLevel;
177 }
178 
179 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
180 static Error
181 compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
182  CompileAndExecuteConfig config, void **args,
183  std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
184  std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
185  if (auto clOptLevel = getCommandLineOptLevel(options))
186  jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);
187 
188  SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
189  options.clSharedLibs.end());
190 
191  mlir::ExecutionEngineOptions engineOptions;
192  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
193  if (config.transformer)
194  engineOptions.transformer = config.transformer;
195  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
196  engineOptions.sharedLibPaths = sharedLibs;
197  engineOptions.enableObjectDump = true;
198  auto expectedEngine =
199  mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
200  if (!expectedEngine)
201  return expectedEngine.takeError();
202 
203  auto engine = std::move(*expectedEngine);
204 
205  engine->initialize();
206 
207  auto expectedFPtr = engine->lookupPacked(entryPoint);
208  if (!expectedFPtr)
209  return expectedFPtr.takeError();
210 
211  if (options.dumpObjectFile)
212  engine->dumpToObjectFile(options.objectFilename.empty()
213  ? options.inputFilename + ".o"
214  : options.objectFilename);
215 
216  void (*fptr)(void **) = *expectedFPtr;
217  (*fptr)(args);
218 
219  return Error::success();
220 }
221 
223  Options &options, Operation *module, StringRef entryPoint,
224  CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
225  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
226  SymbolTable::lookupSymbolIn(module, entryPoint));
227  if (!mainFunction || mainFunction.isExternal())
228  return makeStringError("entry point not found");
229 
230  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
231  .getNumParams() != 0)
232  return makeStringError(
233  "JIT can't invoke a main function expecting arguments");
234 
235  auto resultType = dyn_cast<LLVM::LLVMVoidType>(
236  mainFunction.getFunctionType().getReturnType());
237  if (!resultType)
238  return makeStringError("expected void function");
239 
240  void *empty = nullptr;
241  return compileAndExecute(options, module, entryPoint, std::move(config),
242  &empty, std::move(tm));
243 }
244 
245 template <typename Type>
246 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
247 template <>
248 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
249  auto resultType = dyn_cast<IntegerType>(
250  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
251  .getReturnType());
252  if (!resultType || resultType.getWidth() != 32)
253  return makeStringError("only single i32 function result supported");
254  return Error::success();
255 }
256 template <>
257 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
258  auto resultType = dyn_cast<IntegerType>(
259  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
260  .getReturnType());
261  if (!resultType || resultType.getWidth() != 64)
262  return makeStringError("only single i64 function result supported");
263  return Error::success();
264 }
265 template <>
266 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
267  if (!isa<Float32Type>(
268  cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
269  .getReturnType()))
270  return makeStringError("only single f32 function result supported");
271  return Error::success();
272 }
273 template <typename Type>
275  Options &options, Operation *module, StringRef entryPoint,
276  CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
277  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
278  SymbolTable::lookupSymbolIn(module, entryPoint));
279  if (!mainFunction || mainFunction.isExternal())
280  return makeStringError("entry point not found");
281 
282  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
283  .getNumParams() != 0)
284  return makeStringError(
285  "JIT can't invoke a main function expecting arguments");
286 
287  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
288  return error;
289 
290  Type res;
291  struct {
292  void *data;
293  } data;
294  data.data = &res;
295  if (auto error =
296  compileAndExecute(options, module, entryPoint, std::move(config),
297  (void **)&data, std::move(tm)))
298  return error;
299 
300  // Intentional printing of the output so we can test.
301  llvm::outs() << res << '\n';
302 
303  return Error::success();
304 }
305 
306 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
307 /// standard C++ main functions.
308 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
310  llvm::ExitOnError exitOnErr;
311 
312  // Create the options struct containing the command line options for the
313  // runner. This must come before the command line options are parsed.
314  Options options;
315  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
316 
317  if (options.hostSupportsJit) {
318  auto j = llvm::orc::LLJITBuilder().create();
319  if (j)
320  llvm::outs() << "true\n";
321  else {
322  llvm::outs() << "false\n";
323  exitOnErr(j.takeError());
324  }
325  return 0;
326  }
327 
328  std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
330  options.optO0, options.optO1, options.optO2, options.optO3};
331 
332  MLIRContext context(registry);
333 
334  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
335  &context);
336  if (!m) {
337  llvm::errs() << "could not parse the input IR\n";
338  return 1;
339  }
340 
341  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
342  if (config.mlirTransformer)
343  if (failed(config.mlirTransformer(m.get(), runnerOptions)))
344  return EXIT_FAILURE;
345 
346  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
347  if (!tmBuilderOrError) {
348  llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
349  return EXIT_FAILURE;
350  }
351 
352  // Configure TargetMachine builder based on the command line options
353  llvm::SubtargetFeatures features;
354  if (!options.mAttrs.empty()) {
355  for (StringRef attr : options.mAttrs)
356  features.AddFeature(attr);
357  tmBuilderOrError->addFeatures(features.getFeatures());
358  }
359 
360  if (!options.mArch.empty()) {
361  tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
362  }
363 
364  // Build TargetMachine
365  auto tmOrError = tmBuilderOrError->createTargetMachine();
366 
367  if (!tmOrError) {
368  llvm::errs() << "Failed to create a TargetMachine for the host\n";
369  exitOnErr(tmOrError.takeError());
370  }
371 
372  LLVM_DEBUG({
373  llvm::dbgs() << " JITTargetMachineBuilder is "
374  << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
375  "\n");
376  });
377 
378  CompileAndExecuteConfig compileAndExecuteConfig;
379  if (optLevel) {
380  compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
381  *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
382  }
383  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
384  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
385 
386  // Get the function used to compile and execute the module.
387  using CompileAndExecuteFnT =
388  Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
389  std::unique_ptr<llvm::TargetMachine> tm);
390  auto compileAndExecuteFn =
391  StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
392  .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
393  .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
394  .Case("f32", compileAndExecuteSingleReturnFunction<float>)
395  .Case("void", compileAndExecuteVoidFunction)
396  .Default(nullptr);
397 
398  Error error = compileAndExecuteFn
399  ? compileAndExecuteFn(
400  options, m.get(), options.mainFuncName.getValue(),
401  compileAndExecuteConfig, std::move(tmOrError.get()))
402  : makeStringError("unsupported function type");
403 
404  int exitCode = EXIT_SUCCESS;
405  llvm::handleAllErrors(std::move(error),
406  [&exitCode](const llvm::ErrorInfoBase &info) {
407  llvm::errs() << "Error: ";
408  info.log(llvm::errs());
409  llvm::errs() << '\n';
410  exitCode = EXIT_FAILURE;
411  });
412 
413  return exitCode;
414 }
Error checkCompatibleReturnType< int64_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:257
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)
static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Definition: JitRunner.cpp:181
Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Definition: JitRunner.cpp:274
static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Definition: JitRunner.cpp:222
Error checkCompatibleReturnType< float >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:266
static std::optional< unsigned > getCommandLineOptLevel(Options &options)
Definition: JitRunner.cpp:163
static Error makeStringError(const Twine &message)
Definition: JitRunner.cpp:158
Error checkCompatibleReturnType< int32_t >(LLVM::LLVMFuncOp mainFunction)
Definition: JitRunner.cpp:248
static OwningOpRef< Operation * > parseMLIRInput(StringRef inputFilename, bool insertImplicitModule, MLIRContext *context)
Definition: JitRunner.cpp:134
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={}, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Creates an execution engine for the given MLIR IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction.
Definition: OwningOpRef.h:29
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:51
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
const FrozenRewritePatternSet GreedyRewriteConfig config
int JitRunnerMain(int argc, char **argv, const DialectRegistry &registry, JitRunnerConfig config={})
Entry point for all CPU runners.
Definition: JitRunner.cpp:308
std::function< llvm::Error(llvm::Module *)> makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, llvm::TargetMachine *targetMachine)
Create a module transformer function for MLIR ExecutionEngine that runs LLVM IR passes corresponding ...
OwningOpRef< Operation * > parseSourceFileForTool(const std::shared_ptr< llvm::SourceMgr > &sourceMgr, const ParserConfig &config, bool insertImplicitModule)
This parses the file specified by the indicated SourceMgr.
std::optional< llvm::CodeGenOptLevel > jitCodeGenOptLevel
jitCodeGenOptLevel, when provided, is used as the optimization level for target code generation.
ArrayRef< StringRef > sharedLibPaths
If sharedLibPaths are provided, the underlying JIT-compilation will open and link the shared libraries for symbol resolution.
bool enableObjectDump
If enableObjectDump is set, the JIT compiler will create an object cache to store the object generated for the given module.
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
If llvmModuleBuilder is provided, it will be used to create an LLVM module from the given MLIR IR.
llvm::function_ref< llvm::Error(llvm::Module *)> transformer
If transformer is provided, it will be called on the LLVM module during JIT-compilation and can be used, e.g., for reporting or optimization.
Configuration to override functionality of the JitRunner.
Definition: JitRunner.h:48
JitRunner command line options used by JitRunnerConfig methods.
Definition: JitRunner.h:40
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.