MLIR  22.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 #include <optional>
21 #include <utility>
22 
23 using namespace mlir;
24 using namespace detail;
25 
26 /// Static mapping of all of the registered passes, keyed by the pass
/// argument string (see the `try_emplace(arg, ...)` in pass registration).
27 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
28
29 /// A mapping of the above pass registry entries to the corresponding TypeID
30 /// of the pass that they generate. Pass registration consults it to reject a
/// second registration of the same argument whose allocator produces a pass
/// with a different TypeID.
31 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
32
33 /// Static mapping of all of the registered pass pipelines, keyed by pipeline
/// argument.
/// NOTE(review): the line declaring this variable's name (used below as
/// `passPipelineRegistry`) is missing from this listing.
34 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
36 
37 /// Utility to create a default registry function from a pass instance.
///
/// The returned callback allocates a fresh pass with `allocator`, initializes
/// its options from the `options` string, and appends it to `pm`. If `pm` uses
/// explicit nesting and both `pm` and the pass are anchored on operation names
/// that differ, the pass is not added and the error goes through
/// `errorHandler`. Otherwise the pass is added and the result of option
/// initialization is returned, so the pass is appended even when its options
/// failed to parse.
/// NOTE(review): the signature lines of this definition (which declare
/// `allocator`) are missing from this listing.
40  return [=](OpPassManager &pm, StringRef options,
41  function_ref<LogicalResult(const Twine &)> errorHandler) {
42  std::unique_ptr<Pass> pass = allocator();
43  LogicalResult result = pass->initializeOptions(options, errorHandler);
44

45  std::optional<StringRef> pmOpName = pm.getOpName();
46  std::optional<StringRef> passOpName = pass->getOpName();
  // Only reject an op-name mismatch under explicit nesting, and only when both
  // sides actually name an operation.
47  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
48  passOpName && *pmOpName != *passOpName) {
49  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
50  "' restricted to '" + *pass->getOpName() +
51  "' on a PassManager intended to run on '" +
52  pm.getOpAnchorName() + "', did you intend to nest?");
53  }
54  pm.addPass(std::move(pass));
  // Propagate the option-initialization result after the pass has been added.
55  return result;
56  };
57 }
58 
59 /// Utility to print the help string for a specific option.
60 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
61  size_t descIndent, bool isTopLevel) {
62  size_t numSpaces = descIndent - indent - 4;
63  llvm::outs().indent(indent)
64  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // PassRegistry
69 //===----------------------------------------------------------------------===//
70 
71 /// Prints the passes that were previously registered and stored in passRegistry
/// under a "Passes:" header, sorted by pass argument. Descriptions are aligned
/// using the widest per-entry option width plus 4.
/// NOTE(review): the signature line and the declaration of `orderedEntries`
/// are missing from this listing.
73  size_t maxWidth = 0;
74  for (auto &entry : *passRegistry)
75  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
76

77  // Functor used to print the ordered entries of a registration map.
78  auto printOrderedEntries = [&](StringRef header, auto &map) {
80  for (auto &kv : map)
81  orderedEntries.push_back(&kv.second);
  // Sort by pass argument so the output order is deterministic.
82  llvm::array_pod_sort(
83  orderedEntries.begin(), orderedEntries.end(),
84  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
85  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
86  });
87

88  llvm::outs().indent(0) << header << ":\n";
89  for (PassRegistryEntry *entry : orderedEntries)
90  entry->printHelpStr(/*indent=*/2, maxWidth);
91  };
92

93  // Print the available passes.
94  printOrderedEntries("Passes", *passRegistry);
95 }
96 
97 /// Print the help information for this pass. This includes the argument,
98 /// description, and any pass options. `descIndent` is the indent that the
99 /// descriptions should be aligned.
100 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
101  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
102  /*isTopLevel=*/true);
103  // If this entry has options, print the help for those as well.
104  optHandler([=](const PassOptions &options) {
105  options.printHelp(indent, descIndent);
106  });
107 }
108 
109 /// Return the maximum width required when printing the options of this
110 /// entry: the options' own width plus 2, or 0 if `optHandler` never invokes
/// the callback.
/// NOTE(review): the signature line is missing from this listing.
112  size_t maxLen = 0;
  // `mutable` has no effect here; `maxLen` is captured by reference.
113  optHandler([&](const PassOptions &options) mutable {
114  maxLen = options.getOptionWidth() + 2;
115  });
116  return maxLen;
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // PassPipelineInfo
121 //===----------------------------------------------------------------------===//
122 
124  StringRef arg, StringRef description, const PassRegistryFunction &function,
125  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
  // Register a pass pipeline under `arg`.
  // NOTE(review): the opening line of this definition is missing from this
  // listing.
