MLIR  20.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
20 
21 #include <optional>
22 #include <utility>
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 /// Static mapping of all of the registered passes.
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
29 
30 /// A mapping of the above pass registry entries to the corresponding TypeID
31 /// of the pass that they generate.
32 static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
33 
34 /// Static mapping of all of the registered pass pipelines.
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
37 
38 /// Utility to create a default registry function from a pass instance.
41  return [=](OpPassManager &pm, StringRef options,
42  function_ref<LogicalResult(const Twine &)> errorHandler) {
43  std::unique_ptr<Pass> pass = allocator();
44  LogicalResult result = pass->initializeOptions(options, errorHandler);
45 
46  std::optional<StringRef> pmOpName = pm.getOpName();
47  std::optional<StringRef> passOpName = pass->getOpName();
48  if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
49  passOpName && *pmOpName != *passOpName) {
50  return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
51  "' restricted to '" + *pass->getOpName() +
52  "' on a PassManager intended to run on '" +
53  pm.getOpAnchorName() + "', did you intend to nest?");
54  }
55  pm.addPass(std::move(pass));
56  return result;
57  };
58 }
59 
60 /// Utility to print the help string for a specific option.
61 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
62  size_t descIndent, bool isTopLevel) {
63  size_t numSpaces = descIndent - indent - 4;
64  llvm::outs().indent(indent)
65  << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // PassRegistry
70 //===----------------------------------------------------------------------===//
71 
72 /// Prints the passes that were previously registered and stored in passRegistry
74  size_t maxWidth = 0;
75  for (auto &entry : *passRegistry)
76  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
77 
78  // Functor used to print the ordered entries of a registration map.
79  auto printOrderedEntries = [&](StringRef header, auto &map) {
81  for (auto &kv : map)
82  orderedEntries.push_back(&kv.second);
83  llvm::array_pod_sort(
84  orderedEntries.begin(), orderedEntries.end(),
85  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
86  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
87  });
88 
89  llvm::outs().indent(0) << header << ":\n";
90  for (PassRegistryEntry *entry : orderedEntries)
91  entry->printHelpStr(/*indent=*/2, maxWidth);
92  };
93 
94  // Print the available passes.
95  printOrderedEntries("Passes", *passRegistry);
96 }
97 
98 /// Print the help information for this pass. This includes the argument,
99 /// description, and any pass options. `descIndent` is the indent that the
100 /// descriptions should be aligned.
101 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
102  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
103  /*isTopLevel=*/true);
104  // If this entry has options, print the help for those as well.
105  optHandler([=](const PassOptions &options) {
106  options.printHelp(indent, descIndent);
107  });
108 }
109 
110 /// Return the maximum width required when printing the options of this
111 /// entry.
113  size_t maxLen = 0;
114  optHandler([&](const PassOptions &options) mutable {
115  maxLen = options.getOptionWidth() + 2;
116  });
117  return maxLen;
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // PassPipelineInfo
122 //===----------------------------------------------------------------------===//
123 
125  StringRef arg, StringRef description, const PassRegistryFunction &function,
126  std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
127  PassPipelineInfo pipelineInfo(arg, description, function,
128  std::move(optHandler));
129  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
130 #ifndef NDEBUG
131  if (!inserted)
132  report_fatal_error("Pass pipeline " + arg + " registered multiple times");
133 #endif
134  (void)inserted;
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // PassInfo
139 //===----------------------------------------------------------------------===//
140 
141 PassInfo::PassInfo(StringRef arg, StringRef description,
142  const PassAllocatorFunction &allocator)
144  arg, description, buildDefaultRegistryFn(allocator),
145  // Use a temporary pass to provide an options instance.
146  [=](function_ref<void(const PassOptions &)> optHandler) {
147  optHandler(allocator()->passOptions);
148  }) {}
149 
151  std::unique_ptr<Pass> pass = function();
152  StringRef arg = pass->getArgument();
153  if (arg.empty())
154  llvm::report_fatal_error(llvm::Twine("Trying to register '") +
155  pass->getName() +
156  "' pass that does not override `getArgument()`");
157  StringRef description = pass->getDescription();
158  PassInfo passInfo(arg, description, function);
159  passRegistry->try_emplace(arg, passInfo);
160 
161  // Verify that the registered pass has the same ID as any registered to this
162  // arg before it.
163  TypeID entryTypeID = pass->getTypeID();
164  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
165  if (it->second != entryTypeID)
166  llvm::report_fatal_error(
167  "pass allocator creates a different pass than previously "
168  "registered for pass " +
169  arg);
170 }
171 
172 /// Returns the pass info for the specified pass argument or null if unknown.
173 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
174  auto it = passRegistry->find(passArg);
175  return it == passRegistry->end() ? nullptr : &it->second;
176 }
177 
178 /// Returns the pass pipeline info for the specified pass pipeline argument or
179 /// null if unknown.
180 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
181  auto it = passPipelineRegistry->find(pipelineArg);
182  return it == passPipelineRegistry->end() ? nullptr : &it->second;
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // PassOptions
187 //===----------------------------------------------------------------------===//
188 
189 /// Attempt to find the next occurance of character 'c' in the string starting
190 /// from the `index`-th position , omitting any occurances that appear within
191 /// intervening ranges or literals.
192 static size_t findChar(StringRef str, size_t index, char c) {
193  for (size_t i = index, e = str.size(); i < e; ++i) {
194  if (str[i] == c)
195  return i;
196  // Check for various range characters.
197  if (str[i] == '{')
198  i = findChar(str, i + 1, '}');
199  else if (str[i] == '(')
200  i = findChar(str, i + 1, ')');
201  else if (str[i] == '[')
202  i = findChar(str, i + 1, ']');
203  else if (str[i] == '\"')
204  i = str.find_first_of('\"', i + 1);
205  else if (str[i] == '\'')
206  i = str.find_first_of('\'', i + 1);
207  if (i == StringRef::npos)
208  return StringRef::npos;
209  }
210  return StringRef::npos;
211 }
212 
213 /// Extract an argument from 'options' and update it to point after the arg.
214 /// Returns the cleaned argument string.
215 static StringRef extractArgAndUpdateOptions(StringRef &options,
216  size_t argSize) {
217  StringRef str = options.take_front(argSize).trim();
218  options = options.drop_front(argSize).ltrim();
219 
220  // Early exit if there's no escape sequence.
221  if (str.size() <= 1)
222  return str;
223 
224  const auto escapePairs = {std::make_pair('\'', '\''),
225  std::make_pair('"', '"')};
226  for (const auto &escape : escapePairs) {
227  if (str.front() == escape.first && str.back() == escape.second) {
228  // Drop the escape characters and trim.
229  // Don't process additional escape sequences.
230  return str.drop_front().drop_back().trim();
231  }
232  }
233 
234  // Arguments may be wrapped in `{...}`. Unlike the quotation markers that
235  // denote literals, we respect scoping here. The outer `{...}` should not
236  // be stripped in cases such as "arg={...},{...}", which can be used to denote
237  // lists of nested option structs.
238  if (str.front() == '{') {
239  unsigned match = findChar(str, 1, '}');
240  if (match == str.size() - 1)
241  str = str.drop_front().drop_back().trim();
242  }
243 
244  return str;
245 }
246 
248  llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
249  function_ref<LogicalResult(StringRef)> elementParseFn) {
250  if (optionStr.empty())
251  return success();
252 
253  size_t nextElePos = findChar(optionStr, 0, ',');
254  while (nextElePos != StringRef::npos) {
255  // Process the portion before the comma.
256  if (failed(
257  elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
258  return failure();
259 
260  // Drop the leading ','
261  optionStr = optionStr.drop_front();
262  nextElePos = findChar(optionStr, 0, ',');
263  }
264  return elementParseFn(
265  extractArgAndUpdateOptions(optionStr, optionStr.size()));
266 }
267 
268 /// Out of line virtual function to provide home for the class.
269 void detail::PassOptions::OptionBase::anchor() {}
270 
271 /// Copy the option values from 'other'.
273  assert(options.size() == other.options.size());
274  if (options.empty())
275  return;
276  for (auto optionsIt : llvm::zip(options, other.options))
277  std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
278 }
279 
280 /// Parse in the next argument from the given options string. Returns a tuple
281 /// containing [the key of the option, the value of the option, updated
282 /// `options` string pointing after the parsed option].
283 static std::tuple<StringRef, StringRef, StringRef>
284 parseNextArg(StringRef options) {
285  // Try to process the given punctuation, properly escaping any contained
286  // characters.
287  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
288  if (options[currentPos] != punct)
289  return false;
290  size_t nextIt = options.find_first_of(punct, currentPos + 1);
291  if (nextIt != StringRef::npos)
292  currentPos = nextIt;
293  return true;
294  };
295 
296  // Parse the argument name of the option.
297  StringRef argName;
298  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
299  // Check for the end of the full option.
300  if (argEndIt == optionsE || options[argEndIt] == ' ') {
301  argName = extractArgAndUpdateOptions(options, argEndIt);
302  return std::make_tuple(argName, StringRef(), options);
303  }
304 
305  // Check for the end of the name and the start of the value.
306  if (options[argEndIt] == '=') {
307  argName = extractArgAndUpdateOptions(options, argEndIt);
308  options = options.drop_front();
309  break;
310  }
311  }
312 
313  // Parse the value of the option.
314  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
315  // Handle the end of the options string.
316  if (argEndIt == optionsE || options[argEndIt] == ' ') {
317  StringRef value = extractArgAndUpdateOptions(options, argEndIt);
318  return std::make_tuple(argName, value, options);
319  }
320 
321  // Skip over escaped sequences.
322  char c = options[argEndIt];
323  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
324  continue;
325  // '{...}' is used to specify options to passes, properly escape it so
326  // that we don't accidentally split any nested options.
327  if (c == '{') {
328  size_t braceCount = 1;
329  for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
330  // Allow nested punctuation.
331  if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
332  continue;
333  if (options[argEndIt] == '{')
334  ++braceCount;
335  else if (options[argEndIt] == '}' && --braceCount == 0)
336  break;
337  }
338  // Account for the increment at the top of the loop.
339  --argEndIt;
340  }
341  }
342  llvm_unreachable("unexpected control flow in pass option parsing");
343 }
344 
346  raw_ostream &errorStream) {
347  // NOTE: `options` is modified in place to always refer to the unprocessed
348  // part of the string.
349  while (!options.empty()) {
350  StringRef key, value;
351  std::tie(key, value, options) = parseNextArg(options);
352  if (key.empty())
353  continue;
354 
355  auto it = OptionsMap.find(key);
356  if (it == OptionsMap.end()) {
357  errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
358  return failure();
359  }
360  if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
361  return failure();
362  }
363 
364  return success();
365 }
366 
367 /// Print the options held by this struct in a form that can be parsed via
368 /// 'parseFromString'.
369 void detail::PassOptions::print(raw_ostream &os) const {
370  // If there are no options, there is nothing left to do.
371  if (OptionsMap.empty())
372  return;
373 
374  // Sort the options to make the ordering deterministic.
375  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
376  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
377  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
378  };
379  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
380 
381  // Interleave the options with ' '.
382  os << '{';
383  llvm::interleave(
384  orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
385  os << '}';
386 }
387 
388 /// Print the help string for the options held by this struct. `descIndent` is
389 /// the indent within the stream that the descriptions should be aligned.
390 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
391  // Sort the options to make the ordering deterministic.
392  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
393  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
394  return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
395  };
396  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
397  for (OptionBase *option : orderedOps) {
398  // TODO: printOptionInfo assumes a specific indent and will
399  // print options with values with incorrect indentation. We should add
400  // support to llvm::cl::Option for passing in a base indent to use when
401  // printing.
402  llvm::outs().indent(indent);
403  option->getOption()->printOptionInfo(descIndent - indent);
404  }
405 }
406 
407 /// Return the maximum width required when printing the help string.
409  size_t max = 0;
410  for (auto *option : options)
411  max = std::max(max, option->getOption()->getOptionWidth());
412  return max;
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // MLIR Options
417 //===----------------------------------------------------------------------===//
418 
419 //===----------------------------------------------------------------------===//
420 // OpPassManager: OptionValue
421 
422 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
423 llvm::cl::OptionValue<OpPassManager>::OptionValue(
424  const mlir::OpPassManager &value) {
425  setValue(value);
426 }
427 llvm::cl::OptionValue<OpPassManager>::OptionValue(
429  if (rhs.hasValue())
430  setValue(rhs.getValue());
431 }
432 llvm::cl::OptionValue<OpPassManager> &
433 llvm::cl::OptionValue<OpPassManager>::operator=(
434  const mlir::OpPassManager &rhs) {
435  setValue(rhs);
436  return *this;
437 }
438 
439 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
440 
441 void llvm::cl::OptionValue<OpPassManager>::setValue(
442  const OpPassManager &newValue) {
443  if (hasValue())
444  *value = newValue;
445  else
446  value = std::make_unique<mlir::OpPassManager>(newValue);
447 }
448 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
449  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
450  assert(succeeded(pipeline) && "invalid pass pipeline");
451  setValue(*pipeline);
452 }
453 
455  const mlir::OpPassManager &rhs) const {
456  std::string lhsStr, rhsStr;
457  {
458  raw_string_ostream lhsStream(lhsStr);
459  value->printAsTextualPipeline(lhsStream);
460 
461  raw_string_ostream rhsStream(rhsStr);
462  rhs.printAsTextualPipeline(rhsStream);
463  }
464 
465  // Use the textual format for pipeline comparisons.
466  return lhsStr == rhsStr;
467 }
468 
469 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
470 
471 //===----------------------------------------------------------------------===//
472 // OpPassManager: Parser
473 
474 namespace llvm {
475 namespace cl {
476 template class basic_parser<OpPassManager>;
477 } // namespace cl
478 } // namespace llvm
479 
480 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
481  ParsedPassManager &value) {
482  FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
483  if (failed(pipeline))
484  return true;
485  value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
486  return false;
487 }
488 
489 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
490  const OpPassManager &value) {
491  value.printAsTextualPipeline(os);
492 }
493 
495  const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
496  size_t globalWidth) const {
497  printOptionName(opt, globalWidth);
498  outs() << "= ";
499  pm.printAsTextualPipeline(outs());
500 
501  if (defaultValue.hasValue()) {
502  outs().indent(2) << " (default: ";
503  defaultValue.getValue().printAsTextualPipeline(outs());
504  outs() << ")";
505  }
506  outs() << "\n";
507 }
508 
510 
512  default;
514  ParsedPassManager &&) = default;
516  default;
517 
518 //===----------------------------------------------------------------------===//
519 // TextualPassPipeline Parser
520 //===----------------------------------------------------------------------===//
521 
522 namespace {
523 /// This class represents a textual description of a pass pipeline.
524 class TextualPipeline {
525 public:
526  /// Try to initialize this pipeline with the given pipeline text.
527  /// `errorStream` is the output stream to emit errors to.
528  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
529 
530  /// Add the internal pipeline elements to the provided pass manager.
531  LogicalResult
532  addToPipeline(OpPassManager &pm,
533  function_ref<LogicalResult(const Twine &)> errorHandler) const;
534 
535 private:
536  /// A functor used to emit errors found during pipeline handling. The first
537  /// parameter corresponds to the raw location within the pipeline string. This
538  /// should always return failure.
539  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
540 
541  /// A struct to capture parsed pass pipeline names.
542  ///
543  /// A pipeline is defined as a series of names, each of which may in itself
544  /// recursively contain a nested pipeline. A name is either the name of a pass
545  /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
546  /// the name is the name of a pass, the InnerPipeline is empty, since passes
547  /// cannot contain inner pipelines.
548  struct PipelineElement {
549  PipelineElement(StringRef name) : name(name) {}
550 
551  StringRef name;
552  StringRef options;
553  const PassRegistryEntry *registryEntry = nullptr;
554  std::vector<PipelineElement> innerPipeline;
555  };
556 
557  /// Parse the given pipeline text into the internal pipeline vector. This
558  /// function only parses the structure of the pipeline, and does not resolve
559  /// its elements.
560  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
561 
562  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
563  /// the corresponding registry entry.
564  LogicalResult
565  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
566  ErrorHandlerT errorHandler);
567 
568  /// Resolve a single element of the pipeline.
569  LogicalResult resolvePipelineElement(PipelineElement &element,
570  ErrorHandlerT errorHandler);
571 
572  /// Add the given pipeline elements to the provided pass manager.
573  LogicalResult
574  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
575  function_ref<LogicalResult(const Twine &)> errorHandler) const;
576 
577  std::vector<PipelineElement> pipeline;
578 };
579 
580 } // namespace
581 
582 /// Try to initialize this pipeline with the given pipeline text. An option is
583 /// given to enable accurate error reporting.
584 LogicalResult TextualPipeline::initialize(StringRef text,
585  raw_ostream &errorStream) {
586  if (text.empty())
587  return success();
588 
589  // Build a source manager to use for error reporting.
590  llvm::SourceMgr pipelineMgr;
591  pipelineMgr.AddNewSourceBuffer(
592  llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
593  /*RequiresNullTerminator=*/false),
594  SMLoc());
595  auto errorHandler = [&](const char *rawLoc, Twine msg) {
596  pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
597  llvm::SourceMgr::DK_Error, msg);
598  return failure();
599  };
600 
601  // Parse the provided pipeline string.
602  if (failed(parsePipelineText(text, errorHandler)))
603  return failure();
604  return resolvePipelineElements(pipeline, errorHandler);
605 }
606 
607 /// Add the internal pipeline elements to the provided pass manager.
608 LogicalResult TextualPipeline::addToPipeline(
609  OpPassManager &pm,
610  function_ref<LogicalResult(const Twine &)> errorHandler) const {
611  // Temporarily disable implicit nesting while we append to the pipeline. We
612  // want the created pipeline to exactly match the parsed text pipeline, so
613  // it's preferrable to just error out if implicit nesting would be required.
614  OpPassManager::Nesting nesting = pm.getNesting();
616  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
617 
618  return addToPipeline(pipeline, pm, errorHandler);
619 }
620 
621 /// Parse the given pipeline text into the internal pipeline vector. This
622 /// function only parses the structure of the pipeline, and does not resolve
623 /// its elements.
624 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
625  ErrorHandlerT errorHandler) {
626  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
627  for (;;) {
628  std::vector<PipelineElement> &pipeline = *pipelineStack.back();
629  size_t pos = text.find_first_of(",(){");
630  pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
631 
632  // If we have a single terminating name, we're done.
633  if (pos == StringRef::npos)
634  break;
635 
636  text = text.substr(pos);
637  char sep = text[0];
638 
639  // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
640  if (sep == '{') {
641  text = text.substr(1);
642 
643  // Skip over everything until the closing '}' and store as options.
644  size_t close = StringRef::npos;
645  for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
646  if (text[i] == '{') {
647  ++braceCount;
648  continue;
649  }
650  if (text[i] == '}' && --braceCount == 0) {
651  close = i;
652  break;
653  }
654  }
655 
656  // Check to see if a closing options brace was found.
657  if (close == StringRef::npos) {
658  return errorHandler(
659  /*rawLoc=*/text.data() - 1,
660  "missing closing '}' while processing pass options");
661  }
662  pipeline.back().options = text.substr(0, close);
663  text = text.substr(close + 1);
664 
665  // Consume space characters that an user might add for readability.
666  text = text.ltrim();
667 
668  // Skip checking for '(' because nested pipelines cannot have options.
669  } else if (sep == '(') {
670  text = text.substr(1);
671 
672  // Push the inner pipeline onto the stack to continue processing.
673  pipelineStack.push_back(&pipeline.back().innerPipeline);
674  continue;
675  }
676 
677  // When handling the close parenthesis, we greedily consume them to avoid
678  // empty strings in the pipeline.
679  while (text.consume_front(")")) {
680  // If we try to pop the outer pipeline we have unbalanced parentheses.
681  if (pipelineStack.size() == 1)
682  return errorHandler(/*rawLoc=*/text.data() - 1,
683  "encountered extra closing ')' creating unbalanced "
684  "parentheses while parsing pipeline");
685 
686  pipelineStack.pop_back();
687  // Consume space characters that an user might add for readability.
688  text = text.ltrim();
689  }
690 
691  // Check if we've finished parsing.
692  if (text.empty())
693  break;
694 
695  // Otherwise, the end of an inner pipeline always has to be followed by
696  // a comma, and then we can continue.
697  if (!text.consume_front(","))
698  return errorHandler(text.data(), "expected ',' after parsing pipeline");
699  }
700 
701  // Check for unbalanced parentheses.
702  if (pipelineStack.size() > 1)
703  return errorHandler(
704  text.data(),
705  "encountered unbalanced parentheses while parsing pipeline");
706 
707  assert(pipelineStack.back() == &pipeline &&
708  "wrong pipeline at the bottom of the stack");
709  return success();
710 }
711 
712 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
713 /// the corresponding registry entry.
714 LogicalResult TextualPipeline::resolvePipelineElements(
715  MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
716  for (auto &elt : elements)
717  if (failed(resolvePipelineElement(elt, errorHandler)))
718  return failure();
719  return success();
720 }
721 
722 /// Resolve a single element of the pipeline.
723 LogicalResult
724 TextualPipeline::resolvePipelineElement(PipelineElement &element,
725  ErrorHandlerT errorHandler) {
726  // If the inner pipeline of this element is not empty, this is an operation
727  // pipeline.
728  if (!element.innerPipeline.empty())
729  return resolvePipelineElements(element.innerPipeline, errorHandler);
730 
731  // Otherwise, this must be a pass or pass pipeline.
732  // Check to see if a pipeline was registered with this name.
733  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
734  return success();
735 
736  // If not, then this must be a specific pass name.
737  if ((element.registryEntry = PassInfo::lookup(element.name)))
738  return success();
739 
740  // Emit an error for the unknown pass.
741  auto *rawLoc = element.name.data();
742  return errorHandler(rawLoc, "'" + element.name +
743  "' does not refer to a "
744  "registered pass or pass pipeline");
745 }
746 
747 /// Add the given pipeline elements to the provided pass manager.
748 LogicalResult TextualPipeline::addToPipeline(
750  function_ref<LogicalResult(const Twine &)> errorHandler) const {
751  for (auto &elt : elements) {
752  if (elt.registryEntry) {
753  if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
754  errorHandler))) {
755  return errorHandler("failed to add `" + elt.name + "` with options `" +
756  elt.options + "`");
757  }
758  } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
759  errorHandler))) {
760  return errorHandler("failed to add `" + elt.name + "` with options `" +
761  elt.options + "` to inner pipeline");
762  }
763  }
764  return success();
765 }
766 
767 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
768  raw_ostream &errorStream) {
769  TextualPipeline pipelineParser;
770  if (failed(pipelineParser.initialize(pipeline, errorStream)))
771  return failure();
772  auto errorHandler = [&](Twine msg) {
773  errorStream << msg << "\n";
774  return failure();
775  };
776  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
777  return failure();
778  return success();
779 }
780 
781 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
782  raw_ostream &errorStream) {
783  pipeline = pipeline.trim();
784  // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
785  size_t pipelineStart = pipeline.find_first_of('(');
786  if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
787  !pipeline.consume_back(")")) {
788  errorStream << "expected pass pipeline to be wrapped with the anchor "
789  "operation type, e.g. 'builtin.module(...)'";
790  return failure();
791  }
792 
793  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
794  OpPassManager pm(opName);
795  if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
796  errorStream)))
797  return failure();
798  return pm;
799 }
800 
801 //===----------------------------------------------------------------------===//
802 // PassNameParser
803 //===----------------------------------------------------------------------===//
804 
805 namespace {
806 /// This struct represents the possible data entries in a parsed pass pipeline
807 /// list.
808 struct PassArgData {
809  PassArgData() = default;
810  PassArgData(const PassRegistryEntry *registryEntry)
811  : registryEntry(registryEntry) {}
812 
813  /// This field is used when the parsed option corresponds to a registered pass
814  /// or pass pipeline.
815  const PassRegistryEntry *registryEntry{nullptr};
816 
817  /// This field is set when instance specific pass options have been provided
818  /// on the command line.
819  StringRef options;
820 };
821 } // namespace
822 
823 namespace llvm {
824 namespace cl {
825 /// Define a valid OptionValue for the command line pass argument.
826 template <>
828  : OptionValueBase<PassArgData, /*isClass=*/true> {
829  OptionValue(const PassArgData &value) { this->setValue(value); }
830  OptionValue() = default;
831  void anchor() override {}
832 
833  bool hasValue() const { return true; }
834  const PassArgData &getValue() const { return value; }
835  void setValue(const PassArgData &value) { this->value = value; }
836 
837  PassArgData value;
838 };
839 } // namespace cl
840 } // namespace llvm
841 
842 namespace {
843 
844 /// The name for the command line option used for parsing the textual pass
845 /// pipeline.
846 #define PASS_PIPELINE_ARG "pass-pipeline"
847 
848 /// Adds command line option for each registered pass or pass pipeline, as well
849 /// as textual pass pipelines.
850 struct PassNameParser : public llvm::cl::parser<PassArgData> {
851  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
852 
853  void initialize();
854  void printOptionInfo(const llvm::cl::Option &opt,
855  size_t globalWidth) const override;
856  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
857  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
858  PassArgData &value);
859 
860  /// If true, this parser only parses entries that correspond to a concrete
861  /// pass registry entry, and does not include pipeline entries or the options
862  /// for pass entries.
863  bool passNamesOnly = false;
864 };
865 } // namespace
866 
867 void PassNameParser::initialize() {
869 
870  /// Add the pass entries.
871  for (const auto &kv : *passRegistry) {
872  addLiteralOption(kv.second.getPassArgument(), &kv.second,
873  kv.second.getPassDescription());
874  }
875  /// Add the pass pipeline entries.
876  if (!passNamesOnly) {
877  for (const auto &kv : *passPipelineRegistry) {
878  addLiteralOption(kv.second.getPassArgument(), &kv.second,
879  kv.second.getPassDescription());
880  }
881  }
882 }
883 
884 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
885  size_t globalWidth) const {
886  // If this parser is just parsing pass names, print a simplified option
887  // string.
888  if (passNamesOnly) {
889  llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
890  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
891  return;
892  }
893 
894  // Print the information for the top-level option.
895  if (opt.hasArgStr()) {
896  llvm::outs() << " --" << opt.ArgStr;
897  opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
898  } else {
899  llvm::outs() << " " << opt.HelpStr << '\n';
900  }
901 
902  // Functor used to print the ordered entries of a registration map.
903  auto printOrderedEntries = [&](StringRef header, auto &map) {
905  for (auto &kv : map)
906  orderedEntries.push_back(&kv.second);
907  llvm::array_pod_sort(
908  orderedEntries.begin(), orderedEntries.end(),
909  [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
910  return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
911  });
912 
913  llvm::outs().indent(4) << header << ":\n";
914  for (PassRegistryEntry *entry : orderedEntries)
915  entry->printHelpStr(/*indent=*/6, globalWidth);
916  };
917 
918  // Print the available passes.
919  printOrderedEntries("Passes", *passRegistry);
920 
921  // Print the available pass pipelines.
922  if (!passPipelineRegistry->empty())
923  printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
924 }
925 
926 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
927  size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
928 
929  // Check for any wider pass or pipeline options.
930  for (auto &entry : *passRegistry)
931  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
932  for (auto &entry : *passPipelineRegistry)
933  maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
934  return maxWidth;
935 }
936 
937 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
938  StringRef arg, PassArgData &value) {
939  if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
940  return true;
941  value.options = arg;
942  return false;
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // PassPipelineCLParser
947 //===----------------------------------------------------------------------===//
948 
949 namespace mlir {
950 namespace detail {
952  PassPipelineCLParserImpl(StringRef arg, StringRef description,
953  bool passNamesOnly)
954  : passList(arg, llvm::cl::desc(description)) {
955  passList.getParser().passNamesOnly = passNamesOnly;
956  passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
957  }
958 
959  /// Returns true if the given pass registry entry was registered at the
960  /// top-level of the parser, i.e. not within an explicit textual pipeline.
961  bool contains(const PassRegistryEntry *entry) const {
962  return llvm::any_of(passList, [&](const PassArgData &data) {
963  return data.registryEntry == entry;
964  });
965  }
966 
967  /// The set of passes and pass pipelines to run.
968  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
969 };
970 } // namespace detail
971 } // namespace mlir
972 
973 /// Construct a pass pipeline parser with the given command line description.
974 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
975  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
976  arg, description, /*passNamesOnly=*/false)),
977  passPipeline(
979  llvm::cl::desc("Textual description of the pass pipeline to run")) {}
980 
981 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
982  StringRef alias)
983  : PassPipelineCLParser(arg, description) {
984  passPipelineAlias.emplace(alias,
985  llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
986  llvm::cl::aliasopt(passPipeline));
987 }
988 
990 
991 /// Returns true if this parser contains any valid options to add.
993  return passPipeline.getNumOccurrences() != 0 ||
994  impl->passList.getNumOccurrences() != 0;
995 }
996 
997 /// Returns true if the given pass registry entry was registered at the
998 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1000  return impl->contains(entry);
1001 }
1002 
1003 /// Adds the passes defined by this parser entry to the given pass manager.
1005  OpPassManager &pm,
1006  function_ref<LogicalResult(const Twine &)> errorHandler) const {
1007  if (passPipeline.getNumOccurrences()) {
1008  if (impl->passList.getNumOccurrences())
1009  return errorHandler(
1010  "'-" PASS_PIPELINE_ARG
1011  "' option can't be used with individual pass options");
1012  std::string errMsg;
1013  llvm::raw_string_ostream os(errMsg);
1014  FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1015  if (failed(parsed))
1016  return errorHandler(errMsg);
1017  pm = std::move(*parsed);
1018  return success();
1019  }
1020 
1021  for (auto &passIt : impl->passList) {
1022  if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1023  errorHandler)))
1024  return failure();
1025  }
1026  return success();
1027 }
1028 
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//

1032 /// Construct a pass pipeline parser with the given command line description.
1033 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1034  : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1035  arg, description, /*passNamesOnly=*/true)) {
1036  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1037 }
1039 
1040 /// Returns true if this parser contains any valid options to add.
1042  return impl->passList.getNumOccurrences() != 0;
1043 }
1044 
1045 /// Returns true if the given pass registry entry was registered at the
1046 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1048  return impl->contains(entry);
1049 }
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurrence of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:47
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition: Pass.cpp:401
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition: Pass.cpp:382
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition: Pass.cpp:425
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:363
Nesting getNesting()
Return the current nesting mode.
Definition: Pass.cpp:427
Nesting
This enum represents the nesting behavior of the pass manager.
Definition: PassManager.h:50
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition: Pass.cpp:386
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition: Pass.cpp:353
A structure to represent the information for a derived pass class.
Definition: PassRegistry.h:118
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
Definition: PassRegistry.h:247
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Definition: PassRegistry.h:104
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Base container class and manager for all pass options.
Definition: PassOptions.h:92
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Definition: PassRegistry.h:41
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Definition: PassRegistry.h:40
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
Definition: PassOptions.h:488
bool hasValue() const
Returns if the current option has a value.
Definition: PassOptions.h:485
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...