15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/ScopeExit.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/ManagedStatic.h"
19 #include "llvm/Support/MemoryBuffer.h"
20 #include "llvm/Support/SourceMgr.h"
23 using namespace detail;
26 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
33 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
41 std::unique_ptr<Pass> pass = allocator();
44 std::optional<StringRef> pmOpName = pm.
getOpName();
45 std::optional<StringRef> passOpName = pass->getOpName();
47 passOpName && *pmOpName != *passOpName) {
48 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
49 "' restricted to '" + *pass->getOpName() +
50 "' on a PassManager intended to run on '" +
60 size_t descIndent,
bool isTopLevel) {
61 size_t numSpaces = descIndent - indent - 4;
62 llvm::outs().indent(indent)
63 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
74 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
78 options.printHelp(indent, descIndent);
87 maxLen =
options.getOptionWidth() + 2;
100 std::move(optHandler));
102 assert(inserted &&
"Pass pipeline registered multiple times");
116 optHandler(allocator()->passOptions);
120 std::unique_ptr<Pass> pass =
function();
121 StringRef arg = pass->getArgument();
123 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
125 "' pass that does not override `getArgument()`");
126 StringRef description = pass->getDescription();
127 PassInfo passInfo(arg, description,
function);
132 TypeID entryTypeID = pass->getTypeID();
134 if (it->second != entryTypeID)
135 llvm::report_fatal_error(
136 "pass allocator creates a different pass than previously "
137 "registered for pass " +
144 return it ==
passRegistry->end() ? nullptr : &it->second;
152 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
156 llvm::unique_function<size_t(StringRef,
size_t,
char)> findChar =
157 [&](StringRef str,
size_t index,
char c) ->
size_t {
158 for (
size_t i = index, e = str.size(); i < e; ++i) {
163 i = findChar(str, i + 1,
'}');
164 else if (str[i] ==
'(')
165 i = findChar(str, i + 1,
')');
166 else if (str[i] ==
'[')
167 i = findChar(str, i + 1,
']');
168 else if (str[i] ==
'\"')
169 i = str.find_first_of(
'\"', i + 1);
170 else if (str[i] ==
'\'')
171 i = str.find_first_of(
'\'', i + 1);
173 return StringRef::npos;
176 size_t nextElePos = findChar(optionStr, 0,
',');
177 while (nextElePos != StringRef::npos) {
179 if (
failed(elementParseFn(optionStr.substr(0, nextElePos))))
182 optionStr = optionStr.substr(nextElePos + 1);
183 nextElePos = findChar(optionStr, 0,
',');
185 return elementParseFn(optionStr.substr(0, nextElePos));
/// Intentionally empty, out-of-line definition of `OptionBase::anchor()`.
/// NOTE(review): presumably the usual LLVM "virtual method anchor" idiom that
/// pins the class's vtable to this translation unit -- confirm in the header.
189 void detail::PassOptions::OptionBase::anchor() {}
193 assert(options.size() == other.options.size());
196 for (
auto optionsIt : llvm::zip(options, other.options))
197 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
203 static std::tuple<StringRef, StringRef, StringRef>
207 auto extractArgAndUpdateOptions = [&](
size_t argSize) {
208 StringRef str =
options.take_front(argSize).trim();
214 auto tryProcessPunct = [&](
size_t &currentPos,
char punct) {
215 if (
options[currentPos] != punct)
217 size_t nextIt =
options.find_first_of(punct, currentPos + 1);
218 if (nextIt != StringRef::npos)
225 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
227 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
228 argName = extractArgAndUpdateOptions(argEndIt);
229 return std::make_tuple(argName, StringRef(),
options);
233 if (
options[argEndIt] ==
'=') {
234 argName = extractArgAndUpdateOptions(argEndIt);
241 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
243 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
244 StringRef value = extractArgAndUpdateOptions(argEndIt);
245 return std::make_tuple(argName, value,
options);
250 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
255 size_t braceCount = 1;
256 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
258 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
262 else if (
options[argEndIt] ==
'}' && --braceCount == 0)
269 llvm_unreachable(
"unexpected control flow in pass option parsing");
276 StringRef key, value;
281 auto it = OptionsMap.find(key);
282 if (it == OptionsMap.end()) {
283 llvm::errs() <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
286 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
297 if (OptionsMap.empty())
302 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
303 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
305 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
310 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
319 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
320 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
322 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
323 for (OptionBase *option : orderedOps) {
328 llvm::outs().indent(indent);
329 option->getOption()->printOptionInfo(descIndent - indent);
348 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
349 llvm::cl::OptionValue<OpPassManager>::OptionValue(
353 llvm::cl::OptionValue<OpPassManager>::OptionValue(
358 llvm::cl::OptionValue<OpPassManager> &
359 llvm::cl::OptionValue<OpPassManager>::operator=(
365 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
367 void llvm::cl::OptionValue<OpPassManager>::setValue(
372 value = std::make_unique<mlir::OpPassManager>(newValue);
374 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
376 assert(
succeeded(pipeline) &&
"invalid pass pipeline");
382 std::string lhsStr, rhsStr;
384 raw_string_ostream lhsStream(lhsStr);
385 value->printAsTextualPipeline(lhsStream);
387 raw_string_ostream rhsStream(rhsStr);
392 return lhsStr == rhsStr;
/// Intentionally empty, out-of-line definition of
/// `OptionValue<OpPassManager>::anchor()`.
/// NOTE(review): presumably the usual LLVM "virtual method anchor" idiom that
/// pins the class's vtable to this translation unit -- confirm in the header.
395 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
407 ParsedPassManager &value) {
411 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
421 const Option &opt,
OpPassManager &pm,
const OptVal &defaultValue,
422 size_t globalWidth)
const {
423 printOptionName(opt, globalWidth);
427 if (defaultValue.hasValue()) {
428 outs().indent(2) <<
" (default: ";
429 defaultValue.getValue().printAsTextualPipeline(outs());
440 ParsedPassManager &&) =
default;
450 class TextualPipeline {
454 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
474 struct PipelineElement {
/// Creates a pipeline element whose `name` member is initialized from the
/// given `name`.
475 PipelineElement(StringRef name) : name(name) {}
480 std::vector<PipelineElement> innerPipeline;
486 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
492 ErrorHandlerT errorHandler);
495 LogicalResult resolvePipelineElement(PipelineElement &element,
496 ErrorHandlerT errorHandler);
503 std::vector<PipelineElement> pipeline;
511 raw_ostream &errorStream) {
516 llvm::SourceMgr pipelineMgr;
517 pipelineMgr.AddNewSourceBuffer(
518 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
521 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
522 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
523 llvm::SourceMgr::DK_Error, msg);
528 if (
failed(parsePipelineText(text, errorHandler)))
530 return resolvePipelineElements(pipeline, errorHandler);
542 auto restore = llvm::make_scope_exit([&]() { pm.
setNesting(nesting); });
544 return addToPipeline(pipeline, pm, errorHandler);
550 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
551 ErrorHandlerT errorHandler) {
554 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
555 size_t pos = text.find_first_of(
",(){");
556 pipeline.emplace_back(text.substr(0, pos).trim());
559 if (pos == StringRef::npos)
562 text = text.substr(pos);
567 text = text.substr(1);
570 size_t close = StringRef::npos;
571 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
572 if (text[i] ==
'{') {
576 if (text[i] ==
'}' && --braceCount == 0) {
583 if (close == StringRef::npos) {
586 "missing closing '}' while processing pass options");
588 pipeline.back().options = text.substr(0, close);
589 text = text.substr(close + 1);
592 }
else if (sep ==
'(') {
593 text = text.substr(1);
596 pipelineStack.push_back(&pipeline.back().innerPipeline);
602 while (text.consume_front(
")")) {
604 if (pipelineStack.size() == 1)
605 return errorHandler(text.data() - 1,
606 "encountered extra closing ')' creating unbalanced "
607 "parentheses while parsing pipeline");
609 pipelineStack.pop_back();
618 if (!text.consume_front(
","))
619 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
623 if (pipelineStack.size() > 1)
626 "encountered unbalanced parentheses while parsing pipeline");
628 assert(pipelineStack.back() == &pipeline &&
629 "wrong pipeline at the bottom of the stack");
637 for (
auto &elt : elements)
638 if (
failed(resolvePipelineElement(elt, errorHandler)))
645 TextualPipeline::resolvePipelineElement(PipelineElement &element,
646 ErrorHandlerT errorHandler) {
649 if (!element.innerPipeline.empty())
650 return resolvePipelineElements(element.innerPipeline, errorHandler);
655 element.registryEntry = &pipelineRegistryIt->second;
664 auto *rawLoc = element.name.data();
665 return errorHandler(rawLoc,
"'" + element.name +
666 "' does not refer to a "
667 "registered pass or pass pipeline");
674 for (
auto &elt : elements) {
675 if (elt.registryEntry) {
676 if (
failed(elt.registryEntry->addToPipeline(pm, elt.options,
678 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
681 }
else if (
failed(addToPipeline(elt.innerPipeline, pm.
nest(elt.name),
683 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
684 elt.options +
"` to inner pipeline");
691 raw_ostream &errorStream) {
692 TextualPipeline pipelineParser;
693 if (
failed(pipelineParser.initialize(pipeline, errorStream)))
695 auto errorHandler = [&](Twine msg) {
696 errorStream << msg <<
"\n";
699 if (
failed(pipelineParser.addToPipeline(pm, errorHandler)))
705 raw_ostream &errorStream) {
707 size_t pipelineStart = pipeline.find_first_of(
'(');
708 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
709 !pipeline.consume_back(
")")) {
710 errorStream <<
"expected pass pipeline to be wrapped with the anchor "
711 "operation type, e.g. 'builtin.module(...)'";
715 StringRef opName = pipeline.take_front(pipelineStart);
731 PassArgData() =
default;
733 : registryEntry(registryEntry) {}
/// Returns a const reference to the stored `value`.
756 const PassArgData &
getValue()
const {
return value; }
/// Overwrites the stored value with a copy of the given `value`.
757 void setValue(
const PassArgData &value) { this->value = value; }
/// The name for the command line option used for parsing the textual pass
/// pipeline.
768 #define PASS_PIPELINE_ARG "pass-pipeline"
776 void printOptionInfo(
const llvm::cl::Option &opt,
777 size_t globalWidth)
const override;
778 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
779 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
785 bool passNamesOnly =
false;
789 void PassNameParser::initialize() {
794 addLiteralOption(kv.second.getPassArgument(), &kv.second,
795 kv.second.getPassDescription());
798 if (!passNamesOnly) {
800 addLiteralOption(kv.second.getPassArgument(), &kv.second,
801 kv.second.getPassDescription());
806 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
807 size_t globalWidth)
const {
811 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
812 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
817 if (opt.hasArgStr()) {
818 llvm::outs() <<
" --" << opt.ArgStr;
819 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
821 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
825 auto printOrderedEntries = [&](StringRef header,
auto &map) {
828 orderedEntries.push_back(&kv.second);
829 llvm::array_pod_sort(
830 orderedEntries.begin(), orderedEntries.end(),
832 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
835 llvm::outs().indent(4) << header <<
":\n";
837 entry->printHelpStr(6, globalWidth);
848 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
853 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
855 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
859 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
860 StringRef arg, PassArgData &value) {
876 : passList(arg,
llvm::cl::desc(description)) {
877 passList.getParser().passNamesOnly = passNamesOnly;
878 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
884 return llvm::any_of(passList, [&](
const PassArgData &data) {
885 return data.registryEntry == entry;
890 llvm::cl::list<PassArgData, bool, PassNameParser>
passList;
898 arg, description, false)),
901 llvm::cl::desc(
"Textual description of the pass pipeline to run")) {}
906 passPipelineAlias.emplace(alias,
908 llvm::cl::aliasopt(passPipeline));
915 return passPipeline.getNumOccurrences() != 0 ||
916 impl->passList.getNumOccurrences() != 0;
922 return impl->contains(entry);
929 if (passPipeline.getNumOccurrences()) {
930 if (
impl->passList.getNumOccurrences())
933 "' option can't be used with individual pass options");
935 llvm::raw_string_ostream os(errMsg);
938 return errorHandler(errMsg);
939 pm = std::move(*parsed);
943 for (
auto &passIt :
impl->passList) {
944 if (
failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
957 arg, description, true)) {
958 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
964 return impl->passList.getNumOccurrences() != 0;
970 return impl->contains(entry);
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class provides support for representing a failure result, or a valid value of type T.
This class represents a pass manager that runs passes on either a specific operation type,...
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Nesting getNesting()
Return the current nesting mode.
Nesting
This enum represents the nesting behavior of the pass manager.
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
A structure to represent the information for a derived pass class.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
The PassInfo constructor should not be invoked directly; instead, use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
const PassInfo * lookupPassInfo() const
Returns the pass info for this pass, or null if unknown.
This class provides an efficient unique identifier for a specific C++ type.
Base container class and manager for all pass options.
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options)
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line parsing...
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
void print(raw_ostream &os)
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
Include the generated interface declarations.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
bool hasValue() const
Returns if the current option has a value.
This class represents an efficient way to signal success or failure.
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...