1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a utility that runs an optimization pass and prints the result back
10 // out. It is designed to support unit testing.
11 //
12 //===----------------------------------------------------------------------===//
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Debug/Counter.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/ParseUtilities.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/ToolOutputFile.h"
51 using namespace mlir;
52 using namespace llvm;
54 namespace {
55 class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
56 public:
57  BytecodeVersionParser(cl::Option &o)
58  : cl::parser<std::optional<int64_t>>(o) {}
60  bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
61  std::optional<int64_t> &v) {
62  long long w;
63  if (getAsSignedInteger(arg, 10, w))
64  return o.error("Invalid argument '" + arg +
65  "', only integer is supported.");
66  v = w;
67  return false;
68  }
69 };
71 /// This class is intended to manage the handling of command line options for
72 /// creating a *-opt config. This is a singleton.
73 struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
74  MlirOptMainConfigCLOptions() {
75  // These options are static but all uses ExternalStorage to initialize the
76  // members of the parent class. This is unusual but since this class is a
77  // singleton it basically attaches command line option to the singleton
78  // members.
80  static cl::opt<bool, /*ExternalStorage=*/true> allowUnregisteredDialects(
81  "allow-unregistered-dialect",
82  cl::desc("Allow operation with no registered dialects"),
83  cl::location(allowUnregisteredDialectsFlag), cl::init(false));
85  static cl::opt<bool, /*ExternalStorage=*/true> dumpPassPipeline(
86  "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"),
87  cl::location(dumpPassPipelineFlag), cl::init(false));
89  static cl::opt<bool, /*ExternalStorage=*/true> emitBytecode(
90  "emit-bytecode", cl::desc("Emit bytecode when generating output"),
91  cl::location(emitBytecodeFlag), cl::init(false));
93  static cl::opt<bool, /*ExternalStorage=*/true> elideResourcesFromBytecode(
94  "elide-resource-data-from-bytecode",
95  cl::desc("Elide resources when generating bytecode"),
96  cl::location(elideResourceDataFromBytecodeFlag), cl::init(false));
98  static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true,
99  BytecodeVersionParser>
100  bytecodeVersion(
101  "emit-bytecode-version",
102  cl::desc("Use specified bytecode when generating output"),
103  cl::location(emitBytecodeVersion), cl::init(std::nullopt));
105  static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
106  "irdl-file",
107  cl::desc("IRDL file to register before processing the input"),
108  cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
110  static cl::opt<bool, /*ExternalStorage=*/true> enableDebuggerHook(
111  "mlir-enable-debugger-hook",
112  cl::desc("Enable Debugger hook for debugging MLIR Actions"),
113  cl::location(enableDebuggerActionHookFlag), cl::init(false));
115  static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
116  "no-implicit-module",
117  cl::desc("Disable implicit addition of a top-level module op during "
118  "parsing"),
119  cl::location(useExplicitModuleFlag), cl::init(false));
121  static cl::opt<bool, /*ExternalStorage=*/true> runReproducer(
122  "run-reproducer", cl::desc("Run the pipeline stored in the reproducer"),
123  cl::location(runReproducerFlag), cl::init(false));
125  static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
126  "show-dialects",
127  cl::desc("Print the list of registered dialects and exit"),
128  cl::location(showDialectsFlag), cl::init(false));
130  static cl::opt<std::string, /*ExternalStorage=*/true> splitInputFile{
131  "split-input-file", llvm::cl::ValueOptional,
132  cl::callback([&](const std::string &str) {
133  // Implicit value: use default marker if flag was used without value.
134  if (str.empty())
135  splitInputFile.setValue(kDefaultSplitMarker);
136  }),
137  cl::desc("Split the input file into chunks using the given or "
138  "default marker and process each chunk independently"),
139  cl::location(splitInputFileFlag), cl::init("")};
141  static cl::opt<std::string, /*ExternalStorage=*/true> outputSplitMarker(
142  "output-split-marker",
143  cl::desc("Split marker to use for merging the ouput"),
144  cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker));
146  static cl::opt<bool, /*ExternalStorage=*/true> verifyDiagnostics(
147  "verify-diagnostics",
148  cl::desc("Check that emitted diagnostics match "
149  "expected-* lines on the corresponding line"),
150  cl::location(verifyDiagnosticsFlag), cl::init(false));
152  static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses(
153  "verify-each",
154  cl::desc("Run the verifier after each transformation pass"),
155  cl::location(verifyPassesFlag), cl::init(true));
157  static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
158  "verify-roundtrip",
159  cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
160  cl::location(verifyRoundtripFlag), cl::init(false));
162  static cl::list<std::string> passPlugins(
163  "load-pass-plugin", cl::desc("Load passes from plugin library"));
165  static cl::opt<std::string, /*ExternalStorage=*/true>
166  generateReproducerFile(
167  "mlir-generate-reproducer",
168  llvm::cl::desc(
169  "Generate an mlir reproducer at the provided filename"
170  " (no crash required)"),
171  cl::location(generateReproducerFileFlag), cl::init(""),
172  cl::value_desc("filename"));
174  /// Set the callback to load a pass plugin.
175  passPlugins.setCallback([&](const std::string &pluginPath) {
176  auto plugin = PassPlugin::load(pluginPath);
177  if (!plugin) {
178  errs() << "Failed to load passes from '" << pluginPath
179  << "'. Request ignored.\n";
180  return;
181  }
182  plugin.get().registerPassRegistryCallbacks();
183  });
185  static cl::list<std::string> dialectPlugins(
186  "load-dialect-plugin", cl::desc("Load dialects from plugin library"));
187  this->dialectPlugins = std::addressof(dialectPlugins);
189  static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p");
190  setPassPipelineParser(passPipeline);
191  }
193  /// Set the callback to load a dialect plugin.
194  void setDialectPluginsCallback(DialectRegistry &registry);
196  /// Pointer to static dialectPlugins variable in constructor, needed by
197  /// setDialectPluginsCallback(DialectRegistry&).
198  cl::list<std::string> *dialectPlugins = nullptr;
199 };
200 } // namespace
202 ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
205  clOptionsConfig->setDialectPluginsCallback(registry);
207 }
211  return *clOptionsConfig;
212 }
215  const PassPipelineCLParser &passPipeline) {
216  passPipelineCallback = [&](PassManager &pm) {
217  auto errorHandler = [&](const Twine &msg) {
218  emitError(UnknownLoc::get(pm.getContext())) << msg;
219  return failure();
220  };
221  if (failed(passPipeline.addToPipeline(pm, errorHandler)))
222  return failure();
223  if (this->shouldDumpPassPipeline()) {
225  pm.dump();
226  llvm::errs() << "\n";
227  }
228  return success();
229  };
230  return *this;
231 }
233 void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
234  DialectRegistry &registry) {
235  dialectPlugins->setCallback([&](const std::string &pluginPath) {
236  auto plugin = DialectPlugin::load(pluginPath);
237  if (!plugin) {
238  errs() << "Failed to load dialect plugin from '" << pluginPath
239  << "'. Request ignored.\n";
240  return;
241  };
242  plugin.get().registerDialectRegistryCallbacks(registry);
243  });
244 }
246 LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
247  DialectRegistry registry;
248  registry.insert<irdl::IRDLDialect>();
249  ctx.appendDialectRegistry(registry);
251  // Set up the input file.
252  std::string errorMessage;
253  std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
254  if (!file) {
255  emitError(UnknownLoc::get(&ctx)) << errorMessage;
256  return failure();
257  }
259  // Give the buffer to the source manager.
260  // This will be picked up by the parser.
261  SourceMgr sourceMgr;
262  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
264  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
266  // Parse the input file.
267  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
268  if (!module)
269  return failure();
271  // Load IRDL dialects.
272  return irdl::loadDialects(module.get());
273 }
275 // Return success if the module can correctly round-trip. This intended to test
276 // that the custom printers/parsers are complete.
277 static LogicalResult doVerifyRoundTrip(Operation *op,
278  const MlirOptMainConfig &config,
279  bool useBytecode) {
280  // We use a new context to avoid resource handle renaming issue in the diff.
281  MLIRContext roundtripContext;
282  OwningOpRef<Operation *> roundtripModule;
283  roundtripContext.appendDialectRegistry(
284  op->getContext()->getDialectRegistry());
286  roundtripContext.allowUnregisteredDialects();
287  StringRef irdlFile = config.getIrdlFile();
288  if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext)))
289  return failure();
291  std::string testType = (useBytecode) ? "bytecode" : "textual";
292  // Print a first time with custom format (or bytecode) and parse it back to
293  // the roundtripModule.
294  {
295  std::string buffer;
296  llvm::raw_string_ostream ostream(buffer);
297  if (useBytecode) {
298  if (failed(writeBytecodeToFile(op, ostream))) {
299  op->emitOpError()
300  << "failed to write bytecode, cannot verify round-trip.\n";
301  return failure();
302  }
303  } else {
304  op->print(ostream,
305  OpPrintingFlags().printGenericOpForm().enableDebugInfo());
306  }
307  FallbackAsmResourceMap fallbackResourceMap;
308  ParserConfig parseConfig(&roundtripContext, /*verifyAfterParse=*/true,
309  &fallbackResourceMap);
310  roundtripModule =
311  parseSourceString<Operation *>(ostream.str(), parseConfig);
312  if (!roundtripModule) {
313  op->emitOpError() << "failed to parse " << testType
314  << " content back, cannot verify round-trip.\n";
315  return failure();
316  }
317  }
319  // Print in the generic form for the reference module and the round-tripped
320  // one and compare the outputs.
321  std::string reference, roundtrip;
322  {
323  llvm::raw_string_ostream ostreamref(reference);
324  op->print(ostreamref,
325  OpPrintingFlags().printGenericOpForm().enableDebugInfo());
326  llvm::raw_string_ostream ostreamrndtrip(roundtrip);
327  roundtripModule.get()->print(
328  ostreamrndtrip,
329  OpPrintingFlags().printGenericOpForm().enableDebugInfo());
330  }
331  if (reference != roundtrip) {
332  // TODO implement a diff.
333  return op->emitOpError()
334  << testType
335  << " roundTrip testing roundtripped module differs "
336  "from reference:\n<<<<<<Reference\n"
337  << reference << "\n=====\n"
338  << roundtrip << "\n>>>>>roundtripped\n";
339  }
341  return success();
342 }
344 static LogicalResult doVerifyRoundTrip(Operation *op,
345  const MlirOptMainConfig &config) {
346  auto txtStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/false);
347  auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true);
348  return success(succeeded(txtStatus) && succeeded(bcStatus));
349 }
351 /// Perform the actions on the input file indicated by the command line flags
352 /// within the specified context.
353 ///
354 /// This typically parses the main source file, runs zero or more optimization
355 /// passes, then prints the output.
356 ///
357 static LogicalResult
358 performActions(raw_ostream &os,
359  const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
360  MLIRContext *context, const MlirOptMainConfig &config) {
363  TimingScope timing = tm.getRootScope();
365  // Disable multi-threading when parsing the input file. This removes the
366  // unnecessary/costly context synchronization when parsing.
367  bool wasThreadingEnabled = context->isMultithreadingEnabled();
368  context->disableMultithreading();
370  // Prepare the parser config, and attach any useful/necessary resource
371  // handlers. Unhandled external resources are treated as passthrough, i.e.
372  // they are not processed and will be emitted directly to the output
373  // untouched.
374  PassReproducerOptions reproOptions;
375  FallbackAsmResourceMap fallbackResourceMap;
376  ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
377  &fallbackResourceMap);
378  if (config.shouldRunReproducer())
379  reproOptions.attachResourceParser(parseConfig);
381  // Parse the input file and reset the context threading state.
382  TimingScope parserTiming = timing.nest("Parser");
384  sourceMgr, parseConfig, !config.shouldUseExplicitModule());
385  parserTiming.stop();
386  if (!op)
387  return failure();
389  // Perform round-trip verification if requested
390  if (config.shouldVerifyRoundtrip() &&
391  failed(doVerifyRoundTrip(op.get(), config)))
392  return failure();
394  context->enableMultithreading(wasThreadingEnabled);
396  // Prepare the pass manager, applying command-line and reproducer options.
398  pm.enableVerifier(config.shouldVerifyPasses());
399  if (failed(applyPassManagerCLOptions(pm)))
400  return failure();
401  pm.enableTiming(timing);
402  if (config.shouldRunReproducer() && failed(reproOptions.apply(pm)))
403  return failure();
404  if (failed(config.setupPassPipeline(pm)))
405  return failure();
407  // Run the pipeline.
408  if (failed(*op)))
409  return failure();
411  // Generate reproducers if requested
412  if (!config.getReproducerFilename().empty()) {
413  StringRef anchorName = pm.getAnyOpAnchorName();
414  const auto &passes = pm.getPasses();
415  makeReproducer(anchorName, passes, op.get(),
416  config.getReproducerFilename());
417  }
419  // Print the output.
420  TimingScope outputTiming = timing.nest("Output");
421  if (config.shouldEmitBytecode()) {
422  BytecodeWriterConfig writerConfig(fallbackResourceMap);
423  if (auto v = config.bytecodeVersionToEmit())
424  writerConfig.setDesiredBytecodeVersion(*v);
426  writerConfig.setElideResourceDataFlag();
427  return writeBytecodeToFile(op.get(), os, writerConfig);
428  }
430  if (config.bytecodeVersionToEmit().has_value())
431  return emitError(UnknownLoc::get(pm.getContext()))
432  << "bytecode version while not emitting bytecode";
433  AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
434  &fallbackResourceMap);
435  op.get()->print(os, asmState);
436  os << '\n';
437  return success();
438 }
440 /// Parses the memory buffer. If successfully, run a series of passes against
441 /// it and print the result.
442 static LogicalResult processBuffer(raw_ostream &os,
443  std::unique_ptr<MemoryBuffer> ownedBuffer,
444  const MlirOptMainConfig &config,
445  DialectRegistry &registry,
446  llvm::ThreadPoolInterface *threadPool) {
447  // Tell sourceMgr about this buffer, which is what the parser will pick up.
448  auto sourceMgr = std::make_shared<SourceMgr>();
449  sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
451  // Create a context just for the current buffer. Disable threading on creation
452  // since we'll inject the thread-pool separately.
454  if (threadPool)
455  context.setThreadPool(*threadPool);
457  StringRef irdlFile = config.getIrdlFile();
458  if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
459  return failure();
461  // Parse the input file.
463  if (config.shouldVerifyDiagnostics())
464  context.printOpOnDiagnostic(false);
466  tracing::InstallDebugHandler installDebugHandler(context,
467  config.getDebugConfig());
469  // If we are in verify diagnostics mode then we have a lot of work to do,
470  // otherwise just perform the actions without worrying about it.
471  if (!config.shouldVerifyDiagnostics()) {
472  SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
473  return performActions(os, sourceMgr, &context, config);
474  }
476  SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
478  // Do any processing requested by command line flags. We don't care whether
479  // these actions succeed or fail, we only care what diagnostics they produce
480  // and whether they match our expectations.
481  (void)performActions(os, sourceMgr, &context, config);
483  // Verify the diagnostic handler to make sure that each of the diagnostics
484  // matched.
485  return sourceMgrHandler.verify();
486 }
488 std::pair<std::string, std::string>
489 mlir::registerAndParseCLIOptions(int argc, char **argv,
490  llvm::StringRef toolName,
491  DialectRegistry &registry) {
492  static cl::opt<std::string> inputFilename(
493  cl::Positional, cl::desc("<input file>"), cl::init("-"));
495  static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
496  cl::value_desc("filename"),
497  cl::init("-"));
498  // Register any command line options.
506  // Build the list of dialects as a header for the --help message.
507  std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
508  {
509  llvm::raw_string_ostream os(helpHeader);
510  interleaveComma(registry.getDialectNames(), os,
511  [&](auto name) { os << name; });
512  }
513  // Parse pass names in main to ensure static initialization completed.
514  cl::ParseCommandLineOptions(argc, argv, helpHeader);
515  return std::make_pair(inputFilename.getValue(), outputFilename.getValue());
516 }
518 static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
519  llvm::outs() << "Available Dialects: ";
520  interleave(registry.getDialectNames(), llvm::outs(), ",");
521  llvm::outs() << "\n";
522  return success();
523 }
525 LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
526  std::unique_ptr<llvm::MemoryBuffer> buffer,
527  DialectRegistry &registry,
528  const MlirOptMainConfig &config) {
529  if (config.shouldShowDialects())
530  return printRegisteredDialects(registry);
532  // The split-input-file mode is a very specific mode that slices the file
533  // up into small pieces and checks each independently.
534  // We use an explicit threadpool to avoid creating and joining/destroying
535  // threads for each of the split.
536  ThreadPoolInterface *threadPool = nullptr;
538  // Create a temporary context for the sake of checking if
539  // --mlir-disable-threading was passed on the command line.
540  // We use the thread-pool this context is creating, and avoid
541  // creating any thread when disabled.
542  MLIRContext threadPoolCtx;
543  if (threadPoolCtx.isMultithreadingEnabled())
544  threadPool = &threadPoolCtx.getThreadPool();
546  auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
547  raw_ostream &os) {
548  return processBuffer(os, std::move(chunkBuffer), config, registry,
549  threadPool);
550  };
551  return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
552  config.inputSplitMarker(),
553  config.outputSplitMarker());
554 }
556 LogicalResult mlir::MlirOptMain(int argc, char **argv,
557  llvm::StringRef inputFilename,
558  llvm::StringRef outputFilename,
559  DialectRegistry &registry) {
561  InitLLVM y(argc, argv);
565  if (config.shouldShowDialects())
566  return printRegisteredDialects(registry);
568  // When reading from stdin and the input is a tty, it is often a user mistake
569  // and the process "appears to be stuck". Print a message to let the user know
570  // about it!
571  if (inputFilename == "-" &&
572  sys::Process::FileDescriptorIsDisplayed(fileno(stdin)))
573  llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to "
574  "interrupt)\n";
576  // Set up the input file.
577  std::string errorMessage;
578  auto file = openInputFile(inputFilename, &errorMessage);
579  if (!file) {
580  llvm::errs() << errorMessage << "\n";
581  return failure();
582  }
584  auto output = openOutputFile(outputFilename, &errorMessage);
585  if (!output) {
586  llvm::errs() << errorMessage << "\n";
587  return failure();
588  }
589  if (failed(MlirOptMain(output->os(), std::move(file), registry, config)))
590  return failure();
592  // Keep the output file if the invocation of MlirOptMain was successful.
593  output->keep();
594  return success();
595 }
597 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
598  DialectRegistry &registry) {
600  // Register and parse command line options.
601  std::string inputFilename, outputFilename;
602  std::tie(inputFilename, outputFilename) =
603  registerAndParseCLIOptions(argc, argv, toolName, registry);
605  return MlirOptMain(argc, argv, inputFilename, outputFilename, registry);
606 }
