17 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/StringRef.h"
120 std::vector<ArgKind> &argKinds)
const = 0;
126 llvm::StringRef matcherName,
135 void (*matcherFunc)(),
136 llvm::StringRef matcherName,
138 : marshaller(marshaller), matcherFunc(matcherFunc),
139 matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {}
143 return marshaller(matcherFunc, matcherName, nameRange, args, error);
148 unsigned getNumArgs()
const override {
return argKinds.size(); }
150 void getArgKinds(
unsigned argNo, std::vector<ArgKind> &kinds)
const override {
151 kinds.push_back(argKinds[argNo]);
156 void (*
const matcherFunc)();
157 const llvm::StringRef matcherName;
158 const std::vector<ArgKind> argKinds;
165 VarOp varOp, StringRef matcherName)
166 : minCount(minCount), maxCount(maxCount), varOp(varOp),
167 matcherName(matcherName) {}
171 if (args.size() < minCount || maxCount < args.size()) {
173 {llvm::Twine(
"requires between "), llvm::Twine(minCount),
174 llvm::Twine(
" and "), llvm::Twine(maxCount),
175 llvm::Twine(
" args, got "), llvm::Twine(args.size())});
179 std::vector<VariantMatcher> innerArgs;
180 for (int64_t i = 0, e = args.size(); i != e; ++i) {
185 {llvm::Twine(i + 1), llvm::Twine(
"matcher: "),
186 llvm::Twine(value.getTypeAsString())});
198 void getArgKinds(
unsigned argNo, std::vector<ArgKind> &kinds)
const override {
203 const unsigned minCount;
204 const unsigned maxCount;
206 const StringRef matcherName;
213 if (args.size() != expectedArgCount) {
215 {llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
222 template <
typename ArgType,
size_t Index>
228 {llvm::Twine(matcherName), llvm::Twine(Index + 1)});
235 template <
typename ReturnType,
typename... ArgTypes,
size_t... Is>
240 std::index_sequence<Is...>) {
241 using FuncType = ReturnType (*)(ArgTypes...);
244 if (!
checkArgCount(nameRange,
sizeof...(ArgTypes), args, error))
248 if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
249 ReturnType fnPointer =
reinterpret_cast<FuncType
>(matcherFunc)(
258 template <
typename ReturnType,
typename... ArgTypes>
264 matcherFunc, matcherName, nameRange, args, error,
265 std::index_sequence_for<ArgTypes...>{});
269 template <
typename ReturnType,
typename... ArgTypes>
270 std::unique_ptr<MatcherDescriptor>
272 llvm::StringRef matcherName) {
275 return std::make_unique<FixedArgCountMatcherDescriptor>(
276 matcherMarshallFixed<ReturnType, ArgTypes...>,
277 reinterpret_cast<void (*)()
>(matcherFunc), matcherName, argKinds);
281 template <
unsigned MinCount,
unsigned MaxCount>
282 std::unique_ptr<MatcherDescriptor>
284 StringRef matcherName) {
285 return std::make_unique<VariadicOperatorMatcherDescriptor>(
286 MinCount, MaxCount, func.
varOp, matcherName);
static std::unique_ptr< DynMatcher > constructDynMatcherFromMatcherFn(MatcherFn &matcherFn)
static VariantMatcher SingleMatcher(DynMatcher matcher)
std::optional< DynMatcher > getDynMatcher() const
static VariantMatcher VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, ArrayRef< VariantMatcher > args)
const VariantMatcher & getMatcher() const
const llvm::StringRef & getString() const
int64_t getSigned() const
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
VariantMatcher create(SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) const override
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void(*matcherFunc)(), llvm::StringRef matcherName, llvm::ArrayRef< ArgKind > argKinds)
unsigned getNumArgs() const override
VariantMatcher(*)(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error) MarshallerType
bool isVariadic() const override
virtual bool isVariadic() const =0
virtual VariantMatcher create(SourceRange nameRange, const llvm::ArrayRef< ParserValue > args, Diagnostics *error) const =0
virtual ~MatcherDescriptor()=default
virtual void getArgKinds(unsigned argNo, std::vector< ArgKind > &argKinds) const =0
virtual unsigned getNumArgs() const =0
unsigned getNumArgs() const override
DynMatcher::VariadicOperator VarOp
bool isVariadic() const override
void getArgKinds(unsigned argNo, std::vector< ArgKind > &kinds) const override
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount, VarOp varOp, StringRef matcherName)
VariantMatcher create(SourceRange nameRange, ArrayRef< ParserValue > args, Diagnostics *error) const override
The OpAsmOpInterface, see OpAsmInterface.td for more details.
static VariantMatcher matcherMarshallFixedImpl(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error, std::index_sequence< Is... >)
void addError(Diagnostics *error, SourceRange range, ErrorType errorType, std::initializer_list< llvm::Twine > errorTexts)
static VariantMatcher matcherMarshallFixed(void(*matcherFunc)(), llvm::StringRef matcherName, SourceRange nameRange, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
bool checkArgTypeAtIndex(llvm::StringRef matcherName, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, llvm::ArrayRef< ParserValue > args, Diagnostics *error)
std::unique_ptr< MatcherDescriptor > makeMatcherAutoMarshall(ReturnType(*matcherFunc)(ArgTypes...), llvm::StringRef matcherName)
internal::SourceRange range
DynMatcher::VariadicOperator varOp
static bool hasCorrectType(const VariantValue &value)
static DynMatcher get(const VariantValue &value)
static std::optional< std::string > getBestGuess(const VariantValue &)
static unsigned get(const VariantValue &value)
static bool hasCorrectType(const VariantValue &value)
static std::optional< std::string > getBestGuess(const VariantValue &)
static std::optional< std::string > getBestGuess(const VariantValue &)
static unsigned get(const VariantValue &value)
static bool hasCorrectType(const VariantValue &value)
static std::optional< std::string > getBestGuess(const VariantValue &)
static bool hasCorrectType(const VariantValue &value)
static const llvm::StringRef & get(const VariantValue &value)