16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "pattern-application"
25 : frozenPatternList(frozenPatternList) {
27 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28 bytecode->initializeMutableState(*mutableByteCodeState);
36 llvm::dbgs() <<
"Ignoring pattern '" << pattern.
getRootKind()
37 <<
"' because it is impossible to match or cannot lead "
38 "to legal IR (by cost model)\n";
46 return isolatedParent;
50 llvm::dbgs() <<
"// *** IR Dump After Pattern Application ***\n";
52 llvm::dbgs() <<
"\n\n";
61 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
72 patterns[it.first].push_back(pattern);
75 anyOpPatterns.clear();
78 if (pattern.getBenefit().isImpossibleToMatch())
81 anyOpPatterns.push_back(&pattern);
85 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
87 return benefits[lhs] > benefits[rhs];
91 if (list.size() == 1) {
92 if (model(*list.front()).isImpossibleToMatch()) {
101 for (
const Pattern *pat : list)
102 benefits.try_emplace(pat, model(*pat));
106 std::stable_sort(list.begin(), list.end(), cmp);
107 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
112 for (
auto &it : patterns)
113 processPatternList(it.second);
114 processPatternList(anyOpPatterns);
120 for (
const auto &pattern : it.second)
125 for (
const Pattern &it : bytecode->getPatterns())
141 bytecode->
match(op, rewriter, pdlMatches, *mutableByteCodeState);
145 auto patternIt = patterns.find(op->
getName());
146 if (patternIt != patterns.end())
147 opPatterns = patternIt->second;
151 unsigned opIt = 0, opE = opPatterns.size();
152 unsigned anyIt = 0, anyE = anyOpPatterns.size();
153 unsigned pdlIt = 0, pdlE = pdlMatches.size();
154 LogicalResult result = failure();
157 const Pattern *bestPattern =
nullptr;
158 unsigned *bestPatternIt = &opIt;
162 bestPattern = opPatterns[opIt];
166 bestPattern->
getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
167 bestPatternIt = &anyIt;
168 bestPattern = anyOpPatterns[anyIt];
173 if (pdlIt < pdlE && (!bestPattern || bestPattern->
getBenefit() <
174 pdlMatches[pdlIt].benefit)) {
175 bestPatternIt = &pdlIt;
176 pdlMatch = &pdlMatches[pdlIt];
177 bestPattern = pdlMatch->
pattern;
188 if (canApply && !canApply(*bestPattern))
194 bool matched =
false;
205 bytecode->
rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
207 LLVM_DEBUG(llvm::dbgs() <<
"Trying to match \""
210 const auto *pattern =
214 LLVM_DEBUG(llvm::dbgs()
216 << succeeded(result) <<
"\n");
220 if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
222 if (succeeded(result)) {
230 onFailure(*bestPattern);
237 if (mutableByteCodeState)
238 mutableByteCodeState->cleanupAfterMatchAndRewrite();
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
static void logSucessfulPatternApplication(Operation *op)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
This is the type of Action that is dispatched when a pattern is applied.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const detail::PDLByteCode * getPDLByteCode() const
Return the compiled PDL bytecode held by this list.
iterator_range< llvm::pointee_iterator< NativePatternListT::const_iterator > > getMatchAnyOpNativePatterns() const
Return the "match any" native patterns held by this list.
const OpSpecificNativePatternListT & getOpSpecificNativePatterns() const
Return the op specific native patterns held by this list.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
bool isImpossibleToMatch() const
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const PDLByteCodePattern * pattern