MLIR  18.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
19 
20 #include <optional>
21 #include <utility>
22 
23 using namespace mlir;
24 using namespace detail;
25 
26 /// Static mapping of all of the registered passes.
/// Keyed by the pass's command-line argument (`Pass::getArgument()`), which is
/// also the key searched by `PassInfo::lookup`.
27 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
28
29 /// A mapping of the above pass registry entries to the corresponding TypeID
30 /// of the pass that they generate.
/// Consulted during pass registration: registering a pass whose TypeID differs
/// from the one already recorded for the same argument is a fatal error.
31 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
32 
33 /// Static mapping of all of the registered pass pipelines.
34 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
36 
37 /// Utility to create a default registry function from a pass instance.
40  return [=](OpPassManager &pm, StringRef options,
41  function_ref<LogicalResult(const Twine &)> errorHandler) {
42  std::unique_ptr<Pass> pass = allocator();
43  LogicalResult result = pass->initializeOptions(options);
44 
45  std::optional<StringRef> pmOpName = pm.getOpName();
46  std::optional<StringRef> passOpName = pass->getOpName();
47  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
48  passOpName && *pmOpName != *passOpName) {
49  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
50  "' restricted to '" + *pass->getOpName() +
51  "' on a PassManager intended to run on '" +
52  pm.getOpAnchorName() + "', did you intend to nest?");
53  }
54  pm.addPass(std::move(pass));
55  return result;
56  };
57 }
58 
59 /// Utility to print the help string for a specific option.
60 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
61  size_t descIndent, bool isTopLevel) {
62  size_t numSpaces = descIndent - indent - 4;
63  llvm::outs().indent(indent)
64  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // PassRegistry
69 //===----------------------------------------------------------------------===//
70 
71 /// Print the help information for this pass. This includes the argument,
72 /// description, and any pass options. `descIndent` is the indent that the
73 /// descriptions should be aligned.
74 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
75  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
76  /*isTopLevel=*/true);
77  // If this entry has options, print the help for those as well.
78  optHandler([=](const PassOptions &options) {
79  options.printHelp(indent, descIndent);
80  });
81 }
82 
83 /// Return the maximum width required when printing the options of this
84 /// entry.
86  size_t maxLen = 0;
87  optHandler([&](const PassOptions &options) mutable {
88  maxLen = options.getOptionWidth() + 2;
89  });
90  return maxLen;
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // PassPipelineInfo
95 //===----------------------------------------------------------------------===//
96 
98  StringRef arg, StringRef description, const PassRegistryFunction &function,
99  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
100  PassPipelineInfo pipelineInfo(arg, description, function,
101  std::move(optHandler));
102  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
103 #ifndef NDEBUG
104  if (!inserted)
105  report_fatal_error("Pass pipeline " + arg + " registered multiple times");
106 #endif
107  (void)inserted;
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // PassInfo
112 //===----------------------------------------------------------------------===//
113 
114 PassInfo::PassInfo(StringRef arg, StringRef description,
115  const PassAllocatorFunction &allocator)
117  arg, description, buildDefaultRegistryFn(allocator),
118  // Use a temporary pass to provide an options instance.
119  [=](function_ref<void(const PassOptions &)> optHandler) {
120  optHandler(allocator()->passOptions);
121  }) {}
122 
124  std::unique_ptr<Pass> pass = function();
125  StringRef arg = pass->getArgument();
126  if (arg.empty())
127  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
128  pass->getName() +
129  "' pass that does not override `getArgument()`");
130  StringRef description = pass->getDescription();
131  PassInfo passInfo(arg, description, function);
132  passRegistry->try_emplace(arg, passInfo);
133 
134  // Verify that the registered pass has the same ID as any registered to this
135  // arg before it.
136  TypeID entryTypeID = pass->getTypeID();
137  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
138  if (it->second != entryTypeID)
139  llvm::report_fatal_error(
140  "pass allocator creates a different pass than previously "
141  "registered for pass " +
142  arg);
143 }
144 
145 /// Returns the pass info for the specified pass argument or null if unknown.
146 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
147  auto it = passRegistry->find(passArg);
148  return it == passRegistry->end() ? nullptr : &it->second;
149 }
150 
151 /// Returns the pass pipeline info for the specified pass pipeline argument or
152 /// null if unknown.
153 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
154  auto it = passPipelineRegistry->find(pipelineArg);
155  return it == passPipelineRegistry->end() ? nullptr : &it->second;
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // PassOptions
160 //===----------------------------------------------------------------------===//
161 
163  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
164  function_ref<LogicalResult(StringRef)> elementParseFn) {
165  // Functor used for finding a character in a string, and skipping over
166  // various "range" characters.
167  llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
168  [&](StringRef str, size_t index, char c) -> size_t {
169  for (size_t i = index, e = str.size(); i < e; ++i) {
170  if (str[i] == c)
171  return i;
172  // Check for various range characters.
173  if (str[i] == '{')
174  i = findChar(str, i + 1, '}');
175  else if (str[i] == '(')
176  i = findChar(str, i + 1, ')');
177  else if (str[i] == '[')
178  i = findChar(str, i + 1, ']');
179  else if (str[i] == '\"')
180  i = str.find_first_of('\"', i + 1);
181  else if (str[i] == '\'')
182  i = str.find_first_of('\'', i + 1);
183  }
184  return StringRef::npos;
185  };
186 
187  size_t nextElePos = findChar(optionStr, 0, ',');
188  while (nextElePos != StringRef::npos) {
189  // Process the portion before the comma.
190  if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
191  return failure();
192 
193  optionStr = optionStr.substr(nextElePos + 1);
194  nextElePos = findChar(optionStr, 0, ',');
195  }
196  return elementParseFn(optionStr.substr(0, nextElePos));
197 }
198 
199 /// Out-of-line virtual function that provides a home for the class.
/// Following the LLVM "anchor" convention, defining this empty virtual method
/// out of line gives OptionBase's vtable a single home in this file.
200 void detail::PassOptions::OptionBase::anchor() {}
201 
202 /// Copy the option values from 'other'.
204  assert(options.size() == other.options.size());
205  if (options.empty())
206  return;
207  for (auto optionsIt : llvm::zip(options, other.options))
208  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
209 }
210 
211 /// Parse in the next argument from the given options string. Returns a tuple
212 /// containing [the key of the option, the value of the option, updated
213 /// `options` string pointing after the parsed option].
///
/// An argument is either a bare `key`, ended by a space or the end of the
/// string, or `key=value`. Inside a value, text between matching '...' or
/// "..." quotes and balanced `{...}` groups belong to the value, so spaces
/// inside them do not end it. The returned key and value are trimmed of
/// surrounding whitespace. The returned remainder has leading whitespace
/// stripped. A bare key yields an empty value.
214 static std::tuple<StringRef, StringRef, StringRef>
215 parseNextArg(StringRef options) {
216  // Functor used to extract an argument from 'options' and update it to point
217  // after the arg.
218  auto extractArgAndUpdateOptions = [&](size_t argSize) {
219  StringRef str = options.take_front(argSize).trim();
220  options = options.drop_front(argSize).ltrim();
221  return str;
222  };
223  // Try to process the given punctuation, properly escaping any contained
224  // characters.
// If options[currentPos] is `punct`, move `currentPos` onto the next
// occurrence of `punct` (leaving it unchanged if there is none) and return
// true. Otherwise return false without moving.
225  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
226  if (options[currentPos] != punct)
227  return false;
228  size_t nextIt = options.find_first_of(punct, currentPos + 1);
229  if (nextIt != StringRef::npos)
230  currentPos = nextIt;
231  return true;
232  };
233
234  // Parse the argument name of the option.
235  StringRef argName;
236  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
237  // Check for the end of the full option.
// A key without '=' is a complete argument with an empty value.
238  if (argEndIt == optionsE || options[argEndIt] == ' ') {
239  argName = extractArgAndUpdateOptions(argEndIt);
240  return std::make_tuple(argName, StringRef(), options);
241  }
242
243  // Check for the end of the name and the start of the value.
244  if (options[argEndIt] == '=') {
245  argName = extractArgAndUpdateOptions(argEndIt);
// Drop the '=' itself so `options` now starts at the value.
246  options = options.drop_front();
247  break;
248  }
249  }
250
251  // Parse the value of the option.
252  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
253  // Handle the end of the options string.
254  if (argEndIt == optionsE || options[argEndIt] == ' ') {
255  StringRef value = extractArgAndUpdateOptions(argEndIt);
256  return std::make_tuple(argName, value, options);
257  }
258
259  // Skip over escaped sequences.
260  char c = options[argEndIt];
261  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
262  continue;
263  // '{...}' is used to specify options to passes, properly escape it so
264  // that we don't accidentally split any nested options.
265  if (c == '{') {
266  size_t braceCount = 1;
267  for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
268  // Allow nested punctuation.
269  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
270  continue;
271  if (options[argEndIt] == '{')
272  ++braceCount;
273  else if (options[argEndIt] == '}' && --braceCount == 0)
274  break;
275  }
// Either argEndIt sits on the matching '}' or, for an unterminated group,
// on optionsE. The decrement below plus the outer loop's increment then
// revisit that same position.
276  // Account for the increment at the top of the loop.
277  --argEndIt;
278  }
279  }
280  llvm_unreachable("unexpected control flow in pass option parsing");
281 }
282 
284  // NOTE: `options` is modified in place to always refer to the unprocessed
285  // part of the string.
286  while (!options.empty()) {
287  StringRef key, value;
288  std::tie(key, value, options) = parseNextArg(options);
289  if (key.empty())
290  continue;
291 
292  auto it = OptionsMap.find(key);
293  if (it == OptionsMap.end()) {
294  llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
295  return failure();
296  }
297  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
298  return failure();
299  }
300 
301  return success();
302 }
303 
304 /// Print the options held by this struct in a form that can be parsed via
305 /// 'parseFromString'.
306 void detail::PassOptions::print(raw_ostream &os) {
307  // If there are no options, there is nothing left to do.
308  if (OptionsMap.empty())
309  return;
310 
311  // Sort the options to make the ordering deterministic.
312  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
313  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
314  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
315  };
316  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
317 
318  // Interleave the options with ' '.
319  os << '{';
320  llvm::interleave(
321  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
322  os << '}';
323 }
324 
325 /// Print the help string for the options held by this struct. `descIndent` is
326 /// the indent within the stream that the descriptions should be aligned.
327 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
328  // Sort the options to make the ordering deterministic.
329  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
330  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
331  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
332  };
333  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
334  for (OptionBase *option : orderedOps) {
335  // TODO: printOptionInfo assumes a specific indent and will
336  // print options with values with incorrect indentation. We should add
337  // support to llvm::cl::Option for passing in a base indent to use when
338  // printing.
339  llvm::outs().indent(indent);
340  option->getOption()->printOptionInfo(descIndent - indent);
341  }
342 }
343 
344 /// Return the maximum width required when printing the help string.
346  size_t max = 0;
347  for (auto *option : options)
348  max = std::max(max, option->getOption()->getOptionWidth());
349  return max;
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // MLIR Options
354 //===----------------------------------------------------------------------===//
355 
356 //===----------------------------------------------------------------------===//
357 // OpPassManager: OptionValue
358 
359 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
360 llvm::cl::OptionValue<OpPassManager>::OptionValue(
361  const mlir::OpPassManager &value) {
362  setValue(value);
363 }
364 llvm::cl::OptionValue<OpPassManager>::OptionValue(
366  if (rhs.hasValue())
367  setValue(rhs.getValue());
368 }
369 llvm::cl::OptionValue<OpPassManager> &
370 llvm::cl::OptionValue<OpPassManager>::operator=(
371  const mlir::OpPassManager &rhs) {
372  setValue(rhs);
373  return *this;
374 }
375 
376 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
377 
378 void llvm::cl::OptionValue<OpPassManager>::setValue(
379  const OpPassManager &newValue) {
380  if (hasValue())
381  *value = newValue;
382  else
383  value = std::make_unique<mlir::OpPassManager>(newValue);
384 }
385 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
386  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
387  assert(succeeded(pipeline) && "invalid pass pipeline");
388  setValue(*pipeline);
389 }
390 
392  const mlir::OpPassManager &rhs) const {
393  std::string lhsStr, rhsStr;
394  {
395  raw_string_ostream lhsStream(lhsStr);
396  value->printAsTextualPipeline(lhsStream);
397 
398  raw_string_ostream rhsStream(rhsStr);
399  rhs.printAsTextualPipeline(rhsStream);
400  }
401 
402  // Use the textual format for pipeline comparisons.
403  return lhsStr == rhsStr;
404 }
405 
/// Out-of-line definition of the virtual `anchor()` method (LLVM anchor
/// convention), giving this specialization's vtable a home in this file.
406 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
407 
408 //===----------------------------------------------------------------------===//
409 // OpPassManager: Parser
410 
// Explicit instantiation definition: basic_parser<OpPassManager> is
// instantiated in this translation unit.
411 namespace llvm {
412 namespace cl {
413 template class basic_parser<OpPassManager>;
414 } // namespace cl
415 } // namespace llvm
416 
417 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
418  ParsedPassManager &value) {
420  if (failed(pipeline))
421  return true;
422  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
423  return false;
424 }
425 
/// Print `value` to `os` in its textual pass pipeline form.
426 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
427  const OpPassManager &value) {
428  value.printAsTextualPipeline(os);
429 }
430 
432  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
433  size_t globalWidth) const {
434  printOptionName(opt, globalWidth);
435  outs() << "= ";
436  pm.printAsTextualPipeline(outs());
437 
438  if (defaultValue.hasValue()) {
439  outs().indent(2) << " (default: ";
440  defaultValue.getValue().printAsTextualPipeline(outs());
441  outs() << ")";
442  }
443  outs() << "\n";
444 }
445 
447 
449  default;
451  ParsedPassManager &&) = default;
453  default;
454 
455 //===----------------------------------------------------------------------===//
456 // TextualPassPipeline Parser
457 //===----------------------------------------------------------------------===//
458 
459 namespace {
460 /// This class represents a textual description of a pass pipeline.
461 class TextualPipeline {
462 public:
463  /// Try to initialize this pipeline with the given pipeline text.
464  /// `errorStream` is the output stream to emit errors to.
465  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
466 
467  /// Add the internal pipeline elements to the provided pass manager.
469  addToPipeline(OpPassManager &pm,
470  function_ref<LogicalResult(const Twine &)> errorHandler) const;
471 
472 private:
473  /// A functor used to emit errors found during pipeline handling. The first
474  /// parameter corresponds to the raw location within the pipeline string. This
475  /// should always return failure.
476  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
477 
478  /// A struct to capture parsed pass pipeline names.
479  ///
480  /// A pipeline is defined as a series of names, each of which may in itself
481  /// recursively contain a nested pipeline. A name is either the name of a pass
482  /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
483  /// the name is the name of a pass, the InnerPipeline is empty, since passes
484  /// cannot contain inner pipelines.
485  struct PipelineElement {
486  PipelineElement(StringRef name) : name(name) {}
487 
488  StringRef name;
489  StringRef options;
490  const PassRegistryEntry *registryEntry = nullptr;
491  std::vector<PipelineElement> innerPipeline;
492  };
493 
494  /// Parse the given pipeline text into the internal pipeline vector. This
495  /// function only parses the structure of the pipeline, and does not resolve
496  /// its elements.
497  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
498 
499  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
500  /// the corresponding registry entry.
502  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
503  ErrorHandlerT errorHandler);
504 
505  /// Resolve a single element of the pipeline.
506  LogicalResult resolvePipelineElement(PipelineElement &element,
507  ErrorHandlerT errorHandler);
508 
509  /// Add the given pipeline elements to the provided pass manager.
511  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
512  function_ref<LogicalResult(const Twine &)> errorHandler) const;
513 
514  std::vector<PipelineElement> pipeline;
515 };
516 
517 } // namespace
518 
519 /// Try to initialize this pipeline with the given pipeline text. An option is
520 /// given to enable accurate error reporting.
521 LogicalResult TextualPipeline::initialize(StringRef text,
522  raw_ostream &errorStream) {
523  if (text.empty())
524  return success();
525 
526  // Build a source manager to use for error reporting.
527  llvm::SourceMgr pipelineMgr;
528  pipelineMgr.AddNewSourceBuffer(
529  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
530  /*RequiresNullTerminator=*/false),
531  SMLoc());
532  auto errorHandler = [&](const char *rawLoc, Twine msg) {
533  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
534  llvm::SourceMgr::DK_Error, msg);
535  return failure();
536  };
537 
538  // Parse the provided pipeline string.
539  if (failed(parsePipelineText(text, errorHandler)))
540  return failure();
541  return resolvePipelineElements(pipeline, errorHandler);
542 }
543 
544 /// Add the internal pipeline elements to the provided pass manager.
545 LogicalResult TextualPipeline::addToPipeline(
546  OpPassManager &pm,
547  function_ref<LogicalResult(const Twine &)> errorHandler) const {
548  // Temporarily disable implicit nesting while we append to the pipeline. We
549  // want the created pipeline to exactly match the parsed text pipeline, so
550  // it's preferrable to just error out if implicit nesting would be required.
551  OpPassManager::Nesting nesting = pm.getNesting();
553  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
554 
555  return addToPipeline(pipeline, pm, errorHandler);
556 }
557 
558 /// Parse the given pipeline text into the internal pipeline vector. This
559 /// function only parses the structure of the pipeline, and does not resolve
560 /// its elements.
///
/// Accepted shape: a comma-separated list of names. Each name may be followed
/// either by `{...}` pass options or by a parenthesized nested pipeline.
/// Errors are reported through `errorHandler` at a location inside `text`.
561 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
562  ErrorHandlerT errorHandler) {
// Stack of pipelines currently being filled. The back is the innermost open
// '(' group, and the bottom is the member `pipeline`.
563  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
564  for (;;) {
// NOTE: this local shadows the member `pipeline` within the loop body only.
// The assert after the loop refers to the member.
565  std::vector<PipelineElement> &pipeline = *pipelineStack.back();
566  size_t pos = text.find_first_of(",(){");
567  pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
568
569  // If we have a single terminating name, we're done.
570  if (pos == StringRef::npos)
571  break;
572
573  text = text.substr(pos);
574  char sep = text[0];
575
576  // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
577  if (sep == '{') {
578  text = text.substr(1);
579
580  // Skip over everything until the closing '}' and store as options.
// NOTE(review): only '{' and '}' are counted here, and quotes are not skipped.
// A '}' inside a quoted option value therefore ends the options early, even
// though parseNextArg treats quoted text as opaque.
581  size_t close = StringRef::npos;
582  for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
583  if (text[i] == '{') {
584  ++braceCount;
585  continue;
586  }
587  if (text[i] == '}' && --braceCount == 0) {
588  close = i;
589  break;
590  }
591  }
592
593  // Check to see if a closing options brace was found.
// text.data() - 1 points back at the opening '{'.
594  if (close == StringRef::npos) {
595  return errorHandler(
596  /*rawLoc=*/text.data() - 1,
597  "missing closing '}' while processing pass options");
598  }
599  pipeline.back().options = text.substr(0, close);
600  text = text.substr(close + 1);
601
602  // Consume space characters that a user might add for readability.
603  text = text.ltrim();
604
605  // Skip checking for '(' because nested pipelines cannot have options.
606  } else if (sep == '(') {
607  text = text.substr(1);
608
609  // Push the inner pipeline onto the stack to continue processing.
610  pipelineStack.push_back(&pipeline.back().innerPipeline);
611  continue;
612  }
613
614  // When handling the close parenthesis, we greedily consume them to avoid
615  // empty strings in the pipeline.
616  while (text.consume_front(")")) {
617  // If we try to pop the outer pipeline we have unbalanced parentheses.
618  if (pipelineStack.size() == 1)
619  return errorHandler(/*rawLoc=*/text.data() - 1,
620  "encountered extra closing ')' creating unbalanced "
621  "parentheses while parsing pipeline");
622
623  pipelineStack.pop_back();
624  // Consume space characters that a user might add for readability.
625  text = text.ltrim();
626  }
627
628  // Check if we've finished parsing.
629  if (text.empty())
630  break;
631
632  // Otherwise, the end of an inner pipeline always has to be followed by
633  // a comma, and then we can continue.
634  if (!text.consume_front(","))
635  return errorHandler(text.data(), "expected ',' after parsing pipeline");
636  }
637
638  // Check for unbalanced parentheses.
639  if (pipelineStack.size() > 1)
640  return errorHandler(
641  text.data(),
642  "encountered unbalanced parentheses while parsing pipeline");
643
644  assert(pipelineStack.back() == &pipeline &&
645  "wrong pipeline at the bottom of the stack");
646  return success();
647 }
648 
649 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
650 /// the corresponding registry entry.
651 LogicalResult TextualPipeline::resolvePipelineElements(
652  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
653  for (auto &elt : elements)
654  if (failed(resolvePipelineElement(elt, errorHandler)))
655  return failure();
656  return success();
657 }
658 
659 /// Resolve a single element of the pipeline.
661 TextualPipeline::resolvePipelineElement(PipelineElement &element,
662  ErrorHandlerT errorHandler) {
663  // If the inner pipeline of this element is not empty, this is an operation
664  // pipeline.
665  if (!element.innerPipeline.empty())
666  return resolvePipelineElements(element.innerPipeline, errorHandler);
667 
668  // Otherwise, this must be a pass or pass pipeline.
669  // Check to see if a pipeline was registered with this name.
670  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
671  return success();
672 
673  // If not, then this must be a specific pass name.
674  if ((element.registryEntry = PassInfo::lookup(element.name)))
675  return success();
676 
677  // Emit an error for the unknown pass.
678  auto *rawLoc = element.name.data();
679  return errorHandler(rawLoc, "'" + element.name +
680  "' does not refer to a "
681  "registered pass or pass pipeline");
682 }
683 
684 /// Add the given pipeline elements to the provided pass manager.
685 LogicalResult TextualPipeline::addToPipeline(
687  function_ref<LogicalResult(const Twine &)> errorHandler) const {
688  for (auto &elt : elements) {
689  if (elt.registryEntry) {
690  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
691  errorHandler))) {
692  return errorHandler("failed to add `" + elt.name + "` with options `" +
693  elt.options + "`");
694  }
695  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
696  errorHandler))) {
697  return errorHandler("failed to add `" + elt.name + "` with options `" +
698  elt.options + "` to inner pipeline");
699  }
700  }
701  return success();
702 }
703 
705  raw_ostream &errorStream) {
706  TextualPipeline pipelineParser;
707  if (failed(pipelineParser.initialize(pipeline, errorStream)))
708  return failure();
709  auto errorHandler = [&](Twine msg) {
710  errorStream << msg << "\n";
711  return failure();
712  };
713  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
714  return failure();
715  return success();
716 }
717 
719  raw_ostream &errorStream) {
720  pipeline = pipeline.trim();
721  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
722  size_t pipelineStart = pipeline.find_first_of('(');
723  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
724  !pipeline.consume_back(")")) {
725  errorStream << "expected pass pipeline to be wrapped with the anchor "
726  "operation type, e.g. 'builtin.module(...)'";
727  return failure();
728  }
729 
730  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
731  OpPassManager pm(opName);
732  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
733  errorStream)))
734  return failure();
735  return pm;
736 }
737 
738 //===----------------------------------------------------------------------===//
739 // PassNameParser
740 //===----------------------------------------------------------------------===//
741 
742 namespace {
743 /// This struct represents the possible data entries in a parsed pass pipeline
744 /// list.
745 struct PassArgData {
746  PassArgData() = default;
747  PassArgData(const PassRegistryEntry *registryEntry)
748  : registryEntry(registryEntry) {}
749 
750  /// This field is used when the parsed option corresponds to a registered pass
751  /// or pass pipeline.
752  const PassRegistryEntry *registryEntry{nullptr};
753 
754  /// This field is set when instance specific pass options have been provided
755  /// on the command line.
756  StringRef options;
757 };
758 } // namespace
759 
760 namespace llvm {
761 namespace cl {
762 /// Define a valid OptionValue for the command line pass argument.
763 template <>
765  : OptionValueBase<PassArgData, /*isClass=*/true> {
766  OptionValue(const PassArgData &value) { this->setValue(value); }
767  OptionValue() = default;
768  void anchor() override {}
769 
770  bool hasValue() const { return true; }
771  const PassArgData &getValue() const { return value; }
772  void setValue(const PassArgData &value) { this->value = value; }
773 
774  PassArgData value;
775 };
776 } // namespace cl
777 } // namespace llvm
778 
779 namespace {
780
781 /// The name for the command line option used for parsing the textual pass
782 /// pipeline.
// NOTE: a macro is not scoped by the enclosing anonymous namespace. It remains
// visible for the rest of the file, and is spliced into string literals there
// (e.g. "Alias for --" PASS_PIPELINE_ARG).
783 #define PASS_PIPELINE_ARG "pass-pipeline"
784
785 /// Adds a command line option for each registered pass or pass pipeline, as well
786 /// as textual pass pipelines.
787 struct PassNameParser : public llvm::cl::parser<PassArgData> {
788  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
789
/// Register one literal option per registered pass and, unless
/// `passNamesOnly` is set, per registered pass pipeline.
790  void initialize();
/// Print the option's help, listing the available passes and pipelines.
791  void printOptionInfo(const llvm::cl::Option &opt,
792  size_t globalWidth) const override;
/// Return the width needed to align help for this option and its entries.
793  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
/// Parse a pass argument via the base parser and record `arg` as its options.
/// Returns true on error.
794  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
795  PassArgData &value);
796
797  /// If true, this parser only parses entries that correspond to a concrete
798  /// pass registry entry, and does not include pipeline entries or the options
799  /// for pass entries.
/// NOTE(review): parse() still records `arg` as the options regardless of
/// this flag. The flag affects which entries are registered and how help is
/// printed.
800  bool passNamesOnly = false;
801 };
802 } // namespace
803 
804 void PassNameParser::initialize() {
806 
807  /// Add the pass entries.
808  for (const auto &kv : *passRegistry) {
809  addLiteralOption(kv.second.getPassArgument(), &kv.second,
810  kv.second.getPassDescription());
811  }
812  /// Add the pass pipeline entries.
813  if (!passNamesOnly) {
814  for (const auto &kv : *passPipelineRegistry) {
815  addLiteralOption(kv.second.getPassArgument(), &kv.second,
816  kv.second.getPassDescription());
817  }
818  }
819 }
820 
821 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
822  size_t globalWidth) const {
823  // If this parser is just parsing pass names, print a simplified option
824  // string.
825  if (passNamesOnly) {
826  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
827  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
828  return;
829  }
830 
831  // Print the information for the top-level option.
832  if (opt.hasArgStr()) {
833  llvm::outs() << " --" << opt.ArgStr;
834  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
835  } else {
836  llvm::outs() << " " << opt.HelpStr << '\n';
837  }
838 
839  // Functor used to print the ordered entries of a registration map.
840  auto printOrderedEntries = [&](StringRef header, auto &map) {
842  for (auto &kv : map)
843  orderedEntries.push_back(&kv.second);
844  llvm::array_pod_sort(
845  orderedEntries.begin(), orderedEntries.end(),
846  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
847  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
848  });
849 
850  llvm::outs().indent(4) << header << ":\n";
851  for (PassRegistryEntry *entry : orderedEntries)
852  entry->printHelpStr(/*indent=*/6, globalWidth);
853  };
854 
855  // Print the available passes.
856  printOrderedEntries("Passes", *passRegistry);
857 
858  // Print the available pass pipelines.
859  if (!passPipelineRegistry->empty())
860  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
861 }
862 
863 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
864  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
865 
866  // Check for any wider pass or pipeline options.
867  for (auto &entry : *passRegistry)
868  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
869  for (auto &entry : *passPipelineRegistry)
870  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
871  return maxWidth;
872 }
873 
874 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
875  StringRef arg, PassArgData &value) {
876  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
877  return true;
878  value.options = arg;
879  return false;
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // PassPipelineCLParser
884 //===----------------------------------------------------------------------===//
885 
886 namespace mlir {
887 namespace detail {
889  PassPipelineCLParserImpl(StringRef arg, StringRef description,
890  bool passNamesOnly)
891  : passList(arg, llvm::cl::desc(description)) {
892  passList.getParser().passNamesOnly = passNamesOnly;
893  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
894  }
895 
896  /// Returns true if the given pass registry entry was registered at the
897  /// top-level of the parser, i.e. not within an explicit textual pipeline.
898  bool contains(const PassRegistryEntry *entry) const {
899  return llvm::any_of(passList, [&](const PassArgData &data) {
900  return data.registryEntry == entry;
901  });
902  }
903 
904  /// The set of passes and pass pipelines to run.
905  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
906 };
907 } // namespace detail
908 } // namespace mlir
909 
910 /// Construct a pass pipeline parser with the given command line description.
911 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
912  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
913  arg, description, /*passNamesOnly=*/false)),
914  passPipeline(
916  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
917 
/// Construct a pass pipeline parser as the two-argument constructor does. It
/// additionally registers `alias` as a command line alias for the
/// `--pass-pipeline` option.
918 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
919  StringRef alias)
920  : PassPipelineCLParser(arg, description) {
921  passPipelineAlias.emplace(alias,
922  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
923  llvm::cl::aliasopt(passPipeline));
924 }
925 
927 
928 /// Returns true if this parser contains any valid options to add.
930  return passPipeline.getNumOccurrences() != 0 ||
931  impl->passList.getNumOccurrences() != 0;
932 }
933 
934 /// Returns true if the given pass registry entry was registered at the
935 /// top-level of the parser, i.e. not within an explicit textual pipeline.
937  return impl->contains(entry);
938 }
939 
940 /// Adds the passes defined by this parser entry to the given pass manager.
942  OpPassManager &pm,
943  function_ref<LogicalResult(const Twine &)> errorHandler) const {
944  if (passPipeline.getNumOccurrences()) {
945  if (impl->passList.getNumOccurrences())
946  return errorHandler(
947  "'-" PASS_PIPELINE_ARG
948  "' option can't be used with individual pass options");
949  std::string errMsg;
950  llvm::raw_string_ostream os(errMsg);
951  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
952  if (failed(parsed))
953  return errorHandler(errMsg);
954  pm = std::move(*parsed);
955  return success();
956  }
957 
958  for (auto &passIt : impl->passList) {
959  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
960  errorHandler)))
961  return failure();
962  }
963  return success();
964 }
965 
966 //===----------------------------------------------------------------------===//
967 // PassNameCLParser
968 
969 /// Construct a pass name parser with the given command line description.
/// Only pass names are accepted (no pass pipelines). The CommaSeparated flag
/// lets a single occurrence of the option list several passes separated by
/// commas.
970 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
971  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
972  arg, description, /*passNamesOnly=*/true)) {
973  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
974 }
976 
977 /// Returns true if this parser contains any valid options to add.
979  return impl->passList.getNumOccurrences() != 0;
980 }
981 
982 /// Returns true if the given pass registry entry was registered at the
983 /// top-level of the parser, i.e. not within an explicit textual pipeline.
985  return impl->contains(entry);
986 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they generate.
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition: PassManager.h:48
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:375
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnostic pass manager.
Definition: Pass.cpp:365
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:402
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:346
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:404
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:51
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:369
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:336
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:115
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser, i.e. not within an explicit textual pipeline.
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:244
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser, i.e. not within an explicit textual pipeline.
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:101
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt, description, pass pipeline builder).
Definition: PassRegistry.h:49
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Base container class and manager for all pass options.
Definition: PassOptions.h:79
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options)
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing infrastructure.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
void print(raw_ostream &os)
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
Include the generated interface declarations.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-element and passing them to the provided element parse function.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:65
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassRegistration template.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the PassPipelineRegistration template.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:434
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:431
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser, i.e. not within an explicit textual pipeline.