13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
24 using namespace detail;
27 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
34 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
41 function_ref<LogicalResult(
const Twine &)> errorHandler) {
42 std::unique_ptr<Pass> pass = allocator();
43 LogicalResult result = pass->initializeOptions(
options, errorHandler);
45 std::optional<StringRef> pmOpName = pm.
getOpName();
46 std::optional<StringRef> passOpName = pass->getOpName();
48 passOpName && *pmOpName != *passOpName) {
49 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
50 "' restricted to '" + *pass->getOpName() +
51 "' on a PassManager intended to run on '" +
61 size_t descIndent,
bool isTopLevel) {
// Width available for the left-justified argument name: the description
// column minus the leading indent and the 4 characters used by the "--"
// prefix and the "- " separator printed below.
// NOTE(review): this is unsigned (size_t) arithmetic and wraps around if
// descIndent < indent + 4; callers presumably always pass a descIndent that
// already accounts for the indent -- confirm.
62 size_t numSpaces = descIndent - indent - 4;
63 llvm::outs().indent(indent)
64 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
75 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
78 auto printOrderedEntries = [&](StringRef header,
auto &map) {
81 orderedEntries.push_back(&kv.second);
83 orderedEntries.begin(), orderedEntries.end(),
85 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
88 llvm::outs().indent(0) << header <<
":\n";
90 entry->printHelpStr(2, maxWidth);
101 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
105 options.printHelp(indent, descIndent);
114 maxLen =
options.getOptionWidth() + 2;
127 std::move(optHandler));
131 report_fatal_error(
"Pass pipeline " + arg +
" registered multiple times");
146 optHandler(allocator()->passOptions);
150 std::unique_ptr<Pass> pass =
function();
151 StringRef arg = pass->getArgument();
153 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
155 "' pass that does not override `getArgument()`");
156 StringRef description = pass->getDescription();
157 PassInfo passInfo(arg, description,
function);
162 TypeID entryTypeID = pass->getTypeID();
164 if (it->second != entryTypeID)
165 llvm::report_fatal_error(
166 "pass allocator creates a different pass than previously "
167 "registered for pass " +
174 return it ==
passRegistry->end() ? nullptr : &it->second;
/// Attempt to find the next occurrence of character `c` in `str`, scanning
/// forward from position `index`; StringRef::npos signals "not found".
/// Quoted regions ('...' and "...") are jumped over with find_first_of, so a
/// `c` appearing inside quotes is not matched. '(' and '[' also receive special
/// handling, but the corresponding statements are not part of this listing.
/// NOTE(review): this listing omits source lines 193-197, 199, 201 and 208
/// (presumably including the actual `str[i] == c` match check) -- the full
/// definition should be consulted before relying on exact semantics.
191 static size_t findChar(StringRef str,
size_t index,
char c) {
192 for (
size_t i = index, e = str.size(); i < e; ++i) {
198 else if (str[i] ==
'(')
200 else if (str[i] ==
'[')
// Skip to the matching closing double quote.
202 else if (str[i] ==
'\"')
203 i = str.find_first_of(
'\"', i + 1);
// Skip to the matching closing single quote.
204 else if (str[i] ==
'\'')
205 i = str.find_first_of(
'\'', i + 1);
// An unterminated quote (or unbalanced group) leaves i == npos, which is
// reported to the caller as "not found".
206 if (i == StringRef::npos)
207 return StringRef::npos;
// Reached the end of the string without finding `c`.
209 return StringRef::npos;
216 StringRef str =
options.take_front(argSize).trim();
223 const auto escapePairs = {std::make_pair(
'\'',
'\''),
224 std::make_pair(
'"',
'"')};
225 for (
const auto &escape : escapePairs) {
226 if (str.front() == escape.first && str.back() == escape.second) {
229 return str.drop_front().drop_back().trim();
237 if (str.front() ==
'{') {
238 unsigned match =
findChar(str, 1,
'}');
239 if (match == str.size() - 1)
240 str = str.drop_front().drop_back().trim();
247 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
248 function_ref<LogicalResult(StringRef)> elementParseFn) {
249 if (optionStr.empty())
252 size_t nextElePos =
findChar(optionStr, 0,
',');
253 while (nextElePos != StringRef::npos) {
260 optionStr = optionStr.drop_front();
261 nextElePos =
findChar(optionStr, 0,
',');
263 return elementParseFn(
// Out-of-line, intentionally empty definition of OptionBase::anchor().
// NOTE(review): presumably this follows the LLVM convention of giving a
// class one out-of-line virtual method so its vtable is emitted in this
// translation unit only -- confirm against the header declaration.
268 void detail::PassOptions::OptionBase::anchor() {}
272 assert(options.size() == other.options.size());
275 for (
auto optionsIt : llvm::zip(options, other.options))
276 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
282 static std::tuple<StringRef, StringRef, StringRef>
286 auto tryProcessPunct = [&](
size_t ¤tPos,
char punct) {
287 if (
options[currentPos] != punct)
289 size_t nextIt =
options.find_first_of(punct, currentPos + 1);
290 if (nextIt != StringRef::npos)
297 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
299 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
301 return std::make_tuple(argName, StringRef(),
options);
305 if (
options[argEndIt] ==
'=') {
313 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
315 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
317 return std::make_tuple(argName, value,
options);
322 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
327 size_t braceCount = 1;
328 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
330 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
334 else if (
options[argEndIt] ==
'}' && --braceCount == 0)
341 llvm_unreachable(
"unexpected control flow in pass option parsing");
345 raw_ostream &errorStream) {
349 StringRef key, value;
354 auto it = OptionsMap.find(key);
355 if (it == OptionsMap.end()) {
356 errorStream <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
359 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
370 if (OptionsMap.empty())
375 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
376 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
378 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
383 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
392 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
393 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
395 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
396 for (OptionBase *option : orderedOps) {
401 llvm::outs().indent(indent);
402 option->getOption()->printOptionInfo(descIndent - indent);
422 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
423 llvm::cl::OptionValue<OpPassManager>::OptionValue(
427 llvm::cl::OptionValue<OpPassManager>::OptionValue(
432 llvm::cl::OptionValue<OpPassManager> &
433 llvm::cl::OptionValue<OpPassManager>::operator=(
439 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
441 void llvm::cl::OptionValue<OpPassManager>::setValue(
446 value = std::make_unique<mlir::OpPassManager>(newValue);
448 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
450 assert(succeeded(pipeline) &&
"invalid pass pipeline");
456 std::string lhsStr, rhsStr;
458 raw_string_ostream lhsStream(lhsStr);
459 value->printAsTextualPipeline(lhsStream);
461 raw_string_ostream rhsStream(rhsStr);
466 return lhsStr == rhsStr;
// Out-of-line, intentionally empty definition of
// llvm::cl::OptionValue<OpPassManager>::anchor().
// NOTE(review): presumably anchors the specialization's vtable in this
// translation unit (LLVM convention) -- confirm against the declaration.
469 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
482 ParsedPassManager &value) {
486 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
496 const Option &opt,
OpPassManager &pm,
const OptVal &defaultValue,
497 size_t globalWidth)
const {
498 printOptionName(opt, globalWidth);
502 if (defaultValue.hasValue()) {
503 outs().indent(2) <<
" (default: ";
504 defaultValue.getValue().printAsTextualPipeline(outs());
515 ParsedPassManager &&) =
default;
525 class TextualPipeline {
529 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
534 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
540 using ErrorHandlerT =
function_ref<LogicalResult(
const char *, Twine)>;
549 struct PipelineElement {
550 PipelineElement(StringRef name) : name(name) {}
555 std::vector<PipelineElement> innerPipeline;
561 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
567 ErrorHandlerT errorHandler);
570 LogicalResult resolvePipelineElement(PipelineElement &element,
571 ErrorHandlerT errorHandler);
576 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
578 std::vector<PipelineElement> pipeline;
585 LogicalResult TextualPipeline::initialize(StringRef text,
586 raw_ostream &errorStream) {
591 llvm::SourceMgr pipelineMgr;
592 pipelineMgr.AddNewSourceBuffer(
593 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
596 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
597 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
598 llvm::SourceMgr::DK_Error, msg);
603 if (
failed(parsePipelineText(text, errorHandler)))
605 return resolvePipelineElements(pipeline, errorHandler);
609 LogicalResult TextualPipeline::addToPipeline(
611 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
617 auto restore = llvm::make_scope_exit([&]() { pm.
setNesting(nesting); });
619 return addToPipeline(pipeline, pm, errorHandler);
625 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
626 ErrorHandlerT errorHandler) {
629 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
630 size_t pos = text.find_first_of(
",(){");
631 pipeline.emplace_back(text.substr(0, pos).trim());
634 if (pos == StringRef::npos)
637 text = text.substr(pos);
642 text = text.substr(1);
645 size_t close = StringRef::npos;
646 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
647 if (text[i] ==
'{') {
651 if (text[i] ==
'}' && --braceCount == 0) {
658 if (close == StringRef::npos) {
661 "missing closing '}' while processing pass options");
663 pipeline.back().options = text.substr(0, close);
664 text = text.substr(close + 1);
670 }
else if (sep ==
'(') {
671 text = text.substr(1);
674 pipelineStack.push_back(&pipeline.back().innerPipeline);
680 while (text.consume_front(
")")) {
682 if (pipelineStack.size() == 1)
683 return errorHandler(text.data() - 1,
684 "encountered extra closing ')' creating unbalanced "
685 "parentheses while parsing pipeline");
687 pipelineStack.pop_back();
698 if (!text.consume_front(
","))
699 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
703 if (pipelineStack.size() > 1)
706 "encountered unbalanced parentheses while parsing pipeline");
708 assert(pipelineStack.back() == &pipeline &&
709 "wrong pipeline at the bottom of the stack");
715 LogicalResult TextualPipeline::resolvePipelineElements(
717 for (
auto &elt : elements)
718 if (
failed(resolvePipelineElement(elt, errorHandler)))
725 TextualPipeline::resolvePipelineElement(PipelineElement &element,
726 ErrorHandlerT errorHandler) {
729 if (!element.innerPipeline.empty())
730 return resolvePipelineElements(element.innerPipeline, errorHandler);
742 auto *rawLoc = element.name.data();
743 return errorHandler(rawLoc,
"'" + element.name +
744 "' does not refer to a "
745 "registered pass or pass pipeline");
749 LogicalResult TextualPipeline::addToPipeline(
751 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
752 for (
auto &elt : elements) {
753 if (elt.registryEntry) {
754 if (
failed(elt.registryEntry->addToPipeline(pm, elt.options,
756 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
759 }
else if (
failed(addToPipeline(elt.innerPipeline, pm.
nest(elt.name),
761 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
762 elt.options +
"` to inner pipeline");
769 raw_ostream &errorStream) {
770 TextualPipeline pipelineParser;
771 if (
failed(pipelineParser.initialize(pipeline, errorStream)))
773 auto errorHandler = [&](Twine msg) {
774 errorStream << msg <<
"\n";
777 if (
failed(pipelineParser.addToPipeline(pm, errorHandler)))
783 raw_ostream &errorStream) {
784 pipeline = pipeline.trim();
786 size_t pipelineStart = pipeline.find_first_of(
'(');
787 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
788 !pipeline.consume_back(
")")) {
789 errorStream <<
"expected pass pipeline to be wrapped with the anchor "
790 "operation type, e.g. 'builtin.module(...)'";
794 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
810 PassArgData() =
default;
812 : registryEntry(registryEntry) {}
835 const PassArgData &
getValue()
const {
return value; }
836 void setValue(
const PassArgData &value) { this->value = value; }
847 #define PASS_PIPELINE_ARG "pass-pipeline"
855 void printOptionInfo(
const llvm::cl::Option &opt,
856 size_t globalWidth)
const override;
857 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
858 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
864 bool passNamesOnly =
false;
868 void PassNameParser::initialize() {
873 addLiteralOption(kv.second.getPassArgument(), &kv.second,
874 kv.second.getPassDescription());
877 if (!passNamesOnly) {
879 addLiteralOption(kv.second.getPassArgument(), &kv.second,
880 kv.second.getPassDescription());
885 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
886 size_t globalWidth)
const {
890 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
891 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
896 if (opt.hasArgStr()) {
897 llvm::outs() <<
" --" << opt.ArgStr;
898 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
900 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
904 auto printOrderedEntries = [&](StringRef header,
auto &map) {
907 orderedEntries.push_back(&kv.second);
908 llvm::array_pod_sort(
909 orderedEntries.begin(), orderedEntries.end(),
911 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
914 llvm::outs().indent(4) << header <<
":\n";
916 entry->printHelpStr(6, globalWidth);
927 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
932 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
934 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
939 StringRef arg, PassArgData &value) {
955 : passList(arg,
llvm::cl::desc(description)) {
956 passList.getParser().passNamesOnly = passNamesOnly;
957 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
963 return llvm::any_of(passList, [&](
const PassArgData &data) {
964 return data.registryEntry == entry;
969 llvm::cl::list<PassArgData, bool, PassNameParser>
passList;
977 arg, description, false)),
980 llvm::cl::desc(
"Textual description of the pass pipeline to run")) {}
985 passPipelineAlias.emplace(alias,
987 llvm::cl::aliasopt(passPipeline));
994 return passPipeline.getNumOccurrences() != 0 ||
995 impl->passList.getNumOccurrences() != 0;
1001 return impl->contains(entry);
1007 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
1008 if (passPipeline.getNumOccurrences()) {
1009 if (
impl->passList.getNumOccurrences())
1010 return errorHandler(
1012 "' option can't be used with individual pass options");
1014 llvm::raw_string_ostream os(errMsg);
1017 return errorHandler(errMsg);
1018 pm = std::move(*parsed);
1022 for (
auto &passIt :
impl->passList) {
1023 if (
failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1037 arg, description, true)) {
1038 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1044 return impl->passList.getNumOccurrences() != 0;
1050 return impl->contains(entry);
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurrence of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
void printAsTextualPipeline(raw_ostream &os, bool pretty=false) const
Prints out the passes of the pass manager as the textual representation of pipelines.
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Nesting getNesting()
Return the current nesting mode.
Nesting
This enum represents the nesting behavior of the pass manager.
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
A structure to represent the information for a derived pass class.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Base container class and manager for all pass options.
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
bool hasValue() const
Returns if the current option has a value.
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...