9 #ifndef MLIR_IR_PDLPATTERNMATCH_H
10 #define MLIR_IR_PDLPATTERNMATCH_H
12 #include "mlir/Config/mlir-config.h"
14 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
31 enum class Kind { Attribute, Operation,
Type, TypeRange,
Value, ValueRange };
34 PDLValue(
const PDLValue &other) =
default;
35 PDLValue(std::nullptr_t =
nullptr) {}
36 PDLValue(Attribute value)
37 : value(value.getAsOpaquePointer()), kind(
Kind::Attribute) {}
38 PDLValue(Operation *value) : value(value), kind(
Kind::Operation) {}
39 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(
Kind::
Type) {}
40 PDLValue(TypeRange *value) : value(value), kind(
Kind::TypeRange) {}
42 : value(value.getAsOpaquePointer()), kind(
Kind::
Value) {}
43 PDLValue(ValueRange *value) : value(value), kind(
Kind::ValueRange) {}
48 assert(value &&
"isa<> used on a null value");
49 return kind == getKindOf<T>();
55 typename ResultT = std::conditional_t<
56 std::is_convertible<T, bool>::value, T, std::optional<T>>>
57 ResultT dyn_cast()
const {
58 return isa<T>() ? castImpl<T>() : ResultT();
65 assert(isa<T>() &&
"expected value to be of type `T`");
70 const void *getAsOpaquePointer()
const {
return value; }
73 explicit operator bool()
const {
return value; }
76 Kind getKind()
const {
return kind; }
79 void print(raw_ostream &os)
const;
82 static void print(raw_ostream &os,
Kind kind);
86 template <
typename...>
88 template <
typename T,
typename... R>
89 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
90 template <
typename T,
typename F,
typename... R>
91 struct index_of_t<T, F, R...>
92 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
96 static Kind getKindOf() {
97 return static_cast<Kind>(index_of_t<T, Attribute, Operation *,
Type,
98 TypeRange,
Value, ValueRange>::value);
103 template <
typename T>
104 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
106 return T::getFromOpaquePointer(value);
108 template <
typename T>
109 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
111 return *
reinterpret_cast<T *
>(
const_cast<void *
>(value));
113 template <
typename T>
114 std::enable_if_t<std::is_pointer<T>::value, T> castImpl()
const {
115 return reinterpret_cast<T
>(
const_cast<void *
>(value));
119 const void *value{
nullptr};
121 Kind kind{Kind::Attribute};
124 inline raw_ostream &
operator<<(raw_ostream &os, PDLValue value) {
140 class PDLResultList {
143 void push_back(Attribute value) { results.push_back(value); }
146 void push_back(Operation *value) { results.push_back(value); }
149 void push_back(Type value) { results.push_back(value); }
152 void push_back(TypeRange value) {
155 llvm::OwningArrayRef<Type> storage(value.size());
157 allocatedTypeRanges.emplace_back(std::move(storage));
158 typeRanges.push_back(allocatedTypeRanges.back());
159 results.push_back(&typeRanges.back());
161 void push_back(ValueTypeRange<OperandRange> value) {
162 typeRanges.push_back(value);
163 results.push_back(&typeRanges.back());
165 void push_back(ValueTypeRange<ResultRange> value) {
166 typeRanges.push_back(value);
167 results.push_back(&typeRanges.back());
171 void push_back(Value value) { results.push_back(value); }
174 void push_back(ValueRange value) {
177 llvm::OwningArrayRef<Value> storage(value.size());
179 allocatedValueRanges.emplace_back(std::move(storage));
180 valueRanges.push_back(allocatedValueRanges.back());
181 results.push_back(&valueRanges.back());
183 void push_back(OperandRange value) {
184 valueRanges.push_back(value);
185 results.push_back(&valueRanges.back());
187 void push_back(ResultRange value) {
188 valueRanges.push_back(value);
189 results.push_back(&valueRanges.back());
194 PDLResultList(
unsigned maxNumResults) {
198 typeRanges.reserve(maxNumResults);
199 valueRanges.reserve(maxNumResults);
203 SmallVector<PDLValue> results;
205 SmallVector<TypeRange> typeRanges;
206 SmallVector<ValueRange> valueRanges;
209 SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
210 SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
220 class PDLPatternConfig {
222 virtual ~PDLPatternConfig() =
default;
227 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
228 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
231 TypeID getTypeID()
const {
return id; }
234 PDLPatternConfig(TypeID
id) : id(id) {}
242 template <
typename T>
246 static bool classof(
const PDLPatternConfig *
config) {
247 return config->getTypeID() == getConfigID();
251 static TypeID getConfigID() {
return TypeID::get<T>(); }
260 class PDLPatternConfigSet {
262 PDLPatternConfigSet() =
default;
265 template <
typename... ConfigsT>
266 PDLPatternConfigSet(ConfigsT &&...configs) {
267 (addConfig(std::forward<ConfigsT>(configs)), ...);
272 template <
typename T>
273 const T &
get()
const {
274 const T *
config = tryGet<T>();
275 assert(
config &&
"configuration not found");
281 template <
typename T>
282 const T *tryGet()
const {
283 for (
const auto &configIt : configs)
284 if (
const T *
config = dyn_cast<T>(configIt.get()))
291 void notifyRewriteBegin(PatternRewriter &rewriter) {
292 for (
const auto &
config : configs)
293 config->notifyRewriteBegin(rewriter);
295 void notifyRewriteEnd(PatternRewriter &rewriter) {
296 for (
const auto &
config : configs)
297 config->notifyRewriteEnd(rewriter);
302 template <
typename T>
303 void addConfig(T &&
config) {
304 assert(!tryGet<std::decay_t<T>>() &&
"configuration already exists");
305 configs.emplace_back(
306 std::make_unique<std::decay_t<T>>(std::forward<T>(
config)));
312 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
321 using PDLConstraintFunction = std::function<LogicalResult(
322 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
330 using PDLRewriteFunction = std::function<LogicalResult(
331 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
334 namespace pdl_function_builder {
345 template <
class... T>
346 constexpr
bool always_false =
false;
381 template <
typename T,
typename Enable =
void>
382 struct ProcessPDLValue;
401 template <
typename T,
typename BaseT>
402 struct ProcessPDLValueBasedOn {
404 verifyAsArg(
function_ref<LogicalResult(
const Twine &)> errorFn,
405 PDLValue pdlValue,
size_t argIdx) {
407 if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
409 return ProcessPDLValue<T>::verifyAsArg(
410 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
412 static T processAsArg(PDLValue pdlValue) {
413 return ProcessPDLValue<T>::processAsArg(
414 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
425 static T processAsArg(
BaseT baseValue);
431 template <
typename T>
432 struct ProcessBuiltinPDLValue {
434 verifyAsArg(
function_ref<LogicalResult(
const Twine &)> errorFn,
435 PDLValue pdlValue,
size_t argIdx) {
438 return errorFn(
"expected a non-null value for argument " + Twine(argIdx) +
439 " of type: " + llvm::getTypeName<T>());
442 static T processAsArg(PDLValue pdlValue) {
return pdlValue.cast<T>(); }
443 static void processAsResult(PatternRewriter &, PDLResultList &results,
445 results.push_back(value);
453 template <
typename T,
typename BaseT>
454 struct ProcessDerivedPDLValue :
public ProcessPDLValueBasedOn<T, BaseT> {
456 verifyAsArg(
function_ref<LogicalResult(
const Twine &)> errorFn,
457 BaseT baseValue,
size_t argIdx) {
458 return TypeSwitch<BaseT, LogicalResult>(baseValue)
459 .Case([&](T) {
return success(); })
460 .Default([&](
BaseT) {
461 return errorFn(
"expected argument " + Twine(argIdx) +
462 " to be of type: " + llvm::getTypeName<T>());
465 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
467 static T processAsArg(
BaseT baseValue) {
468 return baseValue.template cast<T>();
470 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
472 static void processAsResult(PatternRewriter &, PDLResultList &results,
474 results.push_back(value);
482 struct ProcessPDLValue<Attribute> :
public ProcessBuiltinPDLValue<Attribute> {};
483 template <
typename T>
484 struct ProcessPDLValue<T,
485 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
486 :
public ProcessDerivedPDLValue<T, Attribute> {};
490 struct ProcessPDLValue<StringRef>
491 :
public ProcessPDLValueBasedOn<StringRef, StringAttr> {
492 static StringRef processAsArg(StringAttr value) {
return value.getValue(); }
493 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
495 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
497 results.push_back(rewriter.getStringAttr(value));
501 struct ProcessPDLValue<std::string>
502 :
public ProcessPDLValueBasedOn<std::string, StringAttr> {
503 template <
typename T>
504 static std::string processAsArg(T value) {
505 static_assert(always_false<T>,
506 "`std::string` arguments require a string copy, use "
507 "`StringRef` for string-like arguments instead");
510 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
512 results.push_back(rewriter.getStringAttr(value));
520 struct ProcessPDLValue<Operation *>
521 :
public ProcessBuiltinPDLValue<Operation *> {};
522 template <
typename T>
523 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
524 :
public ProcessDerivedPDLValue<T, Operation *> {
525 static T processAsArg(Operation *value) {
return cast<T>(value); }
532 struct ProcessPDLValue<
Type> :
public ProcessBuiltinPDLValue<Type> {};
533 template <
typename T>
534 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
535 :
public ProcessDerivedPDLValue<T, Type> {};
541 struct ProcessPDLValue<TypeRange> :
public ProcessBuiltinPDLValue<TypeRange> {};
543 struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
544 static void processAsResult(PatternRewriter &, PDLResultList &results,
545 ValueTypeRange<OperandRange> types) {
546 results.push_back(types);
550 struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
551 static void processAsResult(PatternRewriter &, PDLResultList &results,
552 ValueTypeRange<ResultRange> types) {
553 results.push_back(types);
556 template <
unsigned N>
557 struct ProcessPDLValue<SmallVector<
Type, N>> {
558 static void processAsResult(PatternRewriter &, PDLResultList &results,
559 SmallVector<Type, N> values) {
560 results.push_back(TypeRange(values));
568 struct ProcessPDLValue<
Value> :
public ProcessBuiltinPDLValue<Value> {};
574 struct ProcessPDLValue<ValueRange> :
public ProcessBuiltinPDLValue<ValueRange> {
577 struct ProcessPDLValue<OperandRange> {
578 static void processAsResult(PatternRewriter &, PDLResultList &results,
579 OperandRange values) {
580 results.push_back(values);
584 struct ProcessPDLValue<ResultRange> {
585 static void processAsResult(PatternRewriter &, PDLResultList &results,
586 ResultRange values) {
587 results.push_back(values);
590 template <
unsigned N>
591 struct ProcessPDLValue<SmallVector<
Value, N>> {
592 static void processAsResult(PatternRewriter &, PDLResultList &results,
593 SmallVector<Value, N> values) {
594 results.push_back(ValueRange(values));
607 template <
typename PDLFnT, std::size_t... I>
608 LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
609 std::index_sequence<I...>) {
610 using FnTraitsT = llvm::function_traits<PDLFnT>;
612 auto errorFn = [&](
const Twine &msg) {
613 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
616 (succeeded(ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
617 verifyAsArg(errorFn, values[I], I)) &&
624 template <
typename PDLFnT, std::size_t... I>
625 void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
626 std::index_sequence<I...>) {
629 using FnTraitsT = llvm::function_traits<PDLFnT>;
630 auto errorFn = [&](
const Twine &msg) -> LogicalResult {
631 llvm::report_fatal_error(msg);
634 assert((succeeded(ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
635 verifyAsArg(errorFn, values[I], I)) &&
646 template <
typename T>
647 static LogicalResult processResults(PatternRewriter &rewriter,
648 PDLResultList &results, T &&value) {
649 ProcessPDLValue<T>::processAsResult(rewriter, results,
650 std::forward<T>(value));
655 template <
typename T1,
typename T2>
656 static LogicalResult processResults(PatternRewriter &rewriter,
657 PDLResultList &results,
658 std::pair<T1, T2> &&pair) {
659 if (failed(processResults(rewriter, results, std::move(pair.first))) ||
660 failed(processResults(rewriter, results, std::move(pair.second))))
666 template <
typename... Ts>
667 static LogicalResult processResults(PatternRewriter &rewriter,
668 PDLResultList &results,
669 std::tuple<Ts...> &&tuple) {
670 auto applyFn = [&](
auto &&...args) {
671 return (succeeded(processResults(rewriter, results, std::move(args))) &&
674 return success(std::apply(applyFn, std::move(tuple)));
678 inline LogicalResult processResults(PatternRewriter &rewriter,
679 PDLResultList &results,
680 LogicalResult &&result) {
683 template <
typename T>
684 static LogicalResult processResults(PatternRewriter &rewriter,
685 PDLResultList &results,
686 FailureOr<T> &&result) {
689 return processResults(rewriter, results, std::move(*result));
697 template <
typename PDLFnT, std::size_t... I,
698 typename FnTraitsT = llvm::function_traits<PDLFnT>>
699 typename FnTraitsT::result_t
700 processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
701 ArrayRef<PDLValue> values,
702 std::index_sequence<I...>) {
705 (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
715 template <
typename Constra
intFnT>
717 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
718 PDLConstraintFunction>
719 buildConstraintFn(ConstraintFnT &&constraintFn) {
720 return std::forward<ConstraintFnT>(constraintFn);
724 template <
typename Constra
intFnT>
726 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
727 PDLConstraintFunction>
728 buildConstraintFn(ConstraintFnT &&constraintFn) {
729 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
730 PatternRewriter &rewriter, PDLResultList &,
731 ArrayRef<PDLValue> values) -> LogicalResult {
732 auto argIndices = std::make_index_sequence<
733 llvm::function_traits<ConstraintFnT>::num_args - 1>();
734 if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
736 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
747 template <
typename PDLFnT, std::size_t... I,
748 typename FnTraitsT = llvm::function_traits<PDLFnT>>
749 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
751 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
752 PDLResultList &, ArrayRef<PDLValue> values,
753 std::index_sequence<I...>) {
755 (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
761 template <
typename PDLFnT, std::size_t... I,
762 typename FnTraitsT = llvm::function_traits<PDLFnT>>
763 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
765 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
766 PDLResultList &results, ArrayRef<PDLValue> values,
767 std::index_sequence<I...>) {
768 return processResults(
770 fn(rewriter, (ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
771 processAsArg(values[I]))...));
781 template <
typename RewriteFnT>
782 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
784 buildRewriteFn(RewriteFnT &&rewriteFn) {
785 return std::forward<RewriteFnT>(rewriteFn);
789 template <
typename RewriteFnT>
790 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
792 buildRewriteFn(RewriteFnT &&rewriteFn) {
793 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
794 PatternRewriter &rewriter, PDLResultList &results,
795 ArrayRef<PDLValue> values) {
797 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
799 assertArgs<RewriteFnT>(rewriter, values, argIndices);
800 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
815 class PDLPatternModule {
817 PDLPatternModule() =
default;
820 PDLPatternModule(OwningOpRef<ModuleOp> module)
821 : pdlModule(std::move(module)) {}
822 template <
typename... ConfigsT>
823 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
824 : PDLPatternModule(std::move(module)) {
825 auto configSet = std::make_unique<PDLPatternConfigSet>(
826 std::forward<ConfigsT>(patternConfigs)...);
827 attachConfigToPatterns(*pdlModule, *configSet);
828 configs.emplace_back(std::move(configSet));
832 void mergeIn(PDLPatternModule &&other);
835 ModuleOp getModule() {
return pdlModule.get(); }
838 MLIRContext *
getContext() {
return getModule()->getContext(); }
865 void registerConstraintFunction(StringRef name,
866 PDLConstraintFunction constraintFn);
867 template <
typename Constra
intFnT>
868 void registerConstraintFunction(StringRef name,
869 ConstraintFnT &&constraintFn) {
870 registerConstraintFunction(name,
871 detail::pdl_function_builder::buildConstraintFn(
872 std::forward<ConstraintFnT>(constraintFn)));
899 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
900 template <
typename RewriteFnT>
901 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
902 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
903 std::forward<RewriteFnT>(rewriteFn)));
907 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions()
const {
908 return constraintFunctions;
910 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
911 return constraintFunctions;
914 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions()
const {
915 return rewriteFunctions;
917 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
918 return rewriteFunctions;
922 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
923 return std::move(configs);
925 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
926 return std::move(configMap);
932 constraintFunctions.clear();
933 rewriteFunctions.clear();
939 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
942 OwningOpRef<ModuleOp> pdlModule;
945 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
946 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
949 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
950 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
961 template <
typename T>
966 class PDLResultList {};
967 using PDLConstraintFunction = std::function<LogicalResult(
968 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
969 using PDLRewriteFunction = std::function<LogicalResult(
970 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
972 class PDLPatternModule {
974 PDLPatternModule() =
default;
976 PDLPatternModule(OwningOpRef<ModuleOp> ) {}
978 llvm_unreachable(
"Error: PDL for rewrites when PDL is not enabled");
980 void mergeIn(PDLPatternModule &&other) {}
982 template <
typename Constra
intFnT>
983 void registerConstraintFunction(StringRef name,
984 ConstraintFnT &&constraintFn) {}
985 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
986 template <
typename RewriteFnT>
987 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
988 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions()
const {
989 return constraintFunctions;
993 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Type
An inlay hint that for a type annotation.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
const FrozenRewritePatternSet GreedyRewriteConfig config
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)