9#ifndef MLIR_IR_PDLPATTERNMATCH_H
10#define MLIR_IR_PDLPATTERNMATCH_H
12#include "mlir/Config/mlir-config.h"
14#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Copy constructor, explicitly defaulted.
34 PDLValue(
const PDLValue &other) =
default;
/// Constructs from `nullptr`. Because the parameter has a default argument,
/// this overload is also callable with no arguments. The body is empty, so no
/// member is assigned explicitly here.
35 PDLValue(std::nullptr_t =
nullptr) {}
/// Wraps an `Attribute`: stores the attribute's opaque pointer and tags the
/// value as `Kind::Attribute`.
36 PDLValue(Attribute value)
37 : value(value.getAsOpaquePointer()), kind(
Kind::Attribute) {}
/// Wraps an `Operation *`: the pointer itself is stored (no opaque-pointer
/// conversion is needed) and the value is tagged as `Kind::Operation`.
38 PDLValue(Operation *value) : value(value), kind(
Kind::Operation) {}
/// Wraps a `Type`: stores the type's opaque pointer and tags the value as
/// `Kind::Type`.
39 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(
Kind::Type) {}
42 : value(value.getAsOpaquePointer()), kind(
Kind::Value) {}
48 assert(value &&
"isa<> used on a null value");
49 return kind == getKindOf<T>();
55 typename ResultT = std::conditional_t<
56 std::is_constructible_v<bool, T>, T, std::optional<T>>>
57 ResultT dyn_cast()
const {
58 return isa<T>() ? castImpl<T>() : ResultT();
65 assert(isa<T>() &&
"expected value to be of type `T`");
/// Returns the stored `value` pointer unchanged.
70 const void *getAsOpaquePointer()
const {
return value; }
/// Explicit boolean conversion: yields the stored pointer converted to
/// `bool`, i.e. true exactly when `value` is non-null.
73 explicit operator bool()
const {
return value; }
/// Returns the stored `kind` tag.
76 Kind getKind()
const {
return kind; }
/// Declaration only; the definition is not part of this view. Presumably
/// writes a textual form of this value to `os` -- confirm against the
/// definition.
79 void print(raw_ostream &os)
const;
/// Static overload that takes a `Kind` instead of operating on an instance.
/// Declaration only; the definition is not part of this view.
82 static void print(raw_ostream &os, Kind kind);
86 template <
typename...>
/// Base case of the index search: the first candidate type is `T` itself, so
/// the resulting index is 0.
88 template <
typename T,
typename... R>
89 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
/// Recursive case of the index search: the first candidate `F` is skipped and
/// the result is one more than the index of `T` within the remaining
/// candidates `R...`. When `F` is the same type as `T`, the more specialized
/// `index_of_t<T, T, R...>` form is selected instead of this one.
90 template <
typename T,
typename F,
typename... R>
91 struct index_of_t<T, F, R...>
92 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
96 static Kind getKindOf() {
97 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
103 template <
typename T>
104 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
106 return T::getFromOpaquePointer(value);
108 template <
typename T>
109 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
111 return *
reinterpret_cast<T *
>(
const_cast<void *
>(value));
113 template <
typename T>
114 std::enable_if_t<std::is_pointer<T>::value, T> castImpl()
const {
115 return reinterpret_cast<T
>(
const_cast<void *
>(value));
/// Opaque pointer to the held entity; initialized to null.
119 const void *value{
nullptr};
/// Tag describing what `value` refers to; initialized to `Kind::Attribute`.
121 Kind kind{Kind::Attribute};
130 PDLValue::print(os, kind);
/// Appends an `Attribute` to `results`.
143 void push_back(Attribute value) { results.push_back(value); }
/// Appends an `Operation *` to `results`.
146 void push_back(Operation *value) { results.push_back(value); }
/// Appends a `Type` to `results`.
149 void push_back(Type value) { results.push_back(value); }
155 allocatedTypeRanges.emplace_back(value.begin(), value.end());
156 typeRanges.push_back(allocatedTypeRanges.back());
157 results.push_back(&typeRanges.back());
159 void push_back(ValueTypeRange<OperandRange> value) {
160 typeRanges.push_back(value);
161 results.push_back(&typeRanges.back());
163 void push_back(ValueTypeRange<ResultRange> value) {
164 typeRanges.push_back(value);
165 results.push_back(&typeRanges.back());
/// Appends a `Value` to `results`.
169 void push_back(Value value) { results.push_back(value); }
175 allocatedValueRanges.emplace_back(value.begin(), value.end());
176 valueRanges.push_back(allocatedValueRanges.back());
177 results.push_back(&valueRanges.back());
180 valueRanges.push_back(value);
181 results.push_back(&valueRanges.back());
184 valueRanges.push_back(value);
185 results.push_back(&valueRanges.back());
190 PDLResultList(
unsigned maxNumResults) {
194 typeRanges.reserve(maxNumResults);
195 valueRanges.reserve(maxNumResults);
/// The accumulated result values, in push order.
199 SmallVector<PDLValue> results;
/// Range objects whose element addresses are pushed into `results` by the
/// range-taking `push_back` overloads. The constructor reserves capacity for
/// both vectors up front.
/// NOTE(review): presumably the reserve keeps those stored addresses stable
/// as more ranges are appended -- confirm.
201 SmallVector<TypeRange> typeRanges;
202 SmallVector<ValueRange> valueRanges;
/// Owning element storage for ranges that were built by copying a
/// `[begin, end)` sequence; the corresponding entry in `typeRanges` /
/// `valueRanges` is constructed from the vector stored here.
205 SmallVector<std::vector<Type>> allocatedTypeRanges;
206 SmallVector<std::vector<Value>> allocatedValueRanges;
216class PDLPatternConfig {
/// Virtual destructor, defaulted.
218 virtual ~PDLPatternConfig() =
default;
/// Virtual hook with an empty default body. `PDLPatternConfigSet`'s
/// `notifyRewriteBegin` forwards to this on every config it holds.
223 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
/// Virtual hook with an empty default body. `PDLPatternConfigSet`'s
/// `notifyRewriteEnd` forwards to this on every config it holds.
224 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
/// Returns the `TypeID` stored in `id` at construction.
227 TypeID getTypeID()
const {
return id; }
/// Stores the given `TypeID` in the `id` member; `getTypeID()` returns it.
230 PDLPatternConfig(TypeID
id) : id(id) {}
239class PDLPatternConfigBase :
public PDLPatternConfig {
242 static bool classof(
const PDLPatternConfig *config) {
243 return config->getTypeID() == getConfigID();
/// Returns `TypeID::get<T>()`, the identifier `classof` compares against.
247 static TypeID getConfigID() {
return TypeID::get<T>(); }
/// Initializes the `PDLPatternConfig` base with `getConfigID()`.
250 PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
256class PDLPatternConfigSet {
/// Default constructor, explicitly defaulted.
258 PDLPatternConfigSet() =
default;
261 template <
typename... ConfigsT>
262 PDLPatternConfigSet(ConfigsT &&...configs) {
263 (addConfig(std::forward<ConfigsT>(configs)), ...);
268 template <
typename T>
269 const T &
get()
const {
270 const T *
config = tryGet<T>();
271 assert(config &&
"configuration not found");
277 template <
typename T>
278 const T *tryGet()
const {
279 for (
const auto &configIt : configs)
280 if (
const T *config = dyn_cast<T>(configIt.get()))
287 void notifyRewriteBegin(PatternRewriter &rewriter) {
288 for (
const auto &config : configs)
289 config->notifyRewriteBegin(rewriter);
291 void notifyRewriteEnd(PatternRewriter &rewriter) {
292 for (
const auto &config : configs)
293 config->notifyRewriteEnd(rewriter);
298 template <
typename T>
299 void addConfig(T &&config) {
300 assert(!tryGet<std::decay_t<T>>() &&
"configuration already exists");
301 configs.emplace_back(
302 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
308 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
317using PDLConstraintFunction = std::function<LogicalResult(
326using PDLRewriteFunction = std::function<LogicalResult(
330namespace pdl_function_builder {
342constexpr bool always_false =
false;
/// Primary template, declared here without a definition. The second
/// parameter defaults to `void`; partial specializations later in this file
/// use it as a `std::enable_if_t` slot to match whole families of `T`.
377template <
typename T,
typename Enable =
void>
378struct ProcessPDLValue;
397template <
typename T,
typename BaseT>
398struct ProcessPDLValueBasedOn {
400 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn,
401 PDLValue pdlValue,
size_t argIdx) {
403 if (
failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
405 return ProcessPDLValue<T>::verifyAsArg(
406 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
408 static T processAsArg(PDLValue pdlValue) {
409 return ProcessPDLValue<T>::processAsArg(
410 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
417 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn, BaseT value,
/// Declared without a body in this view.
/// NOTE(review): looks like the `BaseT` -> `T` conversion that the
/// `PDLValue` overload above chains into -- confirm.
421 static T processAsArg(BaseT baseValue);
428struct ProcessBuiltinPDLValue {
430 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn,
431 PDLValue pdlValue,
size_t argIdx) {
434 return errorFn(
"expected a non-null value for argument " + Twine(argIdx) +
435 " of type: " + llvm::getTypeName<T>());
/// Converts the `PDLValue` to `T` via `PDLValue::cast<T>()`.
438 static T processAsArg(PDLValue pdlValue) {
return pdlValue.cast<T>(); }
439 static void processAsResult(PatternRewriter &, PDLResultList &results,
441 results.push_back(value);
449template <
typename T,
typename BaseT>
450struct ProcessDerivedPDLValue :
public ProcessPDLValueBasedOn<T, BaseT> {
452 verifyAsArg(function_ref<LogicalResult(
const Twine &)> errorFn,
453 BaseT baseValue,
size_t argIdx) {
454 return TypeSwitch<BaseT, LogicalResult>(baseValue)
455 .Case([&](T) {
return success(); })
456 .Default([&](BaseT) {
457 return errorFn(
"expected argument " + Twine(argIdx) +
458 " to be of type: " + llvm::getTypeName<T>());
461 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
463 static T processAsArg(BaseT baseValue) {
464 return baseValue.template cast<T>();
466 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
468 static void processAsResult(PatternRewriter &, PDLResultList &results,
470 results.push_back(value);
478struct ProcessPDLValue<Attribute> :
public ProcessBuiltinPDLValue<Attribute> {};
480struct ProcessPDLValue<T,
481 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
482 :
public ProcessDerivedPDLValue<T, Attribute> {};
486struct ProcessPDLValue<StringRef>
487 :
public ProcessPDLValueBasedOn<StringRef, StringAttr> {
/// Produces the `StringRef` argument from a `StringAttr` by returning the
/// attribute's `getValue()`.
488 static StringRef processAsArg(StringAttr value) {
return value.getValue(); }
489 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
491 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
493 results.push_back(rewriter.getStringAttr(value));
497struct ProcessPDLValue<std::string>
498 :
public ProcessPDLValueBasedOn<std::string, StringAttr> {
499 template <
typename T>
500 static std::string processAsArg(T value) {
501 static_assert(always_false<T>,
502 "`std::string` arguments require a string copy, use "
503 "`StringRef` for string-like arguments instead");
506 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
508 results.push_back(rewriter.getStringAttr(value));
516struct ProcessPDLValue<Operation *>
517 :
public ProcessBuiltinPDLValue<Operation *> {};
519struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
520 :
public ProcessDerivedPDLValue<T, Operation *> {
/// Produces the `T` argument from an `Operation *` via `cast<T>(value)`.
521 static T processAsArg(Operation *value) {
return cast<T>(value); }
528struct ProcessPDLValue<Type> :
public ProcessBuiltinPDLValue<Type> {};
530struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
531 :
public ProcessDerivedPDLValue<T, Type> {};
537struct ProcessPDLValue<
TypeRange> :
public ProcessBuiltinPDLValue<TypeRange> {};
539struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
540 static void processAsResult(PatternRewriter &, PDLResultList &results,
541 ValueTypeRange<OperandRange> types) {
542 results.push_back(types);
546struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
547 static void processAsResult(PatternRewriter &, PDLResultList &results,
548 ValueTypeRange<ResultRange> types) {
549 results.push_back(types);
553struct ProcessPDLValue<SmallVector<Type, N>> {
554 static void processAsResult(PatternRewriter &, PDLResultList &results,
555 SmallVector<Type, N> values) {
564struct ProcessPDLValue<Value> :
public ProcessBuiltinPDLValue<Value> {};
570struct ProcessPDLValue<
ValueRange> :
public ProcessBuiltinPDLValue<ValueRange> {
573struct ProcessPDLValue<OperandRange> {
574 static void processAsResult(PatternRewriter &, PDLResultList &results,
575 OperandRange values) {
576 results.push_back(values);
580struct ProcessPDLValue<ResultRange> {
581 static void processAsResult(PatternRewriter &, PDLResultList &results,
582 ResultRange values) {
583 results.push_back(values);
587struct ProcessPDLValue<SmallVector<Value, N>> {
588 static void processAsResult(PatternRewriter &, PDLResultList &results,
589 SmallVector<Value, N> values) {
603template <
typename PDLFnT, std::size_t... I>
604LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
605 std::index_sequence<I...>) {
606 using FnTraitsT = llvm::function_traits<PDLFnT>;
608 auto errorFn = [&](
const Twine &msg) {
609 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
612 (succeeded(ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
613 verifyAsArg(errorFn, values[I], I)) &&
620template <
typename PDLFnT, std::size_t... I>
621void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
622 std::index_sequence<I...>) {
625 using FnTraitsT = llvm::function_traits<PDLFnT>;
626 auto errorFn = [&](
const Twine &msg) -> LogicalResult {
627 llvm::report_fatal_error(msg);
630 assert((succeeded(ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
631 verifyAsArg(errorFn, values[I], I)) &&
643static LogicalResult processResults(PatternRewriter &rewriter,
644 PDLResultList &results, T &&value) {
645 ProcessPDLValue<T>::processAsResult(rewriter, results,
646 std::forward<T>(value));
651template <
typename T1,
typename T2>
652static LogicalResult processResults(PatternRewriter &rewriter,
653 PDLResultList &results,
654 std::pair<T1, T2> &&pair) {
655 if (
failed(processResults(rewriter, results, std::move(pair.first))) ||
656 failed(processResults(rewriter, results, std::move(pair.second))))
662template <
typename... Ts>
663static LogicalResult processResults(PatternRewriter &rewriter,
664 PDLResultList &results,
665 std::tuple<Ts...> &&tuple) {
666 auto applyFn = [&](
auto &&...args) {
667 return (succeeded(processResults(rewriter, results, std::move(args))) &&
670 return success(std::apply(applyFn, std::move(tuple)));
674inline LogicalResult processResults(PatternRewriter &rewriter,
675 PDLResultList &results,
680static LogicalResult processResults(PatternRewriter &rewriter,
681 PDLResultList &results,
685 return processResults(rewriter, results, std::move(*
result));
693template <
typename PDLFnT, std::size_t... I,
694 typename FnTraitsT = llvm::function_traits<PDLFnT>>
695typename FnTraitsT::result_t
696processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
697 ArrayRef<PDLValue> values,
698 std::index_sequence<I...>) {
701 (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
711template <
typename Constra
intFnT>
713 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
714 PDLConstraintFunction>
715buildConstraintFn(ConstraintFnT &&constraintFn) {
716 return std::forward<ConstraintFnT>(constraintFn);
720template <
typename Constra
intFnT>
722 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
723 PDLConstraintFunction>
724buildConstraintFn(ConstraintFnT &&constraintFn) {
725 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
726 PatternRewriter &rewriter, PDLResultList &,
727 ArrayRef<PDLValue> values) -> LogicalResult {
728 auto argIndices = std::make_index_sequence<
729 llvm::function_traits<ConstraintFnT>::num_args - 1>();
730 if (
failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
732 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
743template <
typename PDLFnT, std::size_t... I,
744 typename FnTraitsT = llvm::function_traits<PDLFnT>>
745std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
747processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
748 PDLResultList &, ArrayRef<PDLValue> values,
749 std::index_sequence<I...>) {
751 (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
757template <
typename PDLFnT, std::size_t... I,
758 typename FnTraitsT = llvm::function_traits<PDLFnT>>
759std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
761processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
762 PDLResultList &results, ArrayRef<PDLValue> values,
763 std::index_sequence<I...>) {
764 return processResults(
766 fn(rewriter, (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
767 processAsArg(values[I]))...));
777template <
typename RewriteFnT>
778std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
780buildRewriteFn(RewriteFnT &&rewriteFn) {
781 return std::forward<RewriteFnT>(rewriteFn);
785template <
typename RewriteFnT>
786std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
788buildRewriteFn(RewriteFnT &&rewriteFn) {
789 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
790 PatternRewriter &rewriter, PDLResultList &results,
791 ArrayRef<PDLValue> values) {
793 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
795 assertArgs<RewriteFnT>(rewriter, values, argIndices);
796 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
811class PDLPatternModule {
813 PDLPatternModule() =
default;
/// Takes the owning module reference by value and moves it into `pdlModule`.
816 PDLPatternModule(OwningOpRef<ModuleOp> module)
817 : pdlModule(std::move(module)) {}
818 template <
typename... ConfigsT>
819 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
820 : PDLPatternModule(std::move(module)) {
821 auto configSet = std::make_unique<PDLPatternConfigSet>(
822 std::forward<ConfigsT>(patternConfigs)...);
823 attachConfigToPatterns(*pdlModule, *configSet);
824 configs.emplace_back(std::move(configSet));
/// Declaration only; the definition is not part of this view. Takes `other`
/// by rvalue reference.
/// NOTE(review): presumably `other`'s contents are moved into this module
/// -- confirm against the definition.
828 void mergeIn(PDLPatternModule &&other);
/// Returns the `ModuleOp` held by `pdlModule` (via `get()`, so ownership
/// stays with this object).
831 ModuleOp getModule() {
return pdlModule.get(); }
/// Returns the `MLIRContext` obtained from the module returned by
/// `getModule()`.
834 MLIRContext *
getContext() {
return getModule()->getContext(); }
/// Non-template overload taking an already-built `PDLConstraintFunction`.
/// Declaration only; the definition is not part of this view. The templated
/// overload that follows forwards to this one after wrapping its callable.
861 void registerConstraintFunction(StringRef name,
862 PDLConstraintFunction constraintFn);
863 template <
typename Constra
intFnT>
864 void registerConstraintFunction(StringRef name,
865 ConstraintFnT &&constraintFn) {
866 registerConstraintFunction(name,
867 detail::pdl_function_builder::buildConstraintFn(
868 std::forward<ConstraintFnT>(constraintFn)));
/// Non-template overload taking an already-built `PDLRewriteFunction`.
/// Declaration only; the definition is not part of this view. The templated
/// overload that follows forwards to this one after wrapping its callable.
895 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
896 template <
typename RewriteFnT>
897 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
898 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
899 std::forward<RewriteFnT>(rewriteFn)));
903 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions()
const {
904 return constraintFunctions;
906 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
907 return constraintFunctions;
910 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions()
const {
911 return rewriteFunctions;
913 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
914 return rewriteFunctions;
918 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
919 return std::move(configs);
921 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
922 return std::move(configMap);
928 constraintFunctions.clear();
929 rewriteFunctions.clear();
/// Declaration only; the definition is not part of this view. It is called
/// by the config-taking constructor with the owned module and the newly
/// created config set.
935 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
/// Owning reference to the module; `getModule()` hands out the held op.
938 OwningOpRef<ModuleOp> pdlModule;
/// Config sets owned by this object (held through `unique_ptr`).
941 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
/// Map from an operation to a config set, stored as a raw (non-owning)
/// pointer.
/// NOTE(review): presumably the pointees are the sets owned by `configs`
/// -- confirm.
942 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
/// Constraint and rewrite callables, each keyed by string.
945 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
946 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
957 template <
typename T>
/// Empty placeholder definition.
/// NOTE(review): appears to belong to the configuration in which PDL is
/// disabled (see the "PDL is not enabled" message nearby) -- confirm.
962class PDLResultList {};
963using PDLConstraintFunction = std::function<LogicalResult(
965using PDLRewriteFunction = std::function<LogicalResult(
968class PDLPatternModule {
970 PDLPatternModule() =
default;
/// Accepts a module reference and ignores it: the parameter is unnamed and
/// the body is empty.
972 PDLPatternModule(OwningOpRef<ModuleOp> ) {}
974 llvm_unreachable(
"Error: PDL for rewrites when PDL is not enabled");
/// No-op: the body is empty, so `other` is left untouched.
976 void mergeIn(PDLPatternModule &&other) {}
978 template <
typename Constra
intFnT>
979 void registerConstraintFunction(StringRef name,
980 ConstraintFnT &&constraintFn) {}
/// No-op: the body is empty, so nothing is registered.
981 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
/// No-op template overload: accepts any callable type and discards it.
982 template <
typename RewriteFnT>
983 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
984 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions()
const {
985 return constraintFunctions;
989 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
const FrozenRewritePatternSet GreedyRewriteConfig config