MLIR  20.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 #include <optional>
21 #include <utility>
22 
23 using namespace mlir;
24 using namespace detail;
25 
26 /// Static mapping of all of the registered passes.
27 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
28 
29 /// A mapping of the above pass registry entries to the corresponding TypeID
30 /// of the pass that they generate.
31 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
32 
33 /// Static mapping of all of the registered pass pipelines.
34 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
36 
37 /// Utility to create a default registry function from a pass instance.
40  return [=](OpPassManager &pm, StringRef options,
41  function_ref<LogicalResult(const Twine &)> errorHandler) {
42  std::unique_ptr<Pass> pass = allocator();
43  LogicalResult result = pass->initializeOptions(options, errorHandler);
44 
45  std::optional<StringRef> pmOpName = pm.getOpName();
46  std::optional<StringRef> passOpName = pass->getOpName();
47  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
48  passOpName && *pmOpName != *passOpName) {
49  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
50  "' restricted to '" + *pass->getOpName() +
51  "' on a PassManager intended to run on '" +
52  pm.getOpAnchorName() + "', did you intend to nest?");
53  }
54  pm.addPass(std::move(pass));
55  return result;
56  };
57 }
58 
59 /// Utility to print the help string for a specific option.
60 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
61  size_t descIndent, bool isTopLevel) {
62  size_t numSpaces = descIndent - indent - 4;
63  llvm::outs().indent(indent)
64  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // PassRegistry
69 //===----------------------------------------------------------------------===//
70 
71 /// Print the help information for this pass. This includes the argument,
72 /// description, and any pass options. `descIndent` is the indent that the
73 /// descriptions should be aligned.
74 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
75  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
76  /*isTopLevel=*/true);
77  // If this entry has options, print the help for those as well.
78  optHandler([=](const PassOptions &options) {
79  options.printHelp(indent, descIndent);
80  });
81 }
82 
83 /// Return the maximum width required when printing the options of this
84 /// entry.
86  size_t maxLen = 0;
87  optHandler([&](const PassOptions &options) mutable {
88  maxLen = options.getOptionWidth() + 2;
89  });
90  return maxLen;
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // PassPipelineInfo
95 //===----------------------------------------------------------------------===//
96 
98  StringRef arg, StringRef description, const PassRegistryFunction &function,
99  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
100  PassPipelineInfo pipelineInfo(arg, description, function,
101  std::move(optHandler));
102  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
103 #ifndef NDEBUG
104  if (!inserted)
105  report_fatal_error("Pass pipeline " + arg + " registered multiple times");
106 #endif
107  (void)inserted;
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // PassInfo
112 //===----------------------------------------------------------------------===//
113 
114 PassInfo::PassInfo(StringRef arg, StringRef description,
115  const PassAllocatorFunction &allocator)
117  arg, description, buildDefaultRegistryFn(allocator),
118  // Use a temporary pass to provide an options instance.
119  [=](function_ref<void(const PassOptions &)> optHandler) {
120  optHandler(allocator()->passOptions);
121  }) {}
122 
124  std::unique_ptr<Pass> pass = function();
125  StringRef arg = pass->getArgument();
126  if (arg.empty())
127  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
128  pass->getName() +
129  "' pass that does not override `getArgument()`");
130  StringRef description = pass->getDescription();
131  PassInfo passInfo(arg, description, function);
132  passRegistry->try_emplace(arg, passInfo);
133 
134  // Verify that the registered pass has the same ID as any registered to this
135  // arg before it.
136  TypeID entryTypeID = pass->getTypeID();
137  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
138  if (it->second != entryTypeID)
139  llvm::report_fatal_error(
140  "pass allocator creates a different pass than previously "
141  "registered for pass " +
142  arg);
143 }
144 
145 /// Returns the pass info for the specified pass argument or null if unknown.
146 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
147  auto it = passRegistry->find(passArg);
148  return it == passRegistry->end() ? nullptr : &it->second;
149 }
150 
151 /// Returns the pass pipeline info for the specified pass pipeline argument or
152 /// null if unknown.
153 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
154  auto it = passPipelineRegistry->find(pipelineArg);
155  return it == passPipelineRegistry->end() ? nullptr : &it->second;
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // PassOptions
160 //===----------------------------------------------------------------------===//
161 
163  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
164  function_ref<LogicalResult(StringRef)> elementParseFn) {
165  // Functor used for finding a character in a string, and skipping over
166  // various "range" characters.
167  llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
168  [&](StringRef str, size_t index, char c) -> size_t {
169  for (size_t i = index, e = str.size(); i < e; ++i) {
170  if (str[i] == c)
171  return i;
172  // Check for various range characters.
173  if (str[i] == '{')
174  i = findChar(str, i + 1, '}');
175  else if (str[i] == '(')
176  i = findChar(str, i + 1, ')');
177  else if (str[i] == '[')
178  i = findChar(str, i + 1, ']');
179  else if (str[i] == '\"')
180  i = str.find_first_of('\"', i + 1);
181  else if (str[i] == '\'')
182  i = str.find_first_of('\'', i + 1);
183  }
184  return StringRef::npos;
185  };
186 
187  size_t nextElePos = findChar(optionStr, 0, ',');
188  while (nextElePos != StringRef::npos) {
189  // Process the portion before the comma.
190  if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
191  return failure();
192 
193  optionStr = optionStr.substr(nextElePos + 1);
194  nextElePos = findChar(optionStr, 0, ',');
195  }
196  return elementParseFn(optionStr.substr(0, nextElePos));
197 }
198 
199 /// Out of line virtual function to provide home for the class.
200 void detail::PassOptions::OptionBase::anchor() {}
201 
202 /// Copy the option values from 'other'.
204  assert(options.size() == other.options.size());
205  if (options.empty())
206  return;
207  for (auto optionsIt : llvm::zip(options, other.options))
208  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
209 }
210 
/// Parse in the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, updated
/// `options` string pointing after the parsed option].
///
/// An argument is either a bare key, or a `key=value` pair. A bare key is
/// terminated by a space or the end of the string, and its returned value is
/// empty. While scanning a value:
///   - text between matching single or double quotes is skipped;
///   - `{...}` groups (with nested braces and nested quotes) are skipped.
/// As a result, spaces inside those groups do not end the value. An unmatched
/// quote is treated as an ordinary character. An unmatched `{` makes the value
/// run to the end of the string.
///
/// Both the key and the value are trimmed. If one is longer than two
/// characters and wrapped in `'...'`, `"..."` or `{...}`, that single outer
/// pair is stripped and the contents trimmed again.
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
  // Functor used to extract an argument from 'options' and update it to point
  // after the arg.
  auto extractArgAndUpdateOptions = [&](size_t argSize) {
    StringRef str = options.take_front(argSize).trim();
    options = options.drop_front(argSize).ltrim();
    // Handle escape sequences. Only strings longer than two characters are
    // considered, so e.g. `""` or `{}` are returned unchanged.
    if (str.size() > 2) {
      const auto escapePairs = {std::make_pair('\'', '\''),
                                std::make_pair('"', '"'),
                                std::make_pair('{', '}')};
      for (const auto &escape : escapePairs) {
        if (str.front() == escape.first && str.back() == escape.second) {
          // Drop the escape characters and trim.
          str = str.drop_front().drop_back().trim();
          // Don't process additional escape sequences.
          break;
        }
      }
    }
    return str;
  };
  // Try to process the given punctuation, properly escaping any contained
  // characters. Returns true if `options[currentPos]` is `punct`. In that case
  // `currentPos` is advanced to the matching closing `punct`, or left
  // unchanged when there is none.
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
    if (options[currentPos] != punct)
      return false;
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
    if (nextIt != StringRef::npos)
      currentPos = nextIt;
    return true;
  };

  // Parse the argument name of the option.
  StringRef argName;
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Check for the end of the full option.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      argName = extractArgAndUpdateOptions(argEndIt);
      return std::make_tuple(argName, StringRef(), options);
    }

    // Check for the end of the name and the start of the value.
    if (options[argEndIt] == '=') {
      argName = extractArgAndUpdateOptions(argEndIt);
      // `options` now starts at the '='; drop it so the value scan begins
      // right after it.
      options = options.drop_front();
      break;
    }
  }

  // Parse the value of the option.
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Handle the end of the options string.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      StringRef value = extractArgAndUpdateOptions(argEndIt);
      return std::make_tuple(argName, value, options);
    }

    // Skip over escaped sequences.
    char c = options[argEndIt];
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
      continue;
    // '{...}' is used to specify options to passes, properly escape it so
    // that we don't accidentally split any nested options.
    if (c == '{') {
      size_t braceCount = 1;
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
        // Allow nested punctuation.
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
          continue;
        if (options[argEndIt] == '{')
          ++braceCount;
        else if (options[argEndIt] == '}' && --braceCount == 0)
          break;
      }
      // Account for the increment at the top of the loop.
      --argEndIt;
    }
  }
  llvm_unreachable("unexpected control flow in pass option parsing");
}
296 
298  raw_ostream &errorStream) {
299  // NOTE: `options` is modified in place to always refer to the unprocessed
300  // part of the string.
301  while (!options.empty()) {
302  StringRef key, value;
303  std::tie(key, value, options) = parseNextArg(options);
304  if (key.empty())
305  continue;
306 
307  auto it = OptionsMap.find(key);
308  if (it == OptionsMap.end()) {
309  errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
310  return failure();
311  }
312  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
313  return failure();
314  }
315 
316  return success();
317 }
318 
319 /// Print the options held by this struct in a form that can be parsed via
320 /// 'parseFromString'.
321 void detail::PassOptions::print(raw_ostream &os) {
322  // If there are no options, there is nothing left to do.
323  if (OptionsMap.empty())
324  return;
325 
326  // Sort the options to make the ordering deterministic.
327  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
328  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
329  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
330  };
331  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
332 
333  // Interleave the options with ' '.
334  os << '{';
335  llvm::interleave(
336  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
337  os << '}';
338 }
339 
340 /// Print the help string for the options held by this struct. `descIndent` is
341 /// the indent within the stream that the descriptions should be aligned.
342 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
343  // Sort the options to make the ordering deterministic.
344  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
345  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
346  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
347  };
348  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
349  for (OptionBase *option : orderedOps) {
350  // TODO: printOptionInfo assumes a specific indent and will
351  // print options with values with incorrect indentation. We should add
352  // support to llvm::cl::Option for passing in a base indent to use when
353  // printing.
354  llvm::outs().indent(indent);
355  option->getOption()->printOptionInfo(descIndent - indent);
356  }
357 }
358 
359 /// Return the maximum width required when printing the help string.
361  size_t max = 0;
362  for (auto *option : options)
363  max = std::max(max, option->getOption()->getOptionWidth());
364  return max;
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // MLIR Options
369 //===----------------------------------------------------------------------===//
370 
371 //===----------------------------------------------------------------------===//
372 // OpPassManager: OptionValue
373 
374 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
375 llvm::cl::OptionValue<OpPassManager>::OptionValue(
376  const mlir::OpPassManager &value) {
377  setValue(value);
378 }
379 llvm::cl::OptionValue<OpPassManager>::OptionValue(
381  if (rhs.hasValue())
382  setValue(rhs.getValue());
383 }
384 llvm::cl::OptionValue<OpPassManager> &
385 llvm::cl::OptionValue<OpPassManager>::operator=(
386  const mlir::OpPassManager &rhs) {
387  setValue(rhs);
388  return *this;
389 }
390 
391 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
392 
393 void llvm::cl::OptionValue<OpPassManager>::setValue(
394  const OpPassManager &newValue) {
395  if (hasValue())
396  *value = newValue;
397  else
398  value = std::make_unique<mlir::OpPassManager>(newValue);
399 }
400 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
401  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
402  assert(succeeded(pipeline) && "invalid pass pipeline");
403  setValue(*pipeline);
404 }
405 
407  const mlir::OpPassManager &rhs) const {
408  std::string lhsStr, rhsStr;
409  {
410  raw_string_ostream lhsStream(lhsStr);
411  value->printAsTextualPipeline(lhsStream);
412 
413  raw_string_ostream rhsStream(rhsStr);
414  rhs.printAsTextualPipeline(rhsStream);
415  }
416 
417  // Use the textual format for pipeline comparisons.
418  return lhsStr == rhsStr;
419 }
420 
421 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
422 
423 //===----------------------------------------------------------------------===//
424 // OpPassManager: Parser
425 
426 namespace llvm {
427 namespace cl {
428 template class basic_parser<OpPassManager>;
429 } // namespace cl
430 } // namespace llvm
431 
432 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
433  ParsedPassManager &value) {
434  FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
435  if (failed(pipeline))
436  return true;
437  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
438  return false;
439 }
440 
441 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
442  const OpPassManager &value) {
443  value.printAsTextualPipeline(os);
444 }
445 
447  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
448  size_t globalWidth) const {
449  printOptionName(opt, globalWidth);
450  outs() << "= ";
451  pm.printAsTextualPipeline(outs());
452 
453  if (defaultValue.hasValue()) {
454  outs().indent(2) << " (default: ";
455  defaultValue.getValue().printAsTextualPipeline(outs());
456  outs() << ")";
457  }
458  outs() << "\n";
459 }
460 
462 
464  default;
466  ParsedPassManager &&) = default;
468  default;
469 
470 //===----------------------------------------------------------------------===//
471 // TextualPassPipeline Parser
472 //===----------------------------------------------------------------------===//
473 
474 namespace {
475 /// This class represents a textual description of a pass pipeline.
476 class TextualPipeline {
477 public:
478  /// Try to initialize this pipeline with the given pipeline text.
479  /// `errorStream` is the output stream to emit errors to.
480  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
481 
482  /// Add the internal pipeline elements to the provided pass manager.
483  LogicalResult
484  addToPipeline(OpPassManager &pm,
485  function_ref<LogicalResult(const Twine &)> errorHandler) const;
486 
487 private:
488  /// A functor used to emit errors found during pipeline handling. The first
489  /// parameter corresponds to the raw location within the pipeline string. This
490  /// should always return failure.
491  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
492 
493  /// A struct to capture parsed pass pipeline names.
494  ///
495  /// A pipeline is defined as a series of names, each of which may in itself
496  /// recursively contain a nested pipeline. A name is either the name of a pass
497  /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
498  /// the name is the name of a pass, the InnerPipeline is empty, since passes
499  /// cannot contain inner pipelines.
500  struct PipelineElement {
501  PipelineElement(StringRef name) : name(name) {}
502 
503  StringRef name;
504  StringRef options;
505  const PassRegistryEntry *registryEntry = nullptr;
506  std::vector<PipelineElement> innerPipeline;
507  };
508 
509  /// Parse the given pipeline text into the internal pipeline vector. This
510  /// function only parses the structure of the pipeline, and does not resolve
511  /// its elements.
512  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
513 
514  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
515  /// the corresponding registry entry.
516  LogicalResult
517  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
518  ErrorHandlerT errorHandler);
519 
520  /// Resolve a single element of the pipeline.
521  LogicalResult resolvePipelineElement(PipelineElement &element,
522  ErrorHandlerT errorHandler);
523 
524  /// Add the given pipeline elements to the provided pass manager.
525  LogicalResult
526  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
527  function_ref<LogicalResult(const Twine &)> errorHandler) const;
528 
529  std::vector<PipelineElement> pipeline;
530 };
531 
532 } // namespace
533 
534 /// Try to initialize this pipeline with the given pipeline text. An option is
535 /// given to enable accurate error reporting.
536 LogicalResult TextualPipeline::initialize(StringRef text,
537  raw_ostream &errorStream) {
538  if (text.empty())
539  return success();
540 
541  // Build a source manager to use for error reporting.
542  llvm::SourceMgr pipelineMgr;
543  pipelineMgr.AddNewSourceBuffer(
544  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
545  /*RequiresNullTerminator=*/false),
546  SMLoc());
547  auto errorHandler = [&](const char *rawLoc, Twine msg) {
548  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
549  llvm::SourceMgr::DK_Error, msg);
550  return failure();
551  };
552 
553  // Parse the provided pipeline string.
554  if (failed(parsePipelineText(text, errorHandler)))
555  return failure();
556  return resolvePipelineElements(pipeline, errorHandler);
557 }
558 
559 /// Add the internal pipeline elements to the provided pass manager.
560 LogicalResult TextualPipeline::addToPipeline(
561  OpPassManager &pm,
562  function_ref<LogicalResult(const Twine &)> errorHandler) const {
563  // Temporarily disable implicit nesting while we append to the pipeline. We
564  // want the created pipeline to exactly match the parsed text pipeline, so
565 // it's preferable to just error out if implicit nesting would be required.
566  OpPassManager::Nesting nesting = pm.getNesting();
568  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
569 
570  return addToPipeline(pipeline, pm, errorHandler);
571 }
572 
/// Parse the given pipeline text into the internal pipeline vector. This
/// function only parses the structure of the pipeline, and does not resolve
/// its elements.
///
/// The text is a comma-separated list of elements. Each element is a name,
/// optionally followed by one of:
///   - `{options}`: the raw text between the braces (nested braces balanced)
///     is stored verbatim in `PipelineElement::options`;
///   - `(nested-list)`: parsed into `PipelineElement::innerPipeline`.
/// Names are trimmed of surrounding whitespace. Nesting is tracked with an
/// explicit stack of the element vectors being appended to, not with recursion.
LogicalResult TextualPipeline::parsePipelineText(StringRef text,
                                                 ErrorHandlerT errorHandler) {
  // Vectors currently being populated: the top-level `pipeline` member is at
  // the bottom, the innermost nested pipeline at the back.
  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
  for (;;) {
    std::vector<PipelineElement> &pipeline = *pipelineStack.back();
    size_t pos = text.find_first_of(",(){");
    pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());

    // If we have a single terminating name, we're done.
    if (pos == StringRef::npos)
      break;

    text = text.substr(pos);
    char sep = text[0];

    // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
    if (sep == '{') {
      text = text.substr(1);

      // Skip over everything until the closing '}' and store as options.
      // NOTE(review): quotes are not honored here, so a '}' inside a quoted
      // option value ends the options early.
      size_t close = StringRef::npos;
      for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
        if (text[i] == '{') {
          ++braceCount;
          continue;
        }
        if (text[i] == '}' && --braceCount == 0) {
          close = i;
          break;
        }
      }

      // Check to see if a closing options brace was found.
      if (close == StringRef::npos) {
        return errorHandler(
            /*rawLoc=*/text.data() - 1,
            "missing closing '}' while processing pass options");
      }
      pipeline.back().options = text.substr(0, close);
      text = text.substr(close + 1);

      // Consume space characters that a user might add for readability.
      text = text.ltrim();

      // Skip checking for '(' because nested pipelines cannot have options.
    } else if (sep == '(') {
      text = text.substr(1);

      // Push the inner pipeline onto the stack to continue processing.
      pipelineStack.push_back(&pipeline.back().innerPipeline);
      continue;
    }

    // When handling the close parenthesis, we greedily consume them to avoid
    // empty strings in the pipeline.
    while (text.consume_front(")")) {
      // If we try to pop the outer pipeline we have unbalanced parentheses.
      if (pipelineStack.size() == 1)
        return errorHandler(/*rawLoc=*/text.data() - 1,
                            "encountered extra closing ')' creating unbalanced "
                            "parentheses while parsing pipeline");

      pipelineStack.pop_back();
      // Consume space characters that a user might add for readability.
      text = text.ltrim();
    }

    // Check if we've finished parsing.
    if (text.empty())
      break;

    // Otherwise, the end of an inner pipeline always has to be followed by
    // a comma, and then we can continue.
    if (!text.consume_front(","))
      return errorHandler(text.data(), "expected ',' after parsing pipeline");
  }

  // Check for unbalanced parentheses.
  if (pipelineStack.size() > 1)
    return errorHandler(
        text.data(),
        "encountered unbalanced parentheses while parsing pipeline");

  assert(pipelineStack.back() == &pipeline &&
         "wrong pipeline at the bottom of the stack");
  return success();
}
663 
664 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
665 /// the corresponding registry entry.
666 LogicalResult TextualPipeline::resolvePipelineElements(
667  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
668  for (auto &elt : elements)
669  if (failed(resolvePipelineElement(elt, errorHandler)))
670  return failure();
671  return success();
672 }
673 
674 /// Resolve a single element of the pipeline.
675 LogicalResult
676 TextualPipeline::resolvePipelineElement(PipelineElement &element,
677  ErrorHandlerT errorHandler) {
678  // If the inner pipeline of this element is not empty, this is an operation
679  // pipeline.
680  if (!element.innerPipeline.empty())
681  return resolvePipelineElements(element.innerPipeline, errorHandler);
682 
683  // Otherwise, this must be a pass or pass pipeline.
684  // Check to see if a pipeline was registered with this name.
685  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
686  return success();
687 
688  // If not, then this must be a specific pass name.
689  if ((element.registryEntry = PassInfo::lookup(element.name)))
690  return success();
691 
692  // Emit an error for the unknown pass.
693  auto *rawLoc = element.name.data();
694  return errorHandler(rawLoc, "'" + element.name +
695  "' does not refer to a "
696  "registered pass or pass pipeline");
697 }
698 
699 /// Add the given pipeline elements to the provided pass manager.
700 LogicalResult TextualPipeline::addToPipeline(
702  function_ref<LogicalResult(const Twine &)> errorHandler) const {
703  for (auto &elt : elements) {
704  if (elt.registryEntry) {
705  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
706  errorHandler))) {
707  return errorHandler("failed to add `" + elt.name + "` with options `" +
708  elt.options + "`");
709  }
710  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
711  errorHandler))) {
712  return errorHandler("failed to add `" + elt.name + "` with options `" +
713  elt.options + "` to inner pipeline");
714  }
715  }
716  return success();
717 }
718 
719 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
720  raw_ostream &errorStream) {
721  TextualPipeline pipelineParser;
722  if (failed(pipelineParser.initialize(pipeline, errorStream)))
723  return failure();
724  auto errorHandler = [&](Twine msg) {
725  errorStream << msg << "\n";
726  return failure();
727  };
728  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
729  return failure();
730  return success();
731 }
732 
733 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
734  raw_ostream &errorStream) {
735  pipeline = pipeline.trim();
736  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
737  size_t pipelineStart = pipeline.find_first_of('(');
738  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
739  !pipeline.consume_back(")")) {
740  errorStream << "expected pass pipeline to be wrapped with the anchor "
741  "operation type, e.g. 'builtin.module(...)'";
742  return failure();
743  }
744 
745  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
746  OpPassManager pm(opName);
747  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
748  errorStream)))
749  return failure();
750  return pm;
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // PassNameParser
755 //===----------------------------------------------------------------------===//
756 
757 namespace {
758 /// This struct represents the possible data entries in a parsed pass pipeline
759 /// list.
760 struct PassArgData {
761  PassArgData() = default;
762  PassArgData(const PassRegistryEntry *registryEntry)
763  : registryEntry(registryEntry) {}
764 
765  /// This field is used when the parsed option corresponds to a registered pass
766  /// or pass pipeline.
767  const PassRegistryEntry *registryEntry{nullptr};
768 
769  /// This field is set when instance specific pass options have been provided
770  /// on the command line.
771  StringRef options;
772 };
773 } // namespace
774 
775 namespace llvm {
776 namespace cl {
777 /// Define a valid OptionValue for the command line pass argument.
778 template <>
780  : OptionValueBase<PassArgData, /*isClass=*/true> {
781  OptionValue(const PassArgData &value) { this->setValue(value); }
782  OptionValue() = default;
783  void anchor() override {}
784 
785  bool hasValue() const { return true; }
786  const PassArgData &getValue() const { return value; }
787  void setValue(const PassArgData &value) { this->value = value; }
788 
789  PassArgData value;
790 };
791 } // namespace cl
792 } // namespace llvm
793 
794 namespace {
795 
796 /// The name for the command line option used for parsing the textual pass
797 /// pipeline.
798 #define PASS_PIPELINE_ARG "pass-pipeline"
799 
800 /// Adds command line option for each registered pass or pass pipeline, as well
801 /// as textual pass pipelines.
802 struct PassNameParser : public llvm::cl::parser<PassArgData> {
803  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
804 
805  void initialize();
806  void printOptionInfo(const llvm::cl::Option &opt,
807  size_t globalWidth) const override;
808  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
809  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
810  PassArgData &value);
811 
812  /// If true, this parser only parses entries that correspond to a concrete
813  /// pass registry entry, and does not include pipeline entries or the options
814  /// for pass entries.
815  bool passNamesOnly = false;
816 };
817 } // namespace
818 
819 void PassNameParser::initialize() {
821 
822  /// Add the pass entries.
823  for (const auto &kv : *passRegistry) {
824  addLiteralOption(kv.second.getPassArgument(), &kv.second,
825  kv.second.getPassDescription());
826  }
827  /// Add the pass pipeline entries.
828  if (!passNamesOnly) {
829  for (const auto &kv : *passPipelineRegistry) {
830  addLiteralOption(kv.second.getPassArgument(), &kv.second,
831  kv.second.getPassDescription());
832  }
833  }
834 }
835 
836 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
837  size_t globalWidth) const {
838  // If this parser is just parsing pass names, print a simplified option
839  // string.
840  if (passNamesOnly) {
841  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
842  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
843  return;
844  }
845 
846  // Print the information for the top-level option.
847  if (opt.hasArgStr()) {
848  llvm::outs() << " --" << opt.ArgStr;
849  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
850  } else {
851  llvm::outs() << " " << opt.HelpStr << '\n';
852  }
853 
854  // Functor used to print the ordered entries of a registration map.
855  auto printOrderedEntries = [&](StringRef header, auto &map) {
857  for (auto &kv : map)
858  orderedEntries.push_back(&kv.second);
859  llvm::array_pod_sort(
860  orderedEntries.begin(), orderedEntries.end(),
861  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
862  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
863  });
864 
865  llvm::outs().indent(4) << header << ":\n";
866  for (PassRegistryEntry *entry : orderedEntries)
867  entry->printHelpStr(/*indent=*/6, globalWidth);
868  };
869 
870  // Print the available passes.
871  printOrderedEntries("Passes", *passRegistry);
872 
873  // Print the available pass pipelines.
874  if (!passPipelineRegistry->empty())
875  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
876 }
877 
878 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
879  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
880 
881  // Check for any wider pass or pipeline options.
882  for (auto &entry : *passRegistry)
883  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
884  for (auto &entry : *passPipelineRegistry)
885  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
886  return maxWidth;
887 }
888 
889 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
890  StringRef arg, PassArgData &value) {
891  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
892  return true;
893  value.options = arg;
894  return false;
895 }
896 
897 //===----------------------------------------------------------------------===//
898 // PassPipelineCLParser
899 //===----------------------------------------------------------------------===//
900 
901 namespace mlir {
902 namespace detail {
904  PassPipelineCLParserImpl(StringRef arg, StringRef description,
905  bool passNamesOnly)
906  : passList(arg, llvm::cl::desc(description)) {
907  passList.getParser().passNamesOnly = passNamesOnly;
908  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
909  }
910 
911  /// Returns true if the given pass registry entry was registered at the
912  /// top-level of the parser, i.e. not within an explicit textual pipeline.
913  bool contains(const PassRegistryEntry *entry) const {
914  return llvm::any_of(passList, [&](const PassArgData &data) {
915  return data.registryEntry == entry;
916  });
917  }
918 
919  /// The set of passes and pass pipelines to run.
920  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
921 };
922 } // namespace detail
923 } // namespace mlir
924 
925 /// Construct a pass pipeline parser with the given command line description.
926 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
927  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
928  arg, description, /*passNamesOnly=*/false)),
929  passPipeline(
931  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
932 
933 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
934  StringRef alias)
935  : PassPipelineCLParser(arg, description) {
936  passPipelineAlias.emplace(alias,
937  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
938  llvm::cl::aliasopt(passPipeline));
939 }
940 
942 
943 /// Returns true if this parser contains any valid options to add.
945  return passPipeline.getNumOccurrences() != 0 ||
946  impl->passList.getNumOccurrences() != 0;
947 }
948 
949 /// Returns true if the given pass registry entry was registered at the
950 /// top-level of the parser, i.e. not within an explicit textual pipeline.
952  return impl->contains(entry);
953 }
954 
955 /// Adds the passes defined by this parser entry to the given pass manager.
957  OpPassManager &pm,
958  function_ref<LogicalResult(const Twine &)> errorHandler) const {
959  if (passPipeline.getNumOccurrences()) {
960  if (impl->passList.getNumOccurrences())
961  return errorHandler(
962  "'-" PASS_PIPELINE_ARG
963  "' option can't be used with individual pass options");
964  std::string errMsg;
965  llvm::raw_string_ostream os(errMsg);
966  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
967  if (failed(parsed))
968  return errorHandler(errMsg);
969  pm = std::move(*parsed);
970  return success();
971  }
972 
973  for (auto &passIt : impl->passList) {
974  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
975  errorHandler)))
976  return failure();
977  }
978  return success();
979 }
980 
981 //===----------------------------------------------------------------------===//
982 // PassNameCLParser
983 
984 /// Construct a pass pipeline parser with the given command line description.
985 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
986  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
987  arg, description, /*passNamesOnly=*/true)) {
988  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
989 }
991 
992 /// Returns true if this parser contains any valid options to add.
994  return impl->passList.getNumOccurrences() != 0;
995 }
996 
997 /// Returns true if the given pass registry entry was registered at the
998 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1000  return impl->contains(entry);
1001 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:47
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:402
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition: Pass.cpp:383
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:426
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:364
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:428
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:50
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:387
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:354
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:115
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:244
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:101
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:49
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Base container class and manager for all pass options.
Definition: PassOptions.h:92
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
void print(raw_ostream &os)
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
Include the generated interface declarations.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:67
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:448
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:445
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...