13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
25 using namespace detail;
/// Static mapping of all of the registered passes, keyed by string and
/// holding one `PassInfo` per entry.
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
42 function_ref<LogicalResult(
const Twine &)> errorHandler) {
43 std::unique_ptr<Pass> pass = allocator();
44 LogicalResult result = pass->initializeOptions(
options, errorHandler);
46 std::optional<StringRef> pmOpName = pm.
getOpName();
47 std::optional<StringRef> passOpName = pass->getOpName();
49 passOpName && *pmOpName != *passOpName) {
50 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
51 "' restricted to '" + *pass->getOpName() +
52 "' on a PassManager intended to run on '" +
62 size_t descIndent,
bool isTopLevel) {
63 size_t numSpaces = descIndent - indent - 4;
64 llvm::outs().indent(indent)
65 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
76 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
79 auto printOrderedEntries = [&](StringRef header,
auto &map) {
82 orderedEntries.push_back(&kv.second);
84 orderedEntries.begin(), orderedEntries.end(),
86 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
89 llvm::outs().indent(0) << header <<
":\n";
91 entry->printHelpStr(2, maxWidth);
102 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
106 options.printHelp(indent, descIndent);
115 maxLen =
options.getOptionWidth() + 2;
128 std::move(optHandler));
132 report_fatal_error(
"Pass pipeline " + arg +
" registered multiple times");
147 optHandler(allocator()->passOptions);
151 std::unique_ptr<Pass> pass =
function();
152 StringRef arg = pass->getArgument();
154 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
156 "' pass that does not override `getArgument()`");
157 StringRef description = pass->getDescription();
158 PassInfo passInfo(arg, description,
function);
163 TypeID entryTypeID = pass->getTypeID();
165 if (it->second != entryTypeID)
166 llvm::report_fatal_error(
167 "pass allocator creates a different pass than previously "
168 "registered for pass " +
175 return it ==
passRegistry->end() ? nullptr : &it->second;
193 StringRef str =
options.take_front(argSize).trim();
200 const auto escapePairs = {std::make_pair(
'\'',
'\''),
201 std::make_pair(
'"',
'"'), std::make_pair(
'{',
'}')};
202 for (
const auto &escape : escapePairs) {
203 if (str.front() == escape.first && str.back() == escape.second) {
205 str = str.drop_front().drop_back().trim();
215 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
216 function_ref<LogicalResult(StringRef)> elementParseFn) {
219 llvm::unique_function<size_t(StringRef,
size_t,
char)> findChar =
220 [&](StringRef str,
size_t index,
char c) ->
size_t {
221 for (
size_t i = index, e = str.size(); i < e; ++i) {
226 i = findChar(str, i + 1,
'}');
227 else if (str[i] ==
'(')
228 i = findChar(str, i + 1,
')');
229 else if (str[i] ==
'[')
230 i = findChar(str, i + 1,
']');
231 else if (str[i] ==
'\"')
232 i = str.find_first_of(
'\"', i + 1);
233 else if (str[i] ==
'\'')
234 i = str.find_first_of(
'\'', i + 1);
236 return StringRef::npos;
239 size_t nextElePos = findChar(optionStr, 0,
',');
240 while (nextElePos != StringRef::npos) {
247 optionStr = optionStr.drop_front();
248 nextElePos = findChar(optionStr, 0,
',');
250 return elementParseFn(
/// Out-of-line definition of `OptionBase::anchor()`; the body is empty.
/// NOTE(review): presumably the usual LLVM "anchor" idiom that pins the
/// class's vtable to this translation unit — confirm against the
/// `OptionBase` declaration.
255 void detail::PassOptions::OptionBase::anchor() {}
259 assert(options.size() == other.options.size());
262 for (
auto optionsIt : llvm::zip(options, other.options))
263 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
269 static std::tuple<StringRef, StringRef, StringRef>
273 auto tryProcessPunct = [&](
size_t &currentPos,
char punct) {
274 if (
options[currentPos] != punct)
276 size_t nextIt =
options.find_first_of(punct, currentPos + 1);
277 if (nextIt != StringRef::npos)
284 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
286 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
288 return std::make_tuple(argName, StringRef(),
options);
292 if (
options[argEndIt] ==
'=') {
300 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
302 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
304 return std::make_tuple(argName, value,
options);
309 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
314 size_t braceCount = 1;
315 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
317 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
321 else if (
options[argEndIt] ==
'}' && --braceCount == 0)
328 llvm_unreachable(
"unexpected control flow in pass option parsing");
332 raw_ostream &errorStream) {
336 StringRef key, value;
341 auto it = OptionsMap.find(key);
342 if (it == OptionsMap.end()) {
343 errorStream <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
346 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
357 if (OptionsMap.empty())
362 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
363 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
365 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
370 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
379 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
380 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
382 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
383 for (OptionBase *option : orderedOps) {
388 llvm::outs().indent(indent);
389 option->getOption()->printOptionInfo(descIndent - indent);
/// Out-of-line, compiler-generated default constructor for
/// `llvm::cl::OptionValue<OpPassManager>`.
408 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
409 llvm::cl::OptionValue<OpPassManager>::OptionValue(
413 llvm::cl::OptionValue<OpPassManager>::OptionValue(
418 llvm::cl::OptionValue<OpPassManager> &
419 llvm::cl::OptionValue<OpPassManager>::operator=(
/// Out-of-line, compiler-generated destructor for
/// `llvm::cl::OptionValue<OpPassManager>`.
425 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
427 void llvm::cl::OptionValue<OpPassManager>::setValue(
432 value = std::make_unique<mlir::OpPassManager>(newValue);
434 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
436 assert(succeeded(pipeline) &&
"invalid pass pipeline");
442 std::string lhsStr, rhsStr;
444 raw_string_ostream lhsStream(lhsStr);
445 value->printAsTextualPipeline(lhsStream);
447 raw_string_ostream rhsStream(rhsStr);
452 return lhsStr == rhsStr;
/// Out-of-line definition of `anchor()` for
/// `llvm::cl::OptionValue<OpPassManager>`; the body is empty.
/// NOTE(review): presumably the usual LLVM "anchor" idiom that pins the
/// class's vtable to this translation unit — confirm against the class
/// declaration.
455 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
467 ParsedPassManager &value) {
469 if (failed(pipeline))
471 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
481 const Option &opt,
OpPassManager &pm,
const OptVal &defaultValue,
482 size_t globalWidth)
const {
483 printOptionName(opt, globalWidth);
487 if (defaultValue.hasValue()) {
488 outs().indent(2) <<
" (default: ";
489 defaultValue.getValue().printAsTextualPipeline(outs());
500 ParsedPassManager &&) =
default;
510 class TextualPipeline {
514 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
519 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
525 using ErrorHandlerT =
function_ref<LogicalResult(
const char *, Twine)>;
534 struct PipelineElement {
535 PipelineElement(StringRef name) : name(name) {}
540 std::vector<PipelineElement> innerPipeline;
546 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
552 ErrorHandlerT errorHandler);
555 LogicalResult resolvePipelineElement(PipelineElement &element,
556 ErrorHandlerT errorHandler);
561 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
563 std::vector<PipelineElement> pipeline;
570 LogicalResult TextualPipeline::initialize(StringRef text,
571 raw_ostream &errorStream) {
576 llvm::SourceMgr pipelineMgr;
577 pipelineMgr.AddNewSourceBuffer(
578 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
581 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
582 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
583 llvm::SourceMgr::DK_Error, msg);
588 if (failed(parsePipelineText(text, errorHandler)))
590 return resolvePipelineElements(pipeline, errorHandler);
594 LogicalResult TextualPipeline::addToPipeline(
596 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
602 auto restore = llvm::make_scope_exit([&]() { pm.
setNesting(nesting); });
604 return addToPipeline(pipeline, pm, errorHandler);
610 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
611 ErrorHandlerT errorHandler) {
614 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
615 size_t pos = text.find_first_of(
",(){");
616 pipeline.emplace_back(text.substr(0, pos).trim());
619 if (pos == StringRef::npos)
622 text = text.substr(pos);
627 text = text.substr(1);
630 size_t close = StringRef::npos;
631 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
632 if (text[i] ==
'{') {
636 if (text[i] ==
'}' && --braceCount == 0) {
643 if (close == StringRef::npos) {
646 "missing closing '}' while processing pass options");
648 pipeline.back().options = text.substr(0, close);
649 text = text.substr(close + 1);
655 }
else if (sep ==
'(') {
656 text = text.substr(1);
659 pipelineStack.push_back(&pipeline.back().innerPipeline);
665 while (text.consume_front(
")")) {
667 if (pipelineStack.size() == 1)
668 return errorHandler(text.data() - 1,
669 "encountered extra closing ')' creating unbalanced "
670 "parentheses while parsing pipeline");
672 pipelineStack.pop_back();
683 if (!text.consume_front(
","))
684 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
688 if (pipelineStack.size() > 1)
691 "encountered unbalanced parentheses while parsing pipeline");
693 assert(pipelineStack.back() == &pipeline &&
694 "wrong pipeline at the bottom of the stack");
700 LogicalResult TextualPipeline::resolvePipelineElements(
702 for (
auto &elt : elements)
703 if (failed(resolvePipelineElement(elt, errorHandler)))
710 TextualPipeline::resolvePipelineElement(PipelineElement &element,
711 ErrorHandlerT errorHandler) {
714 if (!element.innerPipeline.empty())
715 return resolvePipelineElements(element.innerPipeline, errorHandler);
727 auto *rawLoc = element.name.data();
728 return errorHandler(rawLoc,
"'" + element.name +
729 "' does not refer to a "
730 "registered pass or pass pipeline");
734 LogicalResult TextualPipeline::addToPipeline(
736 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
737 for (
auto &elt : elements) {
738 if (elt.registryEntry) {
739 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
741 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
744 }
else if (failed(addToPipeline(elt.innerPipeline, pm.
nest(elt.name),
746 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
747 elt.options +
"` to inner pipeline");
754 raw_ostream &errorStream) {
755 TextualPipeline pipelineParser;
756 if (failed(pipelineParser.initialize(pipeline, errorStream)))
758 auto errorHandler = [&](Twine msg) {
759 errorStream << msg <<
"\n";
762 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
768 raw_ostream &errorStream) {
769 pipeline = pipeline.trim();
771 size_t pipelineStart = pipeline.find_first_of(
'(');
772 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
773 !pipeline.consume_back(
")")) {
774 errorStream <<
"expected pass pipeline to be wrapped with the anchor "
775 "operation type, e.g. 'builtin.module(...)'";
779 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
795 PassArgData() =
default;
797 : registryEntry(registryEntry) {}
/// Returns a const reference to the `PassArgData` held in `value`.
820 const PassArgData &
getValue()
const {
return value; }
/// Replaces the held `PassArgData` by copy-assigning from `value`.
/// `this->` is needed because the parameter has the same name as the member.
821 void setValue(
const PassArgData &value) { this->value = value; }
/// The name for the command line option used for parsing the textual pass
/// pipeline.
832 #define PASS_PIPELINE_ARG "pass-pipeline"
840 void printOptionInfo(
const llvm::cl::Option &opt,
841 size_t globalWidth)
const override;
842 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
843 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
849 bool passNamesOnly =
false;
853 void PassNameParser::initialize() {
858 addLiteralOption(kv.second.getPassArgument(), &kv.second,
859 kv.second.getPassDescription());
862 if (!passNamesOnly) {
864 addLiteralOption(kv.second.getPassArgument(), &kv.second,
865 kv.second.getPassDescription());
870 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
871 size_t globalWidth)
const {
875 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
876 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
881 if (opt.hasArgStr()) {
882 llvm::outs() <<
" --" << opt.ArgStr;
883 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
885 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
889 auto printOrderedEntries = [&](StringRef header,
auto &map) {
892 orderedEntries.push_back(&kv.second);
893 llvm::array_pod_sort(
894 orderedEntries.begin(), orderedEntries.end(),
896 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
899 llvm::outs().indent(4) << header <<
":\n";
901 entry->printHelpStr(6, globalWidth);
912 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
917 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
919 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
924 StringRef arg, PassArgData &value) {
940 : passList(arg,
llvm::cl::desc(description)) {
941 passList.getParser().passNamesOnly = passNamesOnly;
942 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
948 return llvm::any_of(passList, [&](
const PassArgData &data) {
949 return data.registryEntry == entry;
/// The set of passes and pass pipelines to run, stored as `PassArgData`
/// entries and parsed by `PassNameParser`.
954 llvm::cl::list<PassArgData, bool, PassNameParser>
passList;
962 arg, description, false)),
965 llvm::cl::desc(
"Textual description of the pass pipeline to run")) {}
970 passPipelineAlias.emplace(alias,
972 llvm::cl::aliasopt(passPipeline));
979 return passPipeline.getNumOccurrences() != 0 ||
980 impl->passList.getNumOccurrences() != 0;
986 return impl->contains(entry);
992 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
993 if (passPipeline.getNumOccurrences()) {
994 if (
impl->passList.getNumOccurrences())
997 "' option can't be used with individual pass options");
999 llvm::raw_string_ostream os(errMsg);
1002 return errorHandler(errMsg);
1003 pm = std::move(*parsed);
1007 for (
auto &passIt :
impl->passList) {
1008 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1021 arg, description, true)) {
1022 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1028 return impl->passList.getNumOccurrences() != 0;
1034 return impl->contains(entry);
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they generate.
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnostic pass manager.
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Nesting getNesting()
Return the current nesting mode.
Nesting
This enum represents the nesting behavior of the pass manager.
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
A structure to represent the information for a derived pass class.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Base container class and manager for all pass options.
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
bool hasValue() const
Returns if the current option has a value.
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...