9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
20 class PatternRewriter;
34 enum { ImpossibleToMatchSentinel = 65535 };
50 return representation == rhs.representation;
54 return representation < rhs.representation;
61 unsigned short representation{ImpossibleToMatchSentinel};
94 if (rootKind == RootKind::OperationName)
103 if (rootKind == RootKind::InterfaceID)
112 if (rootKind == RootKind::TraitID)
129 return contextAndHasBoundedRecursion.getInt();
134 return contextAndHasBoundedRecursion.getPointer();
150 debugLabels.append(labels.begin(), labels.end());
187 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
188 PatternBenefit benefit, MLIRContext *context,
189 ArrayRef<StringRef> generatedNames = {});
196 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
197 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
202 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
206 Pattern(
const void *rootValue, RootKind rootKind,
211 const void *rootValue;
219 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
274 template <
typename T,
typename... Args>
275 static std::unique_ptr<T>
create(Args &&...args) {
276 std::unique_ptr<T> pattern =
277 std::make_unique<T>(std::forward<Args>(args)...);
278 initializePattern<T>(*pattern);
281 if (pattern->getDebugName().empty())
282 pattern->setDebugName(llvm::getTypeName<T>());
// Detection-idiom operation: names the type of calling `initialize()` on a
// T value, and is ill-formed (SFINAE) when T has no such callable member.
// NOTE(review): the `Args` parameter pack is not used by this alias.
292 template <
typename T,
typename... Args>
293 using has_initialize = decltype(std::declval<T>().initialize());
// Trait that is true when `has_initialize<T>` is well-formed, i.e. T exposes
// an `initialize()` member; used to select an `initializePattern` overload.
294 template <
typename T>
295 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
298 template <
typename T>
299 static std::enable_if_t<detect_has_initialize<T>::value>
300 initializePattern(T &pattern) {
301 pattern.initialize();
// Fallback overload chosen when T has no `initialize()` member
// (`detect_has_initialize<T>::value` is false): performs no initialization.
305 template <
typename T>
306 static std::enable_if_t<!detect_has_initialize<T>::value>
307 initializePattern(T &) {}
310 virtual void anchor();
317 template <
typename SourceOp>
319 using RewritePattern::RewritePattern;
323 rewrite(cast<SourceOp>(op), rewriter);
326 return match(cast<SourceOp>(op));
336 llvm_unreachable(
"must override rewrite or matchAndRewrite");
339 llvm_unreachable(
"must override match or matchAndRewrite");
355 template <
typename SourceOp>
364 SourceOp::getOperationName(), benefit, context, generatedNames) {}
370 template <
typename SourceOp>
374 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
382 template <
template <
typename>
class TraitType>
460 llvm::unique_function<
bool(
OpOperand &)
const> functor);
462 llvm::unique_function<
bool(
OpOperand &)
const> functor) {
472 bool *allUsesReplaced =
nullptr);
481 template <
typename OpTy,
typename... Args>
483 auto newOp = create<OpTy>(op->
getLoc(), std::forward<Args>(args)...);
484 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
548 template <
typename CallableT>
561 template <
typename OperandType,
typename ValueT>
563 for (OperandType &operand : llvm::make_early_inc_range(from->
getUses())) {
569 assert(from.size() == to.size() &&
"incorrect number of replacements");
570 for (
auto it : llvm::zip(from, to))
586 return user != exceptedUser;
595 template <
typename CallbackT>
596 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value,
LogicalResult>
599 if (
auto *rewriteListener = dyn_cast_if_present<Listener>(
listener))
600 return rewriteListener->notifyMatchFailure(
607 template <
typename CallbackT>
608 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value,
LogicalResult>
610 if (
auto *rewriteListener = dyn_cast_if_present<Listener>(
listener))
611 return rewriteListener->notifyMatchFailure(
615 template <
typename ArgT>
620 template <
typename ArgT>
706 template <
typename T>
708 assert(value &&
"isa<> used on a null value");
709 return kind == getKindOf<T>();
714 template <
typename T,
715 typename ResultT = std::conditional_t<
716 std::is_convertible<T, bool>::value, T, std::optional<T>>>
718 return isa<T>() ? castImpl<T>() : ResultT();
723 template <
typename T>
725 assert(isa<T>() &&
"expected value to be of type `T`");
726 return castImpl<T>();
// Returns true when this value holds a non-null opaque pointer
// (the stored `value` pointer defaults to nullptr).
733 explicit operator bool()
const {
return value; }
739 void print(raw_ostream &os)
const;
742 static void print(raw_ostream &os,
Kind kind);
746 template <
typename...>
// Compile-time index lookup: `index_of_t<T, Ts...>::value` is the zero-based
// position of the first occurrence of T in Ts (used to map a C++ type to a
// Kind enumerator). This specialization handles a matching head: index 0.
// NOTE(review): the primary template is declared on a line not shown here;
// presumably it is left undefined so that a T absent from Ts fails to compile.
748 template <
typename T,
typename... R>
749 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
// Non-matching head F: skip it and add one to the index found in the tail.
750 template <
typename T,
typename F,
typename... R>
751 struct index_of_t<T, F, R...>
752 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
755 template <
typename T>
756 static Kind getKindOf() {
757 return static_cast<Kind>(index_of_t<T, Attribute, Operation *,
Type,
758 TypeRange,
Value, ValueRange>::value);
763 template <
typename T>
764 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
766 return T::getFromOpaquePointer(value);
768 template <
typename T>
769 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
771 return *
reinterpret_cast<T *
>(
const_cast<void *
>(value));
773 template <
typename T>
774 std::enable_if_t<std::is_pointer<T>::value, T> castImpl()
const {
775 return reinterpret_cast<T
>(
const_cast<void *
>(value));
779 const void *value{
nullptr};
815 llvm::OwningArrayRef<Type> storage(value.size());
837 llvm::OwningArrayRef<Value> storage(value.size());
902 template <
typename T>
925 template <
typename... ConfigsT>
932 template <
typename T>
934 const T *config = tryGet<T>();
935 assert(config &&
"configuration not found");
941 template <
typename T>
943 for (
const auto &configIt :
configs)
944 if (
const T *config = dyn_cast<T>(configIt.get()))
952 for (
const auto &config :
configs)
953 config->notifyRewriteBegin(rewriter);
956 for (
const auto &config :
configs)
957 config->notifyRewriteEnd(rewriter);
962 template <
typename T>
964 assert(!
tryGet<std::decay_t<T>>() &&
"configuration already exists");
966 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
993 namespace pdl_function_builder {
1004 template <
class... T>
1040 template <
typename T,
typename Enable =
void>
1060 template <
typename T,
typename BaseT>
1064 PDLValue pdlValue,
size_t argIdx) {
1090 template <
typename T>
1094 PDLValue pdlValue,
size_t argIdx) {
1097 return errorFn(
"expected a non-null value for argument " + Twine(argIdx) +
1098 " of type: " + llvm::getTypeName<T>());
1112 template <
typename T,
typename BaseT>
1116 BaseT baseValue,
size_t argIdx) {
1118 .Case([&](T) {
return success(); })
1119 .Default([&](
BaseT) {
1120 return errorFn(
"expected argument " + Twine(argIdx) +
1121 " to be of type: " + llvm::getTypeName<T>());
1127 return baseValue.template cast<T>();
1142 template <
typename T>
1144 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
// Converts a StringAttr argument into the StringRef returned by its
// `getValue()`, so native functions can take plain StringRef parameters.
1151 static StringRef
processAsArg(StringAttr value) {
return value.getValue(); }
1162 template <
typename T>
1164 static_assert(always_false<T>,
1165 "`std::string` arguments require a string copy, use "
1166 "`StringRef` for string-like arguments instead");
1181 template <
typename T>
1192 template <
typename T>
1215 template <
unsigned N>
1249 template <
unsigned N>
1266 template <
typename PDLFnT, std::size_t... I>
1268 std::index_sequence<I...>) {
1269 using FnTraitsT = llvm::function_traits<PDLFnT>;
1271 auto errorFn = [&](
const Twine &msg) {
1276 verifyAsArg(errorFn, values[I], I)) &&
1283 template <
typename PDLFnT, std::size_t... I>
1285 std::index_sequence<I...>) {
1287 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1288 using FnTraitsT = llvm::function_traits<PDLFnT>;
1290 llvm::report_fatal_error(msg);
1294 verifyAsArg(errorFn, values[I], I)) &&
1305 template <
typename T>
1309 std::forward<T>(value));
1314 template <
typename T1,
typename T2>
1317 std::pair<T1, T2> &&pair) {
1325 template <
typename... Ts>
1328 std::tuple<Ts...> &&tuple) {
1329 auto applyFn = [&](
auto &&...args) {
1333 return success(std::apply(applyFn, std::move(tuple)));
1342 template <
typename T>
1356 template <
typename PDLFnT, std::size_t... I,
1357 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1358 typename FnTraitsT::result_t
1361 std::index_sequence<I...>) {
1364 (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1374 template <
typename Constra
intFnT>
1376 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1379 return std::forward<ConstraintFnT>(constraintFn);
1383 template <
typename Constra
intFnT>
1385 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1388 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1391 auto argIndices = std::make_index_sequence<
1392 llvm::function_traits<ConstraintFnT>::num_args - 1>();
1393 if (
failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1406 template <
typename PDLFnT, std::size_t... I,
1407 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1408 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
1412 std::index_sequence<I...>) {
1414 (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1420 template <
typename PDLFnT, std::size_t... I,
1421 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1422 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
1426 std::index_sequence<I...>) {
1429 fn(rewriter, (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
1430 processAsArg(values[I]))...));
1440 template <
typename RewriteFnT>
1441 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1444 return std::forward<RewriteFnT>(rewriteFn);
1448 template <
typename RewriteFnT>
1449 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1452 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1456 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1458 assertArgs<RewriteFnT>(rewriter, values, argIndices);
1480 : pdlModule(std::move(module)) {}
1481 template <
typename... ConfigsT>
1484 auto configSet = std::make_unique<PDLPatternConfigSet>(
1485 std::forward<ConfigsT>(patternConfigs)...);
1486 attachConfigToPatterns(*pdlModule, *configSet);
1487 configs.emplace_back(std::move(configSet));
1520 template <
typename Constra
intFnT>
1522 ConstraintFnT &&constraintFn) {
1525 std::forward<ConstraintFnT>(constraintFn)));
1553 template <
typename RewriteFnT>
1556 std::forward<RewriteFnT>(rewriteFn)));
1561 return constraintFunctions;
1564 return constraintFunctions;
1568 return rewriteFunctions;
1571 return rewriteFunctions;
1576 return std::move(configs);
1579 return std::move(configMap);
1584 pdlModule =
nullptr;
1585 constraintFunctions.clear();
1586 rewriteFunctions.clear();
1602 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
1603 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1611 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1618 std::unique_ptr<RewritePattern> pattern)
1619 : context(context) {
1620 nativePatterns.emplace_back(std::move(pattern));
1623 : context(pattern.getModule()->
getContext()),
1624 pdlPatterns(std::move(pattern)) {}
1636 nativePatterns.clear();
1637 pdlPatterns.
clear();
1647 template <
typename... Ts,
typename ConstructorArg,
1648 typename... ConstructorArgs,
1649 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1653 (addImpl<Ts>(std::nullopt,
1654 std::forward<ConstructorArg>(arg),
1655 std::forward<ConstructorArgs>(args)...),
1663 template <
typename... Ts,
typename ConstructorArg,
1664 typename... ConstructorArgs,
1665 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1667 ConstructorArg &&arg,
1668 ConstructorArgs &&...args) {
1671 (addImpl<Ts>(debugLabels, arg, args...), ...);
1677 template <
typename... Ts>
1679 (addImpl<Ts>(), ...);
1686 nativePatterns.emplace_back(std::move(pattern));
1693 pdlPatterns.
mergeIn(std::move(pattern));
1698 template <
typename OpType>
1709 LogicalResult matchAndRewrite(OpType op,
1710 PatternRewriter &rewriter)
const override {
1711 return implFn(op, rewriter);
1715 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1717 add(std::make_unique<FnPattern>(std::move(implFn),
getContext(), benefit,
1731 template <
typename... Ts,
typename ConstructorArg,
1732 typename... ConstructorArgs,
1733 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1737 (addImpl<Ts>(std::nullopt, arg, args...), ...);
1743 template <
typename... Ts>
1745 (addImpl<Ts>(), ...);
1752 nativePatterns.emplace_back(std::move(pattern));
1759 pdlPatterns.
mergeIn(std::move(pattern));
1764 template <
typename OpType>
1771 this->setDebugName(llvm::getTypeName<FnPattern>());
1776 return implFn(op, rewriter);
1782 add(std::make_unique<FnPattern>(std::move(implFn),
getContext()));
1789 template <
typename T,
typename... Args>
1790 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
1792 std::unique_ptr<T> pattern =
1793 RewritePattern::create<T>(std::forward<Args>(args)...);
1794 pattern->addDebugLabels(debugLabels);
1795 nativePatterns.emplace_back(std::move(pattern));
1797 template <
typename T,
typename... Args>
1798 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1799 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1802 pdlPatterns.
mergeIn(T(std::forward<Args>(args)...));
1805 MLIRContext *
const context;
1806 NativePatternListT nativePatterns;
1807 PDLPatternModule pdlPatterns;
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
This class represents a single IR object that contains a use list.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
IRRewriter(const OpBuilder &builder)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Listener * listener
The optional listener for events of this builder.
This class represents an operand of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy get() const
Allow accessing the internal op.
This class provides a base class for users implementing a type of pattern configuration.
static TypeID getConfigID()
Return the type id used for this configuration.
static bool classof(const PDLPatternConfig *config)
Support LLVM style casting.
This class contains a set of configurations for a specific pattern.
const T & get() const
Get the configuration defined by the given type.
PDLPatternConfigSet(ConfigsT &&...configs)
Construct a set with the given configurations.
const T * tryGet() const
Get the configuration defined by the given type, returns nullptr if the configuration does not exist.
SmallVector< std::unique_ptr< PDLPatternConfig > > configs
The set of configurations for this pattern.
void addConfig(T &&config)
Add a configuration to the set.
PDLPatternConfigSet()=default
void notifyRewriteBegin(PatternRewriter &rewriter)
Notify the configurations within this set at the beginning or end of a rewrite of a matched pattern.
void notifyRewriteEnd(PatternRewriter &rewriter)
An individual configuration for a pattern, which can be accessed by native functions via the PDLPatte...
virtual ~PDLPatternConfig()=default
PDLPatternConfig(TypeID id)
virtual void notifyRewriteEnd(PatternRewriter &rewriter)
virtual void notifyRewriteBegin(PatternRewriter &rewriter)
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
TypeID getTypeID() const
Return the TypeID that represents this configuration.
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void clear()
Clear out the patterns and functions within this module.
const llvm::StringMap< PDLRewriteFunction > & getRewriteFunctions() const
Return the set of the registered rewrite functions.
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
PDLPatternModule(OwningOpRef< ModuleOp > module, ConfigsT &&...patternConfigs)
PDLPatternModule()=default
void registerConstraintFunction(StringRef name, ConstraintFnT &&constraintFn)
PDLPatternModule(OwningOpRef< ModuleOp > module)
Construct a PDL pattern with the given module and configurations.
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn)
ModuleOp getModule()
Return the internal PDL module of this pattern.
const llvm::StringMap< PDLConstraintFunction > & getConstraintFunctions() const
Return the set of the registered constraint functions.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
SmallVector< std::unique_ptr< PDLPatternConfigSet > > takeConfigs()
Return the set of the registered pattern configs.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
DenseMap< Operation *, PDLPatternConfigSet * > takeConfigMap()
The class represents a list of PDL results, returned by a native rewrite method.
void push_back(ValueTypeRange< OperandRange > value)
void push_back(ResultRange value)
void push_back(ValueRange value)
Push a new ValueRange onto the result list.
PDLResultList(unsigned maxNumResults)
Create a new result list with the expected number of results.
SmallVector< llvm::OwningArrayRef< Type > > allocatedTypeRanges
Memory allocated to store ranges in the result list whose lifetime was generated in the native functi...
void push_back(ValueTypeRange< ResultRange > value)
SmallVector< llvm::OwningArrayRef< Value > > allocatedValueRanges
void push_back(Attribute value)
Push a new Attribute value onto the result list.
SmallVector< TypeRange > typeRanges
Memory used to store ranges held by the list.
SmallVector< PDLValue > results
The PDL results held by this list.
void push_back(Type value)
Push a new Type onto the result list.
void push_back(Operation *value)
Push a new Operation onto the result list.
void push_back(OperandRange value)
void push_back(Value value)
Push a new Value onto the result list.
SmallVector< ValueRange > valueRanges
void push_back(TypeRange value)
Push a new TypeRange onto the result list.
Storage type of byte-code interpreter values.
PDLValue(std::nullptr_t=nullptr)
const void * getAsOpaquePointer() const
Get an opaque pointer to the value.
PDLValue(Attribute value)
PDLValue(Operation *value)
Kind getKind() const
Return the kind of this value.
ResultT dyn_cast() const
Attempt to dynamically cast this value to type T, returns null if this value is not an instance of T.
bool isa() const
Returns true if the type of the held value is T.
void print(raw_ostream &os) const
Print this value to the provided output stream.
PDLValue(TypeRange *value)
Kind
The underlying kind of a PDL value.
T cast() const
Cast this value to type T, asserts if this value is not an instance of T.
PDLValue(ValueRange *value)
PDLValue(const PDLValue &other)=default
Construct a new PDL value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
bool operator==(const PatternBenefit &rhs) const
static PatternBenefit impossibleToMatch()
bool operator>=(const PatternBenefit &rhs) const
bool operator<=(const PatternBenefit &rhs) const
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
bool operator!=(const PatternBenefit &rhs) const
bool operator>(const PatternBenefit &rhs) const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch(), this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
void addDebugLabels(StringRef label)
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
This class implements the result iterators for the Operation class.
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
RewritePattern is the common base class for all DAG to DAG replacements.
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
void replaceOpWithIf(Operation *op, ValueRange newValues, llvm::unique_function< bool(OpOperand &) const > functor)
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
RewriterBase(const OpBuilder &otherBuilder)
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(ValueRange from, ValueRange to)
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
void replaceAllUsesWith(IRObjectWithUseList< OperandType > *from, ValueT &&to)
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an efficient unique identifier for a specific C++ type.
static TypeID getFromOpaquePointer(const void *pointer)
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
detail::ValueImpl * getImpl() const
Operation * getOwner() const
Return the owner of this operand.
FnTraitsT::result_t processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native constraint and invoke it.
void assertArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Assert that the given PDLValues match the constraints defined by the arguments of the given function.
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Validate the given PDLValues match the constraints defined by the argument types of the given functio...
std::enable_if_t< std::is_same< typename FnTraitsT::result_t, void >::value, LogicalResult > processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native rewrite and invoke it.
static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, T &&value)
Store a single result within the result list.
std::enable_if_t< std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Build a constraint function from the given function ConstraintFnT.
constexpr bool always_false
A utility variable that always resolves to false.
std::enable_if_t< std::is_convertible< RewriteFnT, PDLRewriteFunction >::value, PDLRewriteFunction > buildRewriteFn(RewriteFnT &&rewriteFn)
Build a rewrite function from the given function RewriteFnT.
@ Type
An inlay hint for a type annotation.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
This class represents an efficient way to signal success or failure.
Base class for listeners.
Kind
The kind of listener.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This class acts as a special tag that makes the desire to match "any" operation type explicit.
This class acts as a special tag that makes the desire to match any operation that implements a given...
This class acts as a special tag that makes the desire to match any operation that implements a given...
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
virtual void notifyOperationRemoved(Operation *op)
This is called on an operation that a rewrite is removing, right before the operation is deleted.
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that the specified operation is about to be replaced with the set of values poten...
static bool classof(const OpBuilder::Listener *base)
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
virtual LogicalResult match(SourceOp op) const
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This struct provides a simplified model for processing types that have "builtin" PDLValue support:
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static T processAsArg(PDLValue pdlValue)
This struct provides a simplified model for processing types that inherit from builtin PDLValue types...
static T processAsArg(BaseT baseValue)
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT baseValue, size_t argIdx)
This struct provides a simplified model for processing types that are based on another type,...
static T processAsArg(PDLValue pdlValue)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT value, size_t argIdx)
Explicitly add the expected parent API to ensure the parent class implements the necessary API (and d...
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static T processAsArg(BaseT baseValue)
static void processAsResult(PatternRewriter &, PDLResultList &results, OperandRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, ResultRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Type, N > values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Value, N > values)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
static StringRef processAsArg(StringAttr value)
static T processAsArg(Operation *value)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< OperandRange > types)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< ResultRange > types)
static std::string processAsArg(T value)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
This struct provides a convenient way to determine how to process a given type as either a PDL parame...