MLIR 22.0.0git
RegistryManager.cpp
Go to the documentation of this file.
1//===- RegistryManager.cpp - Matcher registry -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Registry map populated at static initialization time.
10//
11//===----------------------------------------------------------------------===//
12
13#include "RegistryManager.h"
15
16#include <set>
17#include <utility>
18
19namespace mlir::query::matcher {
20namespace {
21
22// Enum to string for autocomplete.
23static std::string asArgString(ArgKind kind) {
24 switch (kind) {
26 return "Boolean";
28 return "Matcher";
29 case ArgKind::Signed:
30 return "Signed";
31 case ArgKind::String:
32 return "String";
33 }
34 llvm_unreachable("Unhandled ArgKind");
35}
36
37} // namespace
38
39void Registry::registerMatcherDescriptor(
40 llvm::StringRef matcherName,
41 std::unique_ptr<internal::MatcherDescriptor> callback) {
42 assert(!constructorMap.contains(matcherName));
43 constructorMap[matcherName] = std::move(callback);
44}
45
46std::optional<MatcherCtor>
47RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
48 const Registry &matcherRegistry) {
49 auto it = matcherRegistry.constructors().find(matcherName);
50 return it == matcherRegistry.constructors().end()
51 ? std::optional<MatcherCtor>()
52 : it->second.get();
53}
54
56 llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
57 // Starting with the above seed of acceptable top-level matcher types, compute
58 // the acceptable type set for the argument indicated by each context element.
59 std::set<ArgKind> typeSet;
60 typeSet.insert(ArgKind::Matcher);
61
62 for (const auto &ctxEntry : context) {
63 MatcherCtor ctor = ctxEntry.first;
64 unsigned argNumber = ctxEntry.second;
65 std::vector<ArgKind> nextTypeSet;
66
67 if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
68 ctor->getArgKinds(argNumber, nextTypeSet);
69
70 typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
71 }
72
73 return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
74}
75
76std::vector<MatcherCompletion>
78 const Registry &matcherRegistry) {
79 std::vector<MatcherCompletion> completions;
80
81 // Search the registry for acceptable matchers.
82 for (const auto &m : matcherRegistry.constructors()) {
83 const internal::MatcherDescriptor &matcher = *m.getValue();
84 llvm::StringRef name = m.getKey();
85
86 unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
87 std::vector<std::vector<ArgKind>> argKinds(numArgs);
88
89 for (const ArgKind &kind : acceptedTypes) {
90 if (kind != ArgKind::Matcher)
91 continue;
92
93 for (unsigned arg = 0; arg != numArgs; ++arg)
94 matcher.getArgKinds(arg, argKinds[arg]);
95 }
96
97 std::string decl;
98 llvm::raw_string_ostream os(decl);
99
100 std::string typedText = std::string(name);
101 os << "Matcher: " << name << "(";
102
103 for (const std::vector<ArgKind> &arg : argKinds) {
104 if (&arg != &argKinds[0])
105 os << ", ";
106
107 bool firstArgKind = true;
108 // Two steps. First all non-matchers, then matchers only.
109 for (const ArgKind &argKind : arg) {
110 if (!firstArgKind)
111 os << "|";
112
113 firstArgKind = false;
114 os << asArgString(argKind);
115 }
116 }
117
118 if (matcher.isVariadic())
119 os << ",...";
120
121 os << ")";
122 typedText += "(";
123
124 if (argKinds.empty())
125 typedText += ")";
126 else if (argKinds[0][0] == ArgKind::String)
127 typedText += "\"";
128
129 completions.emplace_back(typedText, decl);
130 }
131
132 return completions;
133}
134
136 MatcherCtor ctor, internal::SourceRange nameRange,
137 llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
138 internal::Diagnostics *error) {
139 VariantMatcher out = ctor->create(nameRange, args, error);
140 if (functionName.empty() || out.isNull())
141 return out;
142
143 if (std::optional<DynMatcher> result = out.getDynMatcher()) {
144 result->setFunctionName(functionName);
146 }
147
149 return {};
150}
151
152} // namespace mlir::query::matcher
static std::vector< MatcherCompletion > getMatcherCompletions(ArrayRef< ArgKind > acceptedTypes, const Registry &matcherRegistry)
static std::vector< ArgKind > getAcceptedCompletionTypes(llvm::ArrayRef< std::pair< MatcherCtor, unsigned > > context)
static std::optional< MatcherCtor > lookupMatcherCtor(llvm::StringRef matcherName, const Registry &matcherRegistry)
static VariantMatcher constructMatcher(MatcherCtor ctor, internal::SourceRange nameRange, llvm::StringRef functionName, ArrayRef< ParserValue > args, internal::Diagnostics *error)
const ConstructorMap & constructors() const
Definition Registry.h:33
static VariantMatcher SingleMatcher(DynMatcher matcher)
std::optional< DynMatcher > getDynMatcher() const
ArgStream addError(SourceRange range, ErrorType error)
virtual VariantMatcher create(SourceRange nameRange, const llvm::ArrayRef< ParserValue > args, Diagnostics *error) const =0
virtual void getArgKinds(unsigned argNo, std::vector< ArgKind > &argKinds) const =0
Computes the backward-slice of all transitive defs reachable from rootOp, if innerMatcher matches.
const internal::MatcherDescriptor * MatcherCtor