MLIR  21.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
20 
21 #include <optional>
22 #include <utility>
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 /// Static mapping of all of the registered passes.
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
29 
30 /// A mapping of the above pass registry entries to the corresponding TypeID
31 /// of the pass that they generate.
32 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
33 
34 /// Static mapping of all of the registered pass pipelines.
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
37 
38 /// Utility to create a default registry function from a pass instance.
41  return [=](OpPassManager &pm, StringRef options,
42  function_ref<LogicalResult(const Twine &)> errorHandler) {
43  std::unique_ptr<Pass> pass = allocator();
44  LogicalResult result = pass->initializeOptions(options, errorHandler);
45 
46  std::optional<StringRef> pmOpName = pm.getOpName();
47  std::optional<StringRef> passOpName = pass->getOpName();
48  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
49  passOpName && *pmOpName != *passOpName) {
50  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
51  "' restricted to '" + *pass->getOpName() +
52  "' on a PassManager intended to run on '" +
53  pm.getOpAnchorName() + "', did you intend to nest?");
54  }
55  pm.addPass(std::move(pass));
56  return result;
57  };
58 }
59 
60 /// Utility to print the help string for a specific option.
61 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
62  size_t descIndent, bool isTopLevel) {
63  size_t numSpaces = descIndent - indent - 4;
64  llvm::outs().indent(indent)
65  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // PassRegistry
70 //===----------------------------------------------------------------------===//
71 
72 /// Prints the passes that were previously registered and stored in passRegistry
74  size_t maxWidth = 0;
75  for (auto &entry : *passRegistry)
76  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
77 
78  // Functor used to print the ordered entries of a registration map.
79  auto printOrderedEntries = [&](StringRef header, auto &map) {
81  for (auto &kv : map)
82  orderedEntries.push_back(&kv.second);
83  llvm::array_pod_sort(
84  orderedEntries.begin(), orderedEntries.end(),
85  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
86  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
87  });
88 
89  llvm::outs().indent(0) << header << ":\n";
90  for (PassRegistryEntry *entry : orderedEntries)
91  entry->printHelpStr(/*indent=*/2, maxWidth);
92  };
93 
94  // Print the available passes.
95  printOrderedEntries("Passes", *passRegistry);
96 }
97 
98 /// Print the help information for this pass. This includes the argument,
99 /// description, and any pass options. `descIndent` is the indent that the
100 /// descriptions should be aligned.
101 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
102  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
103  /*isTopLevel=*/true);
104  // If this entry has options, print the help for those as well.
105  optHandler([=](const PassOptions &options) {
106  options.printHelp(indent, descIndent);
107  });
108 }
109 
110 /// Return the maximum width required when printing the options of this
111 /// entry.
113  size_t maxLen = 0;
114  optHandler([&](const PassOptions &options) mutable {
115  maxLen = options.getOptionWidth() + 2;
116  });
117  return maxLen;
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // PassPipelineInfo
122 //===----------------------------------------------------------------------===//
123 
125  StringRef arg, StringRef description, const PassRegistryFunction &function,
126  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
127  PassPipelineInfo pipelineInfo(arg, description, function,
128  std::move(optHandler));
129  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
130 #ifndef NDEBUG
131  if (!inserted)
132  report_fatal_error("Pass pipeline " + arg + " registered multiple times");
133 #endif
134  (void)inserted;
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // PassInfo
139 //===----------------------------------------------------------------------===//
140 
141 PassInfo::PassInfo(StringRef arg, StringRef description,
142  const PassAllocatorFunction &allocator)
144  arg, description, buildDefaultRegistryFn(allocator),
145  // Use a temporary pass to provide an options instance.
146  [=](function_ref<void(const PassOptions &)> optHandler) {
147  optHandler(allocator()->passOptions);
148  }) {}
149 
151  std::unique_ptr<Pass> pass = function();
152  StringRef arg = pass->getArgument();
153  if (arg.empty())
154  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
155  pass->getName() +
156  "' pass that does not override `getArgument()`");
157  StringRef description = pass->getDescription();
158  PassInfo passInfo(arg, description, function);
159  passRegistry->try_emplace(arg, passInfo);
160 
161  // Verify that the registered pass has the same ID as any registered to this
162  // arg before it.
163  TypeID entryTypeID = pass->getTypeID();
164  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
165  if (it->second != entryTypeID)
166  llvm::report_fatal_error(
167  "pass allocator creates a different pass than previously "
168  "registered for pass " +
169  arg);
170 }
171 
172 /// Returns the pass info for the specified pass argument or null if unknown.
173 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
174  auto it = passRegistry->find(passArg);
175  return it == passRegistry->end() ? nullptr : &it->second;
176 }
177 
178 /// Returns the pass pipeline info for the specified pass pipeline argument or
179 /// null if unknown.
180 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
181  auto it = passPipelineRegistry->find(pipelineArg);
182  return it == passPipelineRegistry->end() ? nullptr : &it->second;
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // PassOptions
187 //===----------------------------------------------------------------------===//
188 
189 /// Attempt to find the next occurance of character 'c' in the string starting
190 /// from the `index`-th position , omitting any occurances that appear within
191 /// intervening ranges or literals.
192 static size_t findChar(StringRef str, size_t index, char c) {
193  for (size_t i = index, e = str.size(); i < e; ++i) {
194  if (str[i] == c)
195  return i;
196  // Check for various range characters.
197  if (str[i] == '{')
198  i = findChar(str, i + 1, '}');
199  else if (str[i] == '(')
200  i = findChar(str, i + 1, ')');
201  else if (str[i] == '[')
202  i = findChar(str, i + 1, ']');
203  else if (str[i] == '\"')
204  i = str.find_first_of('\"', i + 1);
205  else if (str[i] == '\'')
206  i = str.find_first_of('\'', i + 1);
207  if (i == StringRef::npos)
208  return StringRef::npos;
209  }
210  return StringRef::npos;
211 }
212 
213 /// Extract an argument from 'options' and update it to point after the arg.
214 /// Returns the cleaned argument string.
215 static StringRef extractArgAndUpdateOptions(StringRef &options,
216  size_t argSize) {
217  StringRef str = options.take_front(argSize).trim();
218  options = options.drop_front(argSize).ltrim();
219 
220  // Early exit if there's no escape sequence.
221  if (str.size() <= 1)
222  return str;
223 
224  const auto escapePairs = {std::make_pair('\'', '\''),
225  std::make_pair('"', '"')};
226  for (const auto &escape : escapePairs) {
227  if (str.front() == escape.first && str.back() == escape.second) {
228  // Drop the escape characters and trim.
229  // Don't process additional escape sequences.
230  return str.drop_front().drop_back().trim();
231  }
232  }
233 
234  // Arguments may be wrapped in `{...}`. Unlike the quotation markers that
235  // denote literals, we respect scoping here. The outer `{...}` should not
236  // be stripped in cases such as "arg={...},{...}", which can be used to denote
237  // lists of nested option structs.
238  if (str.front() == '{') {
239  unsigned match = findChar(str, 1, '}');
240  if (match == str.size() - 1)
241  str = str.drop_front().drop_back().trim();
242  }
243 
244  return str;
245 }
246 
248  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
249  function_ref<LogicalResult(StringRef)> elementParseFn) {
250  if (optionStr.empty())
251  return success();
252 
253  size_t nextElePos = findChar(optionStr, 0, ',');
254  while (nextElePos != StringRef::npos) {
255  // Process the portion before the comma.
256  if (failed(
257  elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
258  return failure();
259 
260  // Drop the leading ','
261  optionStr = optionStr.drop_front();
262  nextElePos = findChar(optionStr, 0, ',');
263  }
264  return elementParseFn(
265  extractArgAndUpdateOptions(optionStr, optionStr.size()));
266 }
267 
/// Out-of-line virtual function that provides a home for the class (the usual
/// LLVM "anchor" idiom, so the class has one translation unit to live in).
void detail::PassOptions::OptionBase::anchor() {}
270 
271 /// Copy the option values from 'other'.
273  assert(options.size() == other.options.size());
274  if (options.empty())
275  return;
276  for (auto optionsIt : llvm::zip(options, other.options))
277  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
278 }
279 
/// Parse in the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, updated
/// `options` string pointing after the parsed option].
///
/// The key ends at the first '=', ' ', or the end of the string. If no '=' is
/// reached, the returned value is empty. The value ends at the first ' ' or the
/// end of the string. Spaces inside closed '...' or "..." literals, and inside
/// balanced `{...}` groups, do not end it. Key and value are both cleaned up
/// by extractArgAndUpdateOptions.
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
  // Try to process the given punctuation, properly escaping any contained
  // characters. If `options[currentPos]` is `punct`, move `currentPos` onto
  // the matching closing `punct` (or leave it in place if there is none) and
  // return true. Otherwise return false. `options` is captured by reference,
  // so the lambda always sees the current, partially consumed string.
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
    if (options[currentPos] != punct)
      return false;
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
    if (nextIt != StringRef::npos)
      currentPos = nextIt;
    return true;
  };

  // Parse the argument name of the option.
  StringRef argName;
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Check for the end of the full option.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      argName = extractArgAndUpdateOptions(options, argEndIt);
      return std::make_tuple(argName, StringRef(), options);
    }

    // Check for the end of the name and the start of the value.
    if (options[argEndIt] == '=') {
      argName = extractArgAndUpdateOptions(options, argEndIt);
      // `options` now starts at the '='; drop it so the value is next.
      options = options.drop_front();
      break;
    }
  }

  // Parse the value of the option.
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Handle the end of the options string.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      StringRef value = extractArgAndUpdateOptions(options, argEndIt);
      return std::make_tuple(argName, value, options);
    }

    // Skip over escaped sequences.
    char c = options[argEndIt];
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
      continue;
    // '{...}' is used to specify options to passes, properly escape it so
    // that we don't accidentally split any nested options.
    if (c == '{') {
      size_t braceCount = 1;
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
        // Allow nested punctuation.
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
          continue;
        if (options[argEndIt] == '{')
          ++braceCount;
        else if (options[argEndIt] == '}' && --braceCount == 0)
          break;
      }
      // Account for the increment at the top of the loop. Scanning resumes on
      // the closing '}', or reaches the end of the string if the group was
      // never closed.
      --argEndIt;
    }
  }
  llvm_unreachable("unexpected control flow in pass option parsing");
}
344 
346  raw_ostream &errorStream) {
347  // NOTE: `options` is modified in place to always refer to the unprocessed
348  // part of the string.
349  while (!options.empty()) {
350  StringRef key, value;
351  std::tie(key, value, options) = parseNextArg(options);
352  if (key.empty())
353  continue;
354 
355  auto it = OptionsMap.find(key);
356  if (it == OptionsMap.end()) {
357  errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
358  return failure();
359  }
360  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
361  return failure();
362  }
363 
364  return success();
365 }
366 
367 /// Print the options held by this struct in a form that can be parsed via
368 /// 'parseFromString'.
369 void detail::PassOptions::print(raw_ostream &os) const {
370  // If there are no options, there is nothing left to do.
371  if (OptionsMap.empty())
372  return;
373 
374  // Sort the options to make the ordering deterministic.
375  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
376  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
377  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
378  };
379  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
380 
381  // Interleave the options with ' '.
382  os << '{';
383  llvm::interleave(
384  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
385  os << '}';
386 }
387 
388 /// Print the help string for the options held by this struct. `descIndent` is
389 /// the indent within the stream that the descriptions should be aligned.
390 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
391  // Sort the options to make the ordering deterministic.
392  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
393  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
394  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
395  };
396  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
397  for (OptionBase *option : orderedOps) {
398  // TODO: printOptionInfo assumes a specific indent and will
399  // print options with values with incorrect indentation. We should add
400  // support to llvm::cl::Option for passing in a base indent to use when
401  // printing.
402  llvm::outs().indent(indent);
403  option->getOption()->printOptionInfo(descIndent - indent);
404  }
405 }
406 
407 /// Return the maximum width required when printing the help string.
409  size_t max = 0;
410  for (auto *option : options)
411  max = std::max(max, option->getOption()->getOptionWidth());
412  return max;
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // MLIR Options
417 //===----------------------------------------------------------------------===//
418 
419 //===----------------------------------------------------------------------===//
420 // OpPassManager: OptionValue
421 //===----------------------------------------------------------------------===//
422 
/// Default-construct an option value. The held pass-manager pointer starts
/// out empty (presumably hasValue() is false until setValue is called).
llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
/// Construct an option value holding a copy of `value`.
llvm::cl::OptionValue<OpPassManager>::OptionValue(
    const mlir::OpPassManager &value) {
  setValue(value);
}
428 llvm::cl::OptionValue<OpPassManager>::OptionValue(
430  if (rhs.hasValue())
431  setValue(rhs.getValue());
432 }
433 llvm::cl::OptionValue<OpPassManager> &
434 llvm::cl::OptionValue<OpPassManager>::operator=(
435  const mlir::OpPassManager &rhs) {
436  setValue(rhs);
437  return *this;
438 }
439 
440 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
441 
442 void llvm::cl::OptionValue<OpPassManager>::setValue(
443  const OpPassManager &newValue) {
444  if (hasValue())
445  *value = newValue;
446  else
447  value = std::make_unique<mlir::OpPassManager>(newValue);
448 }
449 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
450  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
451  assert(succeeded(pipeline) && "invalid pass pipeline");
452  setValue(*pipeline);
453 }
454 
456  const mlir::OpPassManager &rhs) const {
457  std::string lhsStr, rhsStr;
458  {
459  raw_string_ostream lhsStream(lhsStr);
460  value->printAsTextualPipeline(lhsStream);
461 
462  raw_string_ostream rhsStream(rhsStr);
463  rhs.printAsTextualPipeline(rhsStream);
464  }
465 
466  // Use the textual format for pipeline comparisons.
467  return lhsStr == rhsStr;
468 }
469 
/// Out-of-line anchor method for OptionValue<OpPassManager>.
void llvm::cl::OptionValue<OpPassManager>::anchor() {}
471 
472 //===----------------------------------------------------------------------===//
473 // OpPassManager: Parser
474 //===----------------------------------------------------------------------===//
475 
// Explicitly instantiate llvm::cl::basic_parser for OpPassManager in this
// translation unit.
namespace llvm {
namespace cl {
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm
481 
482 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
483  ParsedPassManager &value) {
484  FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
485  if (failed(pipeline))
486  return true;
487  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
488  return false;
489 }
490 
491 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
492  const OpPassManager &value) {
493  value.printAsTextualPipeline(os);
494 }
495 
497  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
498  size_t globalWidth) const {
499  printOptionName(opt, globalWidth);
500  outs() << "= ";
501  pm.printAsTextualPipeline(outs());
502 
503  if (defaultValue.hasValue()) {
504  outs().indent(2) << " (default: ";
505  defaultValue.getValue().printAsTextualPipeline(outs());
506  outs() << ")";
507  }
508  outs() << "\n";
509 }
510 
512 
514  default;
516  ParsedPassManager &&) = default;
518  default;
519 
520 //===----------------------------------------------------------------------===//
521 // TextualPassPipeline Parser
522 //===----------------------------------------------------------------------===//
523 
namespace {
/// This class represents a textual description of a pass pipeline.
/// It is used in two steps: `initialize` parses and resolves the text, and
/// `addToPipeline` then appends the resolved elements to a pass manager.
class TextualPipeline {
public:
  /// Try to initialize this pipeline with the given pipeline text.
  /// `errorStream` is the output stream to emit errors to.
  LogicalResult initialize(StringRef text, raw_ostream &errorStream);

  /// Add the internal pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

private:
  /// A functor used to emit errors found during pipeline handling. The first
  /// parameter corresponds to the raw location within the pipeline string. This
  /// should always return failure.
  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;

  /// A struct to capture parsed pass pipeline names.
  ///
  /// A pipeline is defined as a series of names, each of which may in itself
  /// recursively contain a nested pipeline. A name is either the name of a pass
  /// or registered pass pipeline (e.g. "cse"), or the name of an operation type
  /// (e.g. "builtin.module"). If the name is the name of a pass, the
  /// innerPipeline is empty, since passes cannot contain inner pipelines.
  struct PipelineElement {
    PipelineElement(StringRef name) : name(name) {}

    /// The element's name, as a view into the original pipeline text.
    StringRef name;
    /// The text between `{` and `}` following the name, if any.
    StringRef options;
    /// The resolved pass or pass pipeline. It remains null for operation
    /// (nested) pipelines and before resolution.
    const PassRegistryEntry *registryEntry = nullptr;
    /// The nested elements of an operation pipeline.
    std::vector<PipelineElement> innerPipeline;
  };

  /// Parse the given pipeline text into the internal pipeline vector. This
  /// function only parses the structure of the pipeline, and does not resolve
  /// its elements.
  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);

  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
  /// the corresponding registry entry.
  LogicalResult
  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
                          ErrorHandlerT errorHandler);

  /// Resolve a single element of the pipeline.
  LogicalResult resolvePipelineElement(PipelineElement &element,
                                       ErrorHandlerT errorHandler);

  /// Add the given pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

  /// The top-level elements of the parsed pipeline.
  std::vector<PipelineElement> pipeline;
};

} // namespace
583 
584 /// Try to initialize this pipeline with the given pipeline text. An option is
585 /// given to enable accurate error reporting.
586 LogicalResult TextualPipeline::initialize(StringRef text,
587  raw_ostream &errorStream) {
588  if (text.empty())
589  return success();
590 
591  // Build a source manager to use for error reporting.
592  llvm::SourceMgr pipelineMgr;
593  pipelineMgr.AddNewSourceBuffer(
594  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
595  /*RequiresNullTerminator=*/false),
596  SMLoc());
597  auto errorHandler = [&](const char *rawLoc, Twine msg) {
598  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
599  llvm::SourceMgr::DK_Error, msg);
600  return failure();
601  };
602 
603  // Parse the provided pipeline string.
604  if (failed(parsePipelineText(text, errorHandler)))
605  return failure();
606  return resolvePipelineElements(pipeline, errorHandler);
607 }
608 
609 /// Add the internal pipeline elements to the provided pass manager.
610 LogicalResult TextualPipeline::addToPipeline(
611  OpPassManager &pm,
612  function_ref<LogicalResult(const Twine &)> errorHandler) const {
613  // Temporarily disable implicit nesting while we append to the pipeline. We
614  // want the created pipeline to exactly match the parsed text pipeline, so
615  // it's preferrable to just error out if implicit nesting would be required.
616  OpPassManager::Nesting nesting = pm.getNesting();
618  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
619 
620  return addToPipeline(pipeline, pm, errorHandler);
621 }
622 
/// Parse the given pipeline text into the internal pipeline vector. This
/// function only parses the structure of the pipeline, and does not resolve
/// its elements.
///
/// The text is a comma-separated list of elements. Each element is a name,
/// optionally followed by either `{options}` or `(nested, elements)`. Nesting
/// is tracked with an explicit stack of pipelines rather than recursion.
/// Errors are reported through `errorHandler`, with a pointer into `text`
/// marking the offending location.
LogicalResult TextualPipeline::parsePipelineText(StringRef text,
                                                 ErrorHandlerT errorHandler) {
  // The bottom of the stack is the member `pipeline`. Each '(' pushes the
  // innerPipeline of the most recently added element.
  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
  for (;;) {
    // NOTE: this local shadows the member `pipeline`. It refers to whichever
    // pipeline is currently on top of the stack.
    std::vector<PipelineElement> &pipeline = *pipelineStack.back();
    size_t pos = text.find_first_of(",(){");
    pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());

    // If we have a single terminating name, we're done.
    if (pos == StringRef::npos)
      break;

    text = text.substr(pos);
    char sep = text[0];

    // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
    if (sep == '{') {
      text = text.substr(1);

      // Skip over everything until the closing '}' and store as options.
      // Nested braces are counted so inner `{...}` groups stay in the options.
      size_t close = StringRef::npos;
      for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
        if (text[i] == '{') {
          ++braceCount;
          continue;
        }
        if (text[i] == '}' && --braceCount == 0) {
          close = i;
          break;
        }
      }

      // Check to see if a closing options brace was found.
      if (close == StringRef::npos) {
        return errorHandler(
            /*rawLoc=*/text.data() - 1,
            "missing closing '}' while processing pass options");
      }
      pipeline.back().options = text.substr(0, close);
      text = text.substr(close + 1);

      // Consume space characters that a user might add for readability.
      text = text.ltrim();

      // Skip checking for '(' because nested pipelines cannot have options.
    } else if (sep == '(') {
      text = text.substr(1);

      // Push the inner pipeline onto the stack to continue processing.
      pipelineStack.push_back(&pipeline.back().innerPipeline);
      continue;
    }

    // When handling the close parenthesis, we greedily consume them to avoid
    // empty strings in the pipeline.
    while (text.consume_front(")")) {
      // If we try to pop the outer pipeline we have unbalanced parentheses.
      if (pipelineStack.size() == 1)
        return errorHandler(/*rawLoc=*/text.data() - 1,
                            "encountered extra closing ')' creating unbalanced "
                            "parentheses while parsing pipeline");

      pipelineStack.pop_back();
      // Consume space characters that a user might add for readability.
      text = text.ltrim();
    }

    // Check if we've finished parsing.
    if (text.empty())
      break;

    // Otherwise, the end of an inner pipeline always has to be followed by
    // a comma, and then we can continue.
    if (!text.consume_front(","))
      return errorHandler(text.data(), "expected ',' after parsing pipeline");
  }

  // Check for unbalanced parentheses.
  if (pipelineStack.size() > 1)
    return errorHandler(
        text.data(),
        "encountered unbalanced parentheses while parsing pipeline");

  // Out of the loop, `pipeline` names the member again.
  assert(pipelineStack.back() == &pipeline &&
         "wrong pipeline at the bottom of the stack");
  return success();
}
713 
714 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
715 /// the corresponding registry entry.
716 LogicalResult TextualPipeline::resolvePipelineElements(
717  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
718  for (auto &elt : elements)
719  if (failed(resolvePipelineElement(elt, errorHandler)))
720  return failure();
721  return success();
722 }
723 
724 /// Resolve a single element of the pipeline.
725 LogicalResult
726 TextualPipeline::resolvePipelineElement(PipelineElement &element,
727  ErrorHandlerT errorHandler) {
728  // If the inner pipeline of this element is not empty, this is an operation
729  // pipeline.
730  if (!element.innerPipeline.empty())
731  return resolvePipelineElements(element.innerPipeline, errorHandler);
732 
733  // Otherwise, this must be a pass or pass pipeline.
734  // Check to see if a pipeline was registered with this name.
735  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
736  return success();
737 
738  // If not, then this must be a specific pass name.
739  if ((element.registryEntry = PassInfo::lookup(element.name)))
740  return success();
741 
742  // Emit an error for the unknown pass.
743  auto *rawLoc = element.name.data();
744  return errorHandler(rawLoc, "'" + element.name +
745  "' does not refer to a "
746  "registered pass or pass pipeline");
747 }
748 
749 /// Add the given pipeline elements to the provided pass manager.
750 LogicalResult TextualPipeline::addToPipeline(
752  function_ref<LogicalResult(const Twine &)> errorHandler) const {
753  for (auto &elt : elements) {
754  if (elt.registryEntry) {
755  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
756  errorHandler))) {
757  return errorHandler("failed to add `" + elt.name + "` with options `" +
758  elt.options + "`");
759  }
760  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
761  errorHandler))) {
762  return errorHandler("failed to add `" + elt.name + "` with options `" +
763  elt.options + "` to inner pipeline");
764  }
765  }
766  return success();
767 }
768 
769 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
770  raw_ostream &errorStream) {
771  TextualPipeline pipelineParser;
772  if (failed(pipelineParser.initialize(pipeline, errorStream)))
773  return failure();
774  auto errorHandler = [&](Twine msg) {
775  errorStream << msg << "\n";
776  return failure();
777  };
778  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
779  return failure();
780  return success();
781 }
782 
783 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
784  raw_ostream &errorStream) {
785  pipeline = pipeline.trim();
786  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
787  size_t pipelineStart = pipeline.find_first_of('(');
788  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
789  !pipeline.consume_back(")")) {
790  errorStream << "expected pass pipeline to be wrapped with the anchor "
791  "operation type, e.g. 'builtin.module(...)'";
792  return failure();
793  }
794 
795  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
796  OpPassManager pm(opName);
797  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
798  errorStream)))
799  return failure();
800  return pm;
801 }
802 
803 //===----------------------------------------------------------------------===//
804 // PassNameParser
805 //===----------------------------------------------------------------------===//
806 
807 namespace {
808 /// This struct represents the possible data entries in a parsed pass pipeline
809 /// list.
810 struct PassArgData {
811  PassArgData() = default;
812  PassArgData(const PassRegistryEntry *registryEntry)
813  : registryEntry(registryEntry) {}
814 
815  /// This field is used when the parsed option corresponds to a registered pass
816  /// or pass pipeline.
817  const PassRegistryEntry *registryEntry{nullptr};
818 
819  /// This field is set when instance specific pass options have been provided
820  /// on the command line.
821  StringRef options;
822 };
823 } // namespace
824 
825 namespace llvm {
826 namespace cl {
827 /// Define a valid OptionValue for the command line pass argument.
828 template <>
830  : OptionValueBase<PassArgData, /*isClass=*/true> {
831  OptionValue(const PassArgData &value) { this->setValue(value); }
832  OptionValue() = default;
833  void anchor() override {}
834 
835  bool hasValue() const { return true; }
836  const PassArgData &getValue() const { return value; }
837  void setValue(const PassArgData &value) { this->value = value; }
838 
839  PassArgData value;
840 };
841 } // namespace cl
842 } // namespace llvm
843 
namespace {

/// The name for the command line option used for parsing the textual pass
/// pipeline.
#define PASS_PIPELINE_ARG "pass-pipeline"

/// Adds command line option for each registered pass or pass pipeline, as well
/// as textual pass pipelines. Each parsed occurrence produces a PassArgData.
struct PassNameParser : public llvm::cl::parser<PassArgData> {
  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}

  /// Populate the literal options from the pass registry and, unless
  /// `passNamesOnly` is set, from the pass pipeline registry.
  void initialize();
  /// Print the help for `opt`, including the available registry entries.
  void printOptionInfo(const llvm::cl::Option &opt,
                       size_t globalWidth) const override;
  /// Override of the generic parser's help-width computation.
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
  /// Parse one occurrence of the option, `arg`, into `value`.
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             PassArgData &value);

  /// If true, this parser only parses entries that correspond to a concrete
  /// pass registry entry, and does not include pipeline entries or the options
  /// for pass entries.
  bool passNamesOnly = false;
};
} // namespace
868 
869 void PassNameParser::initialize() {
871 
872  /// Add the pass entries.
873  for (const auto &kv : *passRegistry) {
874  addLiteralOption(kv.second.getPassArgument(), &kv.second,
875  kv.second.getPassDescription());
876  }
877  /// Add the pass pipeline entries.
878  if (!passNamesOnly) {
879  for (const auto &kv : *passPipelineRegistry) {
880  addLiteralOption(kv.second.getPassArgument(), &kv.second,
881  kv.second.getPassDescription());
882  }
883  }
884 }
885 
886 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
887  size_t globalWidth) const {
888  // If this parser is just parsing pass names, print a simplified option
889  // string.
890  if (passNamesOnly) {
891  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
892  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
893  return;
894  }
895 
896  // Print the information for the top-level option.
897  if (opt.hasArgStr()) {
898  llvm::outs() << " --" << opt.ArgStr;
899  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
900  } else {
901  llvm::outs() << " " << opt.HelpStr << '\n';
902  }
903 
904  // Functor used to print the ordered entries of a registration map.
905  auto printOrderedEntries = [&](StringRef header, auto &map) {
907  for (auto &kv : map)
908  orderedEntries.push_back(&kv.second);
909  llvm::array_pod_sort(
910  orderedEntries.begin(), orderedEntries.end(),
911  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
912  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
913  });
914 
915  llvm::outs().indent(4) << header << ":\n";
916  for (PassRegistryEntry *entry : orderedEntries)
917  entry->printHelpStr(/*indent=*/6, globalWidth);
918  };
919 
920  // Print the available passes.
921  printOrderedEntries("Passes", *passRegistry);
922 
923  // Print the available pass pipelines.
924  if (!passPipelineRegistry->empty())
925  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
926 }
927 
928 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
929  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
930 
931  // Check for any wider pass or pipeline options.
932  for (auto &entry : *passRegistry)
933  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
934  for (auto &entry : *passPipelineRegistry)
935  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
936  return maxWidth;
937 }
938 
939 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
940  StringRef arg, PassArgData &value) {
941  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
942  return true;
943  value.options = arg;
944  return false;
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // PassPipelineCLParser
949 //===----------------------------------------------------------------------===//
950 
951 namespace mlir {
952 namespace detail {
954  PassPipelineCLParserImpl(StringRef arg, StringRef description,
955  bool passNamesOnly)
956  : passList(arg, llvm::cl::desc(description)) {
957  passList.getParser().passNamesOnly = passNamesOnly;
958  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
959  }
960 
961  /// Returns true if the given pass registry entry was registered at the
962  /// top-level of the parser, i.e. not within an explicit textual pipeline.
963  bool contains(const PassRegistryEntry *entry) const {
964  return llvm::any_of(passList, [&](const PassArgData &data) {
965  return data.registryEntry == entry;
966  });
967  }
968 
969  /// The set of passes and pass pipelines to run.
970  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
971 };
972 } // namespace detail
973 } // namespace mlir
974 
975 /// Construct a pass pipeline parser with the given command line description.
976 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
977  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
978  arg, description, /*passNamesOnly=*/false)),
979  passPipeline(
981  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
982 
983 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
984  StringRef alias)
985  : PassPipelineCLParser(arg, description) {
986  passPipelineAlias.emplace(alias,
987  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
988  llvm::cl::aliasopt(passPipeline));
989 }
990 
992 
993 /// Returns true if this parser contains any valid options to add.
995  return passPipeline.getNumOccurrences() != 0 ||
996  impl->passList.getNumOccurrences() != 0;
997 }
998 
999 /// Returns true if the given pass registry entry was registered at the
1000 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1002  return impl->contains(entry);
1003 }
1004 
1005 /// Adds the passes defined by this parser entry to the given pass manager.
1007  OpPassManager &pm,
1008  function_ref<LogicalResult(const Twine &)> errorHandler) const {
1009  if (passPipeline.getNumOccurrences()) {
1010  if (impl->passList.getNumOccurrences())
1011  return errorHandler(
1012  "'-" PASS_PIPELINE_ARG
1013  "' option can't be used with individual pass options");
1014  std::string errMsg;
1015  llvm::raw_string_ostream os(errMsg);
1016  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1017  if (failed(parsed))
1018  return errorHandler(errMsg);
1019  pm = std::move(*parsed);
1020  return success();
1021  }
1022 
1023  for (auto &passIt : impl->passList) {
1024  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1025  errorHandler)))
1026  return failure();
1027  }
1028  return success();
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // PassNameCLParser
1033 //===----------------------------------------------------------------------===//
1034 
1035 /// Construct a pass pipeline parser with the given command line description.
1036 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1037  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1038  arg, description, /*passNamesOnly=*/true)) {
1039  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1040 }
1042 
1043 /// Returns true if this parser contains any valid options to add.
1045  return impl->passList.getNumOccurrences() != 0;
1046 }
1047 
1048 /// Returns true if the given pass registry entry was registered at the
1049 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1051  return impl->contains(entry);
1052 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they generate.
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurrence of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:46
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:401
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition: Pass.cpp:382
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:425
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:363
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:427
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:49
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:386
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:353
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:118
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:247
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:104
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Base container class and manager for all pass options.
Definition: PassOptions.h:92
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:490
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:487
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...