MLIR 22.0.0git
PassRegistry.cpp
Go to the documentation of this file.
1//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "mlir/Pass/Pass.h"
13#include "llvm/ADT/ScopeExit.h"
14#include "llvm/ADT/StringRef.h"
15#include "llvm/Support/Format.h"
16#include "llvm/Support/ManagedStatic.h"
17#include "llvm/Support/MemoryBuffer.h"
18#include "llvm/Support/SourceMgr.h"
19
20#include <optional>
21#include <utility>
22
23using namespace mlir;
24using namespace detail;
25
26/// Static mapping of all of the registered passes.
// Keyed by pass argument: registerPass inserts with try_emplace(arg, ...), and
// PassInfo::lookup and the help printers read from it.
27static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
28
29/// A mapping of the above pass registry entries to the corresponding TypeID
30/// of the pass that they generate.
// Used by registerPass to detect two different pass types claiming the same
// argument.
31static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
32
33/// Static mapping of all of the registered pass pipelines.
// Keyed by pipeline argument; filled by registerPassPipeline.
// NOTE(review): the declarator of this variable (referenced elsewhere in the
// file as `passPipelineRegistry`) is missing from this view.
34static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
36
37/// Utility to create a default registry function from a pass instance.
// NOTE(review): the signature lines of this helper are missing from this
// view. The returned lambda captures `allocator` (presumably the
// PassAllocatorFunction parameter) by copy. When invoked, it allocates a
// fresh pass, parses `options` into it and appends it to `pm`.
40 return [=](OpPassManager &pm, StringRef options,
41 function_ref<LogicalResult(const Twine &)> errorHandler) {
42 std::unique_ptr<Pass> pass = allocator();
// The option-parsing result is remembered but not acted on yet. Unless the
// anchor check below rejects the pass, the pass is added to `pm` even if
// option parsing failed, and that failure is returned afterwards.
43 LogicalResult result = pass->initializeOptions(options, errorHandler);
44
// Under explicit nesting, a pass restricted to a specific op may only be
// added to a pass manager anchored on that same op. On mismatch the pass is
// dropped and the error handler's result is returned instead of `result`.
45 std::optional<StringRef> pmOpName = pm.getOpName();
46 std::optional<StringRef> passOpName = pass->getOpName();
47 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
48 passOpName && *pmOpName != *passOpName) {
49 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
50 "' restricted to '" + *pass->getOpName() +
51 "' on a PassManager intended to run on '" +
52 pm.getOpAnchorName() + "', did you intend to nest?");
53 }
54 pm.addPass(std::move(pass));
55 return result;
56 };
57}
58
59/// Utility to print the help string for a specific option.
60static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
61 size_t descIndent, bool isTopLevel) {
62 size_t numSpaces = descIndent - indent - 4;
63 llvm::outs().indent(indent)
64 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
65}
66
67//===----------------------------------------------------------------------===//
68// PassRegistry
69//===----------------------------------------------------------------------===//
70
71/// Prints the passes that were previously registered and stored in passRegistry
// NOTE(review): the signature line of this function is missing from this view.
// The description column is derived from the widest option width (+4) only.
// The pass argument lengths are not considered, so long pass arguments may not
// be padded/aligned with the others.
73 size_t maxWidth = 0;
74 for (auto &entry : *passRegistry)
75 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
76
77 // Functor used to print the ordered entries of a registration map.
// Entries are sorted by pass argument so the output is deterministic.
// NOTE(review): the declaration of `orderedEntries` (a local vector of
// PassRegistryEntry pointers, judging by its uses) is missing from this view.
78 auto printOrderedEntries = [&](StringRef header, auto &map) {
80 for (auto &kv : map)
81 orderedEntries.push_back(&kv.second);
82 llvm::array_pod_sort(
83 orderedEntries.begin(), orderedEntries.end(),
84 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
85 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
86 });
87
// indent(0) emits no spaces; the header starts in column 0.
88 llvm::outs().indent(0) << header << ":\n";
89 for (PassRegistryEntry *entry : orderedEntries)
90 entry->printHelpStr(/*indent=*/2, maxWidth);
91 };
92
93 // Print the available passes.
// Only passes are listed here; registered pass pipelines are not printed.
94 printOrderedEntries("Passes", *passRegistry);
95}
96
97/// Print the help information for this pass. This includes the argument,
98/// description, and any pass options. `descIndent` is the indent that the
99/// descriptions should be aligned.
100void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
101 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
102 /*isTopLevel=*/true);
103 // If this entry has options, print the help for those as well.
104 optHandler([=](const PassOptions &options) {
105 options.printHelp(indent, descIndent);
106 });
107}
108
109/// Return the maximum width required when printing the options of this
110/// entry.
// NOTE(review): the signature line of this method is missing from this view.
// Returns 0 if `optHandler` never invokes the callback. Otherwise it returns
// the options' own width plus 2.
112 size_t maxLen = 0;
113 optHandler([&](const PassOptions &options) mutable {
114 maxLen = options.getOptionWidth() + 2;
115 });
116 return maxLen;
117}
118
119//===----------------------------------------------------------------------===//
120// PassPipelineInfo
121//===----------------------------------------------------------------------===//
122
124 StringRef arg, StringRef description, const PassRegistryFunction &function,
125 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
// NOTE(review): the line naming this function is missing from this view.
// Registers a pass pipeline under `arg`. try_emplace keeps the first entry for
// a given `arg`. A duplicate registration is a fatal error only in builds
// without NDEBUG. In NDEBUG builds the later registration is silently ignored.
126 PassPipelineInfo pipelineInfo(arg, description, function,
127 std::move(optHandler));
128 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
129#ifndef NDEBUG
130 if (!inserted)
131 report_fatal_error("Pass pipeline " + arg + " registered multiple times");
132#endif
// Silences the unused-variable warning when the check above is compiled out.
133 (void)inserted;
134}
135
136//===----------------------------------------------------------------------===//
137// PassInfo
138//===----------------------------------------------------------------------===//
139
// Build the registry entry for a pass. The registry function is the default
// one created from `allocator`. The option handler allocates a fresh pass on
// every call, only to hand its `passOptions` to the callback.
// NOTE(review): the line introducing the base-class initializer (presumably
// `: PassRegistryEntry(`) is missing from this view.
140PassInfo::PassInfo(StringRef arg, StringRef description,
141 const PassAllocatorFunction &allocator)
143 arg, description, buildDefaultRegistryFn(allocator),
144 // Use a temporary pass to provide an options instance.
145 [=](function_ref<void(const PassOptions &)> optHandler) {
146 optHandler(allocator()->passOptions);
147 }) {}
148
// NOTE(review): the signature line of this function is missing from this
// view. `function` is invoked as an allocator of the pass being registered.
// A temporary pass instance is created to read its argument, description and
// TypeID.
150 std::unique_ptr<Pass> pass = function();
151 StringRef arg = pass->getArgument();
152 if (arg.empty())
153 llvm::report_fatal_error(llvm::Twine("Trying to register '") +
154 pass->getName() +
155 "' pass that does not override `getArgument()`");
// NOTE(review): `arg` and `description` are StringRefs obtained from `pass`,
// which is destroyed when this function returns. This presumably relies on
// them referring to storage that outlives the pass (e.g. string literals).
// Confirm.
156 StringRef description = pass->getDescription();
157 PassInfo passInfo(arg, description, function);
// try_emplace keeps the first entry, so re-registering the same argument does
// not overwrite it.
158 passRegistry->try_emplace(arg, passInfo);
159
160 // Verify that the registered pass has the same ID as any registered to this
161 // arg before it.
162 TypeID entryTypeID = pass->getTypeID();
163 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
164 if (it->second != entryTypeID)
165 llvm::report_fatal_error(
166 "pass allocator creates a different pass than previously "
167 "registered for pass " +
168 arg);
169}
170
171/// Returns the pass info for the specified pass argument or null if unknown.
172const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
173 auto it = passRegistry->find(passArg);
174 return it == passRegistry->end() ? nullptr : &it->second;
175}
176
177/// Returns the pass pipeline info for the specified pass pipeline argument or
178/// null if unknown.
// NOTE(review): the signature line of this function is missing from this view.
// Mirrors PassInfo::lookup, but searches the pipeline registry.
180 auto it = passPipelineRegistry->find(pipelineArg);
181 return it == passPipelineRegistry->end() ? nullptr : &it->second;
182}
183
184//===----------------------------------------------------------------------===//
185// PassOptions
186//===----------------------------------------------------------------------===//
187
188/// Attempt to find the next occurance of character 'c' in the string starting
189/// from the `index`-th position , omitting any occurances that appear within
190/// intervening ranges or literals.
191static size_t findChar(StringRef str, size_t index, char c) {
192 for (size_t i = index, e = str.size(); i < e; ++i) {
193 if (str[i] == c)
194 return i;
195 // Check for various range characters.
196 if (str[i] == '{')
197 i = findChar(str, i + 1, '}');
198 else if (str[i] == '(')
199 i = findChar(str, i + 1, ')');
200 else if (str[i] == '[')
201 i = findChar(str, i + 1, ']');
202 else if (str[i] == '\"')
203 i = str.find_first_of('\"', i + 1);
204 else if (str[i] == '\'')
205 i = str.find_first_of('\'', i + 1);
206 if (i == StringRef::npos)
207 return StringRef::npos;
208 }
209 return StringRef::npos;
210}
211
212/// Extract an argument from 'options' and update it to point after the arg.
213/// Returns the cleaned argument string.
214static StringRef extractArgAndUpdateOptions(StringRef &options,
215 size_t argSize) {
216 StringRef str = options.take_front(argSize).trim();
217 options = options.drop_front(argSize).ltrim();
218
219 // Early exit if there's no escape sequence.
220 if (str.size() <= 1)
221 return str;
222
223 const auto escapePairs = {std::make_pair('\'', '\''),
224 std::make_pair('"', '"')};
225 for (const auto &escape : escapePairs) {
226 if (str.front() == escape.first && str.back() == escape.second) {
227 // Drop the escape characters and trim.
228 // Don't process additional escape sequences.
229 return str.drop_front().drop_back().trim();
230 }
231 }
232
233 // Arguments may be wrapped in `{...}`. Unlike the quotation markers that
234 // denote literals, we respect scoping here. The outer `{...}` should not
235 // be stripped in cases such as "arg={...},{...}", which can be used to denote
236 // lists of nested option structs.
237 if (str.front() == '{') {
238 unsigned match = findChar(str, 1, '}');
239 if (match == str.size() - 1)
240 str = str.drop_front().drop_back().trim();
241 }
242
243 return str;
244}
245
// NOTE(review): the line naming this function is missing from this view.
// Splits `optionStr` on commas that are not nested inside brackets or quotes
// (see findChar) and calls `elementParseFn` on each cleaned element. `opt`
// and `argName` are not used in the visible body.
247 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
248 function_ref<LogicalResult(StringRef)> elementParseFn) {
// An empty list succeeds without invoking `elementParseFn`.
249 if (optionStr.empty())
250 return success();
251
252 size_t nextElePos = findChar(optionStr, 0, ',');
253 while (nextElePos != StringRef::npos) {
254 // Process the portion before the comma.
255 if (failed(
256 elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
257 return failure();
258
259 // Drop the leading ','
260 optionStr = optionStr.drop_front();
261 nextElePos = findChar(optionStr, 0, ',');
262 }
// The final element is always parsed. A trailing comma therefore produces a
// final empty element.
263 return elementParseFn(
264 extractArgAndUpdateOptions(optionStr, optionStr.size()));
265}
266
267/// Out of line virtual function to provide home for the class.
268void detail::PassOptions::OptionBase::anchor() {}
269
270/// Copy the option values from 'other'.
271void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
272 assert(options.size() == other.options.size());
273 if (options.empty())
274 return;
275 for (auto optionsIt : llvm::zip(options, other.options))
276 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
277}
278
279/// Parse in the next argument from the given options string. Returns a tuple
280/// containing [the key of the option, the value of the option, updated
281/// `options` string pointing after the parsed option].
// Entries are separated by ' '. An entry without '=' yields an empty value.
// NOTE(review): the line holding the function name and parameter (presumably
// `parseNextArg(StringRef options)`, as called from parseFromString) is
// missing from this view.
282static std::tuple<StringRef, StringRef, StringRef>
284 // Try to process the given punctuation, properly escaping any contained
285 // characters.
// Returns false if `options[currentPos]` is not `punct`. Otherwise it moves
// `currentPos` to the matching closing `punct` and returns true. If no closing
// character exists, `currentPos` is left on the opening one, but true is still
// returned.
286 auto tryProcessPunct = [&](size_t &currentPos, char punct) {
287 if (options[currentPos] != punct)
288 return false;
289 size_t nextIt = options.find_first_of(punct, currentPos + 1);
290 if (nextIt != StringRef::npos)
291 currentPos = nextIt;
292 return true;
293 };
294
295 // Parse the argument name of the option.
296 StringRef argName;
297 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
298 // Check for the end of the full option.
299 if (argEndIt == optionsE || options[argEndIt] == ' ') {
300 argName = extractArgAndUpdateOptions(options, argEndIt);
301 return std::make_tuple(argName, StringRef(), options);
302 }
303
304 // Check for the end of the name and the start of the value.
305 if (options[argEndIt] == '=') {
306 argName = extractArgAndUpdateOptions(options, argEndIt);
// Drop the '=' itself; `options` now starts at the value.
307 options = options.drop_front();
308 break;
309 }
310 }
311
312 // Parse the value of the option.
313 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
314 // Handle the end of the options string.
315 if (argEndIt == optionsE || options[argEndIt] == ' ') {
316 StringRef value = extractArgAndUpdateOptions(options, argEndIt);
317 return std::make_tuple(argName, value, options);
318 }
319
320 // Skip over escaped sequences.
321 char c = options[argEndIt];
322 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
323 continue;
324 // '{...}' is used to specify options to passes, properly escape it so
325 // that we don't accidentally split any nested options.
326 if (c == '{') {
327 size_t braceCount = 1;
328 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
329 // Allow nested punctuation.
330 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
331 continue;
332 if (options[argEndIt] == '{')
333 ++braceCount;
334 else if (options[argEndIt] == '}' && --braceCount == 0)
335 break;
336 }
337 // Account for the increment at the top of the loop.
338 --argEndIt;
339 }
340 }
341 llvm_unreachable("unexpected control flow in pass option parsing");
342}
343
344LogicalResult detail::PassOptions::parseFromString(StringRef options,
345 raw_ostream &errorStream) {
346 // NOTE: `options` is modified in place to always refer to the unprocessed
347 // part of the string.
348 while (!options.empty()) {
349 StringRef key, value;
350 std::tie(key, value, options) = parseNextArg(options);
351 if (key.empty())
352 continue;
353
354 auto it = OptionsMap.find(key);
355 if (it == OptionsMap.end()) {
356 errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
357 return failure();
358 }
359 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
360 return failure();
361 }
362
363 return success();
364}
365
366/// Print the options held by this struct in a form that can be parsed via
367/// 'parseFromString'.
368void detail::PassOptions::print(raw_ostream &os) const {
369 // If there are no options, there is nothing left to do.
370 if (OptionsMap.empty())
371 return;
372
373 // Sort the options to make the ordering deterministic.
374 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
375 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
376 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
377 };
378 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
379
380 // Interleave the options with ' '.
381 os << '{';
382 llvm::interleave(
383 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
384 os << '}';
385}
386
387/// Print the help string for the options held by this struct. `descIndent` is
388/// the indent within the stream that the descriptions should be aligned.
389void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
390 // Sort the options to make the ordering deterministic.
391 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
392 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
393 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
394 };
395 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
396 for (OptionBase *option : orderedOps) {
397 // TODO: printOptionInfo assumes a specific indent and will
398 // print options with values with incorrect indentation. We should add
399 // support to llvm::cl::Option for passing in a base indent to use when
400 // printing.
401 llvm::outs().indent(indent);
402 option->getOption()->printOptionInfo(descIndent - indent);
403 }
404}
405
406/// Return the maximum width required when printing the help string.
407size_t detail::PassOptions::getOptionWidth() const {
408 size_t max = 0;
409 for (auto *option : options)
410 max = std::max(max, option->getOption()->getOptionWidth());
411 return max;
412}
413
414//===----------------------------------------------------------------------===//
415// MLIR Options
416//===----------------------------------------------------------------------===//
417
418//===----------------------------------------------------------------------===//
419// OpPassManager: OptionValue
420//===----------------------------------------------------------------------===//
421
422llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
423llvm::cl::OptionValue<OpPassManager>::OptionValue(
424 const mlir::OpPassManager &value) {
425 setValue(value);
426}
427llvm::cl::OptionValue<OpPassManager>::OptionValue(
428 const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {
429 if (rhs.hasValue())
430 setValue(rhs.getValue());
431}
432llvm::cl::OptionValue<OpPassManager> &
433llvm::cl::OptionValue<OpPassManager>::operator=(
434 const mlir::OpPassManager &rhs) {
435 setValue(rhs);
436 return *this;
437}
438
439llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
440
441void llvm::cl::OptionValue<OpPassManager>::setValue(
442 const OpPassManager &newValue) {
443 if (hasValue())
444 *value = newValue;
445 else
446 value = std::make_unique<mlir::OpPassManager>(newValue);
447}
448void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
449 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
450 assert(succeeded(pipeline) && "invalid pass pipeline");
451 setValue(*pipeline);
452}
453
454bool llvm::cl::OptionValue<OpPassManager>::compare(
455 const mlir::OpPassManager &rhs) const {
456 std::string lhsStr, rhsStr;
457 {
458 raw_string_ostream lhsStream(lhsStr);
459 value->printAsTextualPipeline(lhsStream);
460
461 raw_string_ostream rhsStream(rhsStr);
462 rhs.printAsTextualPipeline(rhsStream);
463 }
464
465 // Use the textual format for pipeline comparisons.
466 return lhsStr == rhsStr;
467}
468
469void llvm::cl::OptionValue<OpPassManager>::anchor() {}
470
471//===----------------------------------------------------------------------===//
472// OpPassManager: Parser
473//===----------------------------------------------------------------------===//
474
475namespace llvm {
476namespace cl {
// Explicitly instantiate basic_parser<OpPassManager> in this translation unit.
477template class basic_parser<OpPassManager>;
478} // namespace cl
479} // namespace llvm
480
481bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
482 ParsedPassManager &value) {
483 FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
484 if (failed(pipeline))
485 return true;
486 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
487 return false;
488}
489
490void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
491 const OpPassManager &value) {
492 value.printAsTextualPipeline(os);
493}
494
495void llvm::cl::parser<OpPassManager>::printOptionDiff(
496 const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
497 size_t globalWidth) const {
498 printOptionName(opt, globalWidth);
499 outs() << "= ";
500 pm.printAsTextualPipeline(outs());
501
502 if (defaultValue.hasValue()) {
503 outs().indent(2) << " (default: ";
504 defaultValue.getValue().printAsTextualPipeline(outs());
505 outs() << ")";
506 }
507 outs() << "\n";
508}
509
510void llvm::cl::parser<OpPassManager>::anchor() {}
511
512llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
513 default;
514llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
515 ParsedPassManager &&) = default;
516llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
517 default;
518
519//===----------------------------------------------------------------------===//
520// TextualPassPipeline Parser
521//===----------------------------------------------------------------------===//
522
523namespace {
524/// This class represents a textual description of a pass pipeline.
525class TextualPipeline {
526public:
527 /// Try to initialize this pipeline with the given pipeline text.
528 /// `errorStream` is the output stream to emit errors to.
529 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
530
531 /// Add the internal pipeline elements to the provided pass manager.
532 LogicalResult
533 addToPipeline(OpPassManager &pm,
534 function_ref<LogicalResult(const Twine &)> errorHandler) const;
535
536private:
537 /// A functor used to emit errors found during pipeline handling. The first
538 /// parameter corresponds to the raw location within the pipeline string. This
539 /// should always return failure.
540 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
541
542 /// A struct to capture parsed pass pipeline names.
543 ///
544 /// A pipeline is defined as a series of names, each of which may in itself
545 /// recursively contain a nested pipeline. A name is either the name of a pass
546 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
547 /// the name is the name of a pass, the InnerPipeline is empty, since passes
548 /// cannot contain inner pipelines.
549 struct PipelineElement {
550 PipelineElement(StringRef name) : name(name) {}
551
552 StringRef name;
553 StringRef options;
554 const PassRegistryEntry *registryEntry = nullptr;
555 std::vector<PipelineElement> innerPipeline;
556 };
557
558 /// Parse the given pipeline text into the internal pipeline vector. This
559 /// function only parses the structure of the pipeline, and does not resolve
560 /// its elements.
561 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
562
563 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
564 /// the corresponding registry entry.
565 LogicalResult
566 resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
567 ErrorHandlerT errorHandler);
568
569 /// Resolve a single element of the pipeline.
570 LogicalResult resolvePipelineElement(PipelineElement &element,
571 ErrorHandlerT errorHandler);
572
573 /// Add the given pipeline elements to the provided pass manager.
574 LogicalResult
575 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
576 function_ref<LogicalResult(const Twine &)> errorHandler) const;
577
578 std::vector<PipelineElement> pipeline;
579};
580
581} // namespace
582
583/// Try to initialize this pipeline with the given pipeline text. An option is
584/// given to enable accurate error reporting.
585LogicalResult TextualPipeline::initialize(StringRef text,
586 raw_ostream &errorStream) {
587 if (text.empty())
588 return success();
589
590 // Build a source manager to use for error reporting.
591 llvm::SourceMgr pipelineMgr;
592 pipelineMgr.AddNewSourceBuffer(
593 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
594 /*RequiresNullTerminator=*/false),
595 SMLoc());
596 auto errorHandler = [&](const char *rawLoc, Twine msg) {
597 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
598 llvm::SourceMgr::DK_Error, msg);
599 return failure();
600 };
601
602 // Parse the provided pipeline string.
603 if (failed(parsePipelineText(text, errorHandler)))
604 return failure();
605 return resolvePipelineElements(pipeline, errorHandler);
606}
607
608/// Add the internal pipeline elements to the provided pass manager.
609LogicalResult TextualPipeline::addToPipeline(
610 OpPassManager &pm,
611 function_ref<LogicalResult(const Twine &)> errorHandler) const {
612 // Temporarily disable implicit nesting while we append to the pipeline. We
613 // want the created pipeline to exactly match the parsed text pipeline, so
614 // it's preferable to just error out if implicit nesting would be required.
// Save the caller's nesting mode; the scope guard below restores it on every
// return path.
615 OpPassManager::Nesting nesting = pm.getNesting();
// NOTE(review): the statement that switches `pm` to explicit nesting is not
// visible in this view. Only the save/restore of the previous mode is shown.
617 auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
618
619 return addToPipeline(pipeline, pm, errorHandler);
620}
621
622/// Parse the given pipeline text into the internal pipeline vector. This
623/// function only parses the structure of the pipeline, and does not resolve
624/// its elements.
625LogicalResult TextualPipeline::parsePipelineText(StringRef text,
626 ErrorHandlerT errorHandler) {
627 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
628 for (;;) {
629 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
630 size_t pos = text.find_first_of(",(){");
631 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
632
633 // If we have a single terminating name, we're done.
634 if (pos == StringRef::npos)
635 break;
636
637 text = text.substr(pos);
638 char sep = text[0];
639
640 // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
641 if (sep == '{') {
642 text = text.substr(1);
643
644 // Skip over everything until the closing '}' and store as options.
645 size_t close = StringRef::npos;
646 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
647 if (text[i] == '{') {
648 ++braceCount;
649 continue;
650 }
651 if (text[i] == '}' && --braceCount == 0) {
652 close = i;
653 break;
654 }
655 }
656
657 // Check to see if a closing options brace was found.
658 if (close == StringRef::npos) {
659 return errorHandler(
660 /*rawLoc=*/text.data() - 1,
661 "missing closing '}' while processing pass options");
662 }
663 pipeline.back().options = text.substr(0, close);
664 text = text.substr(close + 1);
665
666 // Consume space characters that an user might add for readability.
667 text = text.ltrim();
668
669 // Skip checking for '(' because nested pipelines cannot have options.
670 } else if (sep == '(') {
671 text = text.substr(1);
672
673 // Push the inner pipeline onto the stack to continue processing.
674 pipelineStack.push_back(&pipeline.back().innerPipeline);
675 continue;
676 }
677
678 // When handling the close parenthesis, we greedily consume them to avoid
679 // empty strings in the pipeline.
680 while (text.consume_front(")")) {
681 // If we try to pop the outer pipeline we have unbalanced parentheses.
682 if (pipelineStack.size() == 1)
683 return errorHandler(/*rawLoc=*/text.data() - 1,
684 "encountered extra closing ')' creating unbalanced "
685 "parentheses while parsing pipeline");
686
687 pipelineStack.pop_back();
688 // Consume space characters that an user might add for readability.
689 text = text.ltrim();
690 }
691
692 // Check if we've finished parsing.
693 if (text.empty())
694 break;
695
696 // Otherwise, the end of an inner pipeline always has to be followed by
697 // a comma, and then we can continue.
698 if (!text.consume_front(","))
699 return errorHandler(text.data(), "expected ',' after parsing pipeline");
700 }
701
702 // Check for unbalanced parentheses.
703 if (pipelineStack.size() > 1)
704 return errorHandler(
705 text.data(),
706 "encountered unbalanced parentheses while parsing pipeline");
707
708 assert(pipelineStack.back() == &pipeline &&
709 "wrong pipeline at the bottom of the stack");
710 return success();
711}
712
713/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
714/// the corresponding registry entry.
715LogicalResult TextualPipeline::resolvePipelineElements(
716 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
717 for (auto &elt : elements)
718 if (failed(resolvePipelineElement(elt, errorHandler)))
719 return failure();
720 return success();
721}
722
723/// Resolve a single element of the pipeline.
724LogicalResult
725TextualPipeline::resolvePipelineElement(PipelineElement &element,
726 ErrorHandlerT errorHandler) {
727 // If the inner pipeline of this element is not empty, this is an operation
728 // pipeline.
729 if (!element.innerPipeline.empty())
730 return resolvePipelineElements(element.innerPipeline, errorHandler);
731
732 // Otherwise, this must be a pass or pass pipeline.
733 // Check to see if a pipeline was registered with this name.
734 if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
735 return success();
736
737 // If not, then this must be a specific pass name.
738 if ((element.registryEntry = PassInfo::lookup(element.name)))
739 return success();
740
741 // Emit an error for the unknown pass.
742 auto *rawLoc = element.name.data();
743 return errorHandler(rawLoc, "'" + element.name +
744 "' does not refer to a "
745 "registered pass or pass pipeline");
746}
747
748/// Add the given pipeline elements to the provided pass manager.
749LogicalResult TextualPipeline::addToPipeline(
750 ArrayRef<PipelineElement> elements, OpPassManager &pm,
751 function_ref<LogicalResult(const Twine &)> errorHandler) const {
752 for (auto &elt : elements) {
753 if (elt.registryEntry) {
754 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
755 errorHandler))) {
756 return errorHandler("failed to add `" + elt.name + "` with options `" +
757 elt.options + "`");
758 }
759 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
760 errorHandler))) {
761 return errorHandler("failed to add `" + elt.name + "` with options `" +
762 elt.options + "` to inner pipeline");
763 }
764 }
765 return success();
766}
767
768LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
769 raw_ostream &errorStream) {
770 TextualPipeline pipelineParser;
771 if (failed(pipelineParser.initialize(pipeline, errorStream)))
772 return failure();
773 auto errorHandler = [&](Twine msg) {
774 errorStream << msg << "\n";
775 return failure();
776 };
777 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
778 return failure();
779 return success();
780}
781
782FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
783 raw_ostream &errorStream) {
784 pipeline = pipeline.trim();
785 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
786 size_t pipelineStart = pipeline.find_first_of('(');
787 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
788 !pipeline.consume_back(")")) {
789 errorStream << "expected pass pipeline to be wrapped with the anchor "
790 "operation type, e.g. 'builtin.module(...)'";
791 return failure();
792 }
793
794 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
795 OpPassManager pm(opName);
796 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
797 errorStream)))
798 return failure();
799 return pm;
800}
801
802//===----------------------------------------------------------------------===//
803// PassNameParser
804//===----------------------------------------------------------------------===//
805
806namespace {
807/// This struct represents the possible data entries in a parsed pass pipeline
808/// list.
809struct PassArgData {
810 PassArgData() = default;
811 PassArgData(const PassRegistryEntry *registryEntry)
812 : registryEntry(registryEntry) {}
813
814 /// This field is used when the parsed option corresponds to a registered pass
815 /// or pass pipeline.
816 const PassRegistryEntry *registryEntry{nullptr};
817
818 /// This field is set when instance specific pass options have been provided
819 /// on the command line.
820 StringRef options;
821};
822} // namespace
823
824namespace llvm {
825namespace cl {
826/// Define a valid OptionValue for the command line pass argument.
// NOTE(review): the line naming this specialization (presumably
// `struct OptionValue<PassArgData> ...`, judging by the constructors and
// base class) is missing from this view.
827template <>
829 : OptionValueBase<PassArgData, /*isClass=*/true> {
830 OptionValue(const PassArgData &value) { this->setValue(value); }
831 OptionValue() = default;
832 void anchor() override {}
833
// Always reports a value; a default-constructed PassArgData counts as valid.
834 bool hasValue() const { return true; }
835 const PassArgData &getValue() const { return value; }
836 void setValue(const PassArgData &value) { this->value = value; }
837
838 PassArgData value;
839};
840} // namespace cl
841} // namespace llvm
842
843namespace {
844
845/// The name for the command line option used for parsing the textual pass
846/// pipeline.
// Note: macros ignore namespaces, so this remains defined for the rest of the
// file despite appearing inside an anonymous namespace.
847#define PASS_PIPELINE_ARG "pass-pipeline"
848
849/// Adds command line option for each registered pass or pass pipeline, as well
850/// as textual pass pipelines.
851struct PassNameParser : public llvm::cl::parser<PassArgData> {
852 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
853
// Registers a literal option per registered pass (and per pipeline unless
// `passNamesOnly` is set).
854 void initialize();
// Prints the option help followed by the sorted lists of passes/pipelines.
855 void printOptionInfo(const llvm::cl::Option &opt,
856 size_t globalWidth) const override;
857 size_t getOptionWidth(const llvm::cl::Option &opt) const override;
858 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
859 PassArgData &value);
860
861 /// If true, this parser only parses entries that correspond to a concrete
862 /// pass registry entry, and does not include pipeline entries or the options
863 /// for pass entries.
864 bool passNamesOnly = false;
865};
866} // namespace
867
868void PassNameParser::initialize() {
869 llvm::cl::parser<PassArgData>::initialize();
870
871 /// Add the pass entries.
872 for (const auto &kv : *passRegistry) {
873 addLiteralOption(kv.second.getPassArgument(), &kv.second,
874 kv.second.getPassDescription());
875 }
876 /// Add the pass pipeline entries.
877 if (!passNamesOnly) {
878 for (const auto &kv : *passPipelineRegistry) {
879 addLiteralOption(kv.second.getPassArgument(), &kv.second,
880 kv.second.getPassDescription());
881 }
882 }
883}
884
885void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
886 size_t globalWidth) const {
887 // If this parser is just parsing pass names, print a simplified option
888 // string.
889 if (passNamesOnly) {
890 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";
891 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
892 return;
893 }
894
895 // Print the information for the top-level option.
896 if (opt.hasArgStr()) {
897 llvm::outs() << " --" << opt.ArgStr;
898 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
899 } else {
900 llvm::outs() << " " << opt.HelpStr << '\n';
901 }
902
903 // Functor used to print the ordered entries of a registration map.
904 auto printOrderedEntries = [&](StringRef header, auto &map) {
905 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
906 for (auto &kv : map)
907 orderedEntries.push_back(&kv.second);
908 llvm::array_pod_sort(
909 orderedEntries.begin(), orderedEntries.end(),
910 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
911 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
912 });
913
914 llvm::outs().indent(4) << header << ":\n";
915 for (PassRegistryEntry *entry : orderedEntries)
916 entry->printHelpStr(/*indent=*/6, globalWidth);
917 };
918
919 // Print the available passes.
920 printOrderedEntries("Passes", *passRegistry);
921
922 // Print the available pass pipelines.
923 if (!passPipelineRegistry->empty())
924 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
925}
926
927size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
928 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
929
930 // Check for any wider pass or pipeline options.
931 for (auto &entry : *passRegistry)
932 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
933 for (auto &entry : *passPipelineRegistry)
934 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
935 return maxWidth;
936}
937
938bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
939 StringRef arg, PassArgData &value) {
940 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
941 return true;
942 value.options = arg;
943 return false;
944}
945
946//===----------------------------------------------------------------------===//
947// PassPipelineCLParser
948//===----------------------------------------------------------------------===//
949
950namespace mlir {
951namespace detail {
953 PassPipelineCLParserImpl(StringRef arg, StringRef description,
954 bool passNamesOnly)
955 : passList(arg, llvm::cl::desc(description)) {
956 passList.getParser().passNamesOnly = passNamesOnly;
957 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
958 }
959
960 /// Returns true if the given pass registry entry was registered at the
961 /// top-level of the parser, i.e. not within an explicit textual pipeline.
962 bool contains(const PassRegistryEntry *entry) const {
963 return llvm::any_of(passList, [&](const PassArgData &data) {
964 return data.registryEntry == entry;
965 });
966 }
967
968 /// The set of passes and pass pipelines to run.
969 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
970};
971} // namespace detail
972} // namespace mlir
973
974/// Construct a pass pipeline parser with the given command line description.
975PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
976 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
977 arg, description, /*passNamesOnly=*/false)),
978 passPipeline(
980 llvm::cl::desc("Textual description of the pass pipeline to run")) {}
981
982PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
983 StringRef alias)
984 : PassPipelineCLParser(arg, description) {
985 passPipelineAlias.emplace(alias,
986 llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
987 llvm::cl::aliasopt(passPipeline));
988}
989
991
992/// Returns true if this parser contains any valid options to add.
994 return passPipeline.getNumOccurrences() != 0 ||
995 impl->passList.getNumOccurrences() != 0;
996}
997
998/// Returns true if the given pass registry entry was registered at the
999/// top-level of the parser, i.e. not within an explicit textual pipeline.
1001 return impl->contains(entry);
1002}
1003
1004/// Adds the passes defined by this parser entry to the given pass manager.
1006 OpPassManager &pm,
1007 function_ref<LogicalResult(const Twine &)> errorHandler) const {
1008 if (passPipeline.getNumOccurrences()) {
1009 if (impl->passList.getNumOccurrences())
1010 return errorHandler(
1012 "' option can't be used with individual pass options");
1013 std::string errMsg;
1014 llvm::raw_string_ostream os(errMsg);
1015 FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1016 if (failed(parsed))
1017 return errorHandler(errMsg);
1018 pm = std::move(*parsed);
1019 return success();
1020 }
1021
1022 for (auto &passIt : impl->passList) {
1023 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1024 errorHandler)))
1025 return failure();
1026 }
1027 return success();
1028}
1029
1030//===----------------------------------------------------------------------===//
1031// PassNameCLParser
1032//===----------------------------------------------------------------------===//
1033
1034/// Construct a pass pipeline parser with the given command line description.
1035PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1036 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1037 arg, description, /*passNamesOnly=*/true)) {
1038 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1039}
1041
1042/// Returns true if this parser contains any valid options to add.
1044 return impl->passList.getNumOccurrences() != 0;
1045}
1046
1047/// Returns true if the given pass registry entry was registered at the
1048/// top-level of the parser, i.e. not within an explicit textual pipeline.
1050 return impl->contains(entry);
1051}
return success()
true
Given two iterators into the same block, return "true" if `a` is before `b`.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
lhs
...if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...
false
Parses a map_entries map type from a string format back into its numeric value.
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they generate.
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurrence of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents a pass manager that runs passes on either a specific operation type,...
Definition PassManager.h:46
void printAsTextualPipeline(raw_ostream &os, bool pretty=false) const
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition Pass.cpp:452
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
Definition Pass.cpp:411
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
Definition Pass.cpp:478
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition Pass.cpp:392
Nesting getNesting()
Return the current nesting mode.
Definition Pass.cpp:480
Nesting
This enum represents the nesting behavior of the pass manager.
Definition PassManager.h:49
@ Explicit
Explicit nesting behavior.
Definition PassManager.h:56
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
Definition Pass.cpp:415
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
Definition Pass.cpp:382
A structure to represent the information for a derived pass class.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
PassPipelineInfo(StringRef arg, StringRef description, const PassRegistryFunction &builder, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
PassRegistryEntry(StringRef arg, StringRef description, const PassRegistryFunction &builder, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
StringRef getPassDescription() const
Returns a description for the pass, this never returns null.
StringRef getPassArgument() const
Returns the command line option that may be passed to 'mlir-opt' that will cause this pass to run or ...
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
This class represents a specific pass option, with a provided data type.
Base container class and manager for all pass options.
Definition PassOptions.h:89
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
AttrTypeReplacer.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
std::function< LogicalResult( OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
const PassArgData & getValue() const
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...