16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "pattern-application"
25 : frozenPatternList(frozenPatternList) {
27 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28 bytecode->initializeMutableState(*mutableByteCodeState);
36 llvm::dbgs() <<
"Ignoring pattern '" << pattern.
getRootKind()
37 <<
"' because it is impossible to match or cannot lead "
38 "to legal IR (by cost model)\n";
46 llvm::dbgs() <<
"// *** IR Dump After Pattern Application ***\n";
48 llvm::dbgs() <<
"\n\n";
57 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
68 patterns[it.first].push_back(pattern);
71 anyOpPatterns.clear();
74 if (pattern.getBenefit().isImpossibleToMatch())
77 anyOpPatterns.push_back(&pattern);
81 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
83 return benefits[lhs] > benefits[rhs];
87 if (list.size() == 1) {
88 if (model(*list.front()).isImpossibleToMatch()) {
98 benefits.try_emplace(pat, model(*pat));
102 std::stable_sort(list.begin(), list.end(), cmp);
103 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
108 for (
auto &it : patterns)
109 processPatternList(it.second);
110 processPatternList(anyOpPatterns);
116 for (
const auto &pattern : it.second)
121 for (
const Pattern &it : bytecode->getPatterns())
137 bytecode->
match(op, rewriter, pdlMatches, *mutableByteCodeState);
141 auto patternIt = patterns.find(op->
getName());
142 if (patternIt != patterns.end())
143 opPatterns = patternIt->second;
147 unsigned opIt = 0, opE = opPatterns.size();
148 unsigned anyIt = 0, anyE = anyOpPatterns.size();
149 unsigned pdlIt = 0, pdlE = pdlMatches.size();
153 const Pattern *bestPattern =
nullptr;
154 unsigned *bestPatternIt = &opIt;
159 bestPattern = opPatterns[opIt];
163 bestPattern->
getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
164 bestPatternIt = &anyIt;
165 bestPattern = anyOpPatterns[anyIt];
168 if (pdlIt < pdlE && (!bestPattern || bestPattern->
getBenefit() <
169 pdlMatches[pdlIt].benefit)) {
170 bestPatternIt = &pdlIt;
171 pdlMatch = &pdlMatches[pdlIt];
172 bestPattern = pdlMatch->
pattern;
182 if (canApply && !canApply(*bestPattern))
188 bool matched =
false;
199 bytecode->
rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
201 LLVM_DEBUG(llvm::dbgs() <<
"Trying to match \""
204 const auto *pattern =
206 result = pattern->matchAndRewrite(op, rewriter);
208 LLVM_DEBUG(llvm::dbgs()
224 onFailure(*bestPattern);
231 if (mutableByteCodeState)
232 mutableByteCodeState->cleanupAfterMatchAndRewrite();
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
static void logSucessfulPatternApplication(Operation *op)
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
This is the type of Action that is dispatched when a pattern is applied.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const detail::PDLByteCode * getPDLByteCode() const
Return the compiled PDL bytecode held by this list.
iterator_range< llvm::pointee_iterator< NativePatternListT::const_iterator > > getMatchAnyOpNativePatterns() const
Return the "match any" native patterns held by this list.
const OpSpecificNativePatternListT & getOpSpecificNativePatterns() const
Return the op specific native patterns held by this list.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied, and hooks to run when a pattern match fails or succeeds.
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
bool isImpossibleToMatch() const
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
RewritePattern is the common base class for all DAG to DAG replacements.
The bytecode class is also the interpreter.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches.
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Detect if any of the given parameter types has a sub-element handler.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Each successful match returns a MatchResult, which contains the information necessary to execute the rewrite and identifies the originating pattern.
const PDLByteCodePattern * pattern
The originating pattern that was matched.