MLIR 22.0.0git
JitRunner.cpp
Go to the documentation of this file.
//===- JitRunner.cpp - MLIR CPU Execution Driver Library -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
//===----------------------------------------------------------------------===//
16
18
23#include "mlir/IR/MLIRContext.h"
24#include "mlir/Parser/Parser.h"
27
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30#include "llvm/ExecutionEngine/Orc/LLJIT.h"
31#include "llvm/IR/IRBuilder.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/IR/LegacyPassNameParser.h"
34#include "llvm/Support/CommandLine.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/FileUtilities.h"
37#include "llvm/Support/SourceMgr.h"
38#include "llvm/Support/StringSaver.h"
39#include "llvm/Support/ToolOutputFile.h"
40#include <cstdint>
41#include <numeric>
42#include <optional>
43#include <utility>
44
45#define DEBUG_TYPE "jit-runner"
46
47using namespace mlir;
48using llvm::Error;
49
50namespace {
51/// This options struct prevents the need for global static initializers, and
52/// is only initialized if the JITRunner is invoked.
53struct Options {
54 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
55 llvm::cl::desc("<input file>"),
56 llvm::cl::init("-")};
57 llvm::cl::opt<std::string> mainFuncName{
58 "e", llvm::cl::desc("The function to be called"),
59 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
60 llvm::cl::opt<std::string> mainFuncType{
61 "entry-point-result",
62 llvm::cl::desc("Textual description of the function type to be called"),
63 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
64
65 llvm::cl::OptionCategory optFlags{"opt-like flags"};
66
67 // CLI variables for -On options.
68 llvm::cl::opt<bool> optO0{"O0",
69 llvm::cl::desc("Run opt passes and codegen at O0"),
70 llvm::cl::cat(optFlags)};
71 llvm::cl::opt<bool> optO1{"O1",
72 llvm::cl::desc("Run opt passes and codegen at O1"),
73 llvm::cl::cat(optFlags)};
74 llvm::cl::opt<bool> optO2{"O2",
75 llvm::cl::desc("Run opt passes and codegen at O2"),
76 llvm::cl::cat(optFlags)};
77 llvm::cl::opt<bool> optO3{"O3",
78 llvm::cl::desc("Run opt passes and codegen at O3"),
79 llvm::cl::cat(optFlags)};
80
81 llvm::cl::list<std::string> mAttrs{
82 "mattr", llvm::cl::MiscFlags::CommaSeparated,
83 llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
84 llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};
85
86 llvm::cl::opt<std::string> mArch{
87 "march",
88 llvm::cl::desc("Architecture to generate code for (see --version)")};
89
90 llvm::cl::OptionCategory clOptionsCategory{"linking options"};
91 llvm::cl::list<std::string> clSharedLibs{
92 "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
93 llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};
94
95 /// CLI variables for debugging.
96 llvm::cl::opt<bool> dumpObjectFile{
97 "dump-object-file",
98 llvm::cl::desc("Dump JITted-compiled object to file specified with "
99 "-object-filename (<input file>.o by default).")};
100
101 llvm::cl::opt<std::string> objectFilename{
102 "object-filename",
103 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
104
105 llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
106 llvm::cl::desc("Report host JIT support"),
107 llvm::cl::Hidden};
108
109 llvm::cl::opt<bool> noImplicitModule{
110 "no-implicit-module",
111 llvm::cl::desc(
112 "Disable implicit addition of a top-level module op during parsing"),
113 llvm::cl::init(false)};
114};
115
116struct CompileAndExecuteConfig {
117 /// LLVM module transformer that is passed to ExecutionEngine.
118 std::function<llvm::Error(llvm::Module *)> transformer;
119
120 /// A custom function that is passed to ExecutionEngine. It processes MLIR
121 /// module and creates LLVM IR module.
123 llvm::LLVMContext &)>
124 llvmModuleBuilder;
125
126 /// A custom function that is passed to ExecutinEngine to register symbols at
127 /// runtime.
128 llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
129 runtimeSymbolMap;
130};
131
132} // namespace
133
134static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
135 bool insertImplicitModule,
136 MLIRContext *context) {
137 // Set up the input file.
138 std::string errorMessage;
139 auto file = openInputFile(inputFilename, &errorMessage);
140 if (!file) {
141 llvm::errs() << errorMessage << "\n";
142 return nullptr;
143 }
144
145 auto sourceMgr = std::make_shared<llvm::SourceMgr>();
146 sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
148 parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
149 if (!module)
150 return nullptr;
151 if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
152 llvm::errs() << "Error: top-level op must be a symbol table.\n";
153 return nullptr;
154 }
155 return module;
156}
157
158static inline Error makeStringError(const Twine &message) {
159 return llvm::make_error<llvm::StringError>(message.str(),
160 llvm::inconvertibleErrorCode());
161}
162
163static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
164 std::optional<unsigned> optLevel;
166 options.optO0, options.optO1, options.optO2, options.optO3};
167
168 // Determine if there is an optimization flag present.
169 for (unsigned j = 0; j < 4; ++j) {
170 auto &flag = optFlags[j].get();
171 if (flag) {
172 optLevel = j;
173 break;
174 }
175 }
176 return optLevel;
177}
178
179// JIT-compile the given module and run "entryPoint" with "args" as arguments.
180static Error
181compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
182 CompileAndExecuteConfig config, void **args,
183 std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
184 std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
185 if (auto clOptLevel = getCommandLineOptLevel(options))
186 jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);
187
188 SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
189 options.clSharedLibs.end());
190
191 mlir::ExecutionEngineOptions engineOptions;
192 engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
193 if (config.transformer)
194 engineOptions.transformer = config.transformer;
195 engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
196 engineOptions.sharedLibPaths = sharedLibs;
197 engineOptions.enableObjectDump = true;
198 auto expectedEngine =
199 mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
200 if (!expectedEngine)
201 return expectedEngine.takeError();
202
203 auto engine = std::move(*expectedEngine);
204
205 engine->initialize();
206
207 auto expectedFPtr = engine->lookupPacked(entryPoint);
208 if (!expectedFPtr)
209 return expectedFPtr.takeError();
210
211 if (options.dumpObjectFile)
212 engine->dumpToObjectFile(options.objectFilename.empty()
213 ? options.inputFilename + ".o"
214 : options.objectFilename);
215
216 void (*fptr)(void **) = *expectedFPtr;
217 (*fptr)(args);
218
219 return Error::success();
220}
221
223 Options &options, Operation *module, StringRef entryPoint,
224 CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
225 auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
226 SymbolTable::lookupSymbolIn(module, entryPoint));
227 if (!mainFunction || mainFunction.isExternal())
228 return makeStringError("entry point not found");
229
230 if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
231 .getNumParams() != 0)
232 return makeStringError(
233 "JIT can't invoke a main function expecting arguments");
234
235 auto resultType = dyn_cast<LLVM::LLVMVoidType>(
236 mainFunction.getFunctionType().getReturnType());
237 if (!resultType)
238 return makeStringError("expected void function");
239
240 void *empty = nullptr;
241 return compileAndExecute(options, module, entryPoint, std::move(config),
242 &empty, std::move(tm));
243}
244
245template <typename Type>
246Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
247template <>
248Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
249 auto resultType = dyn_cast<IntegerType>(
250 cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
251 .getReturnType());
252 if (!resultType || resultType.getWidth() != 32)
253 return makeStringError("only single i32 function result supported");
254 return Error::success();
255}
256template <>
257Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
258 auto resultType = dyn_cast<IntegerType>(
259 cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
260 .getReturnType());
261 if (!resultType || resultType.getWidth() != 64)
262 return makeStringError("only single i64 function result supported");
263 return Error::success();
264}
265template <>
266Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
267 if (!isa<Float32Type>(
268 cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
269 .getReturnType()))
270 return makeStringError("only single f32 function result supported");
271 return Error::success();
272}
273template <typename Type>
275 Options &options, Operation *module, StringRef entryPoint,
276 CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
277 auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
278 SymbolTable::lookupSymbolIn(module, entryPoint));
279 if (!mainFunction || mainFunction.isExternal())
280 return makeStringError("entry point not found");
281
282 if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
283 .getNumParams() != 0)
284 return makeStringError(
285 "JIT can't invoke a main function expecting arguments");
286
287 if (Error error = checkCompatibleReturnType<Type>(mainFunction))
288 return error;
289
290 Type res;
291 struct {
292 void *data;
293 } data;
294 data.data = &res;
295 if (auto error =
296 compileAndExecute(options, module, entryPoint, std::move(config),
297 (void **)&data, std::move(tm)))
298 return error;
299
300 // Intentional printing of the output so we can test.
301 llvm::outs() << res << '\n';
302
303 return Error::success();
304}
305
306/// Entry point for all CPU runners. Expects the common argc/argv arguments for
307/// standard C++ main functions.
308int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
310 llvm::ExitOnError exitOnErr;
311
312 // Create the options struct containing the command line options for the
313 // runner. This must come before the command line options are parsed.
314 Options options;
315 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
316
317 if (options.hostSupportsJit) {
318 auto j = llvm::orc::LLJITBuilder().create();
319 if (j)
320 llvm::outs() << "true\n";
321 else {
322 llvm::outs() << "false\n";
323 exitOnErr(j.takeError());
324 }
325 return 0;
326 }
327
328 std::optional<unsigned> optLevel = getCommandLineOptLevel(options);
330 options.optO0, options.optO1, options.optO2, options.optO3};
331
332 MLIRContext context(registry);
333
334 auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
335 &context);
336 if (!m) {
337 llvm::errs() << "could not parse the input IR\n";
338 return 1;
339 }
340
341 JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
342 if (config.mlirTransformer)
343 if (failed(config.mlirTransformer(m.get(), runnerOptions)))
344 return EXIT_FAILURE;
345
346 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
347 if (!tmBuilderOrError) {
348 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
349 return EXIT_FAILURE;
350 }
351
352 // Configure TargetMachine builder based on the command line options
353 llvm::SubtargetFeatures features;
354 if (!options.mAttrs.empty()) {
355 for (StringRef attr : options.mAttrs)
356 features.AddFeature(attr);
357 tmBuilderOrError->addFeatures(features.getFeatures());
358 }
359
360 if (!options.mArch.empty()) {
361 tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
362 }
363
364 // Build TargetMachine
365 auto tmOrError = tmBuilderOrError->createTargetMachine();
366
367 if (!tmOrError) {
368 llvm::errs() << "Failed to create a TargetMachine for the host\n";
369 exitOnErr(tmOrError.takeError());
370 }
371
372 LLVM_DEBUG({
373 llvm::dbgs() << " JITTargetMachineBuilder is "
374 << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
375 "\n");
376 });
377
378 CompileAndExecuteConfig compileAndExecuteConfig;
379 if (optLevel) {
380 compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
381 *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
382 }
383 compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
384 compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
385
386 // Get the function used to compile and execute the module.
387 using CompileAndExecuteFnT =
388 Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
389 std::unique_ptr<llvm::TargetMachine> tm);
390 auto compileAndExecuteFn =
391 StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
395 .Case("void", compileAndExecuteVoidFunction)
396 .Default(nullptr);
397
398 Error error = compileAndExecuteFn
399 ? compileAndExecuteFn(
400 options, m.get(), options.mainFuncName.getValue(),
401 compileAndExecuteConfig, std::move(tmOrError.get()))
402 : makeStringError("unsupported function type");
403
404 int exitCode = EXIT_SUCCESS;
405 llvm::handleAllErrors(std::move(error),
406 [&exitCode](const llvm::ErrorInfoBase &info) {
407 llvm::errs() << "Error: ";
408 info.log(llvm::errs());
409 llvm::errs() << '\n';
410 exitCode = EXIT_FAILURE;
411 });
412
413 return exitCode;
414}
Error checkCompatibleReturnType< int64_t >(LLVM::LLVMFuncOp mainFunction)
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)
static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
static OwningOpRef< Operation * > parseMLIRInput(StringRef inputFilename, bool insertImplicitModule, MLIRContext *context)
static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
Error checkCompatibleReturnType< float >(LLVM::LLVMFuncOp mainFunction)
static Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, std::unique_ptr< llvm::TargetMachine > tm)
static Error makeStringError(const Twine &message)
static std::optional< unsigned > getCommandLineOptLevel(Options &options)
Error checkCompatibleReturnType< int32_t >(LLVM::LLVMFuncOp mainFunction)
@ Error
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static llvm::Expected< std::unique_ptr< ExecutionEngine > > create(Operation *op, const ExecutionEngineOptions &options={}, std::unique_ptr< llvm::TargetMachine > tm=nullptr)
Creates an execution engine for the given MLIR IR.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
OpTy get() const
Allow accessing the internal op.
Definition OwningOpRef.h:51
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Include the generated interface declarations.
std::function< llvm::Error(llvm::Module *)> makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel, llvm::TargetMachine *targetMachine)
Create a module transformer function for MLIR ExecutionEngine that runs LLVM IR passes corresponding ...
const FrozenRewritePatternSet GreedyRewriteConfig config
int JitRunnerMain(int argc, char **argv, const DialectRegistry &registry, JitRunnerConfig config={})
Entry point for all CPU runners.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
OwningOpRef< Operation * > parseSourceFileForTool(const std::shared_ptr< llvm::SourceMgr > &sourceMgr, const ParserConfig &config, bool insertImplicitModule)
This parses the file specified by the indicated SourceMgr.
llvm::StringSwitch< T, R > StringSwitch
Definition LLVM.h:141
llvm::function_ref< std::unique_ptr< llvm::Module >(Operation *, llvm::LLVMContext &)> llvmModuleBuilder
If llvmModuleBuilder is provided, it will be used to create an LLVM module from the given MLIR IR.
std::optional< llvm::CodeGenOptLevel > jitCodeGenOptLevel
jitCodeGenOptLevel, when provided, is used as the optimization level for target code generation.
ArrayRef< StringRef > sharedLibPaths
If sharedLibPaths are provided, the underlying JIT-compilation will open and link the shared librarie...
bool enableObjectDump
If enableObjectCache is set, the JIT compiler will create one to store the object generated for the g...
llvm::function_ref< llvm::Error(llvm::Module *)> transformer
If transformer is provided, it will be called on the LLVM module during JIT-compilation and can be us...
Configuration to override functionality of the JitRunner.
Definition JitRunner.h:48
JitRunner command line options used by JitRunnerConfig methods.
Definition JitRunner.h:40
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.