MLIR 22.0.0git
Marshallers.h
Go to the documentation of this file.
1//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains function templates and classes that wrap matcher
// construction functions. Together they form a generic marshalling layer on top
// of those functions, which the registry uses to export every matcher
// constructor through a uniform interface. This mechanism takes inspiration
// from clang-query.
14//
15//===----------------------------------------------------------------------===//
16
17#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
19
20#include "ErrorBuilder.h"
21#include "VariantValue.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/StringRef.h"
24
26
// Traits class that maps a matcher argument type to the matching is*/get*
// accessors of VariantValue. The marshallers use it to verify the type of each
// matcher argument and to extract its value.
30template <class T>
32template <class T>
33struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
34
35template <>
36struct ArgTypeTraits<llvm::StringRef> {
37
38 static bool hasCorrectType(const VariantValue &value) {
39 return value.isString();
40 }
41
42 static const llvm::StringRef &get(const VariantValue &value) {
43 return value.getString();
44 }
45
46 static ArgKind getKind() { return ArgKind::String; }
47
48 static std::optional<std::string> getBestGuess(const VariantValue &) {
49 return std::nullopt;
50 }
51};
52
53template <>
55 static bool hasCorrectType(const VariantValue &value) {
56 return value.isSigned();
57 }
58
59 static unsigned get(const VariantValue &value) { return value.getSigned(); }
60
61 static ArgKind getKind() { return ArgKind::Signed; }
62
63 static std::optional<std::string> getBestGuess(const VariantValue &) {
64 return std::nullopt;
65 }
66};
67
68template <>
70 static bool hasCorrectType(const VariantValue &value) {
71 return value.isBoolean();
72 }
73
74 static unsigned get(const VariantValue &value) { return value.getBoolean(); }
75
76 static ArgKind getKind() { return ArgKind::Boolean; }
77
78 static std::optional<std::string> getBestGuess(const VariantValue &) {
79 return std::nullopt;
80 }
81};
82
83template <>
85
86 static bool hasCorrectType(const VariantValue &value) {
87 return value.isMatcher();
88 }
89
90 static DynMatcher get(const VariantValue &value) {
91 return *value.getMatcher().getDynMatcher();
92 }
93
94 static ArgKind getKind() { return ArgKind::Matcher; }
95
96 static std::optional<std::string> getBestGuess(const VariantValue &) {
97 return std::nullopt;
98 }
99};
100
101// Interface for generic matcher descriptor.
102// Offers a create() method that constructs the matcher from the provided
103// arguments.
105public:
106 virtual ~MatcherDescriptor() = default;
109 Diagnostics *error) const = 0;
110
111 // If the matcher is variadic, it can take any number of arguments.
112 virtual bool isVariadic() const = 0;
113
114 // Returns the number of arguments accepted by the matcher.
115 virtual unsigned getNumArgs() const = 0;
116
117 // Append the set of argument types accepted for argument 'argNo' to
118 // 'argKinds'.
119 virtual void getArgKinds(unsigned argNo,
120 std::vector<ArgKind> &argKinds) const = 0;
121};
122
124public:
125 using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
126 llvm::StringRef matcherName,
127 SourceRange nameRange,
129 Diagnostics *error);
130
131 // Marshaller Function to unpack the arguments and call Func. Func is the
132 // Matcher construct function. This is the function that the matcher
133 // expressions would use to create the matcher.
135 void (*matcherFunc)(),
136 llvm::StringRef matcherName,
138 : marshaller(marshaller), matcherFunc(matcherFunc),
139 matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {}
140
142 Diagnostics *error) const override {
143 return marshaller(matcherFunc, matcherName, nameRange, args, error);
144 }
145
146 bool isVariadic() const override { return false; }
147
148 unsigned getNumArgs() const override { return argKinds.size(); }
149
150 void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
151 kinds.push_back(argKinds[argNo]);
152 }
153
154private:
155 const MarshallerType marshaller;
156 void (*const matcherFunc)();
157 const llvm::StringRef matcherName;
158 const std::vector<ArgKind> argKinds;
159};
160
162public:
164 VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165 VarOp varOp, StringRef matcherName)
166 : minCount(minCount), maxCount(maxCount), varOp(varOp),
167 matcherName(matcherName) {}
168
170 Diagnostics *error) const override {
171 if (args.size() < minCount || maxCount < args.size()) {
173 {llvm::Twine("requires between "), llvm::Twine(minCount),
174 llvm::Twine(" and "), llvm::Twine(maxCount),
175 llvm::Twine(" args, got "), llvm::Twine(args.size())});
176 return VariantMatcher();
177 }
178
179 std::vector<VariantMatcher> innerArgs;
180 for (int64_t i = 0, e = args.size(); i != e; ++i) {
181 const ParserValue &arg = args[i];
182 const VariantValue &value = arg.value;
183 if (!value.isMatcher()) {
185 {llvm::Twine(i + 1), llvm::Twine("matcher: "),
186 llvm::Twine(value.getTypeAsString())});
187 return VariantMatcher();
188 }
189 innerArgs.push_back(value.getMatcher());
190 }
191 return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192 }
193
194 bool isVariadic() const override { return true; }
195
196 unsigned getNumArgs() const override { return 0; }
197
198 void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199 kinds.push_back(ArgKind(ArgKind::Matcher));
200 }
201
202private:
203 const unsigned minCount;
204 const unsigned maxCount;
205 const VarOp varOp;
206 const StringRef matcherName;
207};
208
209// Helper function to check if argument count matches expected count
210inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
212 Diagnostics *error) {
213 if (args.size() != expectedArgCount) {
215 {llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
216 return false;
217 }
218 return true;
219}
220
221// Helper function for checking argument type
222template <typename ArgType, size_t Index>
223inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
225 Diagnostics *error) {
226 if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
227 addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
228 {llvm::Twine(matcherName), llvm::Twine(Index + 1)});
229 return false;
230 }
231 return true;
232}
233
234// Marshaller function for fixed number of arguments
235template <typename ReturnType, typename... ArgTypes, size_t... Is>
236static VariantMatcher
237matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
238 SourceRange nameRange,
240 std::index_sequence<Is...>) {
241 using FuncType = ReturnType (*)(ArgTypes...);
242
243 // Check if the argument count matches the expected count
244 if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error))
245 return VariantMatcher();
246
247 // Check if each argument at the corresponding index has the correct type
248 if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
249 ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
250 ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
253 }
254
255 return VariantMatcher();
256}
257
258template <typename ReturnType, typename... ArgTypes>
259static VariantMatcher
260matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
262 Diagnostics *error) {
263 return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
264 matcherFunc, matcherName, nameRange, args, error,
265 std::index_sequence_for<ArgTypes...>{});
266}
267
268// Fixed number of arguments overload
269template <typename ReturnType, typename... ArgTypes>
270std::unique_ptr<MatcherDescriptor>
271makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
272 llvm::StringRef matcherName) {
273 // Create a vector of argument kinds
274 std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
275 return std::make_unique<FixedArgCountMatcherDescriptor>(
277 reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
278}
279
280// Variadic operator overload.
281template <unsigned MinCount, unsigned MaxCount>
282std::unique_ptr<MatcherDescriptor>
284 StringRef matcherName) {
285 return std::make_unique<VariadicOperatorMatcherDescriptor>(
286 MinCount, MaxCount, func.varOp, matcherName);
287}
288} // namespace mlir::query::matcher::internal
289
290#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
static VariantMatcher SingleMatcher(DynMatcher matcher)
std::optional< DynMatcher > getDynMatcher() const
static VariantMatcher VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, ArrayRef< VariantMatcher > args)
const VariantMatcher & getMatcher() const
const llvm::StringRef & getString() const
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
VariantMatcher create(SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) const override
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void(*matcherFunc)(), llvm::StringRef matcherName, llvm::ArrayRef< ArgKind > argKinds)
VariantMatcher(*)(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) MarshallerType
virtual VariantMatcher create(SourceRange nameRange, const llvm::ArrayRef< ParserValue > args, Diagnostics *error) const =0
virtual void getArgKinds(unsigned argNo, std::vector< ArgKind > &argKinds) const =0
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount, VarOp varOp, StringRef matcherName)
VariantMatcher create(SourceRange nameRange, ArrayRef< ParserValue > args, Diagnostics *error) const override
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
static VariantMatcher matcherMarshallFixedImpl(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error, std::index_sequence< Is... >)
std::unique_ptr< MatcherDescriptor > makeMatcherAutoMarshall(ReturnType(*matcherFunc)(ArgTypes...), llvm::StringRef matcherName)
void addError(Diagnostics *error, SourceRange range, ErrorType errorType, std::initializer_list< llvm::Twine > errorTexts)
static VariantMatcher matcherMarshallFixed(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
bool checkArgTypeAtIndex(llvm::StringRef matcherName, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
static bool hasCorrectType(const VariantValue &value)
Definition Marshallers.h:86
static DynMatcher get(const VariantValue &value)
Definition Marshallers.h:90
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition Marshallers.h:96
static unsigned get(const VariantValue &value)
Definition Marshallers.h:74
static bool hasCorrectType(const VariantValue &value)
Definition Marshallers.h:70
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition Marshallers.h:78
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition Marshallers.h:63
static unsigned get(const VariantValue &value)
Definition Marshallers.h:59
static bool hasCorrectType(const VariantValue &value)
Definition Marshallers.h:55
static bool hasCorrectType(const VariantValue &value)
Definition Marshallers.h:38
static const llvm::StringRef & get(const VariantValue &value)
Definition Marshallers.h:42
static std::optional< std::string > getBestGuess(const VariantValue &)
Definition Marshallers.h:48