MLIR 22.0.0git
PDLPatternMatch.h.inc
Go to the documentation of this file.
1//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PDLPATTERNMATCH_H
10#define MLIR_IR_PDLPATTERNMATCH_H
11
12#include "mlir/Config/mlir-config.h"
13
14#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/BuiltinOps.h"
17
18namespace mlir {
19//===----------------------------------------------------------------------===//
20// PDL Patterns
21//===----------------------------------------------------------------------===//
22
23//===----------------------------------------------------------------------===//
24// PDLValue
25
26/// Storage type of byte-code interpreter values. These are passed to constraint
27/// functions as arguments.
28class PDLValue {
29public:
30 /// The underlying kind of a PDL value.
31 enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
32
33 /// Construct a new PDL value.
34 PDLValue(const PDLValue &other) = default;
35 PDLValue(std::nullptr_t = nullptr) {}
36 PDLValue(Attribute value)
37 : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
38 PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
39 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
40 PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
41 PDLValue(Value value)
42 : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
43 PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
44
45 /// Returns true if the type of the held value is `T`.
46 template <typename T>
47 bool isa() const {
48 assert(value && "isa<> used on a null value");
49 return kind == getKindOf<T>();
50 }
51
52 /// Attempt to dynamically cast this value to type `T`, returns null if this
53 /// value is not an instance of `T`.
54 template <typename T,
55 typename ResultT = std::conditional_t<
56 std::is_constructible_v<bool, T>, T, std::optional<T>>>
57 ResultT dyn_cast() const {
58 return isa<T>() ? castImpl<T>() : ResultT();
59 }
60
61 /// Cast this value to type `T`, asserts if this value is not an instance of
62 /// `T`.
63 template <typename T>
64 T cast() const {
65 assert(isa<T>() && "expected value to be of type `T`");
66 return castImpl<T>();
67 }
68
69 /// Get an opaque pointer to the value.
70 const void *getAsOpaquePointer() const { return value; }
71
72 /// Return if this value is null or not.
73 explicit operator bool() const { return value; }
74
75 /// Return the kind of this value.
76 Kind getKind() const { return kind; }
77
78 /// Print this value to the provided output stream.
79 void print(raw_ostream &os) const;
80
81 /// Print the specified value kind to an output stream.
82 static void print(raw_ostream &os, Kind kind);
83
84private:
85 /// Find the index of a given type in a range of other types.
86 template <typename...>
87 struct index_of_t;
88 template <typename T, typename... R>
89 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
90 template <typename T, typename F, typename... R>
91 struct index_of_t<T, F, R...>
92 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
93
94 /// Return the kind used for the given T.
95 template <typename T>
96 static Kind getKindOf() {
97 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
98 TypeRange, Value, ValueRange>::value);
99 }
100
101 /// The internal implementation of `cast`, that returns the underlying value
102 /// as the given type `T`.
103 template <typename T>
104 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
105 castImpl() const {
106 return T::getFromOpaquePointer(value);
107 }
108 template <typename T>
109 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
110 castImpl() const {
111 return *reinterpret_cast<T *>(const_cast<void *>(value));
112 }
113 template <typename T>
114 std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
115 return reinterpret_cast<T>(const_cast<void *>(value));
116 }
117
118 /// The internal opaque representation of a PDLValue.
119 const void *value{nullptr};
120 /// The kind of the opaque value.
121 Kind kind{Kind::Attribute};
122};
123
124inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
125 value.print(os);
126 return os;
127}
128
129inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
130 PDLValue::print(os, kind);
131 return os;
132}
133
134//===----------------------------------------------------------------------===//
135// PDLResultList
136
137/// The class represents a list of PDL results, returned by a native rewrite
138/// method. It provides the mechanism with which to pass PDLValues back to the
139/// PDL bytecode.
140class PDLResultList {
141public:
142 /// Push a new Attribute value onto the result list.
143 void push_back(Attribute value) { results.push_back(value); }
144
145 /// Push a new Operation onto the result list.
146 void push_back(Operation *value) { results.push_back(value); }
147
148 /// Push a new Type onto the result list.
149 void push_back(Type value) { results.push_back(value); }
150
151 /// Push a new TypeRange onto the result list.
152 void push_back(TypeRange value) {
153 // The lifetime of a TypeRange can't be guaranteed, so we'll need to
154 // allocate a storage for it.
155 allocatedTypeRanges.emplace_back(value.begin(), value.end());
156 typeRanges.push_back(allocatedTypeRanges.back());
157 results.push_back(&typeRanges.back());
158 }
159 void push_back(ValueTypeRange<OperandRange> value) {
160 typeRanges.push_back(value);
161 results.push_back(&typeRanges.back());
162 }
163 void push_back(ValueTypeRange<ResultRange> value) {
164 typeRanges.push_back(value);
165 results.push_back(&typeRanges.back());
166 }
167
168 /// Push a new Value onto the result list.
169 void push_back(Value value) { results.push_back(value); }
170
171 /// Push a new ValueRange onto the result list.
172 void push_back(ValueRange value) {
173 // The lifetime of a ValueRange can't be guaranteed, so we'll need to
174 // allocate a storage for it.
175 allocatedValueRanges.emplace_back(value.begin(), value.end());
176 valueRanges.push_back(allocatedValueRanges.back());
177 results.push_back(&valueRanges.back());
178 }
179 void push_back(OperandRange value) {
180 valueRanges.push_back(value);
181 results.push_back(&valueRanges.back());
182 }
183 void push_back(ResultRange value) {
184 valueRanges.push_back(value);
185 results.push_back(&valueRanges.back());
186 }
187
188protected:
189 /// Create a new result list with the expected number of results.
190 PDLResultList(unsigned maxNumResults) {
191 // For now just reserve enough space for all of the results. We could do
192 // separate counts per range type, but it isn't really worth it unless there
193 // are a "large" number of results.
194 typeRanges.reserve(maxNumResults);
195 valueRanges.reserve(maxNumResults);
196 }
197
198 /// The PDL results held by this list.
199 SmallVector<PDLValue> results;
200 /// Memory used to store ranges held by the list.
201 SmallVector<TypeRange> typeRanges;
202 SmallVector<ValueRange> valueRanges;
203 /// Memory allocated to store ranges in the result list whose lifetime was
204 /// generated in the native function.
205 SmallVector<std::vector<Type>> allocatedTypeRanges;
206 SmallVector<std::vector<Value>> allocatedValueRanges;
207};
208
209//===----------------------------------------------------------------------===//
210// PDLPatternConfig
211
212/// An individual configuration for a pattern, which can be accessed by native
213/// functions via the PDLPatternConfigSet. This allows for injecting additional
214/// configuration into PDL patterns that is specific to certain compilation
215/// flows.
216class PDLPatternConfig {
217public:
218 virtual ~PDLPatternConfig() = default;
219
220 /// Hooks that are invoked at the beginning and end of a rewrite of a matched
221 /// pattern. These can be used to setup any specific state necessary for the
222 /// rewrite.
223 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
224 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
225
226 /// Return the TypeID that represents this configuration.
227 TypeID getTypeID() const { return id; }
228
229protected:
230 PDLPatternConfig(TypeID id) : id(id) {}
231
232private:
233 TypeID id;
234};
235
236/// This class provides a base class for users implementing a type of pattern
237/// configuration.
238template <typename T>
239class PDLPatternConfigBase : public PDLPatternConfig {
240public:
241 /// Support LLVM style casting.
242 static bool classof(const PDLPatternConfig *config) {
243 return config->getTypeID() == getConfigID();
244 }
245
246 /// Return the type id used for this configuration.
247 static TypeID getConfigID() { return TypeID::get<T>(); }
248
249protected:
250 PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
251};
252
253/// This class contains a set of configurations for a specific pattern.
254/// Configurations are uniqued by TypeID, meaning that only one configuration of
255/// each type is allowed.
256class PDLPatternConfigSet {
257public:
258 PDLPatternConfigSet() = default;
259
260 /// Construct a set with the given configurations.
261 template <typename... ConfigsT>
262 PDLPatternConfigSet(ConfigsT &&...configs) {
263 (addConfig(std::forward<ConfigsT>(configs)), ...);
264 }
265
266 /// Get the configuration defined by the given type. Asserts that the
267 /// configuration of the provided type exists.
268 template <typename T>
269 const T &get() const {
270 const T *config = tryGet<T>();
271 assert(config && "configuration not found");
272 return *config;
273 }
274
275 /// Get the configuration defined by the given type, returns nullptr if the
276 /// configuration does not exist.
277 template <typename T>
278 const T *tryGet() const {
279 for (const auto &configIt : configs)
280 if (const T *config = dyn_cast<T>(configIt.get()))
281 return config;
282 return nullptr;
283 }
284
285 /// Notify the configurations within this set at the beginning or end of a
286 /// rewrite of a matched pattern.
287 void notifyRewriteBegin(PatternRewriter &rewriter) {
288 for (const auto &config : configs)
289 config->notifyRewriteBegin(rewriter);
290 }
291 void notifyRewriteEnd(PatternRewriter &rewriter) {
292 for (const auto &config : configs)
293 config->notifyRewriteEnd(rewriter);
294 }
295
296protected:
297 /// Add a configuration to the set.
298 template <typename T>
299 void addConfig(T &&config) {
300 assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
301 configs.emplace_back(
302 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
303 }
304
305 /// The set of configurations for this pattern. This uses a vector instead of
306 /// a map with the expectation that the number of configurations per set is
307 /// small (<= 1).
308 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
309};
310
311//===----------------------------------------------------------------------===//
312// PDLPatternModule
313
314/// A generic PDL pattern constraint function. This function applies a
315/// constraint to a given set of opaque PDLValue entities. Returns success if
316/// the constraint successfully held, failure otherwise.
317using PDLConstraintFunction = std::function<LogicalResult(
318 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
319
320/// A native PDL rewrite function. This function performs a rewrite on the
321/// given set of values. Any results from this rewrite that should be passed
322/// back to PDL should be added to the provided result list. This method is only
323/// invoked when the corresponding match was successful. Returns failure if an
324/// invariant of the rewrite was broken (certain rewriters may recover from
325/// partial pattern application).
326using PDLRewriteFunction = std::function<LogicalResult(
327 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
328
329namespace detail {
330namespace pdl_function_builder {
/// A utility variable that always resolves to false. This is useful for static
/// asserts that are always false, but only should fire in certain templated
/// constructs. For example, if a templated function should never be called, the
/// function could be defined as:
///
///   template <typename T>
///   void foo() {
///     static_assert(always_false<T>, "This function should never be called");
///   }
///
/// Declared `inline` (C++17) so this header-defined variable template has one
/// definition shared by all translation units instead of a per-TU copy.
template <class... T>
inline constexpr bool always_false = false;
343
344//===----------------------------------------------------------------------===//
345// PDL Function Builder: Type Processing
346//===----------------------------------------------------------------------===//
347
/// Customization point describing how a C++ type `T` is exchanged with PDL,
/// either as an argument of a native constraint/rewrite function or as one of
/// its results. This lets native functions use rich C++ types without users
/// writing the conversion glue code by hand. A specialization enables `T` by
/// implementing the relevant subset of:
///
///   static LogicalResult verifyAsArg(
///     function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///     size_t argIdx);
///
///     * Check that `pdlValue` can be used as a value of `T`, reporting any
///       problem through `errorFn`.
///
///   static T processAsArg(PDLValue pdlValue);
///
///     * Convert `pdlValue` into a value of `T`.
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               const T &value);
///
///     * Package `value`, returned by a native function, into an appropriate
///       form and append it to `results`.
///
/// Types that are layered on top of another supported type can derive the
/// specialization from `ProcessPDLValueBasedOn` to reuse that type's handling.
///
/// The second template parameter only exists to allow SFINAE-constrained
/// partial specializations.
template <typename T, typename EnableT = void>
struct ProcessPDLValue;
379
380/// This struct provides a simplified model for processing types that are based
381/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
382/// allows for building the necessary processing functions on top of the base
383/// value instead of a PDLValue. Derived users should implement the following
384/// (which subsume the ProcessPDLValue variants):
385///
386/// static LogicalResult verifyAsArg(
387/// function_ref<LogicalResult(const Twine &)> errorFn,
388/// const BaseT &baseValue, size_t argIdx);
389///
390/// * This method verifies that the given PDLValue is valid for use as a
391/// value of `T`.
392///
393/// static T processAsArg(BaseT baseValue);
394///
395/// * This method processes the given base value as a value of `T`.
396///
397template <typename T, typename BaseT>
398struct ProcessPDLValueBasedOn {
399 static LogicalResult
400 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
401 PDLValue pdlValue, size_t argIdx) {
402 // Verify the base class before continuing.
403 if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
404 return failure();
405 return ProcessPDLValue<T>::verifyAsArg(
406 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
407 }
408 static T processAsArg(PDLValue pdlValue) {
409 return ProcessPDLValue<T>::processAsArg(
410 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
411 }
412
413 /// Explicitly add the expected parent API to ensure the parent class
414 /// implements the necessary API (and doesn't implicitly inherit it from
415 /// somewhere else).
416 static LogicalResult
417 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
418 size_t argIdx) {
419 return success();
420 }
421 static T processAsArg(BaseT baseValue);
422};
423
424/// This struct provides a simplified model for processing types that have
425/// "builtin" PDLValue support:
426/// * Attribute, Operation *, Type, TypeRange, ValueRange
427template <typename T>
428struct ProcessBuiltinPDLValue {
429 static LogicalResult
430 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
431 PDLValue pdlValue, size_t argIdx) {
432 if (pdlValue)
433 return success();
434 return errorFn("expected a non-null value for argument " + Twine(argIdx) +
435 " of type: " + llvm::getTypeName<T>());
436 }
437
438 static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
439 static void processAsResult(PatternRewriter &, PDLResultList &results,
440 T value) {
441 results.push_back(value);
442 }
443};
444
445/// This struct provides a simplified model for processing types that inherit
446/// from builtin PDLValue types. For example, derived attributes like
447/// IntegerAttr, derived types like IntegerType, derived operations like
448/// ModuleOp, Interfaces, etc.
449template <typename T, typename BaseT>
450struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
451 static LogicalResult
452 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
453 BaseT baseValue, size_t argIdx) {
454 return TypeSwitch<BaseT, LogicalResult>(baseValue)
455 .Case([&](T) { return success(); })
456 .Default([&](BaseT) {
457 return errorFn("expected argument " + Twine(argIdx) +
458 " to be of type: " + llvm::getTypeName<T>());
459 });
460 }
461 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
462
463 static T processAsArg(BaseT baseValue) {
464 return baseValue.template cast<T>();
465 }
466 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
467
468 static void processAsResult(PatternRewriter &, PDLResultList &results,
469 T value) {
470 results.push_back(value);
471 }
472};
473
474//===----------------------------------------------------------------------===//
475// Attribute
476
477template <>
478struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
479template <typename T>
480struct ProcessPDLValue<T,
481 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
482 : public ProcessDerivedPDLValue<T, Attribute> {};
483
484/// Handling for various Attribute value types.
485template <>
486struct ProcessPDLValue<StringRef>
487 : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
488 static StringRef processAsArg(StringAttr value) { return value.getValue(); }
489 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
490
491 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
492 StringRef value) {
493 results.push_back(rewriter.getStringAttr(value));
494 }
495};
496template <>
497struct ProcessPDLValue<std::string>
498 : public ProcessPDLValueBasedOn<std::string, StringAttr> {
499 template <typename T>
500 static std::string processAsArg(T value) {
501 static_assert(always_false<T>,
502 "`std::string` arguments require a string copy, use "
503 "`StringRef` for string-like arguments instead");
504 return {};
505 }
506 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
507 StringRef value) {
508 results.push_back(rewriter.getStringAttr(value));
509 }
510};
511
512//===----------------------------------------------------------------------===//
513// Operation
514
515template <>
516struct ProcessPDLValue<Operation *>
517 : public ProcessBuiltinPDLValue<Operation *> {};
518template <typename T>
519struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
520 : public ProcessDerivedPDLValue<T, Operation *> {
521 static T processAsArg(Operation *value) { return cast<T>(value); }
522};
523
524//===----------------------------------------------------------------------===//
525// Type
526
527template <>
528struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
529template <typename T>
530struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
531 : public ProcessDerivedPDLValue<T, Type> {};
532
533//===----------------------------------------------------------------------===//
534// TypeRange
535
536template <>
537struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
538template <>
539struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
540 static void processAsResult(PatternRewriter &, PDLResultList &results,
541 ValueTypeRange<OperandRange> types) {
542 results.push_back(types);
543 }
544};
545template <>
546struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
547 static void processAsResult(PatternRewriter &, PDLResultList &results,
548 ValueTypeRange<ResultRange> types) {
549 results.push_back(types);
550 }
551};
552template <unsigned N>
553struct ProcessPDLValue<SmallVector<Type, N>> {
554 static void processAsResult(PatternRewriter &, PDLResultList &results,
555 SmallVector<Type, N> values) {
556 results.push_back(TypeRange(values));
557 }
558};
559
560//===----------------------------------------------------------------------===//
561// Value
562
563template <>
564struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
565
566//===----------------------------------------------------------------------===//
567// ValueRange
568
569template <>
570struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
571};
572template <>
573struct ProcessPDLValue<OperandRange> {
574 static void processAsResult(PatternRewriter &, PDLResultList &results,
575 OperandRange values) {
576 results.push_back(values);
577 }
578};
579template <>
580struct ProcessPDLValue<ResultRange> {
581 static void processAsResult(PatternRewriter &, PDLResultList &results,
582 ResultRange values) {
583 results.push_back(values);
584 }
585};
586template <unsigned N>
587struct ProcessPDLValue<SmallVector<Value, N>> {
588 static void processAsResult(PatternRewriter &, PDLResultList &results,
589 SmallVector<Value, N> values) {
590 results.push_back(ValueRange(values));
591 }
592};
593
594//===----------------------------------------------------------------------===//
595// PDL Function Builder: Argument Handling
596//===----------------------------------------------------------------------===//
597
598/// Validate the given PDLValues match the constraints defined by the argument
599/// types of the given function. In the case of failure, a match failure
600/// diagnostic is emitted.
601/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
602/// does not currently preserve Constraint application ordering.
603template <typename PDLFnT, std::size_t... I>
604LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
605 std::index_sequence<I...>) {
606 using FnTraitsT = llvm::function_traits<PDLFnT>;
607
608 auto errorFn = [&](const Twine &msg) {
609 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
610 };
611 return success(
612 (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
613 verifyAsArg(errorFn, values[I], I)) &&
614 ...));
615}
616
617/// Assert that the given PDLValues match the constraints defined by the
618/// arguments of the given function. In the case of failure, a fatal error
619/// is emitted.
620template <typename PDLFnT, std::size_t... I>
621void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
622 std::index_sequence<I...>) {
623 // We only want to do verification in debug builds, same as with `assert`.
624#ifndef NDEBUG
625 using FnTraitsT = llvm::function_traits<PDLFnT>;
626 auto errorFn = [&](const Twine &msg) -> LogicalResult {
627 llvm::report_fatal_error(msg);
628 };
629 (void)errorFn;
630 assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
631 verifyAsArg(errorFn, values[I], I)) &&
632 ...));
633#endif
634 (void)values;
635}
636
637//===----------------------------------------------------------------------===//
638// PDL Function Builder: Results Handling
639//===----------------------------------------------------------------------===//
640
641/// Store a single result within the result list.
642template <typename T>
643static LogicalResult processResults(PatternRewriter &rewriter,
644 PDLResultList &results, T &&value) {
645 ProcessPDLValue<T>::processAsResult(rewriter, results,
646 std::forward<T>(value));
647 return success();
648}
649
650/// Store a std::pair<> as individual results within the result list.
651template <typename T1, typename T2>
652static LogicalResult processResults(PatternRewriter &rewriter,
653 PDLResultList &results,
654 std::pair<T1, T2> &&pair) {
655 if (failed(processResults(rewriter, results, std::move(pair.first))) ||
656 failed(processResults(rewriter, results, std::move(pair.second))))
657 return failure();
658 return success();
659}
660
661/// Store a std::tuple<> as individual results within the result list.
662template <typename... Ts>
663static LogicalResult processResults(PatternRewriter &rewriter,
664 PDLResultList &results,
665 std::tuple<Ts...> &&tuple) {
666 auto applyFn = [&](auto &&...args) {
667 return (succeeded(processResults(rewriter, results, std::move(args))) &&
668 ...);
669 };
670 return success(std::apply(applyFn, std::move(tuple)));
671}
672
673/// Handle LogicalResult propagation.
674inline LogicalResult processResults(PatternRewriter &rewriter,
675 PDLResultList &results,
676 LogicalResult &&result) {
677 return result;
678}
679template <typename T>
680static LogicalResult processResults(PatternRewriter &rewriter,
681 PDLResultList &results,
682 FailureOr<T> &&result) {
683 if (failed(result))
684 return failure();
685 return processResults(rewriter, results, std::move(*result));
686}
687
688//===----------------------------------------------------------------------===//
689// PDL Constraint Builder
690//===----------------------------------------------------------------------===//
691
692/// Process the arguments of a native constraint and invoke it.
693template <typename PDLFnT, std::size_t... I,
694 typename FnTraitsT = llvm::function_traits<PDLFnT>>
695typename FnTraitsT::result_t
696processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
697 ArrayRef<PDLValue> values,
698 std::index_sequence<I...>) {
699 return fn(
700 rewriter,
701 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
702 values[I]))...);
703}
704
705/// Build a constraint function from the given function `ConstraintFnT`. This
706/// allows for enabling the user to define simpler, more direct constraint
707/// functions without needing to handle the low-level PDL goop.
708///
709/// If the constraint function is already in the correct form, we just forward
710/// it directly.
711template <typename ConstraintFnT>
712std::enable_if_t<
713 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
714 PDLConstraintFunction>
715buildConstraintFn(ConstraintFnT &&constraintFn) {
716 return std::forward<ConstraintFnT>(constraintFn);
717}
718/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
719/// we desire.
720template <typename ConstraintFnT>
721std::enable_if_t<
722 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
723 PDLConstraintFunction>
724buildConstraintFn(ConstraintFnT &&constraintFn) {
725 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
726 PatternRewriter &rewriter, PDLResultList &,
727 ArrayRef<PDLValue> values) -> LogicalResult {
728 auto argIndices = std::make_index_sequence<
729 llvm::function_traits<ConstraintFnT>::num_args - 1>();
730 if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
731 return failure();
732 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
733 argIndices);
734 };
735}
736
737//===----------------------------------------------------------------------===//
738// PDL Rewrite Builder
739//===----------------------------------------------------------------------===//
740
741/// Process the arguments of a native rewrite and invoke it.
742/// This overload handles the case of no return values.
743template <typename PDLFnT, std::size_t... I,
744 typename FnTraitsT = llvm::function_traits<PDLFnT>>
745std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
746 LogicalResult>
747processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
748 PDLResultList &, ArrayRef<PDLValue> values,
749 std::index_sequence<I...>) {
750 fn(rewriter,
751 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
752 values[I]))...);
753 return success();
754}
755/// This overload handles the case of return values, which need to be packaged
756/// into the result list.
757template <typename PDLFnT, std::size_t... I,
758 typename FnTraitsT = llvm::function_traits<PDLFnT>>
759std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
760 LogicalResult>
761processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
762 PDLResultList &results, ArrayRef<PDLValue> values,
763 std::index_sequence<I...>) {
764 return processResults(
765 rewriter, results,
766 fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
767 processAsArg(values[I]))...));
768 (void)values;
769}
770
771/// Build a rewrite function from the given function `RewriteFnT`. This
772/// allows for enabling the user to define simpler, more direct rewrite
773/// functions without needing to handle the low-level PDL goop.
774///
775/// If the rewrite function is already in the correct form, we just forward
776/// it directly.
777template <typename RewriteFnT>
778std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
779 PDLRewriteFunction>
780buildRewriteFn(RewriteFnT &&rewriteFn) {
781 return std::forward<RewriteFnT>(rewriteFn);
782}
783/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
784/// we desire.
785template <typename RewriteFnT>
786std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
787 PDLRewriteFunction>
788buildRewriteFn(RewriteFnT &&rewriteFn) {
789 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
790 PatternRewriter &rewriter, PDLResultList &results,
791 ArrayRef<PDLValue> values) {
792 auto argIndices =
793 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
794 1>();
795 assertArgs<RewriteFnT>(rewriter, values, argIndices);
796 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
797 argIndices);
798 };
799}
800
801} // namespace pdl_function_builder
802} // namespace detail
803
804//===----------------------------------------------------------------------===//
805// PDLPatternModule
806
807/// This class contains all of the necessary data for a set of PDL patterns, or
808/// pattern rewrites specified in the form of the PDL dialect. This PDL module
809/// contained by this pattern may contain any number of `pdl.pattern`
810/// operations.
811class PDLPatternModule {
812public:
813 PDLPatternModule() = default;
814
815 /// Construct a PDL pattern with the given module and configurations.
816 PDLPatternModule(OwningOpRef<ModuleOp> module)
817 : pdlModule(std::move(module)) {}
818 template <typename... ConfigsT>
819 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
820 : PDLPatternModule(std::move(module)) {
821 auto configSet = std::make_unique<PDLPatternConfigSet>(
822 std::forward<ConfigsT>(patternConfigs)...);
823 attachConfigToPatterns(*pdlModule, *configSet);
824 configs.emplace_back(std::move(configSet));
825 }
826
827 /// Merge the state in `other` into this pattern module.
828 void mergeIn(PDLPatternModule &&other);
829
830 /// Return the internal PDL module of this pattern.
831 ModuleOp getModule() { return pdlModule.get(); }
832
833 /// Return the MLIR context of this pattern.
834 MLIRContext *getContext() { return getModule()->getContext(); }
835
836 //===--------------------------------------------------------------------===//
837 // Function Registry
838
839 /// Register a constraint function with PDL. A constraint function may be
840 /// specified in one of two ways:
841 ///
842 /// * `LogicalResult (PatternRewriter &,
843 /// PDLResultList &,
844 /// ArrayRef<PDLValue>)`
845 ///
846 /// In this overload the arguments of the constraint function are passed via
847 /// the low-level PDLValue form, and the results are manually appended to
848 /// the given result list.
849 ///
850 /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
851 ///
852 /// In this form the arguments of the constraint function are passed via the
853 /// expected high level C++ type. In this form, the framework will
854 /// automatically unwrap PDLValues and convert them to the expected ValueTs.
855 /// For example, if the constraint function accepts a `Operation *`, the
856 /// framework will automatically cast the input PDLValue. In the case of a
857 /// `StringRef`, the framework will automatically unwrap the argument as a
858 /// StringAttr and pass the underlying string value. To see the full list of
859 /// supported types, or to see how to add handling for custom types, view
860 /// the definition of `ProcessPDLValue` above.
861 void registerConstraintFunction(StringRef name,
862 PDLConstraintFunction constraintFn);
863 template <typename ConstraintFnT>
864 void registerConstraintFunction(StringRef name,
865 ConstraintFnT &&constraintFn) {
866 registerConstraintFunction(name,
867 detail::pdl_function_builder::buildConstraintFn(
868 std::forward<ConstraintFnT>(constraintFn)));
869 }
870
871 /// Register a rewrite function with PDL. A rewrite function may be specified
872 /// in one of two ways:
873 ///
874 /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
875 ///
876 /// In this overload the arguments of the constraint function are passed via
877 /// the low-level PDLValue form, and the results are manually appended to
878 /// the given result list.
879 ///
880 /// * `ResultT (PatternRewriter &, ValueTs... values)`
881 ///
882 /// In this form the arguments and result of the rewrite function are passed
883 /// via the expected high level C++ type. In this form, the framework will
884 /// automatically unwrap the PDLValues arguments and convert them to the
885 /// expected ValueTs. It will also automatically handle the processing and
886 /// packaging of the result value to the result list. For example, if the
887 /// rewrite function takes a `Operation *`, the framework will automatically
888 /// cast the input PDLValue. In the case of a `StringRef`, the framework
889 /// will automatically unwrap the argument as a StringAttr and pass the
890 /// underlying string value. In the reverse case, if the rewrite returns a
891 /// StringRef or std::string, it will automatically package this as a
892 /// StringAttr and append it to the result list. To see the full list of
893 /// supported types, or to see how to add handling for custom types, view
894 /// the definition of `ProcessPDLValue` above.
895 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
896 template <typename RewriteFnT>
897 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
898 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
899 std::forward<RewriteFnT>(rewriteFn)));
900 }
901
902 /// Return the set of the registered constraint functions.
903 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
904 return constraintFunctions;
905 }
906 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
907 return constraintFunctions;
908 }
909 /// Return the set of the registered rewrite functions.
910 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
911 return rewriteFunctions;
912 }
913 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
914 return rewriteFunctions;
915 }
916
917 /// Return the set of the registered pattern configs.
918 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
919 return std::move(configs);
920 }
921 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
922 return std::move(configMap);
923 }
924
925 /// Clear out the patterns and functions within this module.
926 void clear() {
927 pdlModule = nullptr;
928 constraintFunctions.clear();
929 rewriteFunctions.clear();
930 }
931
932private:
933 /// Attach the given pattern config set to the patterns defined within the
934 /// given module.
935 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
936
937 /// The module containing the `pdl.pattern` operations.
938 OwningOpRef<ModuleOp> pdlModule;
939
940 /// The set of configuration sets referenced by patterns within `pdlModule`.
941 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
942 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
943
944 /// The external functions referenced from within the PDL module.
945 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
946 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
947};
948} // namespace mlir
949
950#else
951
952namespace mlir {
953// Stubs for when PDL in pattern rewrites is not enabled.
954
/// Placeholder for PDL values when PDL support is compiled out. No value is
/// ever held, so every cast attempt produces a null result.
class PDLValue {
public:
  template <typename T>
  T dyn_cast() const {
    T nullResult = nullptr;
    return nullResult;
  }
};
/// Placeholder result list; it carries no state when PDL is disabled.
class PDLResultList {};
963using PDLConstraintFunction = std::function<LogicalResult(
964 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
965using PDLRewriteFunction = std::function<LogicalResult(
966 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
967
968class PDLPatternModule {
969public:
970 PDLPatternModule() = default;
971
972 PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}
973 MLIRContext *getContext() {
974 llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
975 }
976 void mergeIn(PDLPatternModule &&other) {}
977 void clear() {}
978 template <typename ConstraintFnT>
979 void registerConstraintFunction(StringRef name,
980 ConstraintFnT &&constraintFn) {}
981 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
982 template <typename RewriteFnT>
983 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
984 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
985 return constraintFunctions;
986 }
987
988private:
989 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
990};
991
992} // namespace mlir
993#endif
994
995#endif // MLIR_IR_PDLPATTERNMATCH_H
return success()
b getContext())
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
values clear()
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
const FrozenRewritePatternSet GreedyRewriteConfig config