9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
20 class PatternRewriter;
34 enum { ImpossibleToMatchSentinel = 65535 };
50 return representation == rhs.representation;
54 return representation < rhs.representation;
61 unsigned short representation{ImpossibleToMatchSentinel};
94 if (rootKind == RootKind::OperationName)
103 if (rootKind == RootKind::InterfaceID)
112 if (rootKind == RootKind::TraitID)
129 return contextAndHasBoundedRecursion.getInt();
134 return contextAndHasBoundedRecursion.getPointer();
150 debugLabels.append(labels.begin(), labels.end());
187 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
188 PatternBenefit benefit, MLIRContext *context,
189 ArrayRef<StringRef> generatedNames = {});
196 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
197 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
202 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
206 Pattern(
const void *rootValue, RootKind rootKind,
211 const void *rootValue;
219 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
274 template <
typename T,
typename... Args>
275 static std::unique_ptr<T>
create(Args &&...args) {
276 std::unique_ptr<T> pattern =
277 std::make_unique<T>(std::forward<Args>(args)...);
278 initializePattern<T>(*pattern);
281 if (pattern->getDebugName().empty())
282 pattern->setDebugName(llvm::getTypeName<T>());
292 template <
typename T,
typename... Args>
293 using has_initialize = decltype(std::declval<T>().initialize());
294 template <
typename T>
295 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
298 template <
typename T>
299 static std::enable_if_t<detect_has_initialize<T>::value>
300 initializePattern(T &pattern) {
301 pattern.initialize();
305 template <
typename T>
306 static std::enable_if_t<!detect_has_initialize<T>::value>
307 initializePattern(T &) {}
310 virtual void anchor();
317 template <
typename SourceOp>
319 using RewritePattern::RewritePattern;
323 rewrite(cast<SourceOp>(op), rewriter);
326 return match(cast<SourceOp>(op));
336 llvm_unreachable(
"must override rewrite or matchAndRewrite");
339 llvm_unreachable(
"must override match or matchAndRewrite");
355 template <
typename SourceOp>
364 SourceOp::getOperationName(), benefit, context, generatedNames) {}
370 template <
typename SourceOp>
374 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
382 template <
template <
typename>
class TraitType>
456 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
457 rewriteListener->notifyOperationModified(op);
460 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
461 rewriteListener->notifyOperationReplaced(op, newOp);
465 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
466 rewriteListener->notifyOperationReplaced(op, replacement);
469 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
470 rewriteListener->notifyOperationRemoved(op);
475 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
476 return rewriteListener->notifyMatchFailure(loc, reasonCallback);
512 llvm::unique_function<
bool(
OpOperand &)
const> functor);
514 llvm::unique_function<
bool(
OpOperand &)
const> functor) {
524 bool *allUsesReplaced =
nullptr);
538 template <
typename OpTy,
typename... Args>
540 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
605 template <
typename CallableT>
618 template <
typename OperandType,
typename ValueT>
620 for (OperandType &operand : llvm::make_early_inc_range(from->
getUses())) {
626 assert(from.size() == to.size() &&
"incorrect number of replacements");
627 for (
auto it : llvm::zip(from, to))
638 assert(from.size() == to.size() &&
"incorrect number of replacements");
639 for (
auto it : llvm::zip(from, to))
649 return user != exceptedUser;
658 template <
typename CallbackT>
659 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value,
LogicalResult>
662 if (
auto *rewriteListener = dyn_cast_if_present<Listener>(
listener))
663 return rewriteListener->notifyMatchFailure(
670 template <
typename CallbackT>
671 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value,
LogicalResult>
673 if (
auto *rewriteListener = dyn_cast_if_present<Listener>(
listener))
674 return rewriteListener->notifyMatchFailure(
678 template <
typename ArgT>
683 template <
typename ArgT>
765 template <
typename T>
767 assert(value &&
"isa<> used on a null value");
768 return kind == getKindOf<T>();
773 template <
typename T,
774 typename ResultT = std::conditional_t<
775 std::is_convertible<T, bool>::value, T, std::optional<T>>>
777 return isa<T>() ? castImpl<T>() : ResultT();
782 template <
typename T>
784 assert(isa<T>() &&
"expected value to be of type `T`");
785 return castImpl<T>();
792 explicit operator bool()
const {
return value; }
798 void print(raw_ostream &os)
const;
801 static void print(raw_ostream &os,
Kind kind);
805 template <
typename...>
807 template <
typename T,
typename... R>
808 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
809 template <
typename T,
typename F,
typename... R>
810 struct index_of_t<T, F, R...>
811 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
814 template <
typename T>
815 static Kind getKindOf() {
816 return static_cast<Kind>(index_of_t<T, Attribute, Operation *,
Type,
817 TypeRange,
Value, ValueRange>::value);
822 template <
typename T>
823 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
825 return T::getFromOpaquePointer(value);
827 template <
typename T>
828 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
830 return *
reinterpret_cast<T *
>(
const_cast<void *
>(value));
832 template <
typename T>
833 std::enable_if_t<std::is_pointer<T>::value, T> castImpl()
const {
834 return reinterpret_cast<T
>(
const_cast<void *
>(value));
838 const void *value{
nullptr};
874 llvm::OwningArrayRef<Type> storage(value.size());
896 llvm::OwningArrayRef<Value> storage(value.size());
961 template <
typename T>
984 template <
typename... ConfigsT>
991 template <
typename T>
993 const T *config = tryGet<T>();
994 assert(config &&
"configuration not found");
1000 template <
typename T>
1002 for (
const auto &configIt :
configs)
1003 if (
const T *config = dyn_cast<T>(configIt.get()))
1011 for (
const auto &config :
configs)
1012 config->notifyRewriteBegin(rewriter);
1015 for (
const auto &config :
configs)
1016 config->notifyRewriteEnd(rewriter);
1021 template <
typename T>
1023 assert(!
tryGet<std::decay_t<T>>() &&
"configuration already exists");
1025 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
1052 namespace pdl_function_builder {
1063 template <
class... T>
1099 template <
typename T,
typename Enable =
void>
1119 template <
typename T,
typename BaseT>
1123 PDLValue pdlValue,
size_t argIdx) {
1149 template <
typename T>
1153 PDLValue pdlValue,
size_t argIdx) {
1156 return errorFn(
"expected a non-null value for argument " + Twine(argIdx) +
1157 " of type: " + llvm::getTypeName<T>());
1171 template <
typename T,
typename BaseT>
1175 BaseT baseValue,
size_t argIdx) {
1177 .Case([&](T) {
return success(); })
1178 .Default([&](
BaseT) {
1179 return errorFn(
"expected argument " + Twine(argIdx) +
1180 " to be of type: " + llvm::getTypeName<T>());
1186 return baseValue.template cast<T>();
1201 template <
typename T>
1203 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
1210 static StringRef
processAsArg(StringAttr value) {
return value.getValue(); }
1221 template <
typename T>
1223 static_assert(always_false<T>,
1224 "`std::string` arguments require a string copy, use "
1225 "`StringRef` for string-like arguments instead");
1240 template <
typename T>
1251 template <
typename T>
1274 template <
unsigned N>
1308 template <
unsigned N>
1325 template <
typename PDLFnT, std::size_t... I>
1327 std::index_sequence<I...>) {
1328 using FnTraitsT = llvm::function_traits<PDLFnT>;
1330 auto errorFn = [&](
const Twine &msg) {
1335 verifyAsArg(errorFn, values[I], I)) &&
1342 template <
typename PDLFnT, std::size_t... I>
1344 std::index_sequence<I...>) {
1346 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1347 using FnTraitsT = llvm::function_traits<PDLFnT>;
1349 llvm::report_fatal_error(msg);
1353 verifyAsArg(errorFn, values[I], I)) &&
1364 template <
typename T>
1368 std::forward<T>(value));
1373 template <
typename T1,
typename T2>
1376 std::pair<T1, T2> &&pair) {
1384 template <
typename... Ts>
1387 std::tuple<Ts...> &&tuple) {
1388 auto applyFn = [&](
auto &&...args) {
1392 return success(std::apply(applyFn, std::move(tuple)));
1401 template <
typename T>
1415 template <
typename PDLFnT, std::size_t... I,
1416 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1417 typename FnTraitsT::result_t
1420 std::index_sequence<I...>) {
1423 (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1433 template <
typename Constra
intFnT>
1435 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1438 return std::forward<ConstraintFnT>(constraintFn);
1442 template <
typename Constra
intFnT>
1444 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1447 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1450 auto argIndices = std::make_index_sequence<
1451 llvm::function_traits<ConstraintFnT>::num_args - 1>();
1452 if (
failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1465 template <
typename PDLFnT, std::size_t... I,
1466 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1467 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
1471 std::index_sequence<I...>) {
1473 (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1479 template <
typename PDLFnT, std::size_t... I,
1480 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1481 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
1485 std::index_sequence<I...>) {
1488 fn(rewriter, (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
1489 processAsArg(values[I]))...));
1499 template <
typename RewriteFnT>
1500 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1503 return std::forward<RewriteFnT>(rewriteFn);
1507 template <
typename RewriteFnT>
1508 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1511 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1515 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1517 assertArgs<RewriteFnT>(rewriter, values, argIndices);
1539 : pdlModule(std::move(module)) {}
1540 template <
typename... ConfigsT>
1543 auto configSet = std::make_unique<PDLPatternConfigSet>(
1544 std::forward<ConfigsT>(patternConfigs)...);
1545 attachConfigToPatterns(*pdlModule, *configSet);
1546 configs.emplace_back(std::move(configSet));
1579 template <
typename Constra
intFnT>
1581 ConstraintFnT &&constraintFn) {
1584 std::forward<ConstraintFnT>(constraintFn)));
1612 template <
typename RewriteFnT>
1615 std::forward<RewriteFnT>(rewriteFn)));
1620 return constraintFunctions;
1623 return constraintFunctions;
1627 return rewriteFunctions;
1630 return rewriteFunctions;
1635 return std::move(configs);
1638 return std::move(configMap);
1643 pdlModule =
nullptr;
1644 constraintFunctions.clear();
1645 rewriteFunctions.clear();
1661 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
1662 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1670 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1677 std::unique_ptr<RewritePattern> pattern)
1678 : context(context) {
1679 nativePatterns.emplace_back(std::move(pattern));
1682 : context(pattern.getModule()->
getContext()),
1683 pdlPatterns(std::move(pattern)) {}
1695 nativePatterns.clear();
1696 pdlPatterns.
clear();
1706 template <
typename... Ts,
typename ConstructorArg,
1707 typename... ConstructorArgs,
1708 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1712 (addImpl<Ts>(std::nullopt,
1713 std::forward<ConstructorArg>(arg),
1714 std::forward<ConstructorArgs>(args)...),
1722 template <
typename... Ts,
typename ConstructorArg,
1723 typename... ConstructorArgs,
1724 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1726 ConstructorArg &&arg,
1727 ConstructorArgs &&...args) {
1730 (addImpl<Ts>(debugLabels, arg, args...), ...);
1736 template <
typename... Ts>
1738 (addImpl<Ts>(), ...);
1745 nativePatterns.emplace_back(std::move(pattern));
1752 pdlPatterns.
mergeIn(std::move(pattern));
1757 template <
typename OpType>
1768 LogicalResult matchAndRewrite(OpType op,
1769 PatternRewriter &rewriter)
const override {
1770 return implFn(op, rewriter);
1774 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1776 add(std::make_unique<FnPattern>(std::move(implFn),
getContext(), benefit,
1790 template <
typename... Ts,
typename ConstructorArg,
1791 typename... ConstructorArgs,
1792 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1796 (addImpl<Ts>(std::nullopt, arg, args...), ...);
1802 template <
typename... Ts>
1804 (addImpl<Ts>(), ...);
1811 nativePatterns.emplace_back(std::move(pattern));
1818 pdlPatterns.
mergeIn(std::move(pattern));
1823 template <
typename OpType>
1830 this->setDebugName(llvm::getTypeName<FnPattern>());
1835 return implFn(op, rewriter);
1841 add(std::make_unique<FnPattern>(std::move(implFn),
getContext()));
1848 template <
typename T,
typename... Args>
1849 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
1851 std::unique_ptr<T> pattern =
1852 RewritePattern::create<T>(std::forward<Args>(args)...);
1853 pattern->addDebugLabels(debugLabels);
1854 nativePatterns.emplace_back(std::move(pattern));
1856 template <
typename T,
typename... Args>
1857 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1858 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1861 pdlPatterns.
mergeIn(T(std::forward<Args>(args)...));
1864 MLIRContext *
const context;
1865 NativePatternListT nativePatterns;
1866 PDLPatternModule pdlPatterns;
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
This class represents a single IR object that contains a use list.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
IRRewriter(const OpBuilder &builder)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Listener * listener
The optional listener for events of this builder.
This class represents an operand of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the operand iterators for the Operation class.
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
result_range getResults()
OpTy get() const
Allow accessing the internal op.
This class provides a base class for users implementing a type of pattern configuration.
static TypeID getConfigID()
Return the type id used for this configuration.
static bool classof(const PDLPatternConfig *config)
Support LLVM style casting.
This class contains a set of configurations for a specific pattern.
const T & get() const
Get the configuration defined by the given type.
PDLPatternConfigSet(ConfigsT &&...configs)
Construct a set with the given configurations.
const T * tryGet() const
Get the configuration defined by the given type, returns nullptr if the configuration does not exist.
SmallVector< std::unique_ptr< PDLPatternConfig > > configs
The set of configurations for this pattern.
void addConfig(T &&config)
Add a configuration to the set.
PDLPatternConfigSet()=default
void notifyRewriteBegin(PatternRewriter &rewriter)
Notify the configurations within this set at the beginning or end of a rewrite of a matched pattern.
void notifyRewriteEnd(PatternRewriter &rewriter)
An individual configuration for a pattern, which can be accessed by native functions via the PDLPatte...
virtual ~PDLPatternConfig()=default
PDLPatternConfig(TypeID id)
virtual void notifyRewriteEnd(PatternRewriter &rewriter)
virtual void notifyRewriteBegin(PatternRewriter &rewriter)
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
TypeID getTypeID() const
Return the TypeID that represents this configuration.
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void clear()
Clear out the patterns and functions within this module.
const llvm::StringMap< PDLRewriteFunction > & getRewriteFunctions() const
Return the set of the registered rewrite functions.
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
PDLPatternModule(OwningOpRef< ModuleOp > module, ConfigsT &&...patternConfigs)
PDLPatternModule()=default
void registerConstraintFunction(StringRef name, ConstraintFnT &&constraintFn)
PDLPatternModule(OwningOpRef< ModuleOp > module)
Construct a PDL pattern with the given module and configurations.
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn)
ModuleOp getModule()
Return the internal PDL module of this pattern.
const llvm::StringMap< PDLConstraintFunction > & getConstraintFunctions() const
Return the set of the registered constraint functions.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
SmallVector< std::unique_ptr< PDLPatternConfigSet > > takeConfigs()
Return the set of the registered pattern configs.
void mergeIn(PDLPatternModule &&other)
Merge the state in other into this pattern module.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
DenseMap< Operation *, PDLPatternConfigSet * > takeConfigMap()
The class represents a list of PDL results, returned by a native rewrite method.
void push_back(ValueTypeRange< OperandRange > value)
void push_back(ResultRange value)
void push_back(ValueRange value)
Push a new ValueRange onto the result list.
PDLResultList(unsigned maxNumResults)
Create a new result list with the expected number of results.
SmallVector< llvm::OwningArrayRef< Type > > allocatedTypeRanges
Memory allocated to store ranges in the result list whose lifetime was generated in the native functi...
void push_back(ValueTypeRange< ResultRange > value)
SmallVector< llvm::OwningArrayRef< Value > > allocatedValueRanges
void push_back(Attribute value)
Push a new Attribute value onto the result list.
SmallVector< TypeRange > typeRanges
Memory used to store ranges held by the list.
SmallVector< PDLValue > results
The PDL results held by this list.
void push_back(Type value)
Push a new Type onto the result list.
void push_back(Operation *value)
Push a new Operation onto the result list.
void push_back(OperandRange value)
void push_back(Value value)
Push a new Value onto the result list.
SmallVector< ValueRange > valueRanges
void push_back(TypeRange value)
Push a new TypeRange onto the result list.
Storage type of byte-code interpreter values.
PDLValue(std::nullptr_t=nullptr)
const void * getAsOpaquePointer() const
Get an opaque pointer to the value.
PDLValue(Attribute value)
PDLValue(Operation *value)
Kind getKind() const
Return the kind of this value.
ResultT dyn_cast() const
Attempt to dynamically cast this value to type T, returns null if this value is not an instance of T.
bool isa() const
Returns true if the type of the held value is T.
void print(raw_ostream &os) const
Print this value to the provided output stream.
PDLValue(TypeRange *value)
Kind
The underlying kind of a PDL value.
T cast() const
Cast this value to type T, asserts if this value is not an instance of T.
PDLValue(ValueRange *value)
PDLValue(const PDLValue &other)=default
Construct a new PDL value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
bool operator==(const PatternBenefit &rhs) const
static PatternBenefit impossibleToMatch()
bool operator>=(const PatternBenefit &rhs) const
bool operator<=(const PatternBenefit &rhs) const
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
bool operator!=(const PatternBenefit &rhs) const
bool operator>(const PatternBenefit &rhs) const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
void addDebugLabels(StringRef label)
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
This class implements the result iterators for the Operation class.
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
RewritePattern is the common base class for all DAG to DAG replacements.
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
void replaceOpWithIf(Operation *op, ValueRange newValues, llvm::unique_function< bool(OpOperand &) const > functor)
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
RewriterBase(const OpBuilder &otherBuilder)
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(ValueRange from, ValueRange to)
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
This method replaces the uses of the results of op with the values in newValues when a use is nested ...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
void replaceAllUsesWith(IRObjectWithUseList< OperandType > *from, ValueT &&to)
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref< bool(OpOperand &)> functor)
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an efficient unique identifier for a specific C++ type.
static TypeID getFromOpaquePointer(const void *pointer)
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
detail::ValueImpl * getImpl() const
Operation * getOwner() const
Return the owner of this operand.
FnTraitsT::result_t processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native constraint and invoke it.
void assertArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Assert that the given PDLValues match the constraints defined by the arguments of the given function.
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Validate the given PDLValues match the constraints defined by the argument types of the given functio...
std::enable_if_t< std::is_same< typename FnTraitsT::result_t, void >::value, LogicalResult > processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native rewrite and invoke it.
static LogicalResult processResults(PatternRewriter &rewriter, PDLResultList &results, T &&value)
Store a single result within the result list.
std::enable_if_t< std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Build a constraint function from the given function ConstraintFnT.
constexpr bool always_false
A utility variable that always resolves to false.
std::enable_if_t< std::is_convertible< RewriteFnT, PDLRewriteFunction >::value, PDLRewriteFunction > buildRewriteFn(RewriteFnT &&rewriteFn)
Build a rewrite function from the given function RewriteFnT.
@ Type
An inlay hint that is for a type annotation.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
This class represents an efficient way to signal success or failure.
Base class for listeners.
Kind
The kind of listener.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
virtual void notifyBlockCreated(Block *block)
Notification handler for when a block is created using the builder.
virtual void notifyOperationInserted(Operation *op)
Notification handler for when an operation is inserted into the builder.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This class acts as a special tag that makes the desire to match "any" operation type explicit.
This class acts as a special tag that makes the desire to match any operation that implements a given...
This class acts as a special tag that makes the desire to match any operation that implements a given...
A listener that forwards all notifications to another listener.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyOperationInserted(Operation *op) override
Notification handler for when an operation is inserted into the builder.
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that the specified operation is about to be replaced with another operation.
void notifyOperationRemoved(Operation *op) override
Notify the listener that the specified operation is about to be erased.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
void notifyBlockCreated(Block *block) override
Notification handler for when a block is created using the builder.
ForwardingListener(OpBuilder::Listener *listener)
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that the specified operation is about to be replaced with a range of values,...
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that the specified operation is about to be replaced with another operation.
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
virtual void notifyOperationRemoved(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that the specified operation is about to be replaced with a range of values,...
static bool classof(const OpBuilder::Listener *base)
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
virtual LogicalResult match(SourceOp op) const
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
This struct provides a simplified model for processing types that have "builtin" PDLValue support:
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static T processAsArg(PDLValue pdlValue)
This struct provides a simplified model for processing types that inherit from builtin PDLValue types...
static T processAsArg(BaseT baseValue)
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT baseValue, size_t argIdx)
This struct provides a simplified model for processing types that are based on another type,...
static T processAsArg(PDLValue pdlValue)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT value, size_t argIdx)
Explicitly add the expected parent API to ensure the parent class implements the necessary API (and d...
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static T processAsArg(BaseT baseValue)
static void processAsResult(PatternRewriter &, PDLResultList &results, OperandRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, ResultRange values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Type, N > values)
static void processAsResult(PatternRewriter &, PDLResultList &results, SmallVector< Value, N > values)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
static StringRef processAsArg(StringAttr value)
static T processAsArg(Operation *value)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< OperandRange > types)
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< ResultRange > types)
static std::string processAsArg(T value)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
This struct provides a convenient way to determine how to process a given type as either a PDL parame...