14 #include "llvm/ADT/DenseMap.h" 15 #include "llvm/Support/Format.h" 16 #include "llvm/Support/ManagedStatic.h" 17 #include "llvm/Support/MemoryBuffer.h" 18 #include "llvm/Support/SourceMgr.h" 21 using namespace detail;
/// Static mapping of all of the registered passes: a lazily constructed
/// string-keyed map whose values are the `PassInfo` entries.
24 static llvm::ManagedStatic<llvm::StringMap<PassInfo>>
passRegistry;
31 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
39 std::unique_ptr<Pass> pass = allocator();
45 passOpName && *pmOpName != *passOpName) {
46 return errorHandler(llvm::Twine(
"Can't add pass '") + pass->getName() +
47 "' restricted to '" + *pass->getOpName() +
48 "' on a PassManager intended to run on '" +
58 size_t descIndent,
bool isTopLevel) {
59 size_t numSpaces = descIndent - indent - 4;
60 llvm::outs().indent(indent)
61 <<
"--" << llvm::left_justify(arg, numSpaces) <<
"- " << desc <<
'\n';
72 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
98 std::move(optHandler));
100 assert(inserted &&
"Pass pipeline registered multiple times");
114 optHandler(allocator()->passOptions);
118 std::unique_ptr<Pass> pass =
function();
119 StringRef arg = pass->getArgument();
121 llvm::report_fatal_error(llvm::Twine(
"Trying to register '") +
123 "' pass that does not override `getArgument()`");
124 StringRef description = pass->getDescription();
125 PassInfo passInfo(arg, description,
function);
130 TypeID entryTypeID = pass->getTypeID();
132 if (it->second != entryTypeID)
133 llvm::report_fatal_error(
134 "pass allocator creates a different pass than previously " 135 "registered for pass " +
142 return it ==
passRegistry->end() ? nullptr : &it->second;
150 llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
154 llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
155 [&](StringRef str,
size_t index,
char c) ->
size_t {
156 for (
size_t i = index, e = str.size(); i < e; ++i) {
161 i = findChar(str, i + 1,
'}');
162 else if (str[i] ==
'(')
163 i = findChar(str, i + 1,
')');
164 else if (str[i] ==
'[')
165 i = findChar(str, i + 1,
']');
166 else if (str[i] ==
'\"')
167 i = str.find_first_of(
'\"', i + 1);
168 else if (str[i] ==
'\'')
169 i = str.find_first_of(
'\'', i + 1);
171 return StringRef::npos;
174 size_t nextElePos = findChar(optionStr, 0,
',');
175 while (nextElePos != StringRef::npos) {
177 if (
failed(elementParseFn(optionStr.substr(0, nextElePos))))
180 optionStr = optionStr.substr(nextElePos + 1);
181 nextElePos = findChar(optionStr, 0,
',');
183 return elementParseFn(optionStr.substr(0, nextElePos));
/// Out-of-line, intentionally empty definition of `OptionBase::anchor()`.
/// NOTE(review): presumably this exists only to pin the class's vtable to this
/// translation unit (the LLVM "anchor" convention) — confirm against the class
/// declaration, which is not visible here.
187 void detail::PassOptions::OptionBase::anchor() {}
191 assert(
options.size() == other.options.size());
194 for (
auto optionsIt : llvm::zip(
options, other.options))
195 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
201 static std::tuple<StringRef, StringRef, StringRef>
205 auto extractArgAndUpdateOptions = [&](
size_t argSize) {
206 StringRef str = options.take_front(argSize).trim();
207 options = options.drop_front(argSize).ltrim();
212 auto tryProcessPunct = [&](
size_t &currentPos,
char punct) {
213 if (options[currentPos] != punct)
215 size_t nextIt = options.find_first_of(punct, currentPos + 1);
216 if (nextIt != StringRef::npos)
223 for (
size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
225 if (argEndIt == optionsE || options[argEndIt] ==
' ') {
226 argName = extractArgAndUpdateOptions(argEndIt);
227 return std::make_tuple(argName, StringRef(), options);
231 if (options[argEndIt] ==
'=') {
232 argName = extractArgAndUpdateOptions(argEndIt);
233 options = options.drop_front();
239 for (
size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
241 if (argEndIt == optionsE || options[argEndIt] ==
' ') {
242 StringRef
value = extractArgAndUpdateOptions(argEndIt);
243 return std::make_tuple(argName, value, options);
247 char c = options[argEndIt];
248 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
253 size_t braceCount = 1;
254 for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
256 if (tryProcessPunct(argEndIt,
'\'') || tryProcessPunct(argEndIt,
'"'))
258 if (options[argEndIt] ==
'{')
260 else if (options[argEndIt] ==
'}' && --braceCount == 0)
267 llvm_unreachable(
"unexpected control flow in pass option parsing");
273 while (!options.empty()) {
274 StringRef key,
value;
279 auto it = OptionsMap.find(key);
280 if (it == OptionsMap.end()) {
281 llvm::errs() <<
"<Pass-Options-Parser>: no such option " << key <<
"\n";
284 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
295 if (OptionsMap.empty())
300 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
301 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
303 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
308 orderedOps, os, [&](OptionBase *option) { option->print(os); },
" ");
317 auto compareOptionArgs = [](OptionBase *
const *lhs, OptionBase *
const *rhs) {
318 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
320 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
321 for (OptionBase *option : orderedOps) {
326 llvm::outs().indent(indent);
327 option->getOption()->printOptionInfo(descIndent - indent);
335 max =
std::max(max, option->getOption()->getOptionWidth());
346 llvm::cl::OptionValue<OpPassManager>::OptionValue() =
default;
347 llvm::cl::OptionValue<OpPassManager>::OptionValue(
351 llvm::cl::OptionValue<OpPassManager> &
352 llvm::cl::OptionValue<OpPassManager>::operator=(
358 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() =
default;
360 void llvm::cl::OptionValue<OpPassManager>::setValue(
365 value = std::make_unique<mlir::OpPassManager>(newValue);
367 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
369 assert(
succeeded(pipeline) &&
"invalid pass pipeline");
375 std::string lhsStr, rhsStr;
377 raw_string_ostream lhsStream(lhsStr);
380 raw_string_ostream rhsStream(rhsStr);
385 return lhsStr == rhsStr;
/// Out-of-line, intentionally empty definition of
/// `OptionValue<OpPassManager>::anchor()`.
/// NOTE(review): presumably a vtable anchor for this specialization, following
/// the LLVM convention — confirm against the specialization's declaration,
/// which is not visible here.
388 void llvm::cl::OptionValue<OpPassManager>::anchor() {}
400 ParsedPassManager &value) {
404 value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
415 size_t globalWidth)
const {
416 printOptionName(opt, globalWidth);
420 if (defaultValue.hasValue()) {
421 outs().indent(2) <<
" (default: ";
422 defaultValue.getValue().printAsTextualPipeline(outs());
433 ParsedPassManager &&) =
default;
443 class TextualPipeline {
447 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
467 struct PipelineElement {
468 PipelineElement(StringRef name) : name(name) {}
473 std::vector<PipelineElement> innerPipeline;
479 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
485 ErrorHandlerT errorHandler);
488 LogicalResult resolvePipelineElement(PipelineElement &element,
489 ErrorHandlerT errorHandler);
496 std::vector<PipelineElement> pipeline;
504 raw_ostream &errorStream) {
509 llvm::SourceMgr pipelineMgr;
510 pipelineMgr.AddNewSourceBuffer(
511 llvm::MemoryBuffer::getMemBuffer(text,
"MLIR Textual PassPipeline Parser",
514 auto errorHandler = [&](
const char *rawLoc, Twine msg) {
515 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
516 llvm::SourceMgr::DK_Error, msg);
521 if (
failed(parsePipelineText(text, errorHandler)))
523 return resolvePipelineElements(pipeline, errorHandler);
536 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
537 ErrorHandlerT errorHandler) {
540 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
541 size_t pos = text.find_first_of(
",(){");
542 pipeline.emplace_back(text.substr(0, pos).trim());
545 if (pos == StringRef::npos)
548 text = text.substr(pos);
553 text = text.substr(1);
556 size_t close = StringRef::npos;
557 for (
unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
558 if (text[i] ==
'{') {
562 if (text[i] ==
'}' && --braceCount == 0) {
569 if (close == StringRef::npos) {
572 "missing closing '}' while processing pass options");
574 pipeline.back().options = text.substr(0, close);
575 text = text.substr(close + 1);
578 }
else if (sep ==
'(') {
579 text = text.substr(1);
582 pipelineStack.push_back(&pipeline.back().innerPipeline);
588 while (text.consume_front(
")")) {
590 if (pipelineStack.size() == 1)
591 return errorHandler(text.data() - 1,
592 "encountered extra closing ')' creating unbalanced " 593 "parentheses while parsing pipeline");
595 pipelineStack.pop_back();
604 if (!text.consume_front(
","))
605 return errorHandler(text.data(),
"expected ',' after parsing pipeline");
609 if (pipelineStack.size() > 1)
612 "encountered unbalanced parentheses while parsing pipeline");
614 assert(pipelineStack.back() == &pipeline &&
615 "wrong pipeline at the bottom of the stack");
623 for (
auto &elt : elements)
624 if (
failed(resolvePipelineElement(elt, errorHandler)))
631 TextualPipeline::resolvePipelineElement(PipelineElement &element,
632 ErrorHandlerT errorHandler) {
635 if (!element.innerPipeline.empty())
636 return resolvePipelineElements(element.innerPipeline, errorHandler);
641 element.registryEntry = &pipelineRegistryIt->second;
650 auto *rawLoc = element.name.data();
651 return errorHandler(rawLoc,
"'" + element.name +
652 "' does not refer to a " 653 "registered pass or pass pipeline");
660 for (
auto &elt : elements) {
661 if (elt.registryEntry) {
662 if (
failed(elt.registryEntry->addToPipeline(pm, elt.options,
664 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
669 return errorHandler(
"failed to add `" + elt.name +
"` with options `" +
670 elt.options +
"` to inner pipeline");
677 raw_ostream &errorStream) {
678 TextualPipeline pipelineParser;
679 if (
failed(pipelineParser.initialize(pipeline, errorStream)))
681 auto errorHandler = [&](Twine msg) {
682 errorStream << msg <<
"\n";
685 if (
failed(pipelineParser.addToPipeline(pm, errorHandler)))
691 raw_ostream &errorStream) {
693 size_t pipelineStart = pipeline.find_first_of(
'(');
694 if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
695 !pipeline.consume_back(
")")) {
696 errorStream <<
"expected pass pipeline to be wrapped with the anchor " 697 "operation type, e.g. `builtin.module(...)";
701 StringRef opName = pipeline.take_front(pipelineStart);
716 PassArgData() =
default;
718 : registryEntry(registryEntry) {}
730 TextualPipeline pipeline;
741 OptionValue() =
default;
757 static constexpr StringLiteral passPipelineArg =
"pass-pipeline";
765 void printOptionInfo(
const llvm::cl::Option &opt,
766 size_t globalWidth)
const override;
767 size_t getOptionWidth(
const llvm::cl::Option &opt)
const override;
768 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
775 bool passNamesOnly =
false;
779 void PassNameParser::initialize() {
783 if (!passNamesOnly) {
784 addLiteralOption(passPipelineArg, PassArgData(),
785 "A textual description of a pass pipeline to run");
790 addLiteralOption(kv.second.getPassArgument(), &kv.second,
791 kv.second.getPassDescription());
794 if (!passNamesOnly) {
796 addLiteralOption(kv.second.getPassArgument(), &kv.second,
797 kv.second.getPassDescription());
802 void PassNameParser::printOptionInfo(
const llvm::cl::Option &opt,
803 size_t globalWidth)
const {
807 llvm::outs() <<
" --" << opt.ArgStr <<
"=<pass-arg>";
808 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
813 if (opt.hasArgStr()) {
814 llvm::outs() <<
" --" << opt.ArgStr;
815 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
817 llvm::outs() <<
" " << opt.HelpStr <<
'\n';
822 "A textual description of a pass pipeline to run",
823 4, globalWidth, !opt.hasArgStr());
826 auto printOrderedEntries = [&](StringRef header,
auto &map) {
829 orderedEntries.push_back(&kv.second);
830 llvm::array_pod_sort(
831 orderedEntries.begin(), orderedEntries.end(),
833 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
836 llvm::outs().indent(4) << header <<
":\n";
838 entry->printHelpStr(6, globalWidth);
849 size_t PassNameParser::getOptionWidth(
const llvm::cl::Option &opt)
const {
854 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
856 maxWidth =
std::max(maxWidth, entry.second.getOptionWidth() + 4);
860 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
861 StringRef arg, PassArgData &value) {
863 if (argName == passPipelineArg)
864 return failed(value.pipeline.initialize(arg, llvm::errs()));
882 : passList(arg,
llvm::cl::desc(description)) {
883 passList.getParser().passNamesOnly = passNamesOnly;
884 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
890 return llvm::any_of(passList, [&](
const PassArgData &data) {
891 return data.registryEntry == entry;
904 arg, description, false)) {}
909 return impl->passList.getNumOccurrences() != 0;
915 return impl->contains(entry);
922 for (
auto &passIt :
impl->passList) {
923 if (passIt.registryEntry) {
924 if (
failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
930 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
945 arg, description, true)) {
946 impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
952 return impl->passList.getNumOccurrences() != 0;
958 return impl->contains(entry);
Nesting getNesting()
Return the current nesting mode.
Include the generated interface declarations.
static PassRegistryFunction buildDefaultRegistryFn(const PassAllocatorFunction &allocator)
Utility to create a default registry function from a pass instance.
static llvm::ManagedStatic< llvm::StringMap< TypeID > > passRegistryTypeIDs
A mapping of the above pass registry entries to the corresponding TypeID of the pass that they genera...
Explicitly register a set of "builtin" types.
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser...
Explicit nesting behavior.
static llvm::ManagedStatic< llvm::StringMap< PassInfo > > passRegistry
Static mapping of all of the registered passes.
PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator)
PassInfo constructor should not be invoked directly; instead, use PassRegistration or registerPass...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
size_t getOptionWidth() const
Return the maximum width required when printing the help string.
void registerPass(const PassAllocatorFunction &function)
Register a specific dialect pass allocator function with the system, typically used through the PassR...
LogicalResult parseFromString(StringRef options)
Parse options out as key=value pairs that can then be handed off to the llvm::cl command line passing...
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, size_t descIndent, bool isTopLevel)
Utility to print the help string for a specific option.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
static constexpr const bool value
This class provides an efficient unique identifier for a specific C++ type.
size_t getOptionWidth() const
Return the maximum width required when printing the options of this entry.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
This class provides support for representing a failure result, or a valid value of type T...
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser...
StringRef getOpAnchorName() const
Return the name used to anchor this pass manager.
void copyOptionValuesFrom(const PassOptions &other)
Copy the option values from 'other' into 'this', where 'other' has the same options as 'this'...
std::function< LogicalResult(OpPassManager &, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler)> PassRegistryFunction
A registry function that adds passes to the given pass manager.
Optional< OperationName > getOpName(MLIRContext &context) const
Return the operation name that this pass manager operates on, or None if this is an op-agnostic pass ...
LogicalResult addToPipeline(OpPassManager &pm, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds the passes defined by this parser entry to the given pass manager.
Nesting
This enum represents the nesting behavior of the pass manager.
OptionValue(const PassArgData &value)
This class represents a specific pass option, with a provided data type.
void printHelpStr(size_t indent, size_t descIndent) const
Print the help information for this pass.
bool hasAnyOccurrences() const
Returns true if this parser contains any valid options to add.
void setValue(const PassArgData &value)
PassNameCLParser(StringRef arg, StringRef description)
Construct a parser with the given command line description.
static void print(ArrayType type, DialectAsmPrinter &os)
llvm::cl::list< PassArgData, bool, PassNameParser > passList
The set of passes and pass pipelines to run.
OpPassManager & nest(OperationName nestedName)
Nest a new operation pass manager for the given operation kind under this pass manager.
std::function< std::unique_ptr< Pass >()> PassAllocatorFunction
Base container class and manager for all pass options.
PassPipelineCLParserImpl(StringRef arg, StringRef description, bool passNamesOnly)
static llvm::ManagedStatic< PassManagerOptions > options
A structure to represent the information of a registered pass pipeline.
const PassArgData & getValue() const
void printHelp(size_t indent, size_t descIndent) const
Print the help string for the options held by this struct.
void printAsTextualPipeline(raw_ostream &os) const
Prints out the passes of the pass manager as the textual representation of pipelines.
A structure to represent the information for a derived pass class.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success...
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
PassPipelineCLParser(StringRef arg, StringRef description)
Construct a pass pipeline parser with the given command line description.
Define a valid OptionValue for the command line pass argument.
static std::tuple< StringRef, StringRef, StringRef > parseNextArg(StringRef options)
Parse in the next argument from the given options string.
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
const PassInfo * lookupPassInfo() const
Returns the pass info for this pass, or null if unknown.
int compare(Fraction x, Fraction y)
Three-way comparison between two fractions.
void registerPassPipeline(StringRef arg, StringRef description, const PassRegistryFunction &function, std::function< void(function_ref< void(const detail::PassOptions &)>)> optHandler)
Register a specific dialect pipeline registry function with the system, typically used through the Pa...
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
static llvm::ManagedStatic< llvm::StringMap< PassPipelineInfo > > passPipelineRegistry
Static mapping of all of the registered pass pipelines.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt...
void setNesting(Nesting nesting)
Enable or disable the implicit nesting on this particular PassManager.
This class represents a pass manager that runs passes on either a specific operation type...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
bool contains(const PassRegistryEntry *entry) const
Returns true if the given pass registry entry was registered at the top-level of the parser...
void print(raw_ostream &os)
Print the options held by this struct in a form that can be parsed via 'parseFromString'.