13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/ManagedStatic.h"
17 #include "llvm/Support/MemoryBuffer.h"
18 #include "llvm/Support/SourceMgr.h"
24 using namespace detail;
27 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
34 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
41 function_ref<LogicalResult(
const Twine &)> errorHandler) {
42 std::unique_ptr<Pass> pass = allocator();
43 LogicalResult result = pass->initializeOptions(
options, errorHandler);
45 std::optional<StringRef> pmOpName = pm.
getOpName();
46 std::optional<StringRef> passOpName = pass->getOpName();
48 passOpName && *pmOpName != *passOpName) {
49 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
50 "' restricted to '" + *pass->getOpName() +
51 "' on a PassManager intended to run on '" +
61 size_t descIndent,
bool isTopLevel) {
62 size_t numSpaces = descIndent - indent - 4;
63 llvm::outs().indent(indent)
64 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
75 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
79 options.printHelp(indent, descIndent);
88 maxLen =
options.getOptionWidth() + 2;
101 std::move(optHandler));
105 report_fatal_error(
"Pass pipeline " + arg +
" registered multiple times");
120 optHandler(allocator()->passOptions);
124 std::unique_ptr<Pass> pass =
function();
125 StringRef arg = pass->getArgument();
127 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
129 "' pass that does not override `getArgument()`");
130 StringRef description = pass->getDescription();
131 PassInfo passInfo(arg, description,
function);
136 TypeID entryTypeID = pass->getTypeID();
138 if (it->second != entryTypeID)
139 llvm::report_fatal_error(
140 "pass allocator creates a different pass than previously "
141 "registered for pass " +
148 return it ==
passRegistry->end() ? nullptr : &it->second;
163 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
164 function_ref<LogicalResult(StringRef)> elementParseFn) {
167 llvm::unique_function<size_t(StringRef,
size_t,
char)> findChar =
168 [&](StringRef str,
size_t index,
char c) ->
size_t {
169 for (
size_t i = index, e = str.size(); i < e; ++i) {
174 i = findChar(str, i + 1,
'}');
175 else if (str[i] ==
'(')
176 i = findChar(str, i + 1,
')');
177 else if (str[i] ==
'[')
178 i = findChar(str, i + 1,
']');
179 else if (str[i] ==
'\"')
180 i = str.find_first_of(
'\"', i + 1);
181 else if (str[i] ==
'\'')
182 i = str.find_first_of(
'\'', i + 1);
184 return StringRef::npos;
187 size_t nextElePos = findChar(optionStr, 0,
',');
188 while (nextElePos != StringRef::npos) {
190 if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
193 optionStr = optionStr.substr(nextElePos + 1);
194 nextElePos = findChar(optionStr, 0,
',');
196 return elementParseFn(optionStr.substr(0, nextElePos));
200 void detail::PassOptions::OptionBase::anchor() {}
204 assert(options.size() == other.options.size());
207 for (
auto optionsIt : llvm::zip(options, other.options))
208 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
214 static std::tuple<StringRef, StringRef, StringRef>
218 auto extractArgAndUpdateOptions = [&](
size_t argSize) {
219 StringRef str =
options.take_front(argSize).trim();
222 if (str.size() > 2) {
223 const auto escapePairs = {std::make_pair(
'\'',
'\''),
224 std::make_pair(
'"',
'"'),
225 std::make_pair(
'{',
'}')};
226 for (
const auto &escape : escapePairs) {
227 if (str.front() == escape.first && str.back() == escape.second) {
229 str = str.drop_front().drop_back().trim();
239 auto tryProcessPunct = [&](
size_t ¤tPos,
char punct) {
240 if (
options[currentPos] != punct)
242 size_t nextIt =
options.find_first_of(punct, currentPos + 1);
243 if (nextIt != StringRef::npos)
250 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
252 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
253 argName = extractArgAndUpdateOptions(argEndIt);
254 return std::make_tuple(argName, StringRef(),
options);
258 if (
options[argEndIt] ==
'=') {
259 argName = extractArgAndUpdateOptions(argEndIt);
266 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
268 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
269 StringRef value = extractArgAndUpdateOptions(argEndIt);
270 return std::make_tuple(argName, value,
options);
275 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
280 size_t braceCount = 1;
281 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
283 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
287 else if (
options[argEndIt] ==
'}' && --braceCount == 0)
294 llvm_unreachable(
"unexpected control flow in pass option parsing");
298 raw_ostream &errorStream) {
302 StringRef key, value;
307 auto it = OptionsMap.find(key);
308 if (it == OptionsMap.end()) {
309 errorStream <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
312 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
323 if (OptionsMap.empty())
328 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
329 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
331 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
336 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
345 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
346 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
348 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
349 for (OptionBase *option : orderedOps) {
354 llvm::outs().indent(indent);
355 option->getOption()->printOptionInfo(descIndent - indent);
374 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
375 llvm::cl::OptionValue<OpPassManager>::OptionValue(
379 llvm::cl::OptionValue<OpPassManager>::OptionValue(
384 llvm::cl::OptionValue<OpPassManager> &
385 llvm::cl::OptionValue<OpPassManager>::operator=(
391 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
393 void llvm::cl::OptionValue<OpPassManager>::setValue(
398 value = std::make_unique<mlir::OpPassManager>(newValue);
400 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
402 assert(succeeded(pipeline) &&
"invalid pass pipeline");
408 std::string lhsStr, rhsStr;
410 raw_string_ostream lhsStream(lhsStr);
411 value->printAsTextualPipeline(lhsStream);
413 raw_string_ostream rhsStream(rhsStr);
418 return lhsStr == rhsStr;
421 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
433 ParsedPassManager &value) {
435 if (failed(pipeline))
437 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
447 const Option &opt,
OpPassManager &pm,
const OptVal &defaultValue,
448 size_t globalWidth)
const {
449 printOptionName(opt, globalWidth);
453 if (defaultValue.hasValue()) {
454 outs().indent(2) <<
" (default: ";
455 defaultValue.getValue().printAsTextualPipeline(outs());
466 ParsedPassManager &&) =
default;
476 class TextualPipeline {
480 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
485 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
491 using ErrorHandlerT =
function_ref<LogicalResult(
const char *, Twine)>;
500 struct PipelineElement {
501 PipelineElement(StringRef name) : name(name) {}
506 std::vector<PipelineElement> innerPipeline;
512 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
518 ErrorHandlerT errorHandler);
521 LogicalResult resolvePipelineElement(PipelineElement &element,
522 ErrorHandlerT errorHandler);
527 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
529 std::vector<PipelineElement> pipeline;
536 LogicalResult TextualPipeline::initialize(StringRef text,
537 raw_ostream &errorStream) {
542 llvm::SourceMgr pipelineMgr;
543 pipelineMgr.AddNewSourceBuffer(
544 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
547 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
548 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
549 llvm::SourceMgr::DK_Error, msg);
554 if (failed(parsePipelineText(text, errorHandler)))
556 return resolvePipelineElements(pipeline, errorHandler);
560 LogicalResult TextualPipeline::addToPipeline(
562 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
568 auto restore = llvm::make_scope_exit([&]() { pm.
setNesting(nesting); });
570 return addToPipeline(pipeline, pm, errorHandler);
576 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
577 ErrorHandlerT errorHandler) {
580 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
581 size_t pos = text.find_first_of(
",(){");
582 pipeline.emplace_back(text.substr(0, pos).trim());
585 if (pos == StringRef::npos)
588 text = text.substr(pos);
593 text = text.substr(1);
596 size_t close = StringRef::npos;
597 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
598 if (text[i] ==
'{') {
602 if (text[i] ==
'}' && --braceCount == 0) {
609 if (close == StringRef::npos) {
612 "missing closing '}' while processing pass options");
614 pipeline.back().options = text.substr(0, close);
615 text = text.substr(close + 1);
621 }
else if (sep ==
'(') {
622 text = text.substr(1);
625 pipelineStack.push_back(&pipeline.back().innerPipeline);
631 while (text.consume_front(
")")) {
633 if (pipelineStack.size() == 1)
634 return errorHandler(text.data() - 1,
635 "encountered extra closing ')' creating unbalanced "
636 "parentheses while parsing pipeline");
638 pipelineStack.pop_back();
649 if (!text.consume_front(
","))
650 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
654 if (pipelineStack.size() > 1)
657 "encountered unbalanced parentheses while parsing pipeline");
659 assert(pipelineStack.back() == &pipeline &&
660 "wrong pipeline at the bottom of the stack");
666 LogicalResult TextualPipeline::resolvePipelineElements(
668 for (
auto &elt : elements)
669 if (failed(resolvePipelineElement(elt, errorHandler)))
676 TextualPipeline::resolvePipelineElement(PipelineElement &element,
677 ErrorHandlerT errorHandler) {
680 if (!element.innerPipeline.empty())
681 return resolvePipelineElements(element.innerPipeline, errorHandler);
693 auto *rawLoc = element.name.data();
694 return errorHandler(rawLoc,
"'" + element.name +
695 "' does not refer to a "
696 "registered pass or pass pipeline");
700 LogicalResult TextualPipeline::addToPipeline(
702 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
703 for (
auto &elt : elements) {
704 if (elt.registryEntry) {
705 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
707 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
710 }
else if (failed(addToPipeline(elt.innerPipeline, pm.
nest(elt.name),
712 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
713 elt.options +
"` to inner pipeline");
720 raw_ostream &errorStream) {
721 TextualPipeline pipelineParser;
722 if (failed(pipelineParser.initialize(pipeline, errorStream)))
724 auto errorHandler = [&](Twine msg) {
725 errorStream << msg <<
"\n";
728 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
734 raw_ostream &errorStream) {
735 pipeline = pipeline.trim();
737 size_t pipelineStart = pipeline.find_first_of(
'(');
738 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
739 !pipeline.consume_back(
")")) {
740 errorStream <<
"expected pass pipeline to be wrapped with the anchor "
741 "operation type, e.g. 'builtin.module(...)'";
745 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
761 PassArgData() =
default;
763 : registryEntry(registryEntry) {}
786 const PassArgData &
getValue()
const {
return value; }
787 void setValue(
const PassArgData &value) { this->value = value; }
798 #define PASS_PIPELINE_ARG "pass-pipeline"
806 void printOptionInfo(
const llvm::cl::Option &opt,
807 size_t globalWidth)
const override;
808 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
809 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
815 bool passNamesOnly =
false;
819 void PassNameParser::initialize() {
824 addLiteralOption(kv.second.getPassArgument(), &kv.second,
825 kv.second.getPassDescription());
828 if (!passNamesOnly) {
830 addLiteralOption(kv.second.getPassArgument(), &kv.second,
831 kv.second.getPassDescription());
836 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
837 size_t globalWidth)
const {
841 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
842 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
847 if (opt.hasArgStr()) {
848 llvm::outs() <<
" --" << opt.ArgStr;
849 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
851 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
855 auto printOrderedEntries = [&](StringRef header,
auto &map) {
858 orderedEntries.push_back(&kv.second);
859 llvm::array_pod_sort(
860 orderedEntries.begin(), orderedEntries.end(),
862 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
865 llvm::outs().indent(4) << header <<
":\n";
867 entry->printHelpStr(6, globalWidth);
878 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
883 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
885 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
890 StringRef arg, PassArgData &value) {
906 : passList(arg,
llvm::cl::desc(description)) {
907 passList.getParser().passNamesOnly = passNamesOnly;
908 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
914 return llvm::any_of(passList, [&](
const PassArgData &data) {
915 return data.registryEntry == entry;
920 llvm::cl::list<PassArgData, bool, PassNameParser>
passList;
928 arg, description, false)),
931 llvm::cl::desc(
"Textual description of the pass pipeline to run")) {}
936 passPipelineAlias.emplace(alias,
938 llvm::cl::aliasopt(passPipeline));
945 return passPipeline.getNumOccurrences() != 0 ||
946 impl->passList.getNumOccurrences() != 0;
952 return impl->contains(entry);
958 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
959 if (passPipeline.getNumOccurrences()) {
960 if (
impl->passList.getNumOccurrences())
963 "' option can't be used with individual pass options");
965 llvm::raw_string_ostream os(errMsg);
968 return errorHandler(errMsg);
969 pm = std::move(*parsed);
973 for (
auto &passIt :
impl->passList) {
974 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
987 arg, description, true)) {
988 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
994 return impl->passList.getNumOccurrences() != 0;
1000 return impl->contains(entry);
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Nesting getNesting()
Return the current nesting mode.
Nesting
This enum represents the nesting behavior of the pass manager.
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
A structure to represent the information for a derived pass class.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Base container class and manager for all pass options.
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
void print(raw_ostream &os)
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
Include the generated interface declarations.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
bool hasValue() const
Returns if the current option has a value.
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...