MLIR  21.0.0git
Marshallers.h
Go to the documentation of this file.
1 //===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains function templates and classes to wrap matcher construct
10 // functions. It provides a collection of template functions and classes that
11 // present a generic marshalling layer on top of matcher construct functions.
12 // The registry uses these to export all marshaller constructors with a uniform
13 // interface. This mechanism takes inspiration from clang-query.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
19 
20 #include "ErrorBuilder.h"
21 #include "VariantValue.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/StringRef.h"
24 
26 
// ArgTypeTraits<T> routes a C++ matcher-argument type T to the matching
// VariantValue query/accessor pair. It is used both to verify a parsed
// argument (hasCorrectType) and to extract it (get). Only the explicit
// specializations that follow provide these members; the primary template is
// merely declared.
template <class T>
struct ArgTypeTraits;

// Arguments taken as `const T &` reuse the traits of the plain `T`.
template <class T>
struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
34 
35 template <>
36 struct ArgTypeTraits<llvm::StringRef> {
37 
38  static bool hasCorrectType(const VariantValue &value) {
39  return value.isString();
40  }
41 
42  static const llvm::StringRef &get(const VariantValue &value) {
43  return value.getString();
44  }
45 
46  static ArgKind getKind() { return ArgKind::String; }
47 
48  static std::optional<std::string> getBestGuess(const VariantValue &) {
49  return std::nullopt;
50  }
51 };
52 
53 template <>
54 struct ArgTypeTraits<int64_t> {
55  static bool hasCorrectType(const VariantValue &value) {
56  return value.isSigned();
57  }
58 
59  static unsigned get(const VariantValue &value) { return value.getSigned(); }
60 
61  static ArgKind getKind() { return ArgKind::Signed; }
62 
63  static std::optional<std::string> getBestGuess(const VariantValue &) {
64  return std::nullopt;
65  }
66 };
67 
68 template <>
69 struct ArgTypeTraits<bool> {
70  static bool hasCorrectType(const VariantValue &value) {
71  return value.isBoolean();
72  }
73 
74  static unsigned get(const VariantValue &value) { return value.getBoolean(); }
75 
76  static ArgKind getKind() { return ArgKind::Boolean; }
77 
78  static std::optional<std::string> getBestGuess(const VariantValue &) {
79  return std::nullopt;
80  }
81 };
82 
83 template <>
85 
86  static bool hasCorrectType(const VariantValue &value) {
87  return value.isMatcher();
88  }
89 
90  static DynMatcher get(const VariantValue &value) {
91  return *value.getMatcher().getDynMatcher();
92  }
93 
94  static ArgKind getKind() { return ArgKind::Matcher; }
95 
96  static std::optional<std::string> getBestGuess(const VariantValue &) {
97  return std::nullopt;
98  }
99 };
100 
101 // Interface for generic matcher descriptor.
102 // Offers a create() method that constructs the matcher from the provided
103 // arguments.
105 public:
106  virtual ~MatcherDescriptor() = default;
107  virtual VariantMatcher create(SourceRange nameRange,
108  const llvm::ArrayRef<ParserValue> args,
109  Diagnostics *error) const = 0;
110 
111  // If the matcher is variadic, it can take any number of arguments.
112  virtual bool isVariadic() const = 0;
113 
114  // Returns the number of arguments accepted by the matcher.
115  virtual unsigned getNumArgs() const = 0;
116 
117  // Append the set of argument types accepted for argument 'argNo' to
118  // 'argKinds'.
119  virtual void getArgKinds(unsigned argNo,
120  std::vector<ArgKind> &argKinds) const = 0;
121 };
122 
124 public:
125  using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
126  llvm::StringRef matcherName,
127  SourceRange nameRange,
129  Diagnostics *error);
130 
131  // Marshaller Function to unpack the arguments and call Func. Func is the
132  // Matcher construct function. This is the function that the matcher
133  // expressions would use to create the matcher.
135  void (*matcherFunc)(),
136  llvm::StringRef matcherName,
137  llvm::ArrayRef<ArgKind> argKinds)
138  : marshaller(marshaller), matcherFunc(matcherFunc),
139  matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {}
140 
142  Diagnostics *error) const override {
143  return marshaller(matcherFunc, matcherName, nameRange, args, error);
144  }
145 
146  bool isVariadic() const override { return false; }
147 
148  unsigned getNumArgs() const override { return argKinds.size(); }
149 
150  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
151  kinds.push_back(argKinds[argNo]);
152  }
153 
154 private:
155  const MarshallerType marshaller;
156  void (*const matcherFunc)();
157  const llvm::StringRef matcherName;
158  const std::vector<ArgKind> argKinds;
159 };
160 
162 public:
164  VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165  VarOp varOp, StringRef matcherName)
166  : minCount(minCount), maxCount(maxCount), varOp(varOp),
167  matcherName(matcherName) {}
168 
170  Diagnostics *error) const override {
171  if (args.size() < minCount || maxCount < args.size()) {
172  addError(error, nameRange, ErrorType::RegistryWrongArgCount,
173  {llvm::Twine("requires between "), llvm::Twine(minCount),
174  llvm::Twine(" and "), llvm::Twine(maxCount),
175  llvm::Twine(" args, got "), llvm::Twine(args.size())});
176  return VariantMatcher();
177  }
178 
179  std::vector<VariantMatcher> innerArgs;
180  for (int64_t i = 0, e = args.size(); i != e; ++i) {
181  const ParserValue &arg = args[i];
182  const VariantValue &value = arg.value;
183  if (!value.isMatcher()) {
185  {llvm::Twine(i + 1), llvm::Twine("matcher: "),
186  llvm::Twine(value.getTypeAsString())});
187  return VariantMatcher();
188  }
189  innerArgs.push_back(value.getMatcher());
190  }
191  return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192  }
193 
194  bool isVariadic() const override { return true; }
195 
196  unsigned getNumArgs() const override { return 0; }
197 
198  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199  kinds.push_back(ArgKind(ArgKind::Matcher));
200  }
201 
202 private:
203  const unsigned minCount;
204  const unsigned maxCount;
205  const VarOp varOp;
206  const StringRef matcherName;
207 };
208 
209 // Helper function to check if argument count matches expected count
210 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
212  Diagnostics *error) {
213  if (args.size() != expectedArgCount) {
214  addError(error, nameRange, ErrorType::RegistryWrongArgCount,
215  {llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
216  return false;
217  }
218  return true;
219 }
220 
221 // Helper function for checking argument type
222 template <typename ArgType, size_t Index>
223 inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
225  Diagnostics *error) {
226  if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
227  addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
228  {llvm::Twine(matcherName), llvm::Twine(Index + 1)});
229  return false;
230  }
231  return true;
232 }
233 
234 // Marshaller function for fixed number of arguments
235 template <typename ReturnType, typename... ArgTypes, size_t... Is>
236 static VariantMatcher
237 matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
238  SourceRange nameRange,
240  std::index_sequence<Is...>) {
241  using FuncType = ReturnType (*)(ArgTypes...);
242 
243  // Check if the argument count matches the expected count
244  if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error))
245  return VariantMatcher();
246 
247  // Check if each argument at the corresponding index has the correct type
248  if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
249  ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
250  ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
253  }
254 
255  return VariantMatcher();
256 }
257 
258 template <typename ReturnType, typename... ArgTypes>
259 static VariantMatcher
260 matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
262  Diagnostics *error) {
263  return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
264  matcherFunc, matcherName, nameRange, args, error,
265  std::index_sequence_for<ArgTypes...>{});
266 }
267 
268 // Fixed number of arguments overload
269 template <typename ReturnType, typename... ArgTypes>
270 std::unique_ptr<MatcherDescriptor>
271 makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
272  llvm::StringRef matcherName) {
273  // Create a vector of argument kinds
274  std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
275  return std::make_unique<FixedArgCountMatcherDescriptor>(
276  matcherMarshallFixed<ReturnType, ArgTypes...>,
277  reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
278 }
279 
280 // Variadic operator overload.
281 template <unsigned MinCount, unsigned MaxCount>
282 std::unique_ptr<MatcherDescriptor>
284  StringRef matcherName) {
285  return std::make_unique<VariadicOperatorMatcherDescriptor>(
286  MinCount, MaxCount, func.varOp, matcherName);
287 }
288 } // namespace mlir::query::matcher::internal
289 
290 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
static VariantMatcher SingleMatcher(DynMatcher matcher)
std::optional< DynMatcher > getDynMatcher() const
static VariantMatcher VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, ArrayRef< VariantMatcher > args)
const VariantMatcher & getMatcher() const
const llvm::StringRef & getString() const
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
Definition: Marshallers.h:150
VariantMatcher create(SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) const override
Definition: Marshallers.h:141
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void(*matcherFunc)(), llvm::StringRef matcherName, llvm::ArrayRef< ArgKind > argKinds)
Definition: Marshallers.h:134
VariantMatcher(*)(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) MarshallerType
Definition: Marshallers.h:129
virtual VariantMatcher create(SourceRange nameRange, const llvm::ArrayRef< ParserValue > args, Diagnostics *error) const =0
virtual void getArgKinds(unsigned argNo, std::vector< ArgKind > &argKinds) const =0
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
Definition: Marshallers.h:198
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount, VarOp varOp, StringRef matcherName)
Definition: Marshallers.h:164
VariantMatcher create(SourceRange nameRange, ArrayRef< ParserValue > args, Diagnostics *error) const override
Definition: Marshallers.h:169
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
static VariantMatcher matcherMarshallFixedImpl(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error, std::index_sequence< Is... >)
Definition: Marshallers.h:237
void addError(Diagnostics *error, SourceRange range, ErrorType errorType, std::initializer_list< llvm::Twine > errorTexts)
static VariantMatcher matcherMarshallFixed(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
Definition: Marshallers.h:260
bool checkArgTypeAtIndex(llvm::StringRef matcherName, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
Definition: Marshallers.h:223
bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
Definition: Marshallers.h:210
std::unique_ptr< MatcherDescriptor > makeMatcherAutoMarshall(ReturnType(*matcherFunc)(ArgTypes...), llvm::StringRef matcherName)
Definition: Marshallers.h:271
internal::SourceRange range
Definition: VariantValue.h:148
static bool hasCorrectType(const VariantValue &value)
Definition: Marshallers.h:86
static DynMatcher get(const VariantValue &value)
Definition: Marshallers.h:90
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition: Marshallers.h:96
static unsigned get(const VariantValue &value)
Definition: Marshallers.h:74
static bool hasCorrectType(const VariantValue &value)
Definition: Marshallers.h:70
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition: Marshallers.h:78
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition: Marshallers.h:63
static unsigned get(const VariantValue &value)
Definition: Marshallers.h:59
static bool hasCorrectType(const VariantValue &value)
Definition: Marshallers.h:55
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition: Marshallers.h:48
static bool hasCorrectType(const VariantValue &value)
Definition: Marshallers.h:38
static const llvm::StringRef & get(const VariantValue &value)
Definition: Marshallers.h:42