MLIR  20.0.0git
RegistryManager.cpp
Go to the documentation of this file.
1 //===- RegistryManager.cpp - Matcher registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Registry map populated at static initialization time.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "RegistryManager.h"
15 
16 #include <set>
17 #include <utility>
18 
19 namespace mlir::query::matcher {
20 namespace {
21 
22 // This is needed because these matchers are defined as overloaded functions.
23 using IsConstantOp = detail::constant_op_matcher();
24 using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
25 using HasOpName = detail::NameOpMatcher(llvm::StringRef);
26 
27 // Enum to string for autocomplete.
28 static std::string asArgString(ArgKind kind) {
29  switch (kind) {
30  case ArgKind::Matcher:
31  return "Matcher";
32  case ArgKind::String:
33  return "String";
34  }
35  llvm_unreachable("Unhandled ArgKind");
36 }
37 
38 } // namespace
39 
40 void Registry::registerMatcherDescriptor(
41  llvm::StringRef matcherName,
42  std::unique_ptr<internal::MatcherDescriptor> callback) {
43  assert(!constructorMap.contains(matcherName));
44  constructorMap[matcherName] = std::move(callback);
45 }
46 
47 std::optional<MatcherCtor>
48 RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
49  const Registry &matcherRegistry) {
50  auto it = matcherRegistry.constructors().find(matcherName);
51  return it == matcherRegistry.constructors().end()
52  ? std::optional<MatcherCtor>()
53  : it->second.get();
54 }
55 
57  llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
58  // Starting with the above seed of acceptable top-level matcher types, compute
59  // the acceptable type set for the argument indicated by each context element.
60  std::set<ArgKind> typeSet;
61  typeSet.insert(ArgKind::Matcher);
62 
63  for (const auto &ctxEntry : context) {
64  MatcherCtor ctor = ctxEntry.first;
65  unsigned argNumber = ctxEntry.second;
66  std::vector<ArgKind> nextTypeSet;
67 
68  if (argNumber < ctor->getNumArgs())
69  ctor->getArgKinds(argNumber, nextTypeSet);
70 
71  typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
72  }
73 
74  return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
75 }
76 
77 std::vector<MatcherCompletion>
79  const Registry &matcherRegistry) {
80  std::vector<MatcherCompletion> completions;
81 
82  // Search the registry for acceptable matchers.
83  for (const auto &m : matcherRegistry.constructors()) {
84  const internal::MatcherDescriptor &matcher = *m.getValue();
85  llvm::StringRef name = m.getKey();
86 
87  unsigned numArgs = matcher.getNumArgs();
88  std::vector<std::vector<ArgKind>> argKinds(numArgs);
89 
90  for (const ArgKind &kind : acceptedTypes) {
91  if (kind != ArgKind::Matcher)
92  continue;
93 
94  for (unsigned arg = 0; arg != numArgs; ++arg)
95  matcher.getArgKinds(arg, argKinds[arg]);
96  }
97 
98  std::string decl;
99  llvm::raw_string_ostream os(decl);
100 
101  std::string typedText = std::string(name);
102  os << "Matcher: " << name << "(";
103 
104  for (const std::vector<ArgKind> &arg : argKinds) {
105  if (&arg != &argKinds[0])
106  os << ", ";
107 
108  bool firstArgKind = true;
109  // Two steps. First all non-matchers, then matchers only.
110  for (const ArgKind &argKind : arg) {
111  if (!firstArgKind)
112  os << "|";
113 
114  firstArgKind = false;
115  os << asArgString(argKind);
116  }
117  }
118 
119  os << ")";
120  typedText += "(";
121 
122  if (argKinds.empty())
123  typedText += ")";
124  else if (argKinds[0][0] == ArgKind::String)
125  typedText += "\"";
126 
127  completions.emplace_back(typedText, decl);
128  }
129 
130  return completions;
131 }
132 
134  MatcherCtor ctor, internal::SourceRange nameRange,
135  llvm::StringRef functionName, llvm::ArrayRef<ParserValue> args,
136  internal::Diagnostics *error) {
137  VariantMatcher out = ctor->create(nameRange, args, error);
138  if (functionName.empty() || out.isNull())
139  return out;
140 
141  if (std::optional<DynMatcher> result = out.getDynMatcher()) {
142  result->setFunctionName(functionName);
143  return VariantMatcher::SingleMatcher(*result);
144  }
145 
147  return {};
148 }
149 
150 } // namespace mlir::query::matcher
static std::vector< MatcherCompletion > getMatcherCompletions(ArrayRef< ArgKind > acceptedTypes, const Registry &matcherRegistry)
static std::vector< ArgKind > getAcceptedCompletionTypes(llvm::ArrayRef< std::pair< MatcherCtor, unsigned >> context)
static std::optional< MatcherCtor > lookupMatcherCtor(llvm::StringRef matcherName, const Registry &matcherRegistry)
static VariantMatcher constructMatcher(MatcherCtor ctor, internal::SourceRange nameRange, llvm::StringRef functionName, ArrayRef< ParserValue > args, internal::Diagnostics *error)
const ConstructorMap & constructors() const
Definition: Registry.h:33
static VariantMatcher SingleMatcher(DynMatcher matcher)
std::optional< DynMatcher > getDynMatcher() const
ArgStream addError(SourceRange range, ErrorType error)
Definition: Diagnostics.cpp:20
virtual VariantMatcher create(SourceRange nameRange, const llvm::ArrayRef< ParserValue > args, Diagnostics *error) const =0
virtual void getArgKinds(unsigned argNo, std::vector< ArgKind > &argKinds) const =0