126  PassPipelineInfo pipelineInfo(arg, description, function,
127  std::move(optHandler));
  // try_emplace does not overwrite, so the first registration of `arg` wins.
128  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
  // Duplicate registration is only diagnosed when NDEBUG is not defined.
129 #ifndef NDEBUG
130  if (!inserted)
131  report_fatal_error("Pass pipeline " + arg + " registered multiple times");
132 #endif
133  (void)inserted;
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // PassInfo
138 //===----------------------------------------------------------------------===//
139 
/// Construct a PassInfo whose registry function is built by
/// buildDefaultRegistryFn(allocator), and whose option handler allocates a
/// temporary pass and hands its `passOptions` to the callback.
/// NOTE(review): the line opening the base-class initializer is missing from
/// this listing.
140 PassInfo::PassInfo(StringRef arg, StringRef description,
141  const PassAllocatorFunction &allocator)
143  arg, description, buildDefaultRegistryFn(allocator),
144  // Use a temporary pass to provide an options instance.
      // The pass is destroyed at the end of the call, so the options reference
      // is only valid inside `optHandler`.
145  [=](function_ref<void(const PassOptions &)> optHandler) {
146  optHandler(allocator()->passOptions);
147  }) {}
148 
  // Register the pass produced by `function`. One pass is instantiated to
  // query its argument, description and TypeID.
  // NOTE(review): the signature line is missing from this listing.
150  std::unique_ptr<Pass> pass = function();
151  StringRef arg = pass->getArgument();
152  if (arg.empty())
153  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
154  pass->getName() +
155  "' pass that does not override `getArgument()`");
156  StringRef description = pass->getDescription();
  // NOTE(review): `arg`/`description` come from a temporary pass instance;
  // this relies on PassInfo copying them or on them referring to static
  // storage -- confirm.
157  PassInfo passInfo(arg, description, function);
  // try_emplace keeps any earlier registration for the same argument.
158  passRegistry->try_emplace(arg, passInfo);
159

160  // Verify that the registered pass has the same ID as any registered to this
161  // arg before it.
162  TypeID entryTypeID = pass->getTypeID();
163  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
164  if (it->second != entryTypeID)
165  llvm::report_fatal_error(
166  "pass allocator creates a different pass than previously "
167  "registered for pass " +
168  arg);
169 }
170 
171 /// Returns the pass info for the specified pass argument or null if unknown.
172 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
173  auto it = passRegistry->find(passArg);
174  return it == passRegistry->end() ? nullptr : &it->second;
175 }
176 
177 /// Returns the pass pipeline info for the specified pass pipeline argument or
178 /// null if unknown.
179 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
180  auto it = passPipelineRegistry->find(pipelineArg);
181  return it == passPipelineRegistry->end() ? nullptr : &it->second;
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // PassOptions
186 //===----------------------------------------------------------------------===//
187 
188 /// Attempt to find the next occurance of character 'c' in the string starting
189 /// from the `index`-th position , omitting any occurances that appear within
190 /// intervening ranges or literals.
191 static size_t findChar(StringRef str, size_t index, char c) {
192  for (size_t i = index, e = str.size(); i < e; ++i) {
193  if (str[i] == c)
194  return i;
195  // Check for various range characters.
196  if (str[i] == '{')
197  i = findChar(str, i + 1, '}');
198  else if (str[i] == '(')
199  i = findChar(str, i + 1, ')');
200  else if (str[i] == '[')
201  i = findChar(str, i + 1, ']');
202  else if (str[i] == '\"')
203  i = str.find_first_of('\"', i + 1);
204  else if (str[i] == '\'')
205  i = str.find_first_of('\'', i + 1);
206  if (i == StringRef::npos)
207  return StringRef::npos;
208  }
209  return StringRef::npos;
210 }
211 
212 /// Extract an argument from 'options' and update it to point after the arg.
213 /// Returns the cleaned argument string.
214 static StringRef extractArgAndUpdateOptions(StringRef &options,
215  size_t argSize) {
216  StringRef str = options.take_front(argSize).trim();
217  options = options.drop_front(argSize).ltrim();
218 
219  // Early exit if there's no escape sequence.
220  if (str.size() <= 1)
221  return str;
222 
223  const auto escapePairs = {std::make_pair('\'', '\''),
224  std::make_pair('"', '"')};
225  for (const auto &escape : escapePairs) {
226  if (str.front() == escape.first && str.back() == escape.second) {
227  // Drop the escape characters and trim.
228  // Don't process additional escape sequences.
229  return str.drop_front().drop_back().trim();
230  }
231  }
232 
233  // Arguments may be wrapped in `{...}`. Unlike the quotation markers that
234  // denote literals, we respect scoping here. The outer `{...}` should not
235  // be stripped in cases such as "arg={...},{...}", which can be used to denote
236  // lists of nested option structs.
237  if (str.front() == '{') {
238  unsigned match = findChar(str, 1, '}');
239  if (match == str.size() - 1)
240  str = str.drop_front().drop_back().trim();
241  }
242 
243  return str;
244 }
245 
247  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
248  function_ref<LogicalResult(StringRef)> elementParseFn) {
  // Split `optionStr` at top-level commas (findChar skips commas inside
  // brackets and quotes) and feed each element, cleaned up by
  // extractArgAndUpdateOptions, to `elementParseFn`. Stops at the first
  // failure. An empty string produces no elements; a trailing comma produces
  // a final empty element. `opt` and `argName` are not used in this body.
  // NOTE(review): the signature line is missing from this listing.
249  if (optionStr.empty())
250  return success();
251

252  size_t nextElePos = findChar(optionStr, 0, ',');
253  while (nextElePos != StringRef::npos) {
254  // Process the portion before the comma.
255  if (failed(
256  elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
257  return failure();
258

259  // Drop the leading ','
260  optionStr = optionStr.drop_front();
261  nextElePos = findChar(optionStr, 0, ',');
262  }
263  return elementParseFn(
264  extractArgAndUpdateOptions(optionStr, optionStr.size()));
265 }
266 
267 /// Out of line virtual function to provide home for the class.
268 void detail::PassOptions::OptionBase::anchor() {}
269 
270 /// Copy the option values from 'other'. Options are paired positionally, so
/// both option lists are expected to have the same length (asserted).
/// NOTE(review): the signature line is missing from this listing.
272  assert(options.size() == other.options.size());
273  if (options.empty())
274  return;
275  for (auto optionsIt : llvm::zip(options, other.options))
276  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
277 }
278 
279 /// Parse in the next argument from the given options string. Returns a tuple
280 /// containing [the key of the option, the value of the option, updated
281 /// `options` string pointing after the parsed option].
/// A key that ends at a space or at the end of input (no '=') is returned with
/// an empty value. In a value, spaces inside quotes or inside `{...}` do not
/// end the value.
282 static std::tuple<StringRef, StringRef, StringRef>
283 parseNextArg(StringRef options) {
284  // Try to process the given punctuation, properly escaping any contained
285  // characters.
  // If options[currentPos] is `punct`, move currentPos to the next `punct`
  // (left unchanged if there is none) and return true; else return false.
286  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
287  if (options[currentPos] != punct)
288  return false;
289  size_t nextIt = options.find_first_of(punct, currentPos + 1);
290  if (nextIt != StringRef::npos)
291  currentPos = nextIt;
292  return true;
293  };
294

295  // Parse the argument name of the option.
296  StringRef argName;
297  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
298  // Check for the end of the full option.
299  if (argEndIt == optionsE || options[argEndIt] == ' ') {
300  argName = extractArgAndUpdateOptions(options, argEndIt);
301  return std::make_tuple(argName, StringRef(), options);
302  }
303

304  // Check for the end of the name and the start of the value.
305  if (options[argEndIt] == '=') {
306  argName = extractArgAndUpdateOptions(options, argEndIt);
  // `options` now starts at the '='; drop it.
307  options = options.drop_front();
308  break;
309  }
310  }
311

312  // Parse the value of the option.
313  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
314  // Handle the end of the options string.
315  if (argEndIt == optionsE || options[argEndIt] == ' ') {
316  StringRef value = extractArgAndUpdateOptions(options, argEndIt);
317  return std::make_tuple(argName, value, options);
318  }
319

320  // Skip over escaped sequences.
321  char c = options[argEndIt];
322  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
323  continue;
324  // '{...}' is used to specify options to passes, properly escape it so
325  // that we don't accidentally split any nested options.
326  if (c == '{') {
327  size_t braceCount = 1;
328  for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
329  // Allow nested punctuation.
330  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
331  continue;
332  if (options[argEndIt] == '{')
333  ++braceCount;
334  else if (options[argEndIt] == '}' && --braceCount == 0)
335  break;
336  }
337  // Account for the increment at the top of the loop.
  // If the input ran out, the outer loop then sees the end of input. After a
  // `break` on the closing '}', the outer loop re-examines that '}', which is
  // harmless.
338  --argEndIt;
339  }
340  }
341  llvm_unreachable("unexpected control flow in pass option parsing");
342 }
343 
345  raw_ostream &errorStream) {
  // Parse space-separated `key[=value]` entries and hand each value to the
  // matching registered option. Entries whose key comes back empty are
  // skipped. An unknown key reports an error to `errorStream` and fails.
  // NOTE(review): the signature line is missing from this listing.
346  // NOTE: `options` is modified in place to always refer to the unprocessed
347  // part of the string.
348  while (!options.empty()) {
349  StringRef key, value;
350  std::tie(key, value, options) = parseNextArg(options);
351  if (key.empty())
352  continue;
353

354  auto it = OptionsMap.find(key);
355  if (it == OptionsMap.end()) {
356  errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
357  return failure();
358  }
  // A true return from ProvidePositionalOption is treated as a parse failure.
359  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
360  return failure();
361  }
362

363  return success();
364 }
365 
366 /// Print the options held by this struct in a form that can be parsed via
367 /// 'parseFromString'.
368 void detail::PassOptions::print(raw_ostream &os) const {
369  // If there are no options, there is nothing left to do.
370  if (OptionsMap.empty())
371  return;
372 
373  // Sort the options to make the ordering deterministic.
374  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
375  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
376  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
377  };
378  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
379 
380  // Interleave the options with ' '.
381  os << '{';
382  llvm::interleave(
383  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
384  os << '}';
385 }
386 
387 /// Print the help string for the options held by this struct. `descIndent` is
388 /// the indent within the stream that the descriptions should be aligned.
389 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
390  // Sort the options to make the ordering deterministic.
391  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
392  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
393  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
394  };
395  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
396  for (OptionBase *option : orderedOps) {
397  // TODO: printOptionInfo assumes a specific indent and will
398  // print options with values with incorrect indentation. We should add
399  // support to llvm::cl::Option for passing in a base indent to use when
400  // printing.
401  llvm::outs().indent(indent);
402  option->getOption()->printOptionInfo(descIndent - indent);
403  }
404 }
405 
406 /// Return the maximum width required when printing the help string: the
/// largest `getOptionWidth()` among the held options, or 0 if there are none.
/// NOTE(review): the signature line is missing from this listing.
408  size_t max = 0;
409  for (auto *option : options)
410  max = std::max(max, option->getOption()->getOptionWidth());
411  return max;
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // MLIR Options
416 //===----------------------------------------------------------------------===//
417 
418 //===----------------------------------------------------------------------===//
419 // OpPassManager: OptionValue
420 //===----------------------------------------------------------------------===//
421 
/// Default construction holds no pipeline.
422 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
/// Construct holding a copy of `value`.
423 llvm::cl::OptionValue<OpPassManager>::OptionValue(
424  const mlir::OpPassManager &value) {
425  setValue(value);
426 }
/// Copy construction: copies the held pipeline only if `rhs` holds one.
/// NOTE(review): the parameter-list line of this constructor is missing from
/// this listing.
427 llvm::cl::OptionValue<OpPassManager>::OptionValue(
429  if (rhs.hasValue())
430  setValue(rhs.getValue());
431 }
/// Assign a copy of the pipeline `rhs`.
432 llvm::cl::OptionValue<OpPassManager> &
433 llvm::cl::OptionValue<OpPassManager>::operator=(
434  const mlir::OpPassManager &rhs) {
435  setValue(rhs);
436  return *this;
437 }
438

439 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
440 
441 void llvm::cl::OptionValue<OpPassManager>::setValue(
442  const OpPassManager &newValue) {
443  if (hasValue())
444  *value = newValue;
445  else
446  value = std::make_unique<mlir::OpPassManager>(newValue);
447 }
448 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
449  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
450  assert(succeeded(pipeline) && "invalid pass pipeline");
451  setValue(*pipeline);
452 }
453 
455  const mlir::OpPassManager &rhs) const {
  // Two pipelines compare equal when their textual forms are identical.
  // NOTE(review): the signature line is missing from this listing. `value` is
  // dereferenced unconditionally, so this assumes a pipeline is held --
  // confirm that callers check hasValue() first.
456  std::string lhsStr, rhsStr;
  // The streams are scoped, presumably so they have finished writing into the
  // strings before the comparison below.
457  {
458  raw_string_ostream lhsStream(lhsStr);
459  value->printAsTextualPipeline(lhsStream);
460

461  raw_string_ostream rhsStream(rhsStr);
462  rhs.printAsTextualPipeline(rhsStream);
463  }
464

465  // Use the textual format for pipeline comparisons.
466  return lhsStr == rhsStr;
467 }
468

/// Out-of-line anchor for the OptionValue<OpPassManager> specialization.
469 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
470 
471 //===----------------------------------------------------------------------===//
472 // OpPassManager: Parser
473 //===----------------------------------------------------------------------===//
474 
475 namespace llvm {
476 namespace cl {
// Explicit instantiation definition of basic_parser for OpPassManager, emitted
// in this translation unit.
477 template class basic_parser<OpPassManager>;
478 } // namespace cl
479 } // namespace llvm
480 
481 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
482  ParsedPassManager &value) {
483  FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
484  if (failed(pipeline))
485  return true;
486  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
487  return false;
488 }
489 
490 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
491  const OpPassManager &value) {
492  value.printAsTextualPipeline(os);
493 }
494 
496  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
497  size_t globalWidth) const {
  // Print "<option name>= <pipeline>", followed by " (default: <pipeline>)"
  // when a default is set, then a newline.
  // NOTE(review): the signature line is missing from this listing.
498  printOptionName(opt, globalWidth);
499  outs() << "= ";
500  pm.printAsTextualPipeline(outs());
501

502  if (defaultValue.hasValue()) {
503  outs().indent(2) << " (default: ";
504  defaultValue.getValue().printAsTextualPipeline(outs());
505  outs() << ")";
506  }
507  outs() << "\n";
508 }
509

// NOTE(review): the following defaulted special members of ParsedPassManager
// are missing their declaration lines in this listing.
511

513  default;
515  ParsedPassManager &&) = default;
517  default;
518 
519 //===----------------------------------------------------------------------===//
520 // TextualPassPipeline Parser
521 //===----------------------------------------------------------------------===//
522 
523 namespace {
524 /// This class represents a textual description of a pass pipeline.
525 class TextualPipeline {
526 public:
527  /// Try to initialize this pipeline with the given pipeline text.
528  /// `errorStream` is the output stream to emit errors to.
529  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
530 
531  /// Add the internal pipeline elements to the provided pass manager.
532  LogicalResult
533  addToPipeline(OpPassManager &pm,
534  function_ref<LogicalResult(const Twine &)> errorHandler) const;
535 
536 private:
537  /// A functor used to emit errors found during pipeline handling. The first
538  /// parameter corresponds to the raw location within the pipeline string. This
539  /// should always return failure.
540  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
541 
542  /// A struct to capture parsed pass pipeline names.
543  ///
544  /// A pipeline is defined as a series of names, each of which may in itself
545  /// recursively contain a nested pipeline. A name is either the name of a pass
546  /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
547  /// the name is the name of a pass, the InnerPipeline is empty, since passes
548  /// cannot contain inner pipelines.
549  struct PipelineElement {
550  PipelineElement(StringRef name) : name(name) {}
551 
552  StringRef name;
553  StringRef options;
554  const PassRegistryEntry *registryEntry = nullptr;
555  std::vector<PipelineElement> innerPipeline;
556  };
557 
558  /// Parse the given pipeline text into the internal pipeline vector. This
559  /// function only parses the structure of the pipeline, and does not resolve
560  /// its elements.
561  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
562 
563  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
564  /// the corresponding registry entry.
565  LogicalResult
566  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
567  ErrorHandlerT errorHandler);
568 
569  /// Resolve a single element of the pipeline.
570  LogicalResult resolvePipelineElement(PipelineElement &element,
571  ErrorHandlerT errorHandler);
572 
573  /// Add the given pipeline elements to the provided pass manager.
574  LogicalResult
575  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
576  function_ref<LogicalResult(const Twine &)> errorHandler) const;
577 
578  std::vector<PipelineElement> pipeline;
579 };
580 
581 } // namespace
582 
583 /// Try to initialize this pipeline with the given pipeline text. An option is
584 /// given to enable accurate error reporting.
585 LogicalResult TextualPipeline::initialize(StringRef text,
586  raw_ostream &errorStream) {
587  if (text.empty())
588  return success();
589 
590  // Build a source manager to use for error reporting.
591  llvm::SourceMgr pipelineMgr;
592  pipelineMgr.AddNewSourceBuffer(
593  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
594  /*RequiresNullTerminator=*/false),
595  SMLoc());
596  auto errorHandler = [&](const char *rawLoc, Twine msg) {
597  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
598  llvm::SourceMgr::DK_Error, msg);
599  return failure();
600  };
601 
602  // Parse the provided pipeline string.
603  if (failed(parsePipelineText(text, errorHandler)))
604  return failure();
605  return resolvePipelineElements(pipeline, errorHandler);
606 }
607 
608 /// Add the internal pipeline elements to the provided pass manager.
/// The pass manager's nesting mode is restored on exit via a scope guard.
609 LogicalResult TextualPipeline::addToPipeline(
610  OpPassManager &pm,
611  function_ref<LogicalResult(const Twine &)> errorHandler) const {
612  // Temporarily disable implicit nesting while we append to the pipeline. We
613  // want the created pipeline to exactly match the parsed text pipeline, so
614  // it's preferable to just error out if implicit nesting would be required.
  // NOTE(review): the line that switches `pm` to explicit nesting is missing
  // from this listing.
615  OpPassManager::Nesting nesting = pm.getNesting();
617  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
618

619  return addToPipeline(pipeline, pm, errorHandler);
620 }
621 
622 /// Parse the given pipeline text into the internal pipeline vector. This
623 /// function only parses the structure of the pipeline, and does not resolve
624 /// its elements.
/// Elements are separated by ',', `name(...)` opens a nested pipeline and
/// `name{...}` attaches an options string. Unbalanced parentheses, a missing
/// '}' or a missing ',' between elements are reported through `errorHandler`.
625 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
626  ErrorHandlerT errorHandler) {
  // Stack of the pipelines currently being filled; the bottom entry is the
  // member `pipeline`.
627  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
628  for (;;) {
  // NOTE: this local shadows the member `pipeline` inside the loop.
629  std::vector<PipelineElement> &pipeline = *pipelineStack.back();
630  size_t pos = text.find_first_of(",(){");
631  pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
632

633  // If we have a single terminating name, we're done.
634  if (pos == StringRef::npos)
635  break;
636

637  text = text.substr(pos);
638  char sep = text[0];
639

640  // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
641  if (sep == '{') {
642  text = text.substr(1);
643

644  // Skip over everything until the closing '}' and store as options.
  // NOTE(review): this scan only counts braces and does not skip quoted text,
  // so a '}' inside a quoted option value ends the options early.
645  size_t close = StringRef::npos;
646  for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
647  if (text[i] == '{') {
648  ++braceCount;
649  continue;
650  }
651  if (text[i] == '}' && --braceCount == 0) {
652  close = i;
653  break;
654  }
655  }
656

657  // Check to see if a closing options brace was found.
658  if (close == StringRef::npos) {
659  return errorHandler(
660  /*rawLoc=*/text.data() - 1,
661  "missing closing '}' while processing pass options");
662  }
663  pipeline.back().options = text.substr(0, close);
664  text = text.substr(close + 1);
665

666  // Consume space characters that a user might add for readability.
667  text = text.ltrim();
668

669  // Skip checking for '(' because nested pipelines cannot have options.
670  } else if (sep == '(') {
671  text = text.substr(1);
672

673  // Push the inner pipeline onto the stack to continue processing.
  // Holding this pointer is safe: the parent vector is not appended to again
  // until this entry has been popped.
674  pipelineStack.push_back(&pipeline.back().innerPipeline);
675  continue;
676  }
677

678  // When handling the close parenthesis, we greedily consume them to avoid
679  // empty strings in the pipeline.
680  while (text.consume_front(")")) {
681  // If we try to pop the outer pipeline we have unbalanced parentheses.
682  if (pipelineStack.size() == 1)
683  return errorHandler(/*rawLoc=*/text.data() - 1,
684  "encountered extra closing ')' creating unbalanced "
685  "parentheses while parsing pipeline");
686

687  pipelineStack.pop_back();
688  // Consume space characters that a user might add for readability.
689  text = text.ltrim();
690  }
691

692  // Check if we've finished parsing.
693  if (text.empty())
694  break;
695

696  // Otherwise, the end of an inner pipeline always has to be followed by
697  // a comma, and then we can continue.
698  if (!text.consume_front(","))
699  return errorHandler(text.data(), "expected ',' after parsing pipeline");
700  }
701

702  // Check for unbalanced parentheses.
703  if (pipelineStack.size() > 1)
704  return errorHandler(
705  text.data(),
706  "encountered unbalanced parentheses while parsing pipeline");
707

  // Outside the loop, `pipeline` refers to the member again.
708  assert(pipelineStack.back() == &pipeline &&
709  "wrong pipeline at the bottom of the stack");
710  return success();
711 }
712 
713 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
714 /// the corresponding registry entry.
715 LogicalResult TextualPipeline::resolvePipelineElements(
716  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
717  for (auto &elt : elements)
718  if (failed(resolvePipelineElement(elt, errorHandler)))
719  return failure();
720  return success();
721 }
722 
723 /// Resolve a single element of the pipeline.
724 LogicalResult
725 TextualPipeline::resolvePipelineElement(PipelineElement &element,
726  ErrorHandlerT errorHandler) {
727  // If the inner pipeline of this element is not empty, this is an operation
728  // pipeline.
729  if (!element.innerPipeline.empty())
730  return resolvePipelineElements(element.innerPipeline, errorHandler);
731 
732  // Otherwise, this must be a pass or pass pipeline.
733  // Check to see if a pipeline was registered with this name.
734  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
735  return success();
736 
737  // If not, then this must be a specific pass name.
738  if ((element.registryEntry = PassInfo::lookup(element.name)))
739  return success();
740 
741  // Emit an error for the unknown pass.
742  auto *rawLoc = element.name.data();
743  return errorHandler(rawLoc, "'" + element.name +
744  "' does not refer to a "
745  "registered pass or pass pipeline");
746 }
747 
748 /// Add the given pipeline elements to the provided pass manager.
/// Resolved elements are added through their registry entry with their
/// options string. Operation elements add their inner pipeline to
/// `pm.nest(name)`. The first failure is reported through `errorHandler`.
/// NOTE(review): a parameter line of this definition is missing from this
/// listing.
749 LogicalResult TextualPipeline::addToPipeline(
751  function_ref<LogicalResult(const Twine &)> errorHandler) const {
752  for (auto &elt : elements) {
753  if (elt.registryEntry) {
754  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
755  errorHandler))) {
756  return errorHandler("failed to add `" + elt.name + "` with options `" +
757  elt.options + "`");
758  }
759  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
760  errorHandler))) {
761  return errorHandler("failed to add `" + elt.name + "` with options `" +
762  elt.options + "` to inner pipeline");
763  }
764  }
765  return success();
766 }
767 
768 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
769  raw_ostream &errorStream) {
770  TextualPipeline pipelineParser;
771  if (failed(pipelineParser.initialize(pipeline, errorStream)))
772  return failure();
773  auto errorHandler = [&](Twine msg) {
774  errorStream << msg << "\n";
775  return failure();
776  };
777  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
778  return failure();
779  return success();
780 }
781 
782 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
783  raw_ostream &errorStream) {
784  pipeline = pipeline.trim();
785  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
786  size_t pipelineStart = pipeline.find_first_of('(');
787  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
788  !pipeline.consume_back(")")) {
789  errorStream << "expected pass pipeline to be wrapped with the anchor "
790  "operation type, e.g. 'builtin.module(...)'";
791  return failure();
792  }
793 
794  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
795  OpPassManager pm(opName);
796  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
797  errorStream)))
798  return failure();
799  return pm;
800 }
801 
802 //===----------------------------------------------------------------------===//
803 // PassNameParser
804 //===----------------------------------------------------------------------===//
805 
806 namespace {
807 /// This struct represents the possible data entries in a parsed pass pipeline
808 /// list.
809 struct PassArgData {
810  PassArgData() = default;
811  PassArgData(const PassRegistryEntry *registryEntry)
812  : registryEntry(registryEntry) {}
813 
814  /// This field is used when the parsed option corresponds to a registered pass
815  /// or pass pipeline.
816  const PassRegistryEntry *registryEntry{nullptr};
817 
818  /// This field is set when instance specific pass options have been provided
819  /// on the command line.
820  StringRef options;
821 };
822 } // namespace
823 
824 namespace llvm {
825 namespace cl {
826 /// Define a valid OptionValue for the command line pass argument.
/// It always reports that a value is present.
/// NOTE(review): the line naming this specialization is missing from this
/// listing.
827 template <>
829  : OptionValueBase<PassArgData, /*isClass=*/true> {
830  OptionValue(const PassArgData &value) { this->setValue(value); }
831  OptionValue() = default;
832  void anchor() override {}
833

834  bool hasValue() const { return true; }
835  const PassArgData &getValue() const { return value; }
836  void setValue(const PassArgData &value) { this->value = value; }
837

838  PassArgData value;
839 };
840 } // namespace cl
841 } // namespace llvm
842 
843 namespace {
844 
845 /// The name for the command line option used for parsing the textual pass
846 /// pipeline.
847 #define PASS_PIPELINE_ARG "pass-pipeline"
848 
849 /// Adds command line option for each registered pass or pass pipeline, as well
850 /// as textual pass pipelines.
851 struct PassNameParser : public llvm::cl::parser<PassArgData> {
852  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
853 
854  void initialize();
855  void printOptionInfo(const llvm::cl::Option &opt,
856  size_t globalWidth) const override;
857  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
858  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
859  PassArgData &value);
860 
861  /// If true, this parser only parses entries that correspond to a concrete
862  /// pass registry entry, and does not include pipeline entries or the options
863  /// for pass entries.
864  bool passNamesOnly = false;
865 };
866 } // namespace
867 
868 void PassNameParser::initialize() {
870 
871  /// Add the pass entries.
872  for (const auto &kv : *passRegistry) {
873  addLiteralOption(kv.second.getPassArgument(), &kv.second,
874  kv.second.getPassDescription());
875  }
876  /// Add the pass pipeline entries.
877  if (!passNamesOnly) {
878  for (const auto &kv : *passPipelineRegistry) {
879  addLiteralOption(kv.second.getPassArgument(), &kv.second,
880  kv.second.getPassDescription());
881  }
882  }
883 }
884 
885 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
886  size_t globalWidth) const {
887  // If this parser is just parsing pass names, print a simplified option
888  // string.
889  if (passNamesOnly) {
890  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
891  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
892  return;
893  }
894 
895  // Print the information for the top-level option.
896  if (opt.hasArgStr()) {
897  llvm::outs() << " --" << opt.ArgStr;
898  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
899  } else {
900  llvm::outs() << " " << opt.HelpStr << '\n';
901  }
902 
903  // Functor used to print the ordered entries of a registration map.
904  auto printOrderedEntries = [&](StringRef header, auto &map) {
906  for (auto &kv : map)
907  orderedEntries.push_back(&kv.second);
908  llvm::array_pod_sort(
909  orderedEntries.begin(), orderedEntries.end(),
910  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
911  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
912  });
913 
914  llvm::outs().indent(4) << header << ":\n";
915  for (PassRegistryEntry *entry : orderedEntries)
916  entry->printHelpStr(/*indent=*/6, globalWidth);
917  };
918 
919  // Print the available passes.
920  printOrderedEntries("Passes", *passRegistry);
921 
922  // Print the available pass pipelines.
923  if (!passPipelineRegistry->empty())
924  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
925 }
926 
927 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
928  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
929 
930  // Check for any wider pass or pipeline options.
931  for (auto &entry : *passRegistry)
932  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
933  for (auto &entry : *passPipelineRegistry)
934  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
935  return maxWidth;
936 }
937 
938 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
939  StringRef arg, PassArgData &value) {
940  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
941  return true;
942  value.options = arg;
943  return false;
944 }
945 
946 //===----------------------------------------------------------------------===//
947 // PassPipelineCLParser
948 //===----------------------------------------------------------------------===//
949 
950 namespace mlir {
951 namespace detail {
953  PassPipelineCLParserImpl(StringRef arg, StringRef description,
954  bool passNamesOnly)
955  : passList(arg, llvm::cl::desc(description)) {
956  passList.getParser().passNamesOnly = passNamesOnly;
957  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
958  }
959 
960  /// Returns true if the given pass registry entry was registered at the
961  /// top-level of the parser, i.e. not within an explicit textual pipeline.
962  bool contains(const PassRegistryEntry *entry) const {
963  return llvm::any_of(passList, [&](const PassArgData &data) {
964  return data.registryEntry == entry;
965  });
966  }
967 
968  /// The set of passes and pass pipelines to run.
969  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
970 };
971 } // namespace detail
972 } // namespace mlir
973 
974 /// Construct a pass pipeline parser with the given command line description.
975 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
976  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
977  arg, description, /*passNamesOnly=*/false)),
978  passPipeline(
980  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
981 
982 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
983  StringRef alias)
984  : PassPipelineCLParser(arg, description) {
985  passPipelineAlias.emplace(alias,
986  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
987  llvm::cl::aliasopt(passPipeline));
988 }
989 
991 
992 /// Returns true if this parser contains any valid options to add.
994  return passPipeline.getNumOccurrences() != 0 ||
995  impl->passList.getNumOccurrences() != 0;
996 }
997 
998 /// Returns true if the given pass registry entry was registered at the
999 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1001  return impl->contains(entry);
1002 }
1003 
1004 /// Adds the passes defined by this parser entry to the given pass manager.
1006  OpPassManager &pm,
1007  function_ref<LogicalResult(const Twine &)> errorHandler) const {
1008  if (passPipeline.getNumOccurrences()) {
1009  if (impl->passList.getNumOccurrences())
1010  return errorHandler(
1011  "'-" PASS_PIPELINE_ARG
1012  "' option can't be used with individual pass options");
1013  std::string errMsg;
1014  llvm::raw_string_ostream os(errMsg);
1015  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1016  if (failed(parsed))
1017  return errorHandler(errMsg);
1018  pm = std::move(*parsed);
1019  return success();
1020  }
1021 
1022  for (auto &passIt : impl->passList) {
1023  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1024  errorHandler)))
1025  return failure();
1026  }
1027  return success();
1028 }
1029 
1030 //===----------------------------------------------------------------------===//
1031 // PassNameCLParser
1032 //===----------------------------------------------------------------------===//
1033 
1034 /// Construct a pass pipeline parser with the given command line description.
1035 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1036  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1037  arg, description, /*passNamesOnly=*/true)) {
1038  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1039 }
1041 
1042 /// Returns true if this parser contains any valid options to add.
1044  return impl->passList.getNumOccurrences() != 0;
1045 }
1046 
1047 /// Returns true if the given pass registry entry was registered at the
1048 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1050  return impl->contains(entry);
1051 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurrence of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:46
void printAsTextualPipeline(raw_ostream &os, bool pretty=false) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:427
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition: Pass.cpp:386
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:453
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:367
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:455
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:49
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:390
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:357
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:118
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:247
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:104
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Base container class and manager for all pass options.
Definition: PassOptions.h:89
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:487
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:484
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...