13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
25 using namespace detail;
28 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
35 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
42 function_ref<LogicalResult(
const Twine &)> errorHandler) {
43 std::unique_ptr<Pass> pass = allocator();
44 LogicalResult result = pass->initializeOptions(
options, errorHandler);
46 std::optional<StringRef> pmOpName = pm.
getOpName();
47 std::optional<StringRef> passOpName = pass->getOpName();
49 passOpName && *pmOpName != *passOpName) {
50 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
51 "' restricted to '" + *pass->getOpName() +
52 "' on a PassManager intended to run on '" +
62 size_t descIndent,
bool isTopLevel) {
63 size_t numSpaces = descIndent - indent - 4;
64 llvm::outs().indent(indent)
65 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
76 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
79 auto printOrderedEntries = [&](StringRef header,
auto &map) {
82 orderedEntries.push_back(&kv.second);
84 orderedEntries.begin(), orderedEntries.end(),
86 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
89 llvm::outs().indent(0) << header <<
":\n";
91 entry->printHelpStr(2, maxWidth);
102 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
106 options.printHelp(indent, descIndent);
115 maxLen =
options.getOptionWidth() + 2;
128 std::move(optHandler));
132 report_fatal_error(
"Pass pipeline " + arg +
" registered multiple times");
147 optHandler(allocator()->passOptions);
151 std::unique_ptr<Pass> pass =
function();
152 StringRef arg = pass->getArgument();
154 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
156 "' pass that does not override `getArgument()`");
157 StringRef description = pass->getDescription();
158 PassInfo passInfo(arg, description,
function);
163 TypeID entryTypeID = pass->getTypeID();
165 if (it->second != entryTypeID)
166 llvm::report_fatal_error(
167 "pass allocator creates a different pass than previously "
168 "registered for pass " +
175 return it ==
passRegistry->end() ? nullptr : &it->second;
192 static size_t findChar(StringRef str,
size_t index,
char c) {
193 for (
size_t i = index, e = str.size(); i < e; ++i) {
199 else if (str[i] ==
'(')
201 else if (str[i] ==
'[')
203 else if (str[i] ==
'\"')
204 i = str.find_first_of(
'\"', i + 1);
205 else if (str[i] ==
'\'')
206 i = str.find_first_of(
'\'', i + 1);
207 if (i == StringRef::npos)
208 return StringRef::npos;
210 return StringRef::npos;
217 StringRef str =
options.take_front(argSize).trim();
224 const auto escapePairs = {std::make_pair(
'\'',
'\''),
225 std::make_pair(
'"',
'"')};
226 for (
const auto &escape : escapePairs) {
227 if (str.front() == escape.first && str.back() == escape.second) {
230 return str.drop_front().drop_back().trim();
238 if (str.front() ==
'{') {
239 unsigned match =
findChar(str, 1,
'}');
240 if (match == str.size() - 1)
241 str = str.drop_front().drop_back().trim();
248 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
249 function_ref<LogicalResult(StringRef)> elementParseFn) {
250 if (optionStr.empty())
253 size_t nextElePos =
findChar(optionStr, 0,
',');
254 while (nextElePos != StringRef::npos) {
261 optionStr = optionStr.drop_front();
262 nextElePos =
findChar(optionStr, 0,
',');
264 return elementParseFn(
269 void detail::PassOptions::OptionBase::anchor() {}
273 assert(options.size() == other.options.size());
276 for (
auto optionsIt : llvm::zip(options, other.options))
277 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
283 static std::tuple<StringRef, StringRef, StringRef>
287 auto tryProcessPunct = [&](
size_t ¤tPos,
char punct) {
288 if (
options[currentPos] != punct)
290 size_t nextIt =
options.find_first_of(punct, currentPos + 1);
291 if (nextIt != StringRef::npos)
298 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
300 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
302 return std::make_tuple(argName, StringRef(),
options);
306 if (
options[argEndIt] ==
'=') {
314 for (
size_t argEndIt = 0, optionsE =
options.size();; ++argEndIt) {
316 if (argEndIt == optionsE ||
options[argEndIt] ==
' ') {
318 return std::make_tuple(argName, value,
options);
323 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
328 size_t braceCount = 1;
329 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
331 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
335 else if (
options[argEndIt] ==
'}' && --braceCount == 0)
342 llvm_unreachable(
"unexpected control flow in pass option parsing");
346 raw_ostream &errorStream) {
350 StringRef key, value;
355 auto it = OptionsMap.find(key);
356 if (it == OptionsMap.end()) {
357 errorStream <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
360 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
371 if (OptionsMap.empty())
376 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
377 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
379 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
384 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
393 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
394 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
396 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
397 for (OptionBase *option : orderedOps) {
402 llvm::outs().indent(indent);
403 option->getOption()->printOptionInfo(descIndent - indent);
423 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
424 llvm::cl::OptionValue<OpPassManager>::OptionValue(
428 llvm::cl::OptionValue<OpPassManager>::OptionValue(
433 llvm::cl::OptionValue<OpPassManager> &
434 llvm::cl::OptionValue<OpPassManager>::operator=(
440 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
442 void llvm::cl::OptionValue<OpPassManager>::setValue(
447 value = std::make_unique<mlir::OpPassManager>(newValue);
449 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
451 assert(succeeded(pipeline) &&
"invalid pass pipeline");
457 std::string lhsStr, rhsStr;
459 raw_string_ostream lhsStream(lhsStr);
460 value->printAsTextualPipeline(lhsStream);
462 raw_string_ostream rhsStream(rhsStr);
467 return lhsStr == rhsStr;
470 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
483 ParsedPassManager &value) {
485 if (failed(pipeline))
487 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
497 const Option &opt,
OpPassManager &pm,
const OptVal &defaultValue,
498 size_t globalWidth)
const {
499 printOptionName(opt, globalWidth);
503 if (defaultValue.hasValue()) {
504 outs().indent(2) <<
" (default: ";
505 defaultValue.getValue().printAsTextualPipeline(outs());
516 ParsedPassManager &&) =
default;
526 class TextualPipeline {
530 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
535 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
541 using ErrorHandlerT =
function_ref<LogicalResult(
const char *, Twine)>;
550 struct PipelineElement {
551 PipelineElement(StringRef name) : name(name) {}
556 std::vector<PipelineElement> innerPipeline;
562 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
568 ErrorHandlerT errorHandler);
571 LogicalResult resolvePipelineElement(PipelineElement &element,
572 ErrorHandlerT errorHandler);
577 function_ref<LogicalResult(
const Twine &)> errorHandler)
const;
579 std::vector<PipelineElement> pipeline;
586 LogicalResult TextualPipeline::initialize(StringRef text,
587 raw_ostream &errorStream) {
592 llvm::SourceMgr pipelineMgr;
593 pipelineMgr.AddNewSourceBuffer(
594 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
597 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
598 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
599 llvm::SourceMgr::DK_Error, msg);
604 if (failed(parsePipelineText(text, errorHandler)))
606 return resolvePipelineElements(pipeline, errorHandler);
610 LogicalResult TextualPipeline::addToPipeline(
612 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
618 auto restore = llvm::make_scope_exit([&]() { pm.
setNesting(nesting); });
620 return addToPipeline(pipeline, pm, errorHandler);
626 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
627 ErrorHandlerT errorHandler) {
630 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
631 size_t pos = text.find_first_of(
",(){");
632 pipeline.emplace_back(text.substr(0, pos).trim());
635 if (pos == StringRef::npos)
638 text = text.substr(pos);
643 text = text.substr(1);
646 size_t close = StringRef::npos;
647 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
648 if (text[i] ==
'{') {
652 if (text[i] ==
'}' && --braceCount == 0) {
659 if (close == StringRef::npos) {
662 "missing closing '}' while processing pass options");
664 pipeline.back().options = text.substr(0, close);
665 text = text.substr(close + 1);
671 }
else if (sep ==
'(') {
672 text = text.substr(1);
675 pipelineStack.push_back(&pipeline.back().innerPipeline);
681 while (text.consume_front(
")")) {
683 if (pipelineStack.size() == 1)
684 return errorHandler(text.data() - 1,
685 "encountered extra closing ')' creating unbalanced "
686 "parentheses while parsing pipeline");
688 pipelineStack.pop_back();
699 if (!text.consume_front(
","))
700 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
704 if (pipelineStack.size() > 1)
707 "encountered unbalanced parentheses while parsing pipeline");
709 assert(pipelineStack.back() == &pipeline &&
710 "wrong pipeline at the bottom of the stack");
716 LogicalResult TextualPipeline::resolvePipelineElements(
718 for (
auto &elt : elements)
719 if (failed(resolvePipelineElement(elt, errorHandler)))
726 TextualPipeline::resolvePipelineElement(PipelineElement &element,
727 ErrorHandlerT errorHandler) {
730 if (!element.innerPipeline.empty())
731 return resolvePipelineElements(element.innerPipeline, errorHandler);
743 auto *rawLoc = element.name.data();
744 return errorHandler(rawLoc,
"'" + element.name +
745 "' does not refer to a "
746 "registered pass or pass pipeline");
750 LogicalResult TextualPipeline::addToPipeline(
752 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
753 for (
auto &elt : elements) {
754 if (elt.registryEntry) {
755 if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
757 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
760 }
else if (failed(addToPipeline(elt.innerPipeline, pm.
nest(elt.name),
762 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
763 elt.options +
"` to inner pipeline");
770 raw_ostream &errorStream) {
771 TextualPipeline pipelineParser;
772 if (failed(pipelineParser.initialize(pipeline, errorStream)))
774 auto errorHandler = [&](Twine msg) {
775 errorStream << msg <<
"\n";
778 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
784 raw_ostream &errorStream) {
785 pipeline = pipeline.trim();
787 size_t pipelineStart = pipeline.find_first_of(
'(');
788 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
789 !pipeline.consume_back(
")")) {
790 errorStream <<
"expected pass pipeline to be wrapped with the anchor "
791 "operation type, e.g. 'builtin.module(...)'";
795 StringRef opName = pipeline.take_front(pipelineStart).rtrim();
811 PassArgData() =
default;
813 : registryEntry(registryEntry) {}
836 const PassArgData &
getValue()
const {
return value; }
837 void setValue(
const PassArgData &value) { this->value = value; }
848 #define PASS_PIPELINE_ARG "pass-pipeline"
856 void printOptionInfo(
const llvm::cl::Option &opt,
857 size_t globalWidth)
const override;
858 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
859 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
865 bool passNamesOnly =
false;
869 void PassNameParser::initialize() {
874 addLiteralOption(kv.second.getPassArgument(), &kv.second,
875 kv.second.getPassDescription());
878 if (!passNamesOnly) {
880 addLiteralOption(kv.second.getPassArgument(), &kv.second,
881 kv.second.getPassDescription());
886 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
887 size_t globalWidth)
const {
891 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
892 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
897 if (opt.hasArgStr()) {
898 llvm::outs() <<
" --" << opt.ArgStr;
899 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
901 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
905 auto printOrderedEntries = [&](StringRef header,
auto &map) {
908 orderedEntries.push_back(&kv.second);
909 llvm::array_pod_sort(
910 orderedEntries.begin(), orderedEntries.end(),
912 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
915 llvm::outs().indent(4) << header <<
":\n";
917 entry->printHelpStr(6, globalWidth);
928 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
933 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
935 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
940 StringRef arg, PassArgData &value) {
956 : passList(arg,
llvm::cl::desc(description)) {
957 passList.getParser().passNamesOnly = passNamesOnly;
958 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
964 return llvm::any_of(passList, [&](
const PassArgData &data) {
965 return data.registryEntry == entry;
970 llvm::cl::list<PassArgData, bool, PassNameParser>
passList;
978 arg, description, false)),
981 llvm::cl::desc(
"Textual description of the pass pipeline to run")) {}
986 passPipelineAlias.emplace(alias,
988 llvm::cl::aliasopt(passPipeline));
995 return passPipeline.getNumOccurrences() != 0 ||
996 impl->passList.getNumOccurrences() != 0;
1002 return impl->contains(entry);
1008 function_ref<LogicalResult(
const Twine &)> errorHandler)
const {
1009 if (passPipeline.getNumOccurrences()) {
1010 if (
impl->passList.getNumOccurrences())
1011 return errorHandler(
1013 "' option can't be used with individual pass options");
1015 llvm::raw_string_ostream os(errMsg);
1018 return errorHandler(errMsg);
1019 pm = std::move(*parsed);
1023 for (
auto &passIt :
impl->passList) {
1024 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1038 arg, description, true)) {
1039 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1045 return impl->passList.getNumOccurrences() != 0;
1051 return impl->contains(entry);
static llvm::ManagedStatic< PassManagerOptions > options
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
#define PASS_PIPELINE_ARG
The name for the command line option used for parsing the textual pass pipeline.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
static size_t findChar(StringRef str, size_t index, char c)
Attempt to find the next occurance of character 'c' in the string starting from the index-th position...
static StringRef extractArgAndUpdateOptions(StringRef &options, size_t argSize)
Extract an argument from 'options' and update it to point after the arg.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents a pass manager that runs passes on either a specific operation type,...
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
std::optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or std::nullopt if this is an op-agnost...
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Nesting getNesting()
Return the current nesting mode.
Nesting
This enum represents the nesting behavior of the pass manager.
@ Explicit
Explicit nesting behavior.
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
A structure to represent the information for a derived pass class.
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly, instead use PassRegistration or registerPass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
This class implements a command-line parser for MLIR passes.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...
A structure to represent the information of a registered pass pipeline.
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
This class provides an efficient unique identifier for a specific C++ type.
Base container class and manager for all pass options.
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
LogicalResult parseFromString(StringRef options, raw_ostream &errorStream=llvm::errs())
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
void print(raw_ostream &os) const
Print the options held by this struct in a form that can be parsed via 'parseFromString'.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
void printRegisteredPasses()
Prints the passes that were previously registered and stored in passRegistry.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Define a valid OptionValue for the command line pass argument.
OptionValue(const PassArgData &value)
void setValue(const PassArgData &value)
const PassArgData & getValue() const
mlir::OpPassManager & getValue() const
Returns the current value of the option.
bool hasValue() const
Returns if the current option has a value.
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser,...