MLIR 22.0.0git
PatternApplicator.cpp
Go to the documentation of this file.
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an applicator that applies pattern rewrites based upon a
// user-defined cost model.
//
//===----------------------------------------------------------------------===//
13
#include "mlir/Rewrite/PatternApplicator.h"
#include "ByteCode.h"
#include "llvm/Support/DebugLog.h"

#ifndef NDEBUG
#include "llvm/ADT/ScopeExit.h"
#endif

#define DEBUG_TYPE "pattern-application"

using namespace mlir;
using namespace mlir::detail;
26
28 const FrozenRewritePatternSet &frozenPatternList)
29 : frozenPatternList(frozenPatternList) {
30 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
31 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
32 bytecode->initializeMutableState(*mutableByteCodeState);
33 }
34}
36
37#ifndef NDEBUG
38/// Log a message for a pattern that is impossible to match.
39static void logImpossibleToMatch(const Pattern &pattern) {
40 LDBG() << "Ignoring pattern '" << pattern.getRootKind()
41 << "' because it is impossible to match or cannot lead "
42 "to legal IR (by cost model)";
43}
44
45/// Log IR after pattern application.
47 Operation *isolatedParent =
49 if (isolatedParent)
50 return isolatedParent;
51 return op;
52}
54 LDBG(2) << "// *** IR Dump After Pattern Application ***\n" << *op << "\n";
55}
56#endif
57
59 // Apply the cost model to the bytecode patterns first, and then the native
60 // patterns.
61 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
62 for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
63 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
64 }
65
66 // Copy over the patterns so that we can sort by benefit based on the cost
67 // model. Patterns that are already impossible to match are ignored.
68 patterns.clear();
69 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
70 for (const RewritePattern *pattern : it.second) {
71 if (pattern->getBenefit().isImpossibleToMatch())
72 LLVM_DEBUG(logImpossibleToMatch(*pattern));
73 else
74 patterns[it.first].push_back(pattern);
75 }
76 }
77 anyOpPatterns.clear();
78 for (const RewritePattern &pattern :
79 frozenPatternList.getMatchAnyOpNativePatterns()) {
80 if (pattern.getBenefit().isImpossibleToMatch())
81 LLVM_DEBUG(logImpossibleToMatch(pattern));
82 else
83 anyOpPatterns.push_back(&pattern);
84 }
85
86 // Sort the patterns using the provided cost model.
87 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
88 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
89 return benefits[lhs] > benefits[rhs];
90 };
91 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
92 // Special case for one pattern in the list, which is the most common case.
93 if (list.size() == 1) {
94 if (model(*list.front()).isImpossibleToMatch()) {
95 LLVM_DEBUG(logImpossibleToMatch(*list.front()));
96 list.clear();
97 }
98 return;
99 }
100
101 // Collect the dynamic benefits for the current pattern list.
102 benefits.clear();
103 for (const Pattern *pat : list)
104 benefits.try_emplace(pat, model(*pat));
105
106 // Sort patterns with highest benefit first, and remove those that are
107 // impossible to match.
108 llvm::stable_sort(list, cmp);
109 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
110 LLVM_DEBUG(logImpossibleToMatch(*list.back()));
111 list.pop_back();
112 }
113 };
114 for (auto &it : patterns)
115 processPatternList(it.second);
116 processPatternList(anyOpPatterns);
117}
118
120 function_ref<void(const Pattern &)> walk) {
121 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
122 for (const auto &pattern : it.second)
123 walk(*pattern);
124 for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
125 walk(it);
126 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
127 for (const Pattern &it : bytecode->getPatterns())
128 walk(it);
129 }
130}
131
133 Operation *op, PatternRewriter &rewriter,
134 function_ref<bool(const Pattern &)> canApply,
135 function_ref<void(const Pattern &)> onFailure,
136 function_ref<LogicalResult(const Pattern &)> onSuccess) {
137 // Before checking native patterns, first match against the bytecode. This
138 // won't automatically perform any rewrites so there is no need to worry about
139 // conflicts.
141 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
142 if (bytecode)
143 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
144
145 // Check to see if there are patterns matching this specific operation type.
147 auto patternIt = patterns.find(op->getName());
148 if (patternIt != patterns.end())
149 opPatterns = patternIt->second;
150
151 // Process the patterns for that match the specific operation type, and any
152 // operation type in an interleaved fashion.
153 unsigned opIt = 0, opE = opPatterns.size();
154 unsigned anyIt = 0, anyE = anyOpPatterns.size();
155 unsigned pdlIt = 0, pdlE = pdlMatches.size();
156 LogicalResult result = failure();
157 do {
158 // Find the next pattern with the highest benefit.
159 const Pattern *bestPattern = nullptr;
160 unsigned *bestPatternIt = &opIt;
161
162 /// Operation specific patterns.
163 if (opIt < opE)
164 bestPattern = opPatterns[opIt];
165 /// Operation agnostic patterns.
166 if (anyIt < anyE &&
167 (!bestPattern ||
168 bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
169 bestPatternIt = &anyIt;
170 bestPattern = anyOpPatterns[anyIt];
171 }
172
173 const PDLByteCode::MatchResult *pdlMatch = nullptr;
174 /// PDL patterns.
175 if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
176 pdlMatches[pdlIt].benefit)) {
177 bestPatternIt = &pdlIt;
178 pdlMatch = &pdlMatches[pdlIt];
179 bestPattern = pdlMatch->pattern;
180 }
181
182 if (!bestPattern)
183 break;
184
185 // Update the pattern iterator on failure so that this pattern isn't
186 // attempted again.
187 ++(*bestPatternIt);
188
189 // Check that the pattern can be applied.
190 if (canApply && !canApply(*bestPattern))
191 continue;
192
193 // Try to match and rewrite this pattern. The patterns are sorted by
194 // benefit, so if we match we can immediately rewrite. For PDL patterns, the
195 // match has already been performed, we just need to rewrite.
196 bool matched = false;
198 [&]() {
199 rewriter.setInsertionPoint(op);
200#ifndef NDEBUG
201 // Operation `op` may be invalidated after applying the rewrite
202 // pattern.
203 Operation *dumpRootOp = getDumpRootOp(op);
204#endif
205 if (pdlMatch) {
206 result =
207 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
208 } else {
209 LDBG() << "Trying to match \"" << bestPattern->getDebugName()
210 << "\"";
211 const auto *pattern =
212 static_cast<const RewritePattern *>(bestPattern);
213
214#ifndef NDEBUG
215 OpBuilder::Listener *oldListener = rewriter.getListener();
216 auto loggingListener =
217 std::make_unique<RewriterBase::PatternLoggingListener>(
218 oldListener, pattern->getDebugName());
219 rewriter.setListener(loggingListener.get());
220 auto resetListenerCallback = llvm::make_scope_exit(
221 [&] { rewriter.setListener(oldListener); });
222#endif
223 result = pattern->matchAndRewrite(op, rewriter);
224 LDBG() << " -> matchAndRewrite "
225 << (succeeded(result) ? "successful" : "failed");
226 }
227
228 // Process the result of the pattern application.
229 if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
230 result = failure();
231 if (succeeded(result)) {
232 LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
233 matched = true;
234 return;
235 }
236
237 // Perform any necessary cleanups.
238 if (onFailure)
239 onFailure(*bestPattern);
240 },
241 {op}, *bestPattern);
242 if (matched)
243 break;
244 } while (true);
245
246 if (mutableByteCodeState)
247 mutableByteCodeState->cleanupAfterMatchAndRewrite();
248 return result;
249}
lhs
static Operation * getDumpRootOp(Operation *op)
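Return the closest isolated-from-above ancestor of the op, or the op itself; used as the root for IR dumps after pattern application.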
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
static void logSucessfulPatternApplication(Operation *op)
This is the type of Action that is dispatched when a pattern is applied.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:316
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
function_ref< PatternBenefit(const Pattern &)> CostModel
The cost model dynamically assigns a PatternBenefit to a particular pattern.
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
StringRef getDebugName() const
Return a readable name for this pattern.
RewritePattern is the common base class for all DAG to DAG replacements.
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:285
const PDLByteCodePattern * pattern
Definition ByteCode.h:243