9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
21 class PatternRewriter;
35 enum { ImpossibleToMatchSentinel = 65535 };
51 return representation == rhs.representation;
55 return representation < rhs.representation;
62 unsigned short representation{ImpossibleToMatchSentinel};
95 if (rootKind == RootKind::OperationName)
104 if (rootKind == RootKind::InterfaceID)
113 if (rootKind == RootKind::TraitID)
130 return contextAndHasBoundedRecursion.getInt();
135 return contextAndHasBoundedRecursion.getPointer();
151 debugLabels.append(labels.begin(), labels.end());
188 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
189 PatternBenefit benefit, MLIRContext *context,
190 ArrayRef<StringRef> generatedNames = {});
197 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
198 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
203 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
207 Pattern(
const void *rootValue, RootKind rootKind,
212 const void *rootValue;
220 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
266 if (succeeded(
match(op))) {
275 template <
typename T,
typename... Args>
276 static std::unique_ptr<T>
create(Args &&...args) {
277 std::unique_ptr<T> pattern =
278 std::make_unique<T>(std::forward<Args>(args)...);
279 initializePattern<T>(*pattern);
282 if (pattern->getDebugName().empty())
283 pattern->setDebugName(llvm::getTypeName<T>());
293 template <
typename T,
typename... Args>
294 using has_initialize = decltype(std::declval<T>().initialize());
295 template <
typename T>
296 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
299 template <
typename T>
300 static std::enable_if_t<detect_has_initialize<T>::value>
301 initializePattern(T &pattern) {
302 pattern.initialize();
306 template <
typename T>
307 static std::enable_if_t<!detect_has_initialize<T>::value>
308 initializePattern(T &) {}
311 virtual void anchor();
318 template <
typename SourceOp>
320 using RewritePattern::RewritePattern;
324 rewrite(cast<SourceOp>(op), rewriter);
327 return match(cast<SourceOp>(op));
337 llvm_unreachable(
"must override rewrite or matchAndRewrite");
339 virtual LogicalResult
match(SourceOp op)
const {
340 llvm_unreachable(
"must override match or matchAndRewrite");
344 if (succeeded(
match(op))) {
356 template <
typename SourceOp>
365 SourceOp::getOperationName(), benefit, context, generatedNames) {}
371 template <
typename SourceOp>
375 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
383 template <
template <
typename>
class TraitType>
447 LogicalResult status) {}
474 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
475 rewriteListener->notifyBlockErased(block);
478 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
479 rewriteListener->notifyOperationModified(op);
482 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
483 rewriteListener->notifyOperationReplaced(op, newOp);
487 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
488 rewriteListener->notifyOperationReplaced(op, replacement);
491 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
492 rewriteListener->notifyOperationErased(op);
495 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
496 rewriteListener->notifyPatternBegin(pattern, op);
499 LogicalResult status)
override {
500 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
501 rewriteListener->notifyPatternEnd(pattern, status);
506 if (
auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
507 rewriteListener->notifyMatchFailure(loc, reasonCallback);
535 template <
typename OpTy,
typename... Args>
537 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
629 template <
typename CallableT>
651 assert(from.size() == to.size() &&
"incorrect number of replacements");
652 for (
auto it : llvm::zip(from, to))
672 bool *allUsesReplaced =
nullptr);
675 bool *allUsesReplaced =
nullptr);
681 bool *allUsesReplaced =
nullptr) {
690 Block *block,
bool *allUsesReplaced =
nullptr) {
705 return user != exceptedUser;
716 template <
typename CallbackT>
717 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
719 if (
auto *rewriteListener = dyn_cast_if_present<Listener>(
listener))
720 rewriteListener->notifyMatchFailure(
724 template <
typename CallbackT>
725 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
727 if (
auto *rewriteListener = dyn_cast_if_present<Listener>(
listener))
728 rewriteListener->notifyMatchFailure(
732 template <
typename ArgT>
737 template <
typename ArgT>
809 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
816 std::unique_ptr<RewritePattern> pattern)
818 nativePatterns.emplace_back(std::move(pattern));
821 : context(pattern.
getContext()), pdlPatterns(std::move(pattern)) {}
833 nativePatterns.clear();
844 template <
typename... Ts,
typename ConstructorArg,
845 typename... ConstructorArgs,
846 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
850 (addImpl<Ts>(std::nullopt,
851 std::forward<ConstructorArg>(arg),
852 std::forward<ConstructorArgs>(args)...),
860 template <
typename... Ts,
typename ConstructorArg,
861 typename... ConstructorArgs,
862 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
864 ConstructorArg &&arg,
865 ConstructorArgs &&...args) {
868 (addImpl<Ts>(debugLabels, arg, args...), ...);
874 template <
typename... Ts>
876 (addImpl<Ts>(), ...);
883 nativePatterns.emplace_back(std::move(pattern));
890 pdlPatterns.mergeIn(std::move(pattern));
895 template <
typename OpType>
906 LogicalResult matchAndRewrite(OpType op,
907 PatternRewriter &rewriter)
const override {
908 return implFn(op, rewriter);
912 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
914 add(std::make_unique<FnPattern>(std::move(implFn),
getContext(), benefit,
928 template <
typename... Ts,
typename ConstructorArg,
929 typename... ConstructorArgs,
930 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
934 (addImpl<Ts>(std::nullopt, arg, args...), ...);
940 template <
typename... Ts>
942 (addImpl<Ts>(), ...);
949 nativePatterns.emplace_back(std::move(pattern));
956 pdlPatterns.mergeIn(std::move(pattern));
961 template <
typename OpType>
968 this->setDebugName(llvm::getTypeName<FnPattern>());
971 LogicalResult matchAndRewrite(OpType op,
973 return implFn(op, rewriter);
979 add(std::make_unique<FnPattern>(std::move(implFn),
getContext()));
986 template <
typename T,
typename... Args>
987 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
989 std::unique_ptr<T> pattern =
990 RewritePattern::create<T>(std::forward<Args>(args)...);
991 pattern->addDebugLabels(debugLabels);
992 nativePatterns.emplace_back(std::move(pattern));
995 template <
typename T,
typename... Args>
996 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
997 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1000 pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1003 MLIRContext *
const context;
1004 NativePatternListT nativePatterns;
1008 PDLPatternModule pdlPatterns;
static std::string diag(const llvm::Value &value)
A block operand represents an operand that holds a reference to a Block, e.g. for terminator operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
IRRewriter(const OpBuilder &builder)
IRRewriter(Operation *op, OpBuilder::Listener *listener=nullptr)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a saved insertion point.
This class helps build Operations.
Listener * listener
The optional listener for events of this builder.
This class represents an operand of an operation.
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
static OperationName getFromOpaquePointer(const void *pointer)
Operation is the basic unit of execution within MLIR.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
PatternBenefit & operator=(const PatternBenefit &)=default
bool operator<(const PatternBenefit &rhs) const
bool operator==(const PatternBenefit &rhs) const
static PatternBenefit impossibleToMatch()
bool operator>=(const PatternBenefit &rhs) const
bool operator<=(const PatternBenefit &rhs) const
PatternBenefit(const PatternBenefit &)=default
bool isImpossibleToMatch() const
bool operator!=(const PatternBenefit &rhs) const
bool operator>(const PatternBenefit &rhs) const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
PatternRewriter(MLIRContext *ctx)
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
std::optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but is known to bound the recursion.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
std::optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
void addDebugLabels(StringRef label)
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType::iterator iterator
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter), PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
RewritePatternSet(PDLPatternModule &&pattern)
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
void clear()
Clear out all of the held patterns in this list.
RewritePatternSet(MLIRContext *context)
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&...args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
RewritePattern is the common base class for all DAG to DAG replacements.
virtual LogicalResult match(Operation *op) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
virtual ~RewritePattern()=default
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const
Rewrite the IR rooted at the specified operation with the result of this pattern, generating any new ...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
void replaceOpUsesWithIf(Operation *from, ValueRange to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
RewriterBase(const OpBuilder &otherBuilder)
void replaceAllUsesWith(ValueRange from, ValueRange to)
void moveBlockBefore(Block *block, Block *anotherBlock)
Unlink this block and insert it right before anotherBlock.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
RewriterBase(Operation *op, OpBuilder::Listener *listener=nullptr)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceAllUsesWith(Block *from, Block *to)
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
void replaceOpUsesWithinBlock(Operation *op, ValueRange newValues, Block *block, bool *allUsesReplaced=nullptr)
Find uses of from within block and replace them with to.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an efficient unique identifier for a specific C++ type.
static TypeID getFromOpaquePointer(const void *pointer)
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getOwner() const
Return the owner of this operand.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Base class for listeners.
Kind
The kind of listener.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This class acts as a special tag that makes the desire to match "any" operation type explicit.
This class acts as a special tag that makes the desire to match any operation that implements a given...
This class acts as a special tag that makes the desire to match any operation that implements a given...
A listener that forwards all notifications to another listener.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
void notifyOperationInserted(Operation *op, InsertPoint previous) override
Notify the listener that the specified operation was inserted.
void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override
Notify the listener that a pattern application finished with the specified status.
void notifyOperationModified(Operation *op) override
Notify the listener that the specified operation was modified in-place.
void notifyPatternBegin(const Pattern &pattern, Operation *op) override
Notify the listener that the specified pattern is about to be applied at the specified root operation...
void notifyOperationReplaced(Operation *op, Operation *newOp) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
ForwardingListener(OpBuilder::Listener *listener)
void notifyOperationReplaced(Operation *op, ValueRange replacement) override
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notify the listener that the specified block was inserted.
virtual void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
virtual void notifyPatternEnd(const Pattern &pattern, LogicalResult status)
Notify the listener that a pattern application finished with the specified status.
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op)
Notify the listener that the specified pattern is about to be applied at the specified root operation...
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
static bool classof(const OpBuilder::Listener *base)
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
virtual LogicalResult match(SourceOp op) const
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...