MLIR 23.0.0git
PDLPatternMatch.h.inc
Go to the documentation of this file.
1//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PDLPATTERNMATCH_H
10#define MLIR_IR_PDLPATTERNMATCH_H
11
12#include "mlir/Config/mlir-config.h"
13
14#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/BuiltinOps.h"
17#include "llvm/ADT/TypeSwitch.h"
18
19namespace mlir {
20//===----------------------------------------------------------------------===//
21// PDL Patterns
22//===----------------------------------------------------------------------===//
23
24//===----------------------------------------------------------------------===//
25// PDLValue
26
27/// Storage type of byte-code interpreter values. These are passed to constraint
28/// functions as arguments.
29class PDLValue {
30public:
31 /// The underlying kind of a PDL value.
32 enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
33
34 /// Construct a new PDL value.
35 PDLValue(const PDLValue &other) = default;
36 PDLValue(std::nullptr_t = nullptr) {}
37 PDLValue(Attribute value)
38 : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
39 PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
40 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
41 PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
42 PDLValue(Value value)
43 : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
44 PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
45
46 /// Returns true if the type of the held value is `T`.
47 template <typename T>
48 bool isa() const {
49 assert(value && "isa<> used on a null value");
50 return kind == getKindOf<T>();
51 }
52
53 /// Attempt to dynamically cast this value to type `T`, returns null if this
54 /// value is not an instance of `T`.
55 template <typename T,
56 typename ResultT = std::conditional_t<
57 std::is_constructible_v<bool, T>, T, std::optional<T>>>
58 ResultT dyn_cast() const {
59 return isa<T>() ? castImpl<T>() : ResultT();
60 }
61
62 /// Cast this value to type `T`, asserts if this value is not an instance of
63 /// `T`.
64 template <typename T>
65 T cast() const {
66 assert(isa<T>() && "expected value to be of type `T`");
67 return castImpl<T>();
68 }
69
70 /// Get an opaque pointer to the value.
71 const void *getAsOpaquePointer() const { return value; }
72
73 /// Return if this value is null or not.
74 explicit operator bool() const { return value; }
75
76 /// Return the kind of this value.
77 Kind getKind() const { return kind; }
78
79 /// Print this value to the provided output stream.
80 void print(raw_ostream &os) const;
81
82 /// Print the specified value kind to an output stream.
83 static void print(raw_ostream &os, Kind kind);
84
85private:
86 /// Find the index of a given type in a range of other types.
87 template <typename...>
88 struct index_of_t;
89 template <typename T, typename... R>
90 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
91 template <typename T, typename F, typename... R>
92 struct index_of_t<T, F, R...>
93 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
94
95 /// Return the kind used for the given T.
96 template <typename T>
97 static Kind getKindOf() {
98 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
99 TypeRange, Value, ValueRange>::value);
100 }
101
102 /// The internal implementation of `cast`, that returns the underlying value
103 /// as the given type `T`.
104 template <typename T>
105 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
106 castImpl() const {
107 return T::getFromOpaquePointer(value);
108 }
109 template <typename T>
110 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
111 castImpl() const {
112 return *reinterpret_cast<T *>(const_cast<void *>(value));
113 }
114 template <typename T>
115 std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
116 return reinterpret_cast<T>(const_cast<void *>(value));
117 }
118
119 /// The internal opaque representation of a PDLValue.
120 const void *value{nullptr};
121 /// The kind of the opaque value.
122 Kind kind{Kind::Attribute};
123};
124
125inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
126 value.print(os);
127 return os;
128}
129
130inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
131 PDLValue::print(os, kind);
132 return os;
133}
134
135//===----------------------------------------------------------------------===//
136// PDLResultList
137
138/// The class represents a list of PDL results, returned by a native rewrite
139/// method. It provides the mechanism with which to pass PDLValues back to the
140/// PDL bytecode.
141class PDLResultList {
142public:
143 /// Push a new Attribute value onto the result list.
144 void push_back(Attribute value) { results.push_back(value); }
145
146 /// Push a new Operation onto the result list.
147 void push_back(Operation *value) { results.push_back(value); }
148
149 /// Push a new Type onto the result list.
150 void push_back(Type value) { results.push_back(value); }
151
152 /// Push a new TypeRange onto the result list.
153 void push_back(TypeRange value) {
154 // The lifetime of a TypeRange can't be guaranteed, so we'll need to
155 // allocate a storage for it.
156 allocatedTypeRanges.emplace_back(value.begin(), value.end());
157 typeRanges.push_back(allocatedTypeRanges.back());
158 results.push_back(&typeRanges.back());
159 }
160 void push_back(ValueTypeRange<OperandRange> value) {
161 typeRanges.push_back(value);
162 results.push_back(&typeRanges.back());
163 }
164 void push_back(ValueTypeRange<ResultRange> value) {
165 typeRanges.push_back(value);
166 results.push_back(&typeRanges.back());
167 }
168
169 /// Push a new Value onto the result list.
170 void push_back(Value value) { results.push_back(value); }
171
172 /// Push a new ValueRange onto the result list.
173 void push_back(ValueRange value) {
174 // The lifetime of a ValueRange can't be guaranteed, so we'll need to
175 // allocate a storage for it.
176 allocatedValueRanges.emplace_back(value.begin(), value.end());
177 valueRanges.push_back(allocatedValueRanges.back());
178 results.push_back(&valueRanges.back());
179 }
180 void push_back(OperandRange value) {
181 valueRanges.push_back(value);
182 results.push_back(&valueRanges.back());
183 }
184 void push_back(ResultRange value) {
185 valueRanges.push_back(value);
186 results.push_back(&valueRanges.back());
187 }
188
189protected:
190 /// Create a new result list with the expected number of results.
191 PDLResultList(unsigned maxNumResults) {
192 // For now just reserve enough space for all of the results. We could do
193 // separate counts per range type, but it isn't really worth it unless there
194 // are a "large" number of results.
195 typeRanges.reserve(maxNumResults);
196 valueRanges.reserve(maxNumResults);
197 }
198
199 /// The PDL results held by this list.
200 SmallVector<PDLValue> results;
201 /// Memory used to store ranges held by the list.
202 SmallVector<TypeRange> typeRanges;
203 SmallVector<ValueRange> valueRanges;
204 /// Memory allocated to store ranges in the result list whose lifetime was
205 /// generated in the native function.
206 SmallVector<std::vector<Type>> allocatedTypeRanges;
207 SmallVector<std::vector<Value>> allocatedValueRanges;
208};
209
210//===----------------------------------------------------------------------===//
211// PDLPatternConfig
212
213/// An individual configuration for a pattern, which can be accessed by native
214/// functions via the PDLPatternConfigSet. This allows for injecting additional
215/// configuration into PDL patterns that is specific to certain compilation
216/// flows.
217class PDLPatternConfig {
218public:
219 virtual ~PDLPatternConfig() = default;
220
221 /// Hooks that are invoked at the beginning and end of a rewrite of a matched
222 /// pattern. These can be used to setup any specific state necessary for the
223 /// rewrite.
224 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
225 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
226
227 /// Return the TypeID that represents this configuration.
228 TypeID getTypeID() const { return id; }
229
230protected:
231 PDLPatternConfig(TypeID id) : id(id) {}
232
233private:
234 TypeID id;
235};
236
237/// This class provides a base class for users implementing a type of pattern
238/// configuration.
239template <typename T>
240class PDLPatternConfigBase : public PDLPatternConfig {
241public:
242 /// Support LLVM style casting.
243 static bool classof(const PDLPatternConfig *config) {
244 return config->getTypeID() == getConfigID();
245 }
246
247 /// Return the type id used for this configuration.
248 static TypeID getConfigID() { return TypeID::get<T>(); }
249
250protected:
251 PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
252};
253
254/// This class contains a set of configurations for a specific pattern.
255/// Configurations are uniqued by TypeID, meaning that only one configuration of
256/// each type is allowed.
257class PDLPatternConfigSet {
258public:
259 PDLPatternConfigSet() = default;
260
261 /// Construct a set with the given configurations.
262 template <typename... ConfigsT>
263 PDLPatternConfigSet(ConfigsT &&...configs) {
264 (addConfig(std::forward<ConfigsT>(configs)), ...);
265 }
266
267 /// Get the configuration defined by the given type. Asserts that the
268 /// configuration of the provided type exists.
269 template <typename T>
270 const T &get() const {
271 const T *config = tryGet<T>();
272 assert(config && "configuration not found");
273 return *config;
274 }
275
276 /// Get the configuration defined by the given type, returns nullptr if the
277 /// configuration does not exist.
278 template <typename T>
279 const T *tryGet() const {
280 for (const auto &configIt : configs)
281 if (const T *config = dyn_cast<T>(configIt.get()))
282 return config;
283 return nullptr;
284 }
285
286 /// Notify the configurations within this set at the beginning or end of a
287 /// rewrite of a matched pattern.
288 void notifyRewriteBegin(PatternRewriter &rewriter) {
289 for (const auto &config : configs)
290 config->notifyRewriteBegin(rewriter);
291 }
292 void notifyRewriteEnd(PatternRewriter &rewriter) {
293 for (const auto &config : configs)
294 config->notifyRewriteEnd(rewriter);
295 }
296
297protected:
298 /// Add a configuration to the set.
299 template <typename T>
300 void addConfig(T &&config) {
301 assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
302 configs.emplace_back(
303 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
304 }
305
306 /// The set of configurations for this pattern. This uses a vector instead of
307 /// a map with the expectation that the number of configurations per set is
308 /// small (<= 1).
309 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
310};
311
312//===----------------------------------------------------------------------===//
313// PDLPatternModule
314
315/// A generic PDL pattern constraint function. This function applies a
316/// constraint to a given set of opaque PDLValue entities. Returns success if
317/// the constraint successfully held, failure otherwise.
318using PDLConstraintFunction = std::function<LogicalResult(
319 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
320
321/// A native PDL rewrite function. This function performs a rewrite on the
322/// given set of values. Any results from this rewrite that should be passed
323/// back to PDL should be added to the provided result list. This method is only
324/// invoked when the corresponding match was successful. Returns failure if an
325/// invariant of the rewrite was broken (certain rewriters may recover from
326/// partial pattern application).
327using PDLRewriteFunction = std::function<LogicalResult(
328 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
329
330namespace detail {
331namespace pdl_function_builder {
/// A `false` constant that depends on its template arguments. Because the
/// value depends on `T...`, a `static_assert` using it is only evaluated when
/// the enclosing template is actually instantiated, e.g.:
///
///   template <typename T>
///   void foo() {
///     static_assert(always_false<T>, "This function should never be called");
///   }
///
template <typename... T>
constexpr bool always_false = false;
344
345//===----------------------------------------------------------------------===//
346// PDL Function Builder: Type Processing
347//===----------------------------------------------------------------------===//
348
/// Customization point describing how a C++ type `T` is converted to and from
/// PDL values, so that constraint and rewrite functions can take and return
/// natural C++ types without hand-written glue code. Specializations provide
/// whichever of the following static methods they support:
///
///   static LogicalResult verifyAsArg(
///       function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///       size_t argIdx);
///
///     Checks that `pdlValue` is valid for use as a value of `T`, reporting
///     problems through `errorFn`.
///
///   static T processAsArg(PDLValue pdlValue);
///
///     Converts `pdlValue` into a value of `T`.
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               const T &value);
///
///     Packages `value`, the result of a function invocation, into an
///     appropriate form and appends it to `results`.
///
/// If `T` is based on a higher order value, consider deriving the
/// specialization from `ProcessPDLValueBasedOn` to simplify it.
///
/// The second template parameter exists so that partial specializations can
/// select a whole family of types via `std::enable_if_t`, e.g. every class
/// derived from `Attribute`.
template <typename T, typename EnableT = void>
struct ProcessPDLValue;
380
381/// This struct provides a simplified model for processing types that are based
382/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
383/// allows for building the necessary processing functions on top of the base
384/// value instead of a PDLValue. Derived users should implement the following
385/// (which subsume the ProcessPDLValue variants):
386///
387/// static LogicalResult verifyAsArg(
388/// function_ref<LogicalResult(const Twine &)> errorFn,
389/// const BaseT &baseValue, size_t argIdx);
390///
391/// * This method verifies that the given PDLValue is valid for use as a
392/// value of `T`.
393///
394/// static T processAsArg(BaseT baseValue);
395///
396/// * This method processes the given base value as a value of `T`.
397///
398template <typename T, typename BaseT>
399struct ProcessPDLValueBasedOn {
400 static LogicalResult
401 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
402 PDLValue pdlValue, size_t argIdx) {
403 // Verify the base class before continuing.
404 if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
405 return failure();
406 return ProcessPDLValue<T>::verifyAsArg(
407 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
408 }
409 static T processAsArg(PDLValue pdlValue) {
410 return ProcessPDLValue<T>::processAsArg(
411 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
412 }
413
414 /// Explicitly add the expected parent API to ensure the parent class
415 /// implements the necessary API (and doesn't implicitly inherit it from
416 /// somewhere else).
417 static LogicalResult
418 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
419 size_t argIdx) {
420 return success();
421 }
422 static T processAsArg(BaseT baseValue);
423};
424
425/// This struct provides a simplified model for processing types that have
426/// "builtin" PDLValue support:
427/// * Attribute, Operation *, Type, TypeRange, ValueRange
428template <typename T>
429struct ProcessBuiltinPDLValue {
430 static LogicalResult
431 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
432 PDLValue pdlValue, size_t argIdx) {
433 if (pdlValue)
434 return success();
435 return errorFn("expected a non-null value for argument " + Twine(argIdx) +
436 " of type: " + llvm::getTypeName<T>());
437 }
438
439 static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
440 static void processAsResult(PatternRewriter &, PDLResultList &results,
441 T value) {
442 results.push_back(value);
443 }
444};
445
446/// This struct provides a simplified model for processing types that inherit
447/// from builtin PDLValue types. For example, derived attributes like
448/// IntegerAttr, derived types like IntegerType, derived operations like
449/// ModuleOp, Interfaces, etc.
450template <typename T, typename BaseT>
451struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
452 static LogicalResult
453 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
454 BaseT baseValue, size_t argIdx) {
455 return TypeSwitch<BaseT, LogicalResult>(baseValue)
456 .Case([&](T) { return success(); })
457 .Default([&](BaseT) {
458 return errorFn("expected argument " + Twine(argIdx) +
459 " to be of type: " + llvm::getTypeName<T>());
460 });
461 }
462 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
463
464 static T processAsArg(BaseT baseValue) {
465 return baseValue.template cast<T>();
466 }
467 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
468
469 static void processAsResult(PatternRewriter &, PDLResultList &results,
470 T value) {
471 results.push_back(value);
472 }
473};
474
475//===----------------------------------------------------------------------===//
476// Attribute
477
478template <>
479struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
480template <typename T>
481struct ProcessPDLValue<T,
482 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
483 : public ProcessDerivedPDLValue<T, Attribute> {};
484
485/// Handling for various Attribute value types.
486template <>
487struct ProcessPDLValue<StringRef>
488 : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
489 static StringRef processAsArg(StringAttr value) { return value.getValue(); }
490 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
491
492 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
493 StringRef value) {
494 results.push_back(rewriter.getStringAttr(value));
495 }
496};
497template <>
498struct ProcessPDLValue<std::string>
499 : public ProcessPDLValueBasedOn<std::string, StringAttr> {
500 template <typename T>
501 static std::string processAsArg(T value) {
502 static_assert(always_false<T>,
503 "`std::string` arguments require a string copy, use "
504 "`StringRef` for string-like arguments instead");
505 return {};
506 }
507 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
508 StringRef value) {
509 results.push_back(rewriter.getStringAttr(value));
510 }
511};
512
513//===----------------------------------------------------------------------===//
514// Operation
515
516template <>
517struct ProcessPDLValue<Operation *>
518 : public ProcessBuiltinPDLValue<Operation *> {};
519template <typename T>
520struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
521 : public ProcessDerivedPDLValue<T, Operation *> {
522 static T processAsArg(Operation *value) { return cast<T>(value); }
523 using ProcessDerivedPDLValue<T, Operation *>::processAsArg;
524};
525
526//===----------------------------------------------------------------------===//
527// Type
528
529template <>
530struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
531template <typename T>
532struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
533 : public ProcessDerivedPDLValue<T, Type> {};
534
535//===----------------------------------------------------------------------===//
536// TypeRange
537
538template <>
539struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
540template <>
541struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
542 static void processAsResult(PatternRewriter &, PDLResultList &results,
543 ValueTypeRange<OperandRange> types) {
544 results.push_back(types);
545 }
546};
547template <>
548struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
549 static void processAsResult(PatternRewriter &, PDLResultList &results,
550 ValueTypeRange<ResultRange> types) {
551 results.push_back(types);
552 }
553};
554template <unsigned N>
555struct ProcessPDLValue<SmallVector<Type, N>> {
556 static void processAsResult(PatternRewriter &, PDLResultList &results,
557 SmallVector<Type, N> values) {
558 results.push_back(TypeRange(values));
559 }
560};
561
562//===----------------------------------------------------------------------===//
563// Value
564
565template <>
566struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
567
568//===----------------------------------------------------------------------===//
569// ValueRange
570
571template <>
572struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
573};
574template <>
575struct ProcessPDLValue<OperandRange> {
576 static void processAsResult(PatternRewriter &, PDLResultList &results,
577 OperandRange values) {
578 results.push_back(values);
579 }
580};
581template <>
582struct ProcessPDLValue<ResultRange> {
583 static void processAsResult(PatternRewriter &, PDLResultList &results,
584 ResultRange values) {
585 results.push_back(values);
586 }
587};
588template <unsigned N>
589struct ProcessPDLValue<SmallVector<Value, N>> {
590 static void processAsResult(PatternRewriter &, PDLResultList &results,
591 SmallVector<Value, N> values) {
592 results.push_back(ValueRange(values));
593 }
594};
595
596//===----------------------------------------------------------------------===//
597// PDL Function Builder: Argument Handling
598//===----------------------------------------------------------------------===//
599
600/// Validate the given PDLValues match the constraints defined by the argument
601/// types of the given function. In the case of failure, a match failure
602/// diagnostic is emitted.
603/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
604/// does not currently preserve Constraint application ordering.
605template <typename PDLFnT, std::size_t... I>
606LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
607 std::index_sequence<I...>) {
608 using FnTraitsT = llvm::function_traits<PDLFnT>;
609
610 auto errorFn = [&](const Twine &msg) {
611 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
612 };
613 return success(
614 (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
615 verifyAsArg(errorFn, values[I], I)) &&
616 ...));
617}
618
619/// Assert that the given PDLValues match the constraints defined by the
620/// arguments of the given function. In the case of failure, a fatal error
621/// is emitted.
622template <typename PDLFnT, std::size_t... I>
623void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
624 std::index_sequence<I...>) {
625 // We only want to do verification in debug builds, same as with `assert`.
626#ifndef NDEBUG
627 using FnTraitsT = llvm::function_traits<PDLFnT>;
628 auto errorFn = [&](const Twine &msg) -> LogicalResult {
629 llvm::report_fatal_error(msg);
630 };
631 (void)errorFn;
632 assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
633 verifyAsArg(errorFn, values[I], I)) &&
634 ...));
635#endif
636 (void)values;
637}
638
639//===----------------------------------------------------------------------===//
640// PDL Function Builder: Results Handling
641//===----------------------------------------------------------------------===//
642
643/// Store a single result within the result list.
644template <typename T>
645static LogicalResult processResults(PatternRewriter &rewriter,
646 PDLResultList &results, T &&value) {
647 ProcessPDLValue<T>::processAsResult(rewriter, results,
648 std::forward<T>(value));
649 return success();
650}
651
652/// Store a std::pair<> as individual results within the result list.
653template <typename T1, typename T2>
654static LogicalResult processResults(PatternRewriter &rewriter,
655 PDLResultList &results,
656 std::pair<T1, T2> &&pair) {
657 if (failed(processResults(rewriter, results, std::move(pair.first))) ||
658 failed(processResults(rewriter, results, std::move(pair.second))))
659 return failure();
660 return success();
661}
662
663/// Store a std::tuple<> as individual results within the result list.
664template <typename... Ts>
665static LogicalResult processResults(PatternRewriter &rewriter,
666 PDLResultList &results,
667 std::tuple<Ts...> &&tuple) {
668 auto applyFn = [&](auto &&...args) {
669 return (succeeded(processResults(rewriter, results, std::move(args))) &&
670 ...);
671 };
672 return success(std::apply(applyFn, std::move(tuple)));
673}
674
675/// Handle LogicalResult propagation.
676inline LogicalResult processResults(PatternRewriter &rewriter,
677 PDLResultList &results,
678 LogicalResult &&result) {
679 return result;
680}
681template <typename T>
682static LogicalResult processResults(PatternRewriter &rewriter,
683 PDLResultList &results,
684 FailureOr<T> &&result) {
685 if (failed(result))
686 return failure();
687 return processResults(rewriter, results, std::move(*result));
688}
689
690//===----------------------------------------------------------------------===//
691// PDL Constraint Builder
692//===----------------------------------------------------------------------===//
693
694/// Process the arguments of a native constraint and invoke it.
695template <typename PDLFnT, std::size_t... I,
696 typename FnTraitsT = llvm::function_traits<PDLFnT>>
697typename FnTraitsT::result_t
698processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
699 ArrayRef<PDLValue> values,
700 std::index_sequence<I...>) {
701 return fn(
702 rewriter,
703 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
704 values[I]))...);
705}
706
707/// Build a constraint function from the given function `ConstraintFnT`. This
708/// allows for enabling the user to define simpler, more direct constraint
709/// functions without needing to handle the low-level PDL goop.
710///
711/// If the constraint function is already in the correct form, we just forward
712/// it directly.
713template <typename ConstraintFnT>
714std::enable_if_t<
715 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
716 PDLConstraintFunction>
717buildConstraintFn(ConstraintFnT &&constraintFn) {
718 return std::forward<ConstraintFnT>(constraintFn);
719}
720/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
721/// we desire.
722template <typename ConstraintFnT>
723std::enable_if_t<
724 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
725 PDLConstraintFunction>
726buildConstraintFn(ConstraintFnT &&constraintFn) {
727 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
728 PatternRewriter &rewriter, PDLResultList &,
729 ArrayRef<PDLValue> values) -> LogicalResult {
730 auto argIndices = std::make_index_sequence<
731 llvm::function_traits<ConstraintFnT>::num_args - 1>();
732 if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
733 return failure();
734 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
735 argIndices);
736 };
737}
738
739//===----------------------------------------------------------------------===//
740// PDL Rewrite Builder
741//===----------------------------------------------------------------------===//
742
743/// Process the arguments of a native rewrite and invoke it.
744/// This overload handles the case of no return values.
745template <typename PDLFnT, std::size_t... I,
746 typename FnTraitsT = llvm::function_traits<PDLFnT>>
747std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
748 LogicalResult>
749processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
750 PDLResultList &, ArrayRef<PDLValue> values,
751 std::index_sequence<I...>) {
752 fn(rewriter,
753 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
754 values[I]))...);
755 return success();
756}
757/// This overload handles the case of return values, which need to be packaged
758/// into the result list.
759template <typename PDLFnT, std::size_t... I,
760 typename FnTraitsT = llvm::function_traits<PDLFnT>>
761std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
762 LogicalResult>
763processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
764 PDLResultList &results, ArrayRef<PDLValue> values,
765 std::index_sequence<I...>) {
766 return processResults(
767 rewriter, results,
768 fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
769 processAsArg(values[I]))...));
770 (void)values;
771}
772
773/// Build a rewrite function from the given function `RewriteFnT`. This
774/// allows for enabling the user to define simpler, more direct rewrite
775/// functions without needing to handle the low-level PDL goop.
776///
777/// If the rewrite function is already in the correct form, we just forward
778/// it directly.
779template <typename RewriteFnT>
780std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
781 PDLRewriteFunction>
782buildRewriteFn(RewriteFnT &&rewriteFn) {
783 return std::forward<RewriteFnT>(rewriteFn);
784}
785/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
786/// we desire.
787template <typename RewriteFnT>
788std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
789 PDLRewriteFunction>
790buildRewriteFn(RewriteFnT &&rewriteFn) {
791 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
792 PatternRewriter &rewriter, PDLResultList &results,
793 ArrayRef<PDLValue> values) {
794 auto argIndices =
795 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
796 1>();
797 assertArgs<RewriteFnT>(rewriter, values, argIndices);
798 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
799 argIndices);
800 };
801}
802
803} // namespace pdl_function_builder
804} // namespace detail
805
806//===----------------------------------------------------------------------===//
807// PDLPatternModule
808
809/// This class contains all of the necessary data for a set of PDL patterns, or
810/// pattern rewrites specified in the form of the PDL dialect. This PDL module
811/// contained by this pattern may contain any number of `pdl.pattern`
812/// operations.
813class PDLPatternModule {
814public:
815 PDLPatternModule() = default;
816
817 /// Construct a PDL pattern with the given module and configurations.
818 PDLPatternModule(OwningOpRef<ModuleOp> module)
819 : pdlModule(std::move(module)) {}
820 template <typename... ConfigsT>
821 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
822 : PDLPatternModule(std::move(module)) {
823 auto configSet = std::make_unique<PDLPatternConfigSet>(
824 std::forward<ConfigsT>(patternConfigs)...);
825 attachConfigToPatterns(*pdlModule, *configSet);
826 configs.emplace_back(std::move(configSet));
827 }
828
829 /// Merge the state in `other` into this pattern module.
830 void mergeIn(PDLPatternModule &&other);
831
832 /// Return the internal PDL module of this pattern.
833 ModuleOp getModule() { return pdlModule.get(); }
834
835 /// Return the MLIR context of this pattern.
836 MLIRContext *getContext() { return getModule()->getContext(); }
837
838 //===--------------------------------------------------------------------===//
839 // Function Registry
840
841 /// Register a constraint function with PDL. A constraint function may be
842 /// specified in one of two ways:
843 ///
844 /// * `LogicalResult (PatternRewriter &,
845 /// PDLResultList &,
846 /// ArrayRef<PDLValue>)`
847 ///
848 /// In this overload the arguments of the constraint function are passed via
849 /// the low-level PDLValue form, and the results are manually appended to
850 /// the given result list.
851 ///
852 /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
853 ///
854 /// In this form the arguments of the constraint function are passed via the
855 /// expected high level C++ type. In this form, the framework will
856 /// automatically unwrap PDLValues and convert them to the expected ValueTs.
857 /// For example, if the constraint function accepts a `Operation *`, the
858 /// framework will automatically cast the input PDLValue. In the case of a
859 /// `StringRef`, the framework will automatically unwrap the argument as a
860 /// StringAttr and pass the underlying string value. To see the full list of
861 /// supported types, or to see how to add handling for custom types, view
862 /// the definition of `ProcessPDLValue` above.
863 void registerConstraintFunction(StringRef name,
864 PDLConstraintFunction constraintFn);
865 template <typename ConstraintFnT>
866 void registerConstraintFunction(StringRef name,
867 ConstraintFnT &&constraintFn) {
868 registerConstraintFunction(name,
869 detail::pdl_function_builder::buildConstraintFn(
870 std::forward<ConstraintFnT>(constraintFn)));
871 }
872
873 /// Register a rewrite function with PDL. A rewrite function may be specified
874 /// in one of two ways:
875 ///
876 /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
877 ///
878 /// In this overload the arguments of the constraint function are passed via
879 /// the low-level PDLValue form, and the results are manually appended to
880 /// the given result list.
881 ///
882 /// * `ResultT (PatternRewriter &, ValueTs... values)`
883 ///
884 /// In this form the arguments and result of the rewrite function are passed
885 /// via the expected high level C++ type. In this form, the framework will
886 /// automatically unwrap the PDLValues arguments and convert them to the
887 /// expected ValueTs. It will also automatically handle the processing and
888 /// packaging of the result value to the result list. For example, if the
889 /// rewrite function takes a `Operation *`, the framework will automatically
890 /// cast the input PDLValue. In the case of a `StringRef`, the framework
891 /// will automatically unwrap the argument as a StringAttr and pass the
892 /// underlying string value. In the reverse case, if the rewrite returns a
893 /// StringRef or std::string, it will automatically package this as a
894 /// StringAttr and append it to the result list. To see the full list of
895 /// supported types, or to see how to add handling for custom types, view
896 /// the definition of `ProcessPDLValue` above.
897 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
898 template <typename RewriteFnT>
899 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
900 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
901 std::forward<RewriteFnT>(rewriteFn)));
902 }
903
904 /// Return the set of the registered constraint functions.
905 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
906 return constraintFunctions;
907 }
908 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
909 return constraintFunctions;
910 }
911 /// Return the set of the registered rewrite functions.
912 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
913 return rewriteFunctions;
914 }
915 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
916 return rewriteFunctions;
917 }
918
919 /// Return the set of the registered pattern configs.
920 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
921 return std::move(configs);
922 }
923 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
924 return std::move(configMap);
925 }
926
927 /// Clear out the patterns and functions within this module.
928 void clear() {
929 pdlModule = nullptr;
930 constraintFunctions.clear();
931 rewriteFunctions.clear();
932 }
933
934private:
935 /// Attach the given pattern config set to the patterns defined within the
936 /// given module.
937 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
938
939 /// The module containing the `pdl.pattern` operations.
940 OwningOpRef<ModuleOp> pdlModule;
941
942 /// The set of configuration sets referenced by patterns within `pdlModule`.
943 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
944 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
945
946 /// The external functions referenced from within the PDL module.
947 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
948 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
949};
950} // namespace mlir
951
952#else
953
954namespace mlir {
955// Stubs for when PDL in pattern rewrites is not enabled.
956
class PDLValue {
public:
  /// Stubbed cast: with PDL disabled no value is ever held, so every request
  /// yields a `T` copy-initialized from `nullptr`.
  template <typename T>
  T dyn_cast() const {
    T nullValue = nullptr;
    return nullValue;
  }
};
964class PDLResultList {};
965using PDLConstraintFunction = std::function<LogicalResult(
966 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
967using PDLRewriteFunction = std::function<LogicalResult(
968 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
969
970class PDLPatternModule {
971public:
972 PDLPatternModule() = default;
973
974 PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}
975 MLIRContext *getContext() {
976 llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
977 }
978 void mergeIn(PDLPatternModule &&other) {}
979 void clear() {}
980 template <typename ConstraintFnT>
981 void registerConstraintFunction(StringRef name,
982 ConstraintFnT &&constraintFn) {}
983 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
984 template <typename RewriteFnT>
985 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
986 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
987 return constraintFunctions;
988 }
989
990private:
991 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
992};
993
994} // namespace mlir
995#endif
996
997#endif // MLIR_IR_PDLPATTERNMATCH_H
return success()
b getContext())
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
values clear()
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)