MLIR  20.0.0git
Marshallers.h
Go to the documentation of this file.
1 //===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains function templates and classes to wrap matcher construct
10 // functions. It provides a collection of template function and classes that
11 // present a generic marshalling layer on top of matcher construct functions.
12 // The registry uses these to export all marshaller constructors with a uniform
13 // interface. This mechanism takes inspiration from clang-query.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
19 
#include "ErrorBuilder.h"
#include "VariantValue.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

namespace mlir::query::matcher::internal {

// Helper template class for jumping from argument type to the correct is/get
// functions in VariantValue. This is used for verifying and extracting the
// matcher arguments.
//
// The primary template is only declared, never defined, so using an argument
// type that has no explicit specialization fails to compile.
template <class T>
struct ArgTypeTraits;

// Parameters taken by const reference reuse the traits of the referenced type.
template <class T>
struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
34 
35 template <>
36 struct ArgTypeTraits<llvm::StringRef> {
37 
38  static bool hasCorrectType(const VariantValue &value) {
39  return value.isString();
40  }
41 
42  static const llvm::StringRef &get(const VariantValue &value) {
43  return value.getString();
44  }
45 
46  static ArgKind getKind() { return ArgKind::String; }
47 
48  static std::optional<std::string> getBestGuess(const VariantValue &) {
49  return std::nullopt;
50  }
51 };
52 
53 template <>
55 
56  static bool hasCorrectType(const VariantValue &value) {
57  return value.isMatcher();
58  }
59 
60  static DynMatcher get(const VariantValue &value) {
61  return *value.getMatcher().getDynMatcher();
62  }
63 
64  static ArgKind getKind() { return ArgKind::Matcher; }
65 
66  static std::optional<std::string> getBestGuess(const VariantValue &) {
67  return std::nullopt;
68  }
69 };
70 
71 // Interface for generic matcher descriptor.
72 // Offers a create() method that constructs the matcher from the provided
73 // arguments.
75 public:
76  virtual ~MatcherDescriptor() = default;
77  virtual VariantMatcher create(SourceRange nameRange,
78  const llvm::ArrayRef<ParserValue> args,
79  Diagnostics *error) const = 0;
80 
81  // Returns the number of arguments accepted by the matcher.
82  virtual unsigned getNumArgs() const = 0;
83 
84  // Append the set of argument types accepted for argument 'argNo' to
85  // 'argKinds'.
86  virtual void getArgKinds(unsigned argNo,
87  std::vector<ArgKind> &argKinds) const = 0;
88 };
89 
91 public:
92  using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
93  llvm::StringRef matcherName,
94  SourceRange nameRange,
96  Diagnostics *error);
97 
98  // Marshaller Function to unpack the arguments and call Func. Func is the
99  // Matcher construct function. This is the function that the matcher
100  // expressions would use to create the matcher.
102  void (*matcherFunc)(),
103  llvm::StringRef matcherName,
104  llvm::ArrayRef<ArgKind> argKinds)
105  : marshaller(marshaller), matcherFunc(matcherFunc),
106  matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {}
107 
109  Diagnostics *error) const override {
110  return marshaller(matcherFunc, matcherName, nameRange, args, error);
111  }
112 
113  unsigned getNumArgs() const override { return argKinds.size(); }
114 
115  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
116  kinds.push_back(argKinds[argNo]);
117  }
118 
119 private:
120  const MarshallerType marshaller;
121  void (*const matcherFunc)();
122  const llvm::StringRef matcherName;
123  const std::vector<ArgKind> argKinds;
124 };
125 
126 // Helper function to check if argument count matches expected count
127 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
129  Diagnostics *error) {
130  if (args.size() != expectedArgCount) {
131  addError(error, nameRange, ErrorType::RegistryWrongArgCount,
132  {llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
133  return false;
134  }
135  return true;
136 }
137 
138 // Helper function for checking argument type
139 template <typename ArgType, size_t Index>
140 inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
142  Diagnostics *error) {
143  if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
144  addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
145  {llvm::Twine(matcherName), llvm::Twine(Index + 1)});
146  return false;
147  }
148  return true;
149 }
150 
151 // Marshaller function for fixed number of arguments
152 template <typename ReturnType, typename... ArgTypes, size_t... Is>
153 static VariantMatcher
154 matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
155  SourceRange nameRange,
157  std::index_sequence<Is...>) {
158  using FuncType = ReturnType (*)(ArgTypes...);
159 
160  // Check if the argument count matches the expected count
161  if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error))
162  return VariantMatcher();
163 
164  // Check if each argument at the corresponding index has the correct type
165  if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
166  ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
167  ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
170  }
171 
172  return VariantMatcher();
173 }
174 
175 template <typename ReturnType, typename... ArgTypes>
176 static VariantMatcher
177 matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
179  Diagnostics *error) {
180  return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
181  matcherFunc, matcherName, nameRange, args, error,
182  std::index_sequence_for<ArgTypes...>{});
183 }
184 
185 // Fixed number of arguments overload
186 template <typename ReturnType, typename... ArgTypes>
187 std::unique_ptr<MatcherDescriptor>
188 makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
189  llvm::StringRef matcherName) {
190  // Create a vector of argument kinds
191  std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
192  return std::make_unique<FixedArgCountMatcherDescriptor>(
193  matcherMarshallFixed<ReturnType, ArgTypes...>,
194  reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
195 }
196 
197 } // namespace mlir::query::matcher::internal
198 
199 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
static VariantMatcher SingleMatcher(DynMatcher matcher)
std::optional< DynMatcher > getDynMatcher() const
const VariantMatcher & getMatcher() const
const llvm::StringRef & getString() const
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
Definition: Marshallers.h:115
VariantMatcher create(SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) const override
Definition: Marshallers.h:108
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void(*matcherFunc)(), llvm::StringRef matcherName, llvm::ArrayRef< ArgKind > argKinds)
Definition: Marshallers.h:101
VariantMatcher(*)(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) MarshallerType
Definition: Marshallers.h:96
virtual VariantMatcher create(SourceRange nameRange, const llvm::ArrayRef< ParserValue > args, Diagnostics *error) const =0
virtual void getArgKinds(unsigned argNo, std::vector< ArgKind > &argKinds) const =0
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
static VariantMatcher matcherMarshallFixedImpl(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error, std::index_sequence< Is... >)
Definition: Marshallers.h:154
void addError(Diagnostics *error, SourceRange range, ErrorType errorType, std::initializer_list< llvm::Twine > errorTexts)
static VariantMatcher matcherMarshallFixed(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
Definition: Marshallers.h:177
bool checkArgTypeAtIndex(llvm::StringRef matcherName, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
Definition: Marshallers.h:140
bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
Definition: Marshallers.h:127
std::unique_ptr< MatcherDescriptor > makeMatcherAutoMarshall(ReturnType(*matcherFunc)(ArgTypes...), llvm::StringRef matcherName)
Definition: Marshallers.h:188
static bool hasCorrectType(const VariantValue &value)
Definition: Marshallers.h:56
static DynMatcher get(const VariantValue &value)
Definition: Marshallers.h:60
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition: Marshallers.h:66
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition: Marshallers.h:48
static bool hasCorrectType(const VariantValue &value)
Definition: Marshallers.h:38
static const llvm::StringRef & get(const VariantValue &value)
Definition: Marshallers.h:42