MLIR  20.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
20 
21 #include <optional>
22 #include <utility>
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 /// Static mapping of all of the registered passes.
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
29 
30 /// A mapping of the above pass registry entries to the corresponding TypeID
31 /// of the pass that they generate.
32 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
33 
34 /// Static mapping of all of the registered pass pipelines.
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
37 
38 /// Utility to create a default registry function from a pass instance.
41  return [=](OpPassManager &pm, StringRef options,
42  function_ref<LogicalResult(const Twine &)> errorHandler) {
43  std::unique_ptr<Pass> pass = allocator();
44  LogicalResult result = pass->initializeOptions(options, errorHandler);
45 
46  std::optional<StringRef> pmOpName = pm.getOpName();
47  std::optional<StringRef> passOpName = pass->getOpName();
48  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
49  passOpName && *pmOpName != *passOpName) {
50  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
51  "' restricted to '" + *pass->getOpName() +
52  "' on a PassManager intended to run on '" +
53  pm.getOpAnchorName() + "', did you intend to nest?");
54  }
55  pm.addPass(std::move(pass));
56  return result;
57  };
58 }
59 
60 /// Utility to print the help string for a specific option.
61 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
62  size_t descIndent, bool isTopLevel) {
63  size_t numSpaces = descIndent - indent - 4;
64  llvm::outs().indent(indent)
65  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // PassRegistry
70 //===----------------------------------------------------------------------===//
71 
72 /// Prints the passes that were previously registered and stored in passRegistry
/// Passes are listed under a "Passes:" header, sorted by pass argument, with
/// descriptions aligned to the widest per-pass option width plus 4 columns.
// NOTE(review): this listing omits the function signature (original line 73)
// and the declaration of `orderedEntries` (original line 80); both are needed
// for this definition to compile.
74  size_t maxWidth = 0;
75  for (auto &entry : *passRegistry)
76  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
77 
78  // Functor used to print the ordered entries of a registration map.
79  auto printOrderedEntries = [&](StringRef header, auto &map) {
81  for (auto &kv : map)
82  orderedEntries.push_back(&kv.second);
// Sort by pass argument so the output order is deterministic.
83  llvm::array_pod_sort(
84  orderedEntries.begin(), orderedEntries.end(),
85  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
86  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
87  });
88 
89  llvm::outs().indent(0) << header << ":\n";
90  for (PassRegistryEntry *entry : orderedEntries)
91  entry->printHelpStr(/*indent=*/2, maxWidth);
92  };
93 
94  // Print the available passes.
95  printOrderedEntries("Passes", *passRegistry);
96 }
97 
98 /// Print the help information for this pass. This includes the argument,
99 /// description, and any pass options. `descIndent` is the indent that the
100 /// descriptions should be aligned.
101 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
102  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
103  /*isTopLevel=*/true);
104  // If this entry has options, print the help for those as well.
105  optHandler([=](const PassOptions &options) {
106  options.printHelp(indent, descIndent);
107  });
108 }
109 
110 /// Return the maximum width required when printing the options of this
111 /// entry.
/// The result is the width reported by the entry's options plus 2, or 0 if
/// `optHandler` never invokes the callback.
// NOTE(review): the function signature (original line 112) is missing from
// this listing.
113  size_t maxLen = 0;
114  optHandler([&](const PassOptions &options) mutable {
115  maxLen = options.getOptionWidth() + 2;
116  });
117  return maxLen;
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // PassPipelineInfo
122 //===----------------------------------------------------------------------===//
123 
125  StringRef arg, StringRef description, const PassRegistryFunction &function,
126  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
// Register a named pass pipeline under `arg`. `function` is presumably the
// callback that adds the pipeline to a pass manager, and `optHandler` exposes
// the pipeline's options (e.g. for help printing).
// NOTE(review): the start of the signature (original line 124) is missing
// from this listing.
127  PassPipelineInfo pipelineInfo(arg, description, function,
128  std::move(optHandler));
// try_emplace keeps any existing entry for `arg`. A duplicate registration is
// only diagnosed (fatally) when NDEBUG is not defined.
129  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
130 #ifndef NDEBUG
131  if (!inserted)
132  report_fatal_error("Pass pipeline " + arg + " registered multiple times");
133 #endif
134  (void)inserted;
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // PassInfo
139 //===----------------------------------------------------------------------===//
140 
/// Construct the registry entry for the pass produced by `allocator`. The
/// registry function is built with buildDefaultRegistryFn. Each time the
/// option handler is invoked, it allocates a temporary pass to expose that
/// pass's options.
141 PassInfo::PassInfo(StringRef arg, StringRef description,
142  const PassAllocatorFunction &allocator)
// NOTE(review): original line 143 (presumably the base-class initializer
// `: PassRegistryEntry(`) is missing from this listing.
144  arg, description, buildDefaultRegistryFn(allocator),
145  // Use a temporary pass to provide an options instance.
146  [=](function_ref<void(const PassOptions &)> optHandler) {
// The temporary pass lives until the end of this full-expression, so its
// options remain valid for the whole `optHandler` call.
147  optHandler(allocator()->passOptions);
148  }) {}
149 
// NOTE(review): the function signature (original line 150) is missing from
// this listing. The body registers the pass produced by `function`.
// Instantiate the pass once to query its argument, description and TypeID.
151  std::unique_ptr<Pass> pass = function();
152  StringRef arg = pass->getArgument();
153  if (arg.empty())
154  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
155  pass->getName() +
156  "' pass that does not override `getArgument()`");
157  StringRef description = pass->getDescription();
158  PassInfo passInfo(arg, description, function);
// try_emplace keeps any entry already registered under `arg`.
159  passRegistry->try_emplace(arg, passInfo);
160 
161  // Verify that the registered pass has the same ID as any registered to this
162  // arg before it.
163  TypeID entryTypeID = pass->getTypeID();
164  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
165  if (it->second != entryTypeID)
166  llvm::report_fatal_error(
167  "pass allocator creates a different pass than previously "
168  "registered for pass " +
169  arg);
170 }
171 
172 /// Returns the pass info for the specified pass argument or null if unknown.
173 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
174  auto it = passRegistry->find(passArg);
175  return it == passRegistry->end() ? nullptr : &it->second;
176 }
177 
178 /// Returns the pass pipeline info for the specified pass pipeline argument or
179 /// null if unknown.
180 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
181  auto it = passPipelineRegistry->find(pipelineArg);
182  return it == passPipelineRegistry->end() ? nullptr : &it->second;
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // PassOptions
187 //===----------------------------------------------------------------------===//
188 
189 /// Extract an argument from 'options' and update it to point after the arg.
190 /// Returns the cleaned argument string.
191 static StringRef extractArgAndUpdateOptions(StringRef &options,
192  size_t argSize) {
193  StringRef str = options.take_front(argSize).trim();
194  options = options.drop_front(argSize).ltrim();
195 
196  // Early exit if there's no escape sequence.
197  if (str.size() <= 2)
198  return str;
199 
200  const auto escapePairs = {std::make_pair('\'', '\''),
201  std::make_pair('"', '"'), std::make_pair('{', '}')};
202  for (const auto &escape : escapePairs) {
203  if (str.front() == escape.first && str.back() == escape.second) {
204  // Drop the escape characters and trim.
205  str = str.drop_front().drop_back().trim();
206  // Don't process additional escape sequences.
207  break;
208  }
209  }
210 
211  return str;
212 }
213 
215  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
216  function_ref<LogicalResult(StringRef)> elementParseFn) {
217  // Functor used for finding a character in a string, and skipping over
218  // various "range" characters.
219  llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
220  [&](StringRef str, size_t index, char c) -> size_t {
221  for (size_t i = index, e = str.size(); i < e; ++i) {
222  if (str[i] == c)
223  return i;
224  // Check for various range characters.
225  if (str[i] == '{')
226  i = findChar(str, i + 1, '}');
227  else if (str[i] == '(')
228  i = findChar(str, i + 1, ')');
229  else if (str[i] == '[')
230  i = findChar(str, i + 1, ']');
231  else if (str[i] == '\"')
232  i = str.find_first_of('\"', i + 1);
233  else if (str[i] == '\'')
234  i = str.find_first_of('\'', i + 1);
235  }
236  return StringRef::npos;
237  };
238 
239  size_t nextElePos = findChar(optionStr, 0, ',');
240  while (nextElePos != StringRef::npos) {
241  // Process the portion before the comma.
242  if (failed(
243  elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
244  return failure();
245 
246  // Drop the leading ','
247  optionStr = optionStr.drop_front();
248  nextElePos = findChar(optionStr, 0, ',');
249  }
250  return elementParseFn(
251  extractArgAndUpdateOptions(optionStr, optionStr.size()));
252 }
253 
254 /// Out of line virtual function to provide home for the class.
255 void detail::PassOptions::OptionBase::anchor() {}
256 
257 /// Copy the option values from 'other'.
/// Options are paired positionally with those of `other`; both sets are
/// asserted to have the same size.
// NOTE(review): the function signature (original line 258) is missing from
// this listing.
259  assert(options.size() == other.options.size());
260  if (options.empty())
261  return;
262  for (auto optionsIt : llvm::zip(options, other.options))
263  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
264 }
265 
266 /// Parse in the next argument from the given options string. Returns a tuple
267 /// containing [the key of the option, the value of the option, updated
268 /// `options` string pointing after the parsed option].
/// A key that is not followed by '=' (i.e. one terminated by a space or by the
/// end of the string) yields an empty value. Within a value, quoted ('...',
/// "...") and braced ({...}) sections are skipped, so spaces inside them do
/// not end the value.
269 static std::tuple<StringRef, StringRef, StringRef>
270 parseNextArg(StringRef options) {
271  // Try to process the given punctuation, properly escaping any contained
272  // characters.
// Returns false if `options[currentPos]` is not `punct`. Otherwise moves
// `currentPos` to the matching closing `punct` (left unchanged if there is
// none) and returns true.
273  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
274  if (options[currentPos] != punct)
275  return false;
276  size_t nextIt = options.find_first_of(punct, currentPos + 1);
277  if (nextIt != StringRef::npos)
278  currentPos = nextIt;
279  return true;
280  };
281 
282  // Parse the argument name of the option.
283  StringRef argName;
284  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
285  // Check for the end of the full option.
286  if (argEndIt == optionsE || options[argEndIt] == ' ') {
287  argName = extractArgAndUpdateOptions(options, argEndIt);
288  return std::make_tuple(argName, StringRef(), options);
289  }
290 
291  // Check for the end of the name and the start of the value.
292  if (options[argEndIt] == '=') {
293  argName = extractArgAndUpdateOptions(options, argEndIt);
// Drop the '=' separator.
294  options = options.drop_front();
295  break;
296  }
297  }
298 
299  // Parse the value of the option.
300  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
301  // Handle the end of the options string.
302  if (argEndIt == optionsE || options[argEndIt] == ' ') {
303  StringRef value = extractArgAndUpdateOptions(options, argEndIt);
304  return std::make_tuple(argName, value, options);
305  }
306 
307  // Skip over escaped sequences.
308  char c = options[argEndIt];
309  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
310  continue;
311  // '{...}' is used to specify options to passes, properly escape it so
312  // that we don't accidentally split any nested options.
313  if (c == '{') {
314  size_t braceCount = 1;
315  for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
316  // Allow nested punctuation.
317  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
318  continue;
319  if (options[argEndIt] == '{')
320  ++braceCount;
321  else if (options[argEndIt] == '}' && --braceCount == 0)
322  break;
323  }
324  // Account for the increment at the top of the loop.
// After a matched '}', the outer loop re-examines that '}' (harmlessly). With
// unbalanced braces it resumes at the end of the string.
325  --argEndIt;
326  }
327  }
328  llvm_unreachable("unexpected control flow in pass option parsing");
329 }
330 
332  raw_ostream &errorStream) {
// Parse `options` as space-separated `key[=value]` entries and hand each value
// to the option registered under `key`. Unknown keys are reported to
// `errorStream`.
// NOTE(review): the start of the signature (original line 331) is missing
// from this listing.
333  // NOTE: `options` is modified in place to always refer to the unprocessed
334  // part of the string.
335  while (!options.empty()) {
336  StringRef key, value;
337  std::tie(key, value, options) = parseNextArg(options);
// Entries with an empty key are skipped.
338  if (key.empty())
339  continue;
340 
341  auto it = OptionsMap.find(key);
342  if (it == OptionsMap.end()) {
343  errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
344  return failure();
345  }
// ProvidePositionalOption returns true on error.
346  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
347  return failure();
348  }
349 
350  return success();
351 }
352 
353 /// Print the options held by this struct in a form that can be parsed via
354 /// 'parseFromString'.
355 void detail::PassOptions::print(raw_ostream &os) const {
356  // If there are no options, there is nothing left to do.
357  if (OptionsMap.empty())
358  return;
359 
360  // Sort the options to make the ordering deterministic.
361  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
362  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
363  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
364  };
365  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
366 
367  // Interleave the options with ' '.
368  os << '{';
369  llvm::interleave(
370  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
371  os << '}';
372 }
373 
374 /// Print the help string for the options held by this struct. `descIndent` is
375 /// the indent within the stream that the descriptions should be aligned.
376 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
377  // Sort the options to make the ordering deterministic.
378  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
379  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
380  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
381  };
382  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
383  for (OptionBase *option : orderedOps) {
384  // TODO: printOptionInfo assumes a specific indent and will
385  // print options with values with incorrect indentation. We should add
386  // support to llvm::cl::Option for passing in a base indent to use when
387  // printing.
388  llvm::outs().indent(indent);
389  option->getOption()->printOptionInfo(descIndent - indent);
390  }
391 }
392 
393 /// Return the maximum width required when printing the help string.
/// This is the largest width reported by any held option, or 0 if there are
/// no options.
// NOTE(review): the function signature (original line 394) is missing from
// this listing.
395  size_t max = 0;
396  for (auto *option : options)
397  max = std::max(max, option->getOption()->getOptionWidth());
398  return max;
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // MLIR Options
403 //===----------------------------------------------------------------------===//
404 
405 //===----------------------------------------------------------------------===//
406 // OpPassManager: OptionValue
407 
// Default construction; presumably no pass manager is held until one is set.
408 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
// Construct holding a copy of `value`.
409 llvm::cl::OptionValue<OpPassManager>::OptionValue(
410  const mlir::OpPassManager &value) {
411  setValue(value);
412 }
// Construct from `rhs`, copying its pass manager only if it holds one.
// NOTE(review): the parameter line (original line 414) is missing from this
// listing.
413 llvm::cl::OptionValue<OpPassManager>::OptionValue(
415  if (rhs.hasValue())
416  setValue(rhs.getValue());
417 }
// Assign a copy of `rhs` as the held pass manager.
418 llvm::cl::OptionValue<OpPassManager> &
419 llvm::cl::OptionValue<OpPassManager>::operator=(
420  const mlir::OpPassManager &rhs) {
421  setValue(rhs);
422  return *this;
423 }
424 
425 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
426 
427 void llvm::cl::OptionValue<OpPassManager>::setValue(
428  const OpPassManager &newValue) {
429  if (hasValue())
430  *value = newValue;
431  else
432  value = std::make_unique<mlir::OpPassManager>(newValue);
433 }
434 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
435  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
436  assert(succeeded(pipeline) && "invalid pass pipeline");
437  setValue(*pipeline);
438 }
439 
441  const mlir::OpPassManager &rhs) const {
// Returns true if the held pass manager and `rhs` print to the same textual
// pipeline.
// NOTE(review): the start of the signature (original line 440) is missing
// from this listing. `value` is dereferenced unconditionally, so this assumes
// a pass manager is held.
442  std::string lhsStr, rhsStr;
// The streams are scoped, presumably so that they are destroyed (and any
// buffered output flushed) before the strings are compared.
443  {
444  raw_string_ostream lhsStream(lhsStr);
445  value->printAsTextualPipeline(lhsStream);
446 
447  raw_string_ostream rhsStream(rhsStr);
448  rhs.printAsTextualPipeline(rhsStream);
449  }
450 
451  // Use the textual format for pipeline comparisons.
452  return lhsStr == rhsStr;
453 }
454 
455 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
456 
457 //===----------------------------------------------------------------------===//
458 // OpPassManager: Parser
459 
460 namespace llvm {
461 namespace cl {
462 template class basic_parser<OpPassManager>;
463 } // namespace cl
464 } // namespace llvm
465 
466 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
467  ParsedPassManager &value) {
468  FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
469  if (failed(pipeline))
470  return true;
471  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
472  return false;
473 }
474 
475 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
476  const OpPassManager &value) {
477  value.printAsTextualPipeline(os);
478 }
479 
481  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
482  size_t globalWidth) const {
// Print the option name and `= <pipeline>` for `pm`. When `defaultValue`
// holds a pass manager, append its pipeline in a "(default: ...)" suffix.
// NOTE(review): the start of the signature (original line 480) is missing
// from this listing.
483  printOptionName(opt, globalWidth);
484  outs() << "= ";
485  pm.printAsTextualPipeline(outs());
486 
487  if (defaultValue.hasValue()) {
488  outs().indent(2) << " (default: ";
489  defaultValue.getValue().printAsTextualPipeline(outs());
490  outs() << ")";
491  }
492  outs() << "\n";
493 }
494 
// NOTE(review): original lines 495-502 are only partially present in this
// listing. They appear to declare defaulted special members of
// ParsedPassManager.
496 
498  default;
500  ParsedPassManager &&) = default;
502  default;
503 
504 //===----------------------------------------------------------------------===//
505 // TextualPassPipeline Parser
506 //===----------------------------------------------------------------------===//
507 
508 namespace {
509 /// This class represents a textual description of a pass pipeline.
510 class TextualPipeline {
511 public:
512  /// Try to initialize this pipeline with the given pipeline text.
513  /// `errorStream` is the output stream to emit errors to.
514  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
515 
516  /// Add the internal pipeline elements to the provided pass manager.
517  LogicalResult
518  addToPipeline(OpPassManager &pm,
519  function_ref<LogicalResult(const Twine &)> errorHandler) const;
520 
521 private:
522  /// A functor used to emit errors found during pipeline handling. The first
523  /// parameter corresponds to the raw location within the pipeline string. This
524  /// should always return failure.
525  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
526 
527  /// A struct to capture parsed pass pipeline names.
528  ///
529  /// A pipeline is defined as a series of names, each of which may in itself
530  /// recursively contain a nested pipeline. A name is either the name of a pass
531  /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
532  /// the name is the name of a pass, the InnerPipeline is empty, since passes
533  /// cannot contain inner pipelines.
534  struct PipelineElement {
535  PipelineElement(StringRef name) : name(name) {}
536 
537  StringRef name;
538  StringRef options;
539  const PassRegistryEntry *registryEntry = nullptr;
540  std::vector<PipelineElement> innerPipeline;
541  };
542 
543  /// Parse the given pipeline text into the internal pipeline vector. This
544  /// function only parses the structure of the pipeline, and does not resolve
545  /// its elements.
546  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
547 
548  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
549  /// the corresponding registry entry.
550  LogicalResult
551  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
552  ErrorHandlerT errorHandler);
553 
554  /// Resolve a single element of the pipeline.
555  LogicalResult resolvePipelineElement(PipelineElement &element,
556  ErrorHandlerT errorHandler);
557 
558  /// Add the given pipeline elements to the provided pass manager.
559  LogicalResult
560  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
561  function_ref<LogicalResult(const Twine &)> errorHandler) const;
562 
563  std::vector<PipelineElement> pipeline;
564 };
565 
566 } // namespace
567 
568 /// Try to initialize this pipeline with the given pipeline text. An option is
569 /// given to enable accurate error reporting.
570 LogicalResult TextualPipeline::initialize(StringRef text,
571  raw_ostream &errorStream) {
572  if (text.empty())
573  return success();
574 
575  // Build a source manager to use for error reporting.
576  llvm::SourceMgr pipelineMgr;
577  pipelineMgr.AddNewSourceBuffer(
578  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
579  /*RequiresNullTerminator=*/false),
580  SMLoc());
581  auto errorHandler = [&](const char *rawLoc, Twine msg) {
582  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
583  llvm::SourceMgr::DK_Error, msg);
584  return failure();
585  };
586 
587  // Parse the provided pipeline string.
588  if (failed(parsePipelineText(text, errorHandler)))
589  return failure();
590  return resolvePipelineElements(pipeline, errorHandler);
591 }
592 
593 /// Add the internal pipeline elements to the provided pass manager.
594 LogicalResult TextualPipeline::addToPipeline(
595  OpPassManager &pm,
596  function_ref<LogicalResult(const Twine &)> errorHandler) const {
597  // Temporarily disable implicit nesting while we append to the pipeline. We
598  // want the created pipeline to exactly match the parsed text pipeline, so
599  // it's preferrable to just error out if implicit nesting would be required.
600  OpPassManager::Nesting nesting = pm.getNesting();
602  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
603 
604  return addToPipeline(pipeline, pm, errorHandler);
605 }
606 
607 /// Parse the given pipeline text into the internal pipeline vector. This
608 /// function only parses the structure of the pipeline, and does not resolve
609 /// its elements.
/// The text is a comma-separated list of names. A name may be followed by
/// `{options}` or by a parenthesized nested pipeline. Names are trimmed, and
/// whitespace after a closing `}` or `)` is skipped.
610 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
611  ErrorHandlerT errorHandler) {
// Stack of the pipelines being populated, innermost at the back. A '(' pushes
// the current element's inner pipeline and a ')' pops it.
612  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
613  for (;;) {
// NOTE: this local shadows the `pipeline` member. The assertion after the
// loop refers to the member.
614  std::vector<PipelineElement> &pipeline = *pipelineStack.back();
615  size_t pos = text.find_first_of(",(){");
616  pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
617 
618  // If we have a single terminating name, we're done.
619  if (pos == StringRef::npos)
620  break;
621 
622  text = text.substr(pos);
623  char sep = text[0];
624 
625  // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
626  if (sep == '{') {
627  text = text.substr(1);
628 
629  // Skip over everything until the closing '}' and store as options.
// NOTE(review): quote characters are not special here, so a '{' or '}' inside
// a quoted option value takes part in the brace matching.
630  size_t close = StringRef::npos;
631  for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
632  if (text[i] == '{') {
633  ++braceCount;
634  continue;
635  }
636  if (text[i] == '}' && --braceCount == 0) {
637  close = i;
638  break;
639  }
640  }
641 
642  // Check to see if a closing options brace was found.
643  if (close == StringRef::npos) {
644  return errorHandler(
645  /*rawLoc=*/text.data() - 1,
646  "missing closing '}' while processing pass options");
647  }
648  pipeline.back().options = text.substr(0, close);
649  text = text.substr(close + 1);
650 
651  // Consume space characters that a user might add for readability.
652  text = text.ltrim();
653 
654  // Skip checking for '(' because nested pipelines cannot have options.
655  } else if (sep == '(') {
656  text = text.substr(1);
657 
658  // Push the inner pipeline onto the stack to continue processing.
659  pipelineStack.push_back(&pipeline.back().innerPipeline);
660  continue;
661  }
662 
663  // When handling the close parenthesis, we greedily consume them to avoid
664  // empty strings in the pipeline.
665  while (text.consume_front(")")) {
666  // If we try to pop the outer pipeline we have unbalanced parentheses.
667  if (pipelineStack.size() == 1)
668  return errorHandler(/*rawLoc=*/text.data() - 1,
669  "encountered extra closing ')' creating unbalanced "
670  "parentheses while parsing pipeline");
671 
672  pipelineStack.pop_back();
673  // Consume space characters that a user might add for readability.
674  text = text.ltrim();
675  }
676 
677  // Check if we've finished parsing.
678  if (text.empty())
679  break;
680 
681  // Otherwise, the end of an inner pipeline always has to be followed by
682  // a comma, and then we can continue.
683  if (!text.consume_front(","))
684  return errorHandler(text.data(), "expected ',' after parsing pipeline");
685  }
686 
687  // Check for unbalanced parentheses.
688  if (pipelineStack.size() > 1)
689  return errorHandler(
690  text.data(),
691  "encountered unbalanced parentheses while parsing pipeline");
692 
693  assert(pipelineStack.back() == &pipeline &&
694  "wrong pipeline at the bottom of the stack");
695  return success();
696 }
697 
698 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
699 /// the corresponding registry entry.
700 LogicalResult TextualPipeline::resolvePipelineElements(
701  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
702  for (auto &elt : elements)
703  if (failed(resolvePipelineElement(elt, errorHandler)))
704  return failure();
705  return success();
706 }
707 
708 /// Resolve a single element of the pipeline.
709 LogicalResult
710 TextualPipeline::resolvePipelineElement(PipelineElement &element,
711  ErrorHandlerT errorHandler) {
712  // If the inner pipeline of this element is not empty, this is an operation
713  // pipeline.
714  if (!element.innerPipeline.empty())
715  return resolvePipelineElements(element.innerPipeline, errorHandler);
716 
717  // Otherwise, this must be a pass or pass pipeline.
718  // Check to see if a pipeline was registered with this name.
719  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
720  return success();
721 
722  // If not, then this must be a specific pass name.
723  if ((element.registryEntry = PassInfo::lookup(element.name)))
724  return success();
725 
726  // Emit an error for the unknown pass.
727  auto *rawLoc = element.name.data();
728  return errorHandler(rawLoc, "'" + element.name +
729  "' does not refer to a "
730  "registered pass or pass pipeline");
731 }
732 
733 /// Add the given pipeline elements to the provided pass manager.
734 LogicalResult TextualPipeline::addToPipeline(
736  function_ref<LogicalResult(const Twine &)> errorHandler) const {
737  for (auto &elt : elements) {
738  if (elt.registryEntry) {
739  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
740  errorHandler))) {
741  return errorHandler("failed to add `" + elt.name + "` with options `" +
742  elt.options + "`");
743  }
744  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
745  errorHandler))) {
746  return errorHandler("failed to add `" + elt.name + "` with options `" +
747  elt.options + "` to inner pipeline");
748  }
749  }
750  return success();
751 }
752 
753 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
754  raw_ostream &errorStream) {
755  TextualPipeline pipelineParser;
756  if (failed(pipelineParser.initialize(pipeline, errorStream)))
757  return failure();
758  auto errorHandler = [&](Twine msg) {
759  errorStream << msg << "\n";
760  return failure();
761  };
762  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
763  return failure();
764  return success();
765 }
766 
767 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
768  raw_ostream &errorStream) {
769  pipeline = pipeline.trim();
770  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
771  size_t pipelineStart = pipeline.find_first_of('(');
772  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
773  !pipeline.consume_back(")")) {
774  errorStream << "expected pass pipeline to be wrapped with the anchor "
775  "operation type, e.g. 'builtin.module(...)'";
776  return failure();
777  }
778 
779  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
780  OpPassManager pm(opName);
781  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
782  errorStream)))
783  return failure();
784  return pm;
785 }
786 
787 //===----------------------------------------------------------------------===//
788 // PassNameParser
789 //===----------------------------------------------------------------------===//
790 
791 namespace {
792 /// This struct represents the possible data entries in a parsed pass pipeline
793 /// list.
794 struct PassArgData {
795  PassArgData() = default;
796  PassArgData(const PassRegistryEntry *registryEntry)
797  : registryEntry(registryEntry) {}
798 
799  /// This field is used when the parsed option corresponds to a registered pass
800  /// or pass pipeline.
801  const PassRegistryEntry *registryEntry{nullptr};
802 
803  /// This field is set when instance specific pass options have been provided
804  /// on the command line.
805  StringRef options;
806 };
807 } // namespace
808 
809 namespace llvm {
810 namespace cl {
811 /// Define a valid OptionValue for the command line pass argument.
/// A PassArgData value is always reported as present (`hasValue` returns
/// true).
812 template <>
// NOTE(review): original line 813 (the specialization's declarator) is
// missing from this listing.
814  : OptionValueBase<PassArgData, /*isClass=*/true> {
815  OptionValue(const PassArgData &value) { this->setValue(value); }
816  OptionValue() = default;
817  void anchor() override {}
818 
819  bool hasValue() const { return true; }
820  const PassArgData &getValue() const { return value; }
821  void setValue(const PassArgData &value) { this->value = value; }
822 
823  PassArgData value;
824 };
825 } // namespace cl
826 } // namespace llvm
827 
828 namespace {
829 
830 /// The name for the command line option used for parsing the textual pass
831 /// pipeline.
832 #define PASS_PIPELINE_ARG "pass-pipeline"
833 
834 /// Adds command line option for each registered pass or pass pipeline, as well
835 /// as textual pass pipelines.
836 struct PassNameParser : public llvm::cl::parser<PassArgData> {
837  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
838 
839  void initialize();
840  void printOptionInfo(const llvm::cl::Option &opt,
841  size_t globalWidth) const override;
842  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
843  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
844  PassArgData &value);
845 
846  /// If true, this parser only parses entries that correspond to a concrete
847  /// pass registry entry, and does not include pipeline entries or the options
848  /// for pass entries.
849  bool passNamesOnly = false;
850 };
851 } // namespace
852 
/// Register every registered pass, and unless `passNamesOnly` is set every
/// registered pass pipeline, as a literal value of this option. Each entry's
/// argument is used as the name and its description as the help text.
853 void PassNameParser::initialize() {
// NOTE(review): original line 854 is missing from this listing (presumably
// the call to the base parser's initialize()).
855 
856  /// Add the pass entries.
857  for (const auto &kv : *passRegistry) {
858  addLiteralOption(kv.second.getPassArgument(), &kv.second,
859  kv.second.getPassDescription());
860  }
861  /// Add the pass pipeline entries.
862  if (!passNamesOnly) {
863  for (const auto &kv : *passPipelineRegistry) {
864  addLiteralOption(kv.second.getPassArgument(), &kv.second,
865  kv.second.getPassDescription());
866  }
867  }
868 }
869 
870 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
871  size_t globalWidth) const {
872  // If this parser is just parsing pass names, print a simplified option
873  // string.
874  if (passNamesOnly) {
875  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
876  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
877  return;
878  }
879 
880  // Print the information for the top-level option.
881  if (opt.hasArgStr()) {
882  llvm::outs() << " --" << opt.ArgStr;
883  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
884  } else {
885  llvm::outs() << " " << opt.HelpStr << '\n';
886  }
887 
888  // Functor used to print the ordered entries of a registration map.
889  auto printOrderedEntries = [&](StringRef header, auto &map) {
891  for (auto &kv : map)
892  orderedEntries.push_back(&kv.second);
893  llvm::array_pod_sort(
894  orderedEntries.begin(), orderedEntries.end(),
895  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
896  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
897  });
898 
899  llvm::outs().indent(4) << header << ":\n";
900  for (PassRegistryEntry *entry : orderedEntries)
901  entry->printHelpStr(/*indent=*/6, globalWidth);
902  };
903 
904  // Print the available passes.
905  printOrderedEntries("Passes", *passRegistry);
906 
907  // Print the available pass pipelines.
908  if (!passPipelineRegistry->empty())
909  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
910 }
911 
912 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
913  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
914 
915  // Check for any wider pass or pipeline options.
916  for (auto &entry : *passRegistry)
917  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
918  for (auto &entry : *passPipelineRegistry)
919  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
920  return maxWidth;
921 }
922 
923 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
924  StringRef arg, PassArgData &value) {
925  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
926  return true;
927  value.options = arg;
928  return false;
929 }
930 
931 //===----------------------------------------------------------------------===//
932 // PassPipelineCLParser
933 //===----------------------------------------------------------------------===//
934 
935 namespace mlir {
936 namespace detail {
938  PassPipelineCLParserImpl(StringRef arg, StringRef description,
939  bool passNamesOnly)
940  : passList(arg, llvm::cl::desc(description)) {
941  passList.getParser().passNamesOnly = passNamesOnly;
942  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
943  }
944 
945  /// Returns true if the given pass registry entry was registered at the
946  /// top-level of the parser, i.e. not within an explicit textual pipeline.
947  bool contains(const PassRegistryEntry *entry) const {
948  return llvm::any_of(passList, [&](const PassArgData &data) {
949  return data.registryEntry == entry;
950  });
951  }
952 
953  /// The set of passes and pass pipelines to run.
954  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
955 };
956 } // namespace detail
957 } // namespace mlir
958 
959 /// Construct a pass pipeline parser with the given command line description.
960 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
961  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
962  arg, description, /*passNamesOnly=*/false)),
963  passPipeline(
965  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
966 
967 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
968  StringRef alias)
969  : PassPipelineCLParser(arg, description) {
970  passPipelineAlias.emplace(alias,
971  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
972  llvm::cl::aliasopt(passPipeline));
973 }
974 
976 
977 /// Returns true if this parser contains any valid options to add.
979  return passPipeline.getNumOccurrences() != 0 ||
980  impl->passList.getNumOccurrences() != 0;
981 }
982 
983 /// Returns true if the given pass registry entry was registered at the
984 /// top-level of the parser, i.e. not within an explicit textual pipeline.
986  return impl->contains(entry);
987 }
988 
989 /// Adds the passes defined by this parser entry to the given pass manager.
991  OpPassManager &pm,
992  function_ref<LogicalResult(const Twine &)> errorHandler) const {
993  if (passPipeline.getNumOccurrences()) {
994  if (impl->passList.getNumOccurrences())
995  return errorHandler(
996  "'-" PASS_PIPELINE_ARG
997  "' option can't be used with individual pass options");
998  std::string errMsg;
999  llvm::raw_string_ostream os(errMsg);
1000  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1001  if (failed(parsed))
1002  return errorHandler(errMsg);
1003  pm = std::move(*parsed);
1004  return success();
1005  }
1006 
1007  for (auto &passIt : impl->passList) {
1008  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1009  errorHandler)))
1010  return failure();
1011  }
1012  return success();
1013 }
1014 
1015 //===----------------------------------------------------------------------===//
1016 // PassNameCLParser
1017 
1018 /// Construct a pass pipeline parser with the given command line description.
1019 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1020  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1021  arg, description, /*passNamesOnly=*/true)) {
1022  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1023 }
1025 
1026 /// Returns true if this parser contains any valid options to add.
1028  return impl->passList.getNumOccurrences() != 0;
1029 }
1030 
1031 /// Returns true if the given pass registry entry was registered at the
1032 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1034  return impl->contains(entry);
1035 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:47
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:401
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition: Pass.cpp:382
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:425
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:363
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:427
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:50
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:386
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:353
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:118
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:247
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:104
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Base container class and manager for all pass options.
Definition: PassOptions.h:92
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:473
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:470
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...