16 #include "llvm/Support/DebugLog.h"
19 #include "llvm/ADT/ScopeExit.h"
22 #define DEBUG_TYPE "pattern-application"
29 : frozenPatternList(frozenPatternList) {
31 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
32 bytecode->initializeMutableState(*mutableByteCodeState);
40 LDBG() <<
"Ignoring pattern '" << pattern.
getRootKind()
41 <<
"' because it is impossible to match or cannot lead "
42 "to legal IR (by cost model)";
50 return isolatedParent;
54 LDBG(2) <<
"// *** IR Dump After Pattern Application ***\n" << *op <<
"\n";
63 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
74 patterns[it.first].push_back(pattern);
77 anyOpPatterns.clear();
80 if (pattern.getBenefit().isImpossibleToMatch())
83 anyOpPatterns.push_back(&pattern);
87 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
89 return benefits[lhs] > benefits[rhs];
93 if (list.size() == 1) {
94 if (model(*list.front()).isImpossibleToMatch()) {
103 for (
const Pattern *pat : list)
104 benefits.try_emplace(pat, model(*pat));
108 llvm::stable_sort(list, cmp);
109 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
114 for (
auto &it : patterns)
115 processPatternList(it.second);
116 processPatternList(anyOpPatterns);
122 for (
const auto &pattern : it.second)
127 for (
const Pattern &it : bytecode->getPatterns())
143 bytecode->
match(op, rewriter, pdlMatches, *mutableByteCodeState);
147 auto patternIt = patterns.find(op->
getName());
148 if (patternIt != patterns.end())
149 opPatterns = patternIt->second;
153 unsigned opIt = 0, opE = opPatterns.size();
154 unsigned anyIt = 0, anyE = anyOpPatterns.size();
155 unsigned pdlIt = 0, pdlE = pdlMatches.size();
156 LogicalResult result = failure();
159 const Pattern *bestPattern =
nullptr;
160 unsigned *bestPatternIt = &opIt;
164 bestPattern = opPatterns[opIt];
168 bestPattern->
getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
169 bestPatternIt = &anyIt;
170 bestPattern = anyOpPatterns[anyIt];
175 if (pdlIt < pdlE && (!bestPattern || bestPattern->
getBenefit() <
176 pdlMatches[pdlIt].benefit)) {
177 bestPatternIt = &pdlIt;
178 pdlMatch = &pdlMatches[pdlIt];
179 bestPattern = pdlMatch->
pattern;
190 if (canApply && !canApply(*bestPattern))
196 bool matched =
false;
207 bytecode->
rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
209 LDBG() <<
"Trying to match \"" << bestPattern->
getDebugName()
211 const auto *pattern =
216 auto loggingListener =
217 std::make_unique<RewriterBase::PatternLoggingListener>(
218 oldListener, pattern->getDebugName());
220 auto resetListenerCallback = llvm::make_scope_exit(
223 result = pattern->matchAndRewrite(op, rewriter);
224 LDBG() <<
" -> matchAndRewrite "
225 << (succeeded(result) ?
"successful" :
"failed");
229 if (succeeded(result) && onSuccess &&
failed(onSuccess(*bestPattern)))
231 if (succeeded(result)) {
239 onFailure(*bestPattern);
246 if (mutableByteCodeState)
247 mutableByteCodeState->cleanupAfterMatchAndRewrite();
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
static void logSucessfulPatternApplication(Operation *op)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
This is the type of Action that is dispatched when a pattern is applied.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const detail::PDLByteCode * getPDLByteCode() const
Return the compiled PDL bytecode held by this list.
iterator_range< llvm::pointee_iterator< NativePatternListT::const_iterator > > getMatchAnyOpNativePatterns() const
Return the "match any" native patterns held by this list.
const OpSpecificNativePatternListT & getOpSpecificNativePatterns() const
Return the op specific native patterns held by this list.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
bool isImpossibleToMatch() const
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
RewritePattern is the common base class for all DAG to DAG replacements.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
const PDLByteCodePattern * pattern