13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
25 using namespace detail;
/// Static mapping of all of the registered passes (StringMap of PassInfo).
/// Wrapped in llvm::ManagedStatic so the map is constructed lazily on first
/// use instead of during static initialization.
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
42 function_ref<LogicalResult(
const Twine &)> errorHandler) {
43 std::unique_ptr<Pass> pass = allocator();
44 LogicalResult result = pass->initializeOptions(
options, errorHandler);
46 std::optional<StringRef> pmOpName = pm.
getOpName();
47 std::optional<StringRef> passOpName = pass->getOpName();
49 passOpName && *pmOpName != *passOpName) {
50 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
51 "' restricted to '" + *pass->getOpName() +
52 "' on a PassManager intended to run on '" +
62 size_t descIndent,
bool isTopLevel) {
63 size_t numSpaces = descIndent - indent - 4;
64 llvm::outs().indent(indent)
65 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
76 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
79 auto printOrderedEntries = [&](StringRef header,
auto &map) {
82 orderedEntries.push_back(&kv.second);
84 orderedEntries.begin(), orderedEntries.end(),
86 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
89 llvm::outs().indent(0) << header <<
":\n";
91 entry->printHelpStr(2, maxWidth);
102 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
106 options.printHelp(indent, descIndent);
115 maxLen =
options.getOptionWidth() + 2;
128 std::move(optHandler));
132 report_fatal_error(
"Pass pipeline " + arg +
" registered multiple times");
147 optHandler(allocator()->passOptions);
151 std::unique_ptr<Pass> pass =
function();
152 StringRef arg = pass->getArgument();
154 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
156 "' pass that does not override `getArgument()`");
157 StringRef description = pass->getDescription();
158 PassInfo passInfo(arg, description,
function);
163 TypeID entryTypeID = pass->getTypeID();
165 if (it->second != entryTypeID)
166 llvm::report_fatal_error(
167 "pass allocator creates a different pass than previously "
168 "registered for pass " +
175 return it ==
passRegistry->end() ? nullptr : &it->second;
192 static size_t findChar(StringRef str,
size_t index,
char c) {
193 for (
size_t i = index, e = str.size(); i < e; ++i) {
199 else if (str[i] ==
'(')
201 else if (str[i] ==
'[')
203 else if (str[i] ==
'\"')
204 i = str.find_first_of(
'\"', i + 1);
205 else if (str[i] ==
'\'')
206 i = str.find_first_of(
'\'', i + 1);
207 if (i == StringRef::npos)
208 return StringRef::npos;
210 return StringRef::npos;
217 StringRef str =
options.take_front(argSize).trim();
224 const auto escapePairs = {std::make_pair(
'\'',
'\''),
225 std::make_pair(
'"',
'"')};
226 for (
const auto &escape : escapePairs) {
227 if (str.front() == escape.first && str.back() == escape.second) {
230 return str.drop_front().drop_back().trim();
238 if (str.front() ==
'{') {
239 unsigned match =
findChar(str, 1,
'}');
240 if (match == str.size() - 1)
241 str = str.drop_front().drop_back().trim();
248 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
249 function_ref<LogicalResult(StringRef)> elementParseFn) {
250 if (optionStr.empty())
253 size_t nextElePos =
findChar(optionStr, 0,
',');
254 while (nextElePos != StringRef::npos) {
261 optionStr = optionStr.drop_front();
262 nextElePos =
findChar(optionStr, 0,
',');
264 return elementParseFn(
/// Empty out-of-line definition of OptionBase::anchor().
/// NOTE(review): presumably this is the LLVM-style "anchor" key function that
/// pins the class's vtable to this translation unit -- confirm against the
/// declaration in the header.
269 void detail::PassOptions::OptionBase::anchor() {}
273 assert(options.size() == other.options.size());
276 for (
auto optionsIt : llvm::zip(options, other.options))
277 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
283 static std::tuple<StringRef, StringRef, StringRef>
287 auto tryProcessPunct = [&](
size_t ¤tPos,
char punct) {
288 if (
options[currentPos] != punct)
290 size_t nextIt =
options.find_first_of(punct, currentPos + 1);
291 if (nextIt != StringRef::npos)
298 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
300 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
302 return std::make_tuple(argName, StringRef(),
options);
306 if (
options[argEndIt] ==
'=') {
314 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
316 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
318 return std::make_tuple(argName, value,
options);
323 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
328 size_t braceCount = 1;
329 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
331 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
335 else if (
options[argEndIt] ==
'}' && --braceCount == 0)
342 llvm_unreachable(
"unexpected control flow in pass option parsing");
346 raw_ostream &errorStream) {
350 StringRef key, value;
355 auto it = OptionsMap.find(key);
356 if (it == OptionsMap.end()) {
357 errorStream <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
360 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
371 if (OptionsMap.empty())
376 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
377 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
379 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
384 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
393 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
394 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
396 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
397 for (OptionBase *option : orderedOps) {
402 llvm::outs().indent(indent);
403 option->getOption()->printOptionInfo(descIndent - indent);
/// Defaulted default constructor, defined out of line in this file.
/// NOTE(review): the held value is created via std::make_unique<OpPassManager>
/// in setValue, so it is likely a std::unique_ptr. Defining the special members
/// here is presumably what lets the header forward-declare OpPassManager --
/// confirm.
422 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
423 llvm::cl::OptionValue<OpPassManager>::OptionValue(
427 llvm::cl::OptionValue<OpPassManager>::OptionValue(
432 llvm::cl::OptionValue<OpPassManager> &
433 llvm::cl::OptionValue<OpPassManager>::operator=(
/// Defaulted destructor, defined out of line in this file.
/// NOTE(review): presumably kept out of line so that destroying the owned
/// OpPassManager happens where that type is complete -- confirm.
439 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
441 void llvm::cl::OptionValue<OpPassManager>::setValue(
446 value = std::make_unique<mlir::OpPassManager>(newValue);
448 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
450 assert(succeeded(pipeline) &&
"invalid pass pipeline");
456 std::string lhsStr, rhsStr;
458 raw_string_ostream lhsStream(lhsStr);
459 value->printAsTextualPipeline(lhsStream);
461 raw_string_ostream rhsStream(rhsStr);
466 return lhsStr == rhsStr;
/// Empty out-of-line definition of OptionValue<OpPassManager>::anchor().
/// NOTE(review): presumably the LLVM-style vtable anchor for this
/// specialization -- confirm against the header declaration.
469 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
481 ParsedPassManager &value) {
483 if (failed(pipeline))
485 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
495 const Option &opt,
OpPassManager &pm,
const OptVal &defaultValue,
496 size_t globalWidth)
const {
497 printOptionName(opt, globalWidth);
501 if (defaultValue.hasValue()) {
502 outs().indent(2) <<
" (default: ";
503 defaultValue.getValue().printAsTextualPipeline(outs());
514 ParsedPassManager &&) =
default;
524 class TextualPipeline {
528 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
533 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
539 using ErrorHandlerT =
function_ref<LogicalResult(
const char *, Twine)>;
548 struct PipelineElement {
549 PipelineElement(StringRef name) : name(name) {}
554 std::vector<PipelineElement> innerPipeline;
560 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
566 ErrorHandlerT errorHandler);
569 LogicalResult resolvePipelineElement(PipelineElement &element,
570 ErrorHandlerT errorHandler);
575 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
577 std::vector<PipelineElement> pipeline;
584 LogicalResult TextualPipeline::initialize(StringRef text,
585 raw_ostream &errorStream) {
590 llvm::SourceMgr pipelineMgr;
591 pipelineMgr.AddNewSourceBuffer(
592 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
595 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
596 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
597 llvm::SourceMgr::DK_Error, msg);
602 if (failed(parsePipelineText(text, errorHandler)))
604 return resolvePipelineElements(pipeline, errorHandler);
608 LogicalResult TextualPipeline::addToPipeline(
610 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
616 auto restore = llvm::make_scope_exit([&]() { pm.
setNesting(nesting); });
618 return addToPipeline(pipeline, pm, errorHandler);
624 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
625 ErrorHandlerT errorHandler) {
628 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
629 size_t pos = text.find_first_of(
",(){");
630 pipeline.emplace_back(text.substr(0, pos).trim());
633 if (pos == StringRef::npos)
636 text = text.substr(pos);
641 text = text.substr(1);
644 size_t close = StringRef::npos;
645 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
646 if (text[i] ==
'{') {
650 if (text[i] ==
'}' && --braceCount == 0) {
657 if (close == StringRef::npos) {
660 "missing closing '}' while processing pass options");
662 pipeline.back().options = text.substr(0, close);
663 text = text.substr(close + 1);
669 }
else if (sep ==
'(') {
670 text = text.substr(1);
673 pipelineStack.push_back(&pipeline.back().innerPipeline);
679 while (text.consume_front(
")")) {
681 if (pipelineStack.size() == 1)
682 return errorHandler(text.data() - 1,
683 "encountered extra closing ')' creating unbalanced "
684 "parentheses while parsing pipeline");
686 pipelineStack.pop_back();
697 if (!text.consume_front(
","))
698 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
702 if (pipelineStack.size() > 1)
705 "encountered unbalanced parentheses while parsing pipeline");
707 assert(pipelineStack.back() == &pipeline &&
708 "wrong pipeline at the bottom of the stack");
714 LogicalResult TextualPipeline::resolvePipelineElements(
716 for (
auto &elt : elements)
717 if (failed(resolvePipelineElement(elt, errorHandler)))
724 TextualPipeline::resolvePipelineElement(PipelineElement &element,
725 ErrorHandlerT errorHandler) {
728 if (!element.innerPipeline.empty())
729 return resolvePipelineElements(element.innerPipeline, errorHandler);
741 auto *rawLoc = element.name.data();
742 return errorHandler(rawLoc,
"'" + element.name +
743 "' does not refer to a "
744 "registered pass or pass pipeline");
748 LogicalResult TextualPipeline::addToPipeline(
750 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
751 for (
auto &elt : elements) {
752 if (elt.registryEntry) {
753 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
755 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
758 }
else if (failed(addToPipeline(elt.innerPipeline, pm.
nest(elt.name),
760 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
761 elt.options +
"` to inner pipeline");
768 raw_ostream &errorStream) {
769 TextualPipeline pipelineParser;
770 if (failed(pipelineParser.initialize(pipeline, errorStream)))
772 auto errorHandler = [&](Twine msg) {
773 errorStream << msg <<
"\n";
776 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
782 raw_ostream &errorStream) {
783 pipeline = pipeline.trim();
785 size_t pipelineStart = pipeline.find_first_of(
'(');
786 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
787 !pipeline.consume_back(
")")) {
788 errorStream <<
"expected pass pipeline to be wrapped with the anchor "
789 "operation type, e.g. 'builtin.module(...)'";
793 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
/// Defaulted default constructor. Member initial values come from whatever
/// in-class initializers PassArgData declares; those are not visible in this
/// excerpt.
809 PassArgData() =
default;
811 : registryEntry(registryEntry) {}
/// Return the currently stored PassArgData by const reference (no copy).
834 const PassArgData &
getValue()
const {
return value; }
/// Replace the stored PassArgData with a copy of `value`.
835 void setValue(
const PassArgData &value) { this->value = value; }
/// The name of the command line option used for parsing the textual pass
/// pipeline.
846 #define PASS_PIPELINE_ARG "pass-pipeline"
854 void printOptionInfo(
const llvm::cl::Option &opt,
855 size_t globalWidth)
const override;
856 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
857 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
863 bool passNamesOnly =
false;
867 void PassNameParser::initialize() {
872 addLiteralOption(kv.second.getPassArgument(), &kv.second,
873 kv.second.getPassDescription());
876 if (!passNamesOnly) {
878 addLiteralOption(kv.second.getPassArgument(), &kv.second,
879 kv.second.getPassDescription());
884 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
885 size_t globalWidth)
const {
889 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
890 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
895 if (opt.hasArgStr()) {
896 llvm::outs() <<
" --" << opt.ArgStr;
897 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
899 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
903 auto printOrderedEntries = [&](StringRef header,
auto &map) {
906 orderedEntries.push_back(&kv.second);
907 llvm::array_pod_sort(
908 orderedEntries.begin(), orderedEntries.end(),
910 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
913 llvm::outs().indent(4) << header <<
":\n";
915 entry->printHelpStr(6, globalWidth);
926 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
931 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
933 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
938 StringRef arg, PassArgData &value) {
954 : passList(arg,
llvm::cl::desc(description)) {
955 passList.getParser().passNamesOnly = passNamesOnly;
956 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
962 return llvm::any_of(passList, [&](
const PassArgData &data) {
963 return data.registryEntry == entry;
968 llvm::cl::list<PassArgData, bool, PassNameParser>
passList;
976 arg, description, false)),
979 llvm::cl::desc(
"Textual description of the pass pipeline to run")) {}
984 passPipelineAlias.emplace(alias,
986 llvm::cl::aliasopt(passPipeline));
993 return passPipeline.getNumOccurrences() != 0 ||
994 impl->passList.getNumOccurrences() != 0;
1000 return impl->contains(entry);
1006 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
1007 if (passPipeline.getNumOccurrences()) {
1008 if (
impl->passList.getNumOccurrences())
1009 return errorHandler(
1011 "' option can't be used with individual pass options");
1013 llvm::raw_string_ostream os(errMsg);
1016 return errorHandler(errMsg);
1017 pm = std::move(*parsed);
1021 for (
auto &passIt :
impl->passList) {
1022 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1035 arg, description, true)) {
1036 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1042 return impl->passList.getNumOccurrences() != 0;
1048 return impl->contains(entry);
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurrence of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Nesting getNesting()
Return the current nesting mode.
Nesting
This enum represents the nesting behavior of the pass manager.
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
A structure to represent the information for a derived pass class.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Base container class and manager for all pass options.
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line parsing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
bool hasValue() const
Returns if the current option has a value.
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...