9#ifndef MLIR_IR_PDLPATTERNMATCH_H
10#define MLIR_IR_PDLPATTERNMATCH_H
12#include "mlir/Config/mlir-config.h"
14#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
17#include "llvm/ADT/TypeSwitch.h"
35 PDLValue(
const PDLValue &other) =
default;
36 PDLValue(std::nullptr_t =
nullptr) {}
37 PDLValue(Attribute value)
38 : value(value.getAsOpaquePointer()), kind(
Kind::Attribute) {}
39 PDLValue(Operation *value) : value(value), kind(
Kind::Operation) {}
40 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(
Kind::Type) {}
43 : value(value.getAsOpaquePointer()), kind(
Kind::Value) {}
49 assert(value &&
"isa<> used on a null value");
50 return kind == getKindOf<T>();
56 typename ResultT = std::conditional_t<
57 std::is_constructible_v<bool, T>, T, std::optional<T>>>
58 ResultT dyn_cast()
const {
59 return isa<T>() ? castImpl<T>() : ResultT();
66 assert(isa<T>() &&
"expected value to be of type `T`");
71 const void *getAsOpaquePointer()
const {
return value; }
74 explicit operator bool()
const {
return value; }
77 Kind getKind()
const {
return kind; }
80 void print(raw_ostream &os)
const;
83 static void print(raw_ostream &os, Kind kind);
87 template <
typename...>
89 template <
typename T,
typename... R>
90 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
91 template <
typename T,
typename F,
typename... R>
92 struct index_of_t<T, F, R...>
93 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
97 static Kind getKindOf() {
98 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
104 template <
typename T>
105 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
107 return T::getFromOpaquePointer(value);
109 template <
typename T>
110 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
112 return *
reinterpret_cast<T *
>(
const_cast<void *
>(value));
114 template <
typename T>
115 std::enable_if_t<std::is_pointer<T>::value, T> castImpl()
const {
116 return reinterpret_cast<T
>(
const_cast<void *
>(value));
120 const void *value{
nullptr};
122 Kind kind{Kind::Attribute};
131 PDLValue::print(os, kind);
144 void push_back(Attribute value) { results.push_back(value); }
147 void push_back(Operation *value) { results.push_back(value); }
150 void push_back(Type value) { results.push_back(value); }
156 allocatedTypeRanges.emplace_back(value.begin(), value.end());
157 typeRanges.push_back(allocatedTypeRanges.back());
158 results.push_back(&typeRanges.back());
160 void push_back(ValueTypeRange<OperandRange> value) {
161 typeRanges.push_back(value);
162 results.push_back(&typeRanges.back());
164 void push_back(ValueTypeRange<ResultRange> value) {
165 typeRanges.push_back(value);
166 results.push_back(&typeRanges.back());
170 void push_back(Value value) { results.push_back(value); }
176 allocatedValueRanges.emplace_back(value.begin(), value.end());
177 valueRanges.push_back(allocatedValueRanges.back());
178 results.push_back(&valueRanges.back());
181 valueRanges.push_back(value);
182 results.push_back(&valueRanges.back());
185 valueRanges.push_back(value);
186 results.push_back(&valueRanges.back());
191 PDLResultList(
unsigned maxNumResults) {
195 typeRanges.reserve(maxNumResults);
196 valueRanges.reserve(maxNumResults);
200 SmallVector<PDLValue> results;
202 SmallVector<TypeRange> typeRanges;
203 SmallVector<ValueRange> valueRanges;
206 SmallVector<std::vector<Type>> allocatedTypeRanges;
207 SmallVector<std::vector<Value>> allocatedValueRanges;
217class PDLPatternConfig {
219 virtual ~PDLPatternConfig() =
default;
224 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
225 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
228 TypeID getTypeID()
const {
return id; }
231 PDLPatternConfig(TypeID
id) : id(id) {}
240class PDLPatternConfigBase :
public PDLPatternConfig {
243 static bool classof(
const PDLPatternConfig *config) {
244 return config->getTypeID() == getConfigID();
248 static TypeID getConfigID() {
return TypeID::get<T>(); }
251 PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
257class PDLPatternConfigSet {
259 PDLPatternConfigSet() =
default;
262 template <
typename... ConfigsT>
263 PDLPatternConfigSet(ConfigsT &&...configs) {
264 (addConfig(std::forward<ConfigsT>(configs)), ...);
269 template <
typename T>
270 const T &
get()
const {
271 const T *config = tryGet<T>();
272 assert(config &&
"configuration not found");
278 template <
typename T>
279 const T *tryGet()
const {
280 for (
const auto &configIt : configs)
281 if (
const T *config = dyn_cast<T>(configIt.get()))
288 void notifyRewriteBegin(PatternRewriter &rewriter) {
289 for (
const auto &config : configs)
290 config->notifyRewriteBegin(rewriter);
292 void notifyRewriteEnd(PatternRewriter &rewriter) {
293 for (
const auto &config : configs)
294 config->notifyRewriteEnd(rewriter);
299 template <
typename T>
300 void addConfig(T &&config) {
301 assert(!tryGet<std::decay_t<T>>() &&
"configuration already exists");
302 configs.emplace_back(
303 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
309 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
318using PDLConstraintFunction = std::function<LogicalResult(
327using PDLRewriteFunction = std::function<LogicalResult(
331namespace pdl_function_builder {
343constexpr bool always_false =
false;
378template <
typename T,
typename Enable =
void>
379struct ProcessPDLValue;
398template <
typename T,
typename BaseT>
399struct ProcessPDLValueBasedOn {
401 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn,
402 PDLValue pdlValue,
size_t argIdx) {
404 if (
failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
406 return ProcessPDLValue<T>::verifyAsArg(
407 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
409 static T processAsArg(PDLValue pdlValue) {
410 return ProcessPDLValue<T>::processAsArg(
411 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
418 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn, BaseT value,
422 static T processAsArg(BaseT baseValue);
429struct ProcessBuiltinPDLValue {
431 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn,
432 PDLValue pdlValue,
size_t argIdx) {
435 return errorFn(
"expected a non-null value for argument " + Twine(argIdx) +
436 " of type: " + llvm::getTypeName<T>());
439 static T processAsArg(PDLValue pdlValue) {
return pdlValue.cast<T>(); }
440 static void processAsResult(PatternRewriter &, PDLResultList &results,
442 results.push_back(value);
450template <
typename T,
typename BaseT>
451struct ProcessDerivedPDLValue :
public ProcessPDLValueBasedOn<T, BaseT> {
453 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn,
454 BaseT baseValue,
size_t argIdx) {
455 return TypeSwitch<BaseT, LogicalResult>(baseValue)
456 .Case([&](T) {
return success(); })
457 .Default([&](BaseT) {
458 return errorFn(
"expected argument " + Twine(argIdx) +
459 " to be of type: " + llvm::getTypeName<T>());
462 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
464 static T processAsArg(BaseT baseValue) {
465 return baseValue.template cast<T>();
467 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
469 static void processAsResult(PatternRewriter &, PDLResultList &results,
471 results.push_back(value);
479struct ProcessPDLValue<Attribute> :
public ProcessBuiltinPDLValue<Attribute> {};
481struct ProcessPDLValue<T,
482 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
483 :
public ProcessDerivedPDLValue<T, Attribute> {};
487struct ProcessPDLValue<StringRef>
488 :
public ProcessPDLValueBasedOn<StringRef, StringAttr> {
489 static StringRef processAsArg(StringAttr value) {
return value.getValue(); }
490 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
492 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
494 results.push_back(rewriter.getStringAttr(value));
498struct ProcessPDLValue<std::string>
499 :
public ProcessPDLValueBasedOn<std::string, StringAttr> {
500 template <
typename T>
501 static std::string processAsArg(T value) {
502 static_assert(always_false<T>,
503 "`std::string` arguments require a string copy, use "
504 "`StringRef` for string-like arguments instead");
507 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
509 results.push_back(rewriter.getStringAttr(value));
517struct ProcessPDLValue<Operation *>
518 :
public ProcessBuiltinPDLValue<Operation *> {};
520struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
521 :
public ProcessDerivedPDLValue<T, Operation *> {
522 static T processAsArg(Operation *value) {
return cast<T>(value); }
523 using ProcessDerivedPDLValue<T, Operation *>::processAsArg;
530struct ProcessPDLValue<Type> :
public ProcessBuiltinPDLValue<Type> {};
532struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
533 :
public ProcessDerivedPDLValue<T, Type> {};
539struct ProcessPDLValue<
TypeRange> :
public ProcessBuiltinPDLValue<TypeRange> {};
541struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
542 static void processAsResult(PatternRewriter &, PDLResultList &results,
543 ValueTypeRange<OperandRange> types) {
544 results.push_back(types);
548struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
549 static void processAsResult(PatternRewriter &, PDLResultList &results,
550 ValueTypeRange<ResultRange> types) {
551 results.push_back(types);
555struct ProcessPDLValue<SmallVector<Type, N>> {
556 static void processAsResult(PatternRewriter &, PDLResultList &results,
557 SmallVector<Type, N> values) {
566struct ProcessPDLValue<Value> :
public ProcessBuiltinPDLValue<Value> {};
572struct ProcessPDLValue<
ValueRange> :
public ProcessBuiltinPDLValue<ValueRange> {
575struct ProcessPDLValue<OperandRange> {
576 static void processAsResult(PatternRewriter &, PDLResultList &results,
577 OperandRange values) {
578 results.push_back(values);
582struct ProcessPDLValue<ResultRange> {
583 static void processAsResult(PatternRewriter &, PDLResultList &results,
584 ResultRange values) {
585 results.push_back(values);
589struct ProcessPDLValue<SmallVector<Value, N>> {
590 static void processAsResult(PatternRewriter &, PDLResultList &results,
591 SmallVector<Value, N> values) {
605template <
typename PDLFnT, std::size_t... I>
606LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
607 std::index_sequence<I...>) {
608 using FnTraitsT = llvm::function_traits<PDLFnT>;
610 auto errorFn = [&](
const Twine &msg) {
611 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
614 (succeeded(ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
615 verifyAsArg(errorFn, values[I], I)) &&
622template <
typename PDLFnT, std::size_t... I>
623void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
624 std::index_sequence<I...>) {
627 using FnTraitsT = llvm::function_traits<PDLFnT>;
628 auto errorFn = [&](
const Twine &msg) -> LogicalResult {
629 llvm::report_fatal_error(msg);
632 assert((succeeded(ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
633 verifyAsArg(errorFn, values[I], I)) &&
645static LogicalResult processResults(PatternRewriter &rewriter,
646 PDLResultList &results, T &&value) {
647 ProcessPDLValue<T>::processAsResult(rewriter, results,
648 std::forward<T>(value));
653template <
typename T1,
typename T2>
654static LogicalResult processResults(PatternRewriter &rewriter,
655 PDLResultList &results,
656 std::pair<T1, T2> &&pair) {
657 if (
failed(processResults(rewriter, results, std::move(pair.first))) ||
658 failed(processResults(rewriter, results, std::move(pair.second))))
664template <
typename... Ts>
665static LogicalResult processResults(PatternRewriter &rewriter,
666 PDLResultList &results,
667 std::tuple<Ts...> &&tuple) {
668 auto applyFn = [&](
auto &&...args) {
669 return (succeeded(processResults(rewriter, results, std::move(args))) &&
672 return success(std::apply(applyFn, std::move(tuple)));
676inline LogicalResult processResults(PatternRewriter &rewriter,
677 PDLResultList &results,
682static LogicalResult processResults(PatternRewriter &rewriter,
683 PDLResultList &results,
687 return processResults(rewriter, results, std::move(*
result));
695template <
typename PDLFnT, std::size_t... I,
696 typename FnTraitsT = llvm::function_traits<PDLFnT>>
697typename FnTraitsT::result_t
698processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
699 ArrayRef<PDLValue> values,
700 std::index_sequence<I...>) {
703 (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
713template <
typename Constra
intFnT>
715 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
716 PDLConstraintFunction>
717buildConstraintFn(ConstraintFnT &&constraintFn) {
718 return std::forward<ConstraintFnT>(constraintFn);
722template <
typename Constra
intFnT>
724 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
725 PDLConstraintFunction>
726buildConstraintFn(ConstraintFnT &&constraintFn) {
727 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
728 PatternRewriter &rewriter, PDLResultList &,
729 ArrayRef<PDLValue> values) -> LogicalResult {
730 auto argIndices = std::make_index_sequence<
731 llvm::function_traits<ConstraintFnT>::num_args - 1>();
732 if (
failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
734 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
745template <
typename PDLFnT, std::size_t... I,
746 typename FnTraitsT = llvm::function_traits<PDLFnT>>
747std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
749processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
750 PDLResultList &, ArrayRef<PDLValue> values,
751 std::index_sequence<I...>) {
753 (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
759template <
typename PDLFnT, std::size_t... I,
760 typename FnTraitsT = llvm::function_traits<PDLFnT>>
761std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
763processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
764 PDLResultList &results, ArrayRef<PDLValue> values,
765 std::index_sequence<I...>) {
766 return processResults(
768 fn(rewriter, (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
769 processAsArg(values[I]))...));
779template <
typename RewriteFnT>
780std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
782buildRewriteFn(RewriteFnT &&rewriteFn) {
783 return std::forward<RewriteFnT>(rewriteFn);
787template <
typename RewriteFnT>
788std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
790buildRewriteFn(RewriteFnT &&rewriteFn) {
791 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
792 PatternRewriter &rewriter, PDLResultList &results,
793 ArrayRef<PDLValue> values) {
795 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
797 assertArgs<RewriteFnT>(rewriter, values, argIndices);
798 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
813class PDLPatternModule {
815 PDLPatternModule() =
default;
818 PDLPatternModule(OwningOpRef<ModuleOp> module)
819 : pdlModule(std::move(module)) {}
820 template <
typename... ConfigsT>
821 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
822 : PDLPatternModule(std::move(module)) {
823 auto configSet = std::make_unique<PDLPatternConfigSet>(
824 std::forward<ConfigsT>(patternConfigs)...);
825 attachConfigToPatterns(*pdlModule, *configSet);
826 configs.emplace_back(std::move(configSet));
830 void mergeIn(PDLPatternModule &&other);
833 ModuleOp getModule() {
return pdlModule.get(); }
836 MLIRContext *
getContext() {
return getModule()->getContext(); }
863 void registerConstraintFunction(StringRef name,
864 PDLConstraintFunction constraintFn);
865 template <
typename Constra
intFnT>
866 void registerConstraintFunction(StringRef name,
867 ConstraintFnT &&constraintFn) {
868 registerConstraintFunction(name,
869 detail::pdl_function_builder::buildConstraintFn(
870 std::forward<ConstraintFnT>(constraintFn)));
897 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
898 template <
typename RewriteFnT>
899 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
900 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
901 std::forward<RewriteFnT>(rewriteFn)));
905 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions()
const {
906 return constraintFunctions;
908 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
909 return constraintFunctions;
912 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions()
const {
913 return rewriteFunctions;
915 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
916 return rewriteFunctions;
920 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
921 return std::move(configs);
923 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
924 return std::move(configMap);
930 constraintFunctions.clear();
931 rewriteFunctions.clear();
937 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
940 OwningOpRef<ModuleOp> pdlModule;
943 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
944 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
947 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
948 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
959 template <
typename T>
964class PDLResultList {};
965using PDLConstraintFunction = std::function<LogicalResult(
967using PDLRewriteFunction = std::function<LogicalResult(
970class PDLPatternModule {
972 PDLPatternModule() =
default;
974 PDLPatternModule(OwningOpRef<ModuleOp> ) {}
976 llvm_unreachable(
"Error: PDL for rewrites when PDL is not enabled");
978 void mergeIn(PDLPatternModule &&other) {}
980 template <
typename Constra
intFnT>
981 void registerConstraintFunction(StringRef name,
982 ConstraintFnT &&constraintFn) {}
983 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
984 template <
typename RewriteFnT>
985 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
986 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions()
const {
987 return constraintFunctions;
991 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)