MLIR  17.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 #include <optional>
11 
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
14 #include "mlir/Pass/PassRegistry.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/ManagedStatic.h"
19 #include "llvm/Support/MemoryBuffer.h"
20 #include "llvm/Support/SourceMgr.h"
21 
22 using namespace mlir;
23 using namespace detail;
24 
25 /// Static mapping of all of the registered passes.
26 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
27 
28 /// A mapping of the above pass registry entries to the corresponding TypeID
29 /// of the pass that they generate.
30 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
31 
32 /// Static mapping of all of the registered pass pipelines.
33 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
35 
36 /// Utility to create a default registry function from a pass instance.
39  return [=](OpPassManager &pm, StringRef options,
40  function_ref<LogicalResult(const Twine &)> errorHandler) {
41  std::unique_ptr<Pass> pass = allocator();
42  LogicalResult result = pass->initializeOptions(options);
43 
44  std::optional<StringRef> pmOpName = pm.getOpName();
45  std::optional<StringRef> passOpName = pass->getOpName();
46  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
47  passOpName && *pmOpName != *passOpName) {
48  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
49  "' restricted to '" + *pass->getOpName() +
50  "' on a PassManager intended to run on '" +
51  pm.getOpAnchorName() + "', did you intend to nest?");
52  }
53  pm.addPass(std::move(pass));
54  return result;
55  };
56 }
57 
58 /// Utility to print the help string for a specific option.
59 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
60  size_t descIndent, bool isTopLevel) {
61  size_t numSpaces = descIndent - indent - 4;
62  llvm::outs().indent(indent)
63  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // PassRegistry
68 //===----------------------------------------------------------------------===//
69 
70 /// Print the help information for this pass. This includes the argument,
71 /// description, and any pass options. `descIndent` is the indent that the
72 /// descriptions should be aligned.
73 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
74  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
75  /*isTopLevel=*/true);
76  // If this entry has options, print the help for those as well.
77  optHandler([=](const PassOptions &options) {
78  options.printHelp(indent, descIndent);
79  });
80 }
81 
82 /// Return the maximum width required when printing the options of this
83 /// entry.
85  size_t maxLen = 0;
86  optHandler([&](const PassOptions &options) mutable {
87  maxLen = options.getOptionWidth() + 2;
88  });
89  return maxLen;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // PassPipelineInfo
94 //===----------------------------------------------------------------------===//
95 
97  StringRef arg, StringRef description, const PassRegistryFunction &function,
98  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
99  PassPipelineInfo pipelineInfo(arg, description, function,
100  std::move(optHandler));
101  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
102  assert(inserted && "Pass pipeline registered multiple times");
103  (void)inserted;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // PassInfo
108 //===----------------------------------------------------------------------===//
109 
110 PassInfo::PassInfo(StringRef arg, StringRef description,
111  const PassAllocatorFunction &allocator)
113  arg, description, buildDefaultRegistryFn(allocator),
114  // Use a temporary pass to provide an options instance.
115  [=](function_ref<void(const PassOptions &)> optHandler) {
116  optHandler(allocator()->passOptions);
117  }) {}
118 
120  std::unique_ptr<Pass> pass = function();
121  StringRef arg = pass->getArgument();
122  if (arg.empty())
123  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
124  pass->getName() +
125  "' pass that does not override `getArgument()`");
126  StringRef description = pass->getDescription();
127  PassInfo passInfo(arg, description, function);
128  passRegistry->try_emplace(arg, passInfo);
129 
130  // Verify that the registered pass has the same ID as any registered to this
131  // arg before it.
132  TypeID entryTypeID = pass->getTypeID();
133  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
134  if (it->second != entryTypeID)
135  llvm::report_fatal_error(
136  "pass allocator creates a different pass than previously "
137  "registered for pass " +
138  arg);
139 }
140 
141 /// Returns the pass info for the specified pass argument or null if unknown.
142 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
143  auto it = passRegistry->find(passArg);
144  return it == passRegistry->end() ? nullptr : &it->second;
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // PassOptions
149 //===----------------------------------------------------------------------===//
150 
152  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
153  function_ref<LogicalResult(StringRef)> elementParseFn) {
154  // Functor used for finding a character in a string, and skipping over
155  // various "range" characters.
156  llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
157  [&](StringRef str, size_t index, char c) -> size_t {
158  for (size_t i = index, e = str.size(); i < e; ++i) {
159  if (str[i] == c)
160  return i;
161  // Check for various range characters.
162  if (str[i] == '{')
163  i = findChar(str, i + 1, '}');
164  else if (str[i] == '(')
165  i = findChar(str, i + 1, ')');
166  else if (str[i] == '[')
167  i = findChar(str, i + 1, ']');
168  else if (str[i] == '\"')
169  i = str.find_first_of('\"', i + 1);
170  else if (str[i] == '\'')
171  i = str.find_first_of('\'', i + 1);
172  }
173  return StringRef::npos;
174  };
175 
176  size_t nextElePos = findChar(optionStr, 0, ',');
177  while (nextElePos != StringRef::npos) {
178  // Process the portion before the comma.
179  if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
180  return failure();
181 
182  optionStr = optionStr.substr(nextElePos + 1);
183  nextElePos = findChar(optionStr, 0, ',');
184  }
185  return elementParseFn(optionStr.substr(0, nextElePos));
186 }
187 
188 /// Out of line virtual function to provide home for the class.
189 void detail::PassOptions::OptionBase::anchor() {}
190 
191 /// Copy the option values from 'other'.
193  assert(options.size() == other.options.size());
194  if (options.empty())
195  return;
196  for (auto optionsIt : llvm::zip(options, other.options))
197  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
198 }
199 
200 /// Parse in the next argument from the given options string. Returns a tuple
201 /// containing [the key of the option, the value of the option, updated
202 /// `options` string pointing after the parsed option].
203 static std::tuple<StringRef, StringRef, StringRef>
204 parseNextArg(StringRef options) {
205  // Functor used to extract an argument from 'options' and update it to point
206  // after the arg.
207  auto extractArgAndUpdateOptions = [&](size_t argSize) {
208  StringRef str = options.take_front(argSize).trim();
209  options = options.drop_front(argSize).ltrim();
210  return str;
211  };
212  // Try to process the given punctuation, properly escaping any contained
213  // characters.
214  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
215  if (options[currentPos] != punct)
216  return false;
217  size_t nextIt = options.find_first_of(punct, currentPos + 1);
218  if (nextIt != StringRef::npos)
219  currentPos = nextIt;
220  return true;
221  };
222 
223  // Parse the argument name of the option.
224  StringRef argName;
225  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
226  // Check for the end of the full option.
227  if (argEndIt == optionsE || options[argEndIt] == ' ') {
228  argName = extractArgAndUpdateOptions(argEndIt);
229  return std::make_tuple(argName, StringRef(), options);
230  }
231 
232  // Check for the end of the name and the start of the value.
233  if (options[argEndIt] == '=') {
234  argName = extractArgAndUpdateOptions(argEndIt);
235  options = options.drop_front();
236  break;
237  }
238  }
239 
240  // Parse the value of the option.
241  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
242  // Handle the end of the options string.
243  if (argEndIt == optionsE || options[argEndIt] == ' ') {
244  StringRef value = extractArgAndUpdateOptions(argEndIt);
245  return std::make_tuple(argName, value, options);
246  }
247 
248  // Skip over escaped sequences.
249  char c = options[argEndIt];
250  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
251  continue;
252  // '{...}' is used to specify options to passes, properly escape it so
253  // that we don't accidentally split any nested options.
254  if (c == '{') {
255  size_t braceCount = 1;
256  for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
257  // Allow nested punctuation.
258  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
259  continue;
260  if (options[argEndIt] == '{')
261  ++braceCount;
262  else if (options[argEndIt] == '}' && --braceCount == 0)
263  break;
264  }
265  // Account for the increment at the top of the loop.
266  --argEndIt;
267  }
268  }
269  llvm_unreachable("unexpected control flow in pass option parsing");
270 }
271 
273  // NOTE: `options` is modified in place to always refer to the unprocessed
274  // part of the string.
275  while (!options.empty()) {
276  StringRef key, value;
277  std::tie(key, value, options) = parseNextArg(options);
278  if (key.empty())
279  continue;
280 
281  auto it = OptionsMap.find(key);
282  if (it == OptionsMap.end()) {
283  llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
284  return failure();
285  }
286  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
287  return failure();
288  }
289 
290  return success();
291 }
292 
293 /// Print the options held by this struct in a form that can be parsed via
294 /// 'parseFromString'.
295 void detail::PassOptions::print(raw_ostream &os) {
296  // If there are no options, there is nothing left to do.
297  if (OptionsMap.empty())
298  return;
299 
300  // Sort the options to make the ordering deterministic.
301  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
302  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
303  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
304  };
305  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
306 
307  // Interleave the options with ' '.
308  os << '{';
309  llvm::interleave(
310  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
311  os << '}';
312 }
313 
314 /// Print the help string for the options held by this struct. `descIndent` is
315 /// the indent within the stream that the descriptions should be aligned.
316 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
317  // Sort the options to make the ordering deterministic.
318  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
319  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
320  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
321  };
322  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
323  for (OptionBase *option : orderedOps) {
324  // TODO: printOptionInfo assumes a specific indent and will
325  // print options with values with incorrect indentation. We should add
326  // support to llvm::cl::Option for passing in a base indent to use when
327  // printing.
328  llvm::outs().indent(indent);
329  option->getOption()->printOptionInfo(descIndent - indent);
330  }
331 }
332 
333 /// Return the maximum width required when printing the help string.
335  size_t max = 0;
336  for (auto *option : options)
337  max = std::max(max, option->getOption()->getOptionWidth());
338  return max;
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // MLIR Options
343 //===----------------------------------------------------------------------===//
344 
345 //===----------------------------------------------------------------------===//
346 // OpPassManager: OptionValue
347 
348 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
349 llvm::cl::OptionValue<OpPassManager>::OptionValue(
350  const mlir::OpPassManager &value) {
351  setValue(value);
352 }
353 llvm::cl::OptionValue<OpPassManager>::OptionValue(
355  if (rhs.hasValue())
356  setValue(rhs.getValue());
357 }
358 llvm::cl::OptionValue<OpPassManager> &
359 llvm::cl::OptionValue<OpPassManager>::operator=(
360  const mlir::OpPassManager &rhs) {
361  setValue(rhs);
362  return *this;
363 }
364 
365 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
366 
367 void llvm::cl::OptionValue<OpPassManager>::setValue(
368  const OpPassManager &newValue) {
369  if (hasValue())
370  *value = newValue;
371  else
372  value = std::make_unique<mlir::OpPassManager>(newValue);
373 }
374 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
375  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
376  assert(succeeded(pipeline) && "invalid pass pipeline");
377  setValue(*pipeline);
378 }
379 
381  const mlir::OpPassManager &rhs) const {
382  std::string lhsStr, rhsStr;
383  {
384  raw_string_ostream lhsStream(lhsStr);
385  value->printAsTextualPipeline(lhsStream);
386 
387  raw_string_ostream rhsStream(rhsStr);
388  rhs.printAsTextualPipeline(rhsStream);
389  }
390 
391  // Use the textual format for pipeline comparisons.
392  return lhsStr == rhsStr;
393 }
394 
395 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
396 
397 //===----------------------------------------------------------------------===//
398 // OpPassManager: Parser
399 
400 namespace llvm {
401 namespace cl {
402 template class basic_parser<OpPassManager>;
403 } // namespace cl
404 } // namespace llvm
405 
406 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
407  ParsedPassManager &value) {
409  if (failed(pipeline))
410  return true;
411  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
412  return false;
413 }
414 
415 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
416  const OpPassManager &value) {
417  value.printAsTextualPipeline(os);
418 }
419 
421  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
422  size_t globalWidth) const {
423  printOptionName(opt, globalWidth);
424  outs() << "= ";
425  pm.printAsTextualPipeline(outs());
426 
427  if (defaultValue.hasValue()) {
428  outs().indent(2) << " (default: ";
429  defaultValue.getValue().printAsTextualPipeline(outs());
430  outs() << ")";
431  }
432  outs() << "\n";
433 }
434 
436 
438  default;
440  ParsedPassManager &&) = default;
442  default;
443 
444 //===----------------------------------------------------------------------===//
445 // TextualPassPipeline Parser
446 //===----------------------------------------------------------------------===//
447 
448 namespace {
449 /// This class represents a textual description of a pass pipeline.
450 class TextualPipeline {
451 public:
452  /// Try to initialize this pipeline with the given pipeline text.
453  /// `errorStream` is the output stream to emit errors to.
454  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
455 
456  /// Add the internal pipeline elements to the provided pass manager.
458  addToPipeline(OpPassManager &pm,
459  function_ref<LogicalResult(const Twine &)> errorHandler) const;
460 
461 private:
462  /// A functor used to emit errors found during pipeline handling. The first
463  /// parameter corresponds to the raw location within the pipeline string. This
464  /// should always return failure.
465  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
466 
467  /// A struct to capture parsed pass pipeline names.
468  ///
469  /// A pipeline is defined as a series of names, each of which may in itself
470  /// recursively contain a nested pipeline. A name is either the name of a pass
471  /// (e.g. "cse") or the name of an operation type (e.g. "builtin.module"). If
472  /// the name is the name of a pass, the InnerPipeline is empty, since passes
473  /// cannot contain inner pipelines.
474  struct PipelineElement {
475  PipelineElement(StringRef name) : name(name) {}
476 
477  StringRef name;
478  StringRef options;
479  const PassRegistryEntry *registryEntry = nullptr;
480  std::vector<PipelineElement> innerPipeline;
481  };
482 
483  /// Parse the given pipeline text into the internal pipeline vector. This
484  /// function only parses the structure of the pipeline, and does not resolve
485  /// its elements.
486  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
487 
488  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
489  /// the corresponding registry entry.
491  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
492  ErrorHandlerT errorHandler);
493 
494  /// Resolve a single element of the pipeline.
495  LogicalResult resolvePipelineElement(PipelineElement &element,
496  ErrorHandlerT errorHandler);
497 
498  /// Add the given pipeline elements to the provided pass manager.
500  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
501  function_ref<LogicalResult(const Twine &)> errorHandler) const;
502 
503  std::vector<PipelineElement> pipeline;
504 };
505 
506 } // namespace
507 
508 /// Try to initialize this pipeline with the given pipeline text. An option is
509 /// given to enable accurate error reporting.
510 LogicalResult TextualPipeline::initialize(StringRef text,
511  raw_ostream &errorStream) {
512  if (text.empty())
513  return success();
514 
515  // Build a source manager to use for error reporting.
516  llvm::SourceMgr pipelineMgr;
517  pipelineMgr.AddNewSourceBuffer(
518  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
519  /*RequiresNullTerminator=*/false),
520  SMLoc());
521  auto errorHandler = [&](const char *rawLoc, Twine msg) {
522  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
523  llvm::SourceMgr::DK_Error, msg);
524  return failure();
525  };
526 
527  // Parse the provided pipeline string.
528  if (failed(parsePipelineText(text, errorHandler)))
529  return failure();
530  return resolvePipelineElements(pipeline, errorHandler);
531 }
532 
533 /// Add the internal pipeline elements to the provided pass manager.
534 LogicalResult TextualPipeline::addToPipeline(
535  OpPassManager &pm,
536  function_ref<LogicalResult(const Twine &)> errorHandler) const {
537  // Temporarily disable implicit nesting while we append to the pipeline. We
538  // want the created pipeline to exactly match the parsed text pipeline, so
539  // it's preferable to just error out if implicit nesting would be required.
540  OpPassManager::Nesting nesting = pm.getNesting();
542  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
543 
544  return addToPipeline(pipeline, pm, errorHandler);
545 }
546 
547 /// Parse the given pipeline text into the internal pipeline vector. This
548 /// function only parses the structure of the pipeline, and does not resolve
549 /// its elements.
550 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
551  ErrorHandlerT errorHandler) {
552  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
553  for (;;) {
554  std::vector<PipelineElement> &pipeline = *pipelineStack.back();
555  size_t pos = text.find_first_of(",(){");
556  pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
557 
558  // If we have a single terminating name, we're done.
559  if (pos == StringRef::npos)
560  break;
561 
562  text = text.substr(pos);
563  char sep = text[0];
564 
565  // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
566  if (sep == '{') {
567  text = text.substr(1);
568 
569  // Skip over everything until the closing '}' and store as options.
570  size_t close = StringRef::npos;
571  for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
572  if (text[i] == '{') {
573  ++braceCount;
574  continue;
575  }
576  if (text[i] == '}' && --braceCount == 0) {
577  close = i;
578  break;
579  }
580  }
581 
582  // Check to see if a closing options brace was found.
583  if (close == StringRef::npos) {
584  return errorHandler(
585  /*rawLoc=*/text.data() - 1,
586  "missing closing '}' while processing pass options");
587  }
588  pipeline.back().options = text.substr(0, close);
589  text = text.substr(close + 1);
590 
591  // Skip checking for '(' because nested pipelines cannot have options.
592  } else if (sep == '(') {
593  text = text.substr(1);
594 
595  // Push the inner pipeline onto the stack to continue processing.
596  pipelineStack.push_back(&pipeline.back().innerPipeline);
597  continue;
598  }
599 
600  // When handling the close parenthesis, we greedily consume them to avoid
601  // empty strings in the pipeline.
602  while (text.consume_front(")")) {
603  // If we try to pop the outer pipeline we have unbalanced parentheses.
604  if (pipelineStack.size() == 1)
605  return errorHandler(/*rawLoc=*/text.data() - 1,
606  "encountered extra closing ')' creating unbalanced "
607  "parentheses while parsing pipeline");
608 
609  pipelineStack.pop_back();
610  }
611 
612  // Check if we've finished parsing.
613  if (text.empty())
614  break;
615 
616  // Otherwise, the end of an inner pipeline always has to be followed by
617  // a comma, and then we can continue.
618  if (!text.consume_front(","))
619  return errorHandler(text.data(), "expected ',' after parsing pipeline");
620  }
621 
622  // Check for unbalanced parentheses.
623  if (pipelineStack.size() > 1)
624  return errorHandler(
625  text.data(),
626  "encountered unbalanced parentheses while parsing pipeline");
627 
628  assert(pipelineStack.back() == &pipeline &&
629  "wrong pipeline at the bottom of the stack");
630  return success();
631 }
632 
633 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
634 /// the corresponding registry entry.
635 LogicalResult TextualPipeline::resolvePipelineElements(
636  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
637  for (auto &elt : elements)
638  if (failed(resolvePipelineElement(elt, errorHandler)))
639  return failure();
640  return success();
641 }
642 
643 /// Resolve a single element of the pipeline.
645 TextualPipeline::resolvePipelineElement(PipelineElement &element,
646  ErrorHandlerT errorHandler) {
647  // If the inner pipeline of this element is not empty, this is an operation
648  // pipeline.
649  if (!element.innerPipeline.empty())
650  return resolvePipelineElements(element.innerPipeline, errorHandler);
651  // Otherwise, this must be a pass or pass pipeline.
652  // Check to see if a pipeline was registered with this name.
653  auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
654  if (pipelineRegistryIt != passPipelineRegistry->end()) {
655  element.registryEntry = &pipelineRegistryIt->second;
656  return success();
657  }
658 
659  // If not, then this must be a specific pass name.
660  if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
661  return success();
662 
663  // Emit an error for the unknown pass.
664  auto *rawLoc = element.name.data();
665  return errorHandler(rawLoc, "'" + element.name +
666  "' does not refer to a "
667  "registered pass or pass pipeline");
668 }
669 
670 /// Add the given pipeline elements to the provided pass manager.
671 LogicalResult TextualPipeline::addToPipeline(
673  function_ref<LogicalResult(const Twine &)> errorHandler) const {
674  for (auto &elt : elements) {
675  if (elt.registryEntry) {
676  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
677  errorHandler))) {
678  return errorHandler("failed to add `" + elt.name + "` with options `" +
679  elt.options + "`");
680  }
681  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
682  errorHandler))) {
683  return errorHandler("failed to add `" + elt.name + "` with options `" +
684  elt.options + "` to inner pipeline");
685  }
686  }
687  return success();
688 }
689 
691  raw_ostream &errorStream) {
692  TextualPipeline pipelineParser;
693  if (failed(pipelineParser.initialize(pipeline, errorStream)))
694  return failure();
695  auto errorHandler = [&](Twine msg) {
696  errorStream << msg << "\n";
697  return failure();
698  };
699  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
700  return failure();
701  return success();
702 }
703 
705  raw_ostream &errorStream) {
706  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
707  size_t pipelineStart = pipeline.find_first_of('(');
708  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
709  !pipeline.consume_back(")")) {
710  errorStream << "expected pass pipeline to be wrapped with the anchor "
711  "operation type, e.g. 'builtin.module(...)'";
712  return failure();
713  }
714 
715  StringRef opName = pipeline.take_front(pipelineStart);
716  OpPassManager pm(opName);
717  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
718  errorStream)))
719  return failure();
720  return pm;
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // PassNameParser
725 //===----------------------------------------------------------------------===//
726 
727 namespace {
728 /// This struct represents the possible data entries in a parsed pass pipeline
729 /// list.
730 struct PassArgData {
731  PassArgData() = default;
732  PassArgData(const PassRegistryEntry *registryEntry)
733  : registryEntry(registryEntry) {}
734 
735  /// This field is used when the parsed option corresponds to a registered pass
736  /// or pass pipeline.
737  const PassRegistryEntry *registryEntry{nullptr};
738 
739  /// This field is set when instance specific pass options have been provided
740  /// on the command line.
741  StringRef options;
742 };
743 } // namespace
744 
745 namespace llvm {
746 namespace cl {
747 /// Define a valid OptionValue for the command line pass argument.
748 template <>
750  : OptionValueBase<PassArgData, /*isClass=*/true> {
751  OptionValue(const PassArgData &value) { this->setValue(value); }
752  OptionValue() = default;
753  void anchor() override {}
754 
755  bool hasValue() const { return true; }
756  const PassArgData &getValue() const { return value; }
757  void setValue(const PassArgData &value) { this->value = value; }
758 
759  PassArgData value;
760 };
761 } // namespace cl
762 } // namespace llvm
763 
764 namespace {
765 
766 /// The name for the command line option used for parsing the textual pass
767 /// pipeline.
768 #define PASS_PIPELINE_ARG "pass-pipeline"
769 
770 /// Adds command line option for each registered pass or pass pipeline, as well
771 /// as textual pass pipelines.
772 struct PassNameParser : public llvm::cl::parser<PassArgData> {
773  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
774 
775  void initialize();
776  void printOptionInfo(const llvm::cl::Option &opt,
777  size_t globalWidth) const override;
778  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
779  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
780  PassArgData &value);
781 
782  /// If true, this parser only parses entries that correspond to a concrete
783  /// pass registry entry, and does not include pipeline entries or the options
784  /// for pass entries.
785  bool passNamesOnly = false;
786 };
787 } // namespace
788 
789 void PassNameParser::initialize() {
791 
792  /// Add the pass entries.
793  for (const auto &kv : *passRegistry) {
794  addLiteralOption(kv.second.getPassArgument(), &kv.second,
795  kv.second.getPassDescription());
796  }
797  /// Add the pass pipeline entries.
798  if (!passNamesOnly) {
799  for (const auto &kv : *passPipelineRegistry) {
800  addLiteralOption(kv.second.getPassArgument(), &kv.second,
801  kv.second.getPassDescription());
802  }
803  }
804 }
805 
806 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
807  size_t globalWidth) const {
808  // If this parser is just parsing pass names, print a simplified option
809  // string.
810  if (passNamesOnly) {
811  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
812  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
813  return;
814  }
815 
816  // Print the information for the top-level option.
817  if (opt.hasArgStr()) {
818  llvm::outs() << " --" << opt.ArgStr;
819  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
820  } else {
821  llvm::outs() << " " << opt.HelpStr << '\n';
822  }
823 
824  // Functor used to print the ordered entries of a registration map.
825  auto printOrderedEntries = [&](StringRef header, auto &map) {
827  for (auto &kv : map)
828  orderedEntries.push_back(&kv.second);
829  llvm::array_pod_sort(
830  orderedEntries.begin(), orderedEntries.end(),
831  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
832  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
833  });
834 
835  llvm::outs().indent(4) << header << ":\n";
836  for (PassRegistryEntry *entry : orderedEntries)
837  entry->printHelpStr(/*indent=*/6, globalWidth);
838  };
839 
840  // Print the available passes.
841  printOrderedEntries("Passes", *passRegistry);
842 
843  // Print the available pass pipelines.
844  if (!passPipelineRegistry->empty())
845  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
846 }
847 
848 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
849  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
850 
851  // Check for any wider pass or pipeline options.
852  for (auto &entry : *passRegistry)
853  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
854  for (auto &entry : *passPipelineRegistry)
855  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
856  return maxWidth;
857 }
858 
859 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
860  StringRef arg, PassArgData &value) {
861  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
862  return true;
863  value.options = arg;
864  return false;
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // PassPipelineCLParser
869 //===----------------------------------------------------------------------===//
870 
871 namespace mlir {
872 namespace detail {
874  PassPipelineCLParserImpl(StringRef arg, StringRef description,
875  bool passNamesOnly)
876  : passList(arg, llvm::cl::desc(description)) {
877  passList.getParser().passNamesOnly = passNamesOnly;
878  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
879  }
880 
881  /// Returns true if the given pass registry entry was registered at the
882  /// top-level of the parser, i.e. not within an explicit textual pipeline.
883  bool contains(const PassRegistryEntry *entry) const {
884  return llvm::any_of(passList, [&](const PassArgData &data) {
885  return data.registryEntry == entry;
886  });
887  }
888 
889  /// The set of passes and pass pipelines to run.
890  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
891 };
892 } // namespace detail
893 } // namespace mlir
894 
895 /// Construct a pass pipeline parser with the given command line description.
896 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
897  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
898  arg, description, /*passNamesOnly=*/false)),
899  passPipeline(
901  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
902 
903 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
904  StringRef alias)
905  : PassPipelineCLParser(arg, description) {
906  passPipelineAlias.emplace(alias,
907  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
908  llvm::cl::aliasopt(passPipeline));
909 }
910 
912 
913 /// Returns true if this parser contains any valid options to add.
915  return passPipeline.getNumOccurrences() != 0 ||
916  impl->passList.getNumOccurrences() != 0;
917 }
918 
919 /// Returns true if the given pass registry entry was registered at the
920 /// top-level of the parser, i.e. not within an explicit textual pipeline.
922  return impl->contains(entry);
923 }
924 
925 /// Adds the passes defined by this parser entry to the given pass manager.
927  OpPassManager &pm,
928  function_ref<LogicalResult(const Twine &)> errorHandler) const {
929  if (passPipeline.getNumOccurrences()) {
930  if (impl->passList.getNumOccurrences())
931  return errorHandler(
932  "'-" PASS_PIPELINE_ARG
933  "' option can't be used with individual pass options");
934  std::string errMsg;
935  llvm::raw_string_ostream os(errMsg);
936  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
937  if (failed(parsed))
938  return errorHandler(errMsg);
939  pm = std::move(*parsed);
940  return success();
941  }
942 
943  for (auto &passIt : impl->passList) {
944  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
945  errorHandler)))
946  return failure();
947  }
948  return success();
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // PassNameCLParser
953 
954 /// Construct a pass pipeline parser with the given command line description.
955 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
956  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
957  arg, description, /*passNamesOnly=*/true)) {
958  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
959 }
961 
962 /// Returns true if this parser contains any valid options to add.
964  return impl->passList.getNumOccurrences() != 0;
965 }
966 
967 /// Returns true if the given pass registry entry was registered at the
968 /// top-level of the parser, i.e. not within an explicit textual pipeline.
970  return impl->contains(entry);
971 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they generate.
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:52
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:365
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition: Pass.cpp:355
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:392
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:336
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:394
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:55
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:359
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:326
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:111
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:237
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:101
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:49
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
const PassInfo * lookupPassInfo() const
Returns the pass info for this pass, or null if unknown.
Definition: Pass.h:62
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Base container class and manager for all pass options.
Definition: PassOptions.h:79
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options)
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
void print(raw_ostream &os)
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
Include the generated interface declarations.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:59
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:434
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:431
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...