MLIR  22.0.0git
PatternApplicator.cpp
Go to the documentation of this file.
1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "ByteCode.h"
16 #include "llvm/Support/DebugLog.h"
17 
18 #ifndef NDEBUG
19 #include "llvm/ADT/ScopeExit.h"
20 #endif
21 
22 #define DEBUG_TYPE "pattern-application"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
/// Construct an applicator over `frozenPatternList`, keeping a reference to it.
/// When the list carries compiled PDL bytecode, a PDLByteCodeMutableState is
/// allocated and initialized by that bytecode; it is later passed to the
/// bytecode's match/rewrite calls and updated by applyCostModel.
/// NOTE(review): the constructor's opening line (source line 27, listed in
/// this page's index as `PatternApplicator(const FrozenRewritePatternSet &)`)
/// is missing from this extraction — restore it from upstream.
28  const FrozenRewritePatternSet &frozenPatternList)
29  : frozenPatternList(frozenPatternList) {
// Only pattern sets with PDL bytecode need mutable bytecode state; otherwise
// `mutableByteCodeState` stays null (matchAndRewrite checks it before cleanup).
30  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
31  mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
32  bytecode->initializeMutableState(*mutableByteCodeState);
33  }
34 }
36 
37 #ifndef NDEBUG
38 /// Log a message for a pattern that is impossible to match.
39 static void logImpossibleToMatch(const Pattern &pattern) {
40  LDBG() << "Ignoring pattern '" << pattern.getRootKind()
41  << "' because it is impossible to match or cannot lead "
42  "to legal IR (by cost model)";
43 }
44 
45 /// Return the operation to print when dumping IR after a pattern has been
/// applied: `isolatedParent` when it is non-null, otherwise `op` itself.
/// Judging by the variable name, `isolatedParent` is presumably the closest
/// isolated-from-above ancestor of `op` — TODO confirm against the
/// initializer on source line 48, which (like the signature on line 46,
/// `static Operation *getDumpRootOp(Operation *op)` per this page's index)
/// is missing from this extraction.
47  Operation *isolatedParent =
// Fall back to dumping `op` itself when there is no such parent.
49  if (isolatedParent)
50  return isolatedParent;
51  return op;
52 }
/// Log IR after pattern application: print `*op` under an
/// "IR Dump After Pattern Application" banner via `LDBG(2)`.
/// NOTE(review): the signature line (source line 53, listed in this page's
/// index as `static void logSucessfulPatternApplication(Operation *op)`) is
/// missing from this extraction.
54  LDBG(2) << "// *** IR Dump After Pattern Application ***\n" << *op << "\n";
55 }
56 #endif
57 
/// Re-rank the applicator's patterns using `model`.
///
/// PDL bytecode patterns have their benefit updated in the mutable bytecode
/// state. Native patterns are copied from the frozen list into `patterns`
/// (keyed by operation name, as looked up in matchAndRewrite) and
/// `anyOpPatterns`. Patterns whose own benefit is already impossible-to-match
/// are dropped during the copy. Each resulting list is then stable-sorted by
/// `model`'s benefit, highest first, and entries that `model` rates
/// impossible-to-match are removed.
///
/// NOTE(review): source line 58 (the `applyCostModel(CostModel model)`
/// signature line, per this page's index) is missing from this extraction.
59  // Apply the cost model to the bytecode patterns first, and then the native
60  // patterns.
// `mutableByteCodeState` is non-null here: the constructor allocates it
// whenever getPDLByteCode() is non-null.
61  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
62  for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
63  mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
64  }
65 
66  // Copy over the patterns so that we can sort by benefit based on the cost
67  // model. Patterns that are already impossible to match are ignored.
68  patterns.clear();
69  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
70  for (const RewritePattern *pattern : it.second) {
71  if (pattern->getBenefit().isImpossibleToMatch())
72  LLVM_DEBUG(logImpossibleToMatch(*pattern));
73  else
74  patterns[it.first].push_back(pattern);
75  }
76  }
77  anyOpPatterns.clear();
78  for (const RewritePattern &pattern :
79  frozenPatternList.getMatchAnyOpNativePatterns()) {
80  if (pattern.getBenefit().isImpossibleToMatch())
81  LLVM_DEBUG(logImpossibleToMatch(pattern));
82  else
83  anyOpPatterns.push_back(&pattern);
84  }
85 
86  // Sort the patterns using the provided cost model.
// `benefits` caches the model's verdict per pattern for the list currently
// being processed; `cmp` orders by that cached benefit, highest first.
87  llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
88  auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
89  return benefits[lhs] > benefits[rhs];
90  };
91  auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
92  // Special case for one pattern in the list, which is the most common case.
// Nothing to sort: the model is consulted only to decide whether to drop it.
93  if (list.size() == 1) {
94  if (model(*list.front()).isImpossibleToMatch()) {
95  LLVM_DEBUG(logImpossibleToMatch(*list.front()));
96  list.clear();
97  }
98  return;
99  }
100 
101  // Collect the dynamic benefits for the current pattern list.
102  benefits.clear();
103  for (const Pattern *pat : list)
104  benefits.try_emplace(pat, model(*pat));
105 
106  // Sort patterns with highest benefit first, and remove those that are
107  // impossible to match.
// NOTE(review): this loop only strips impossible-to-match entries if they
// sort to the back under `operator>` on PatternBenefit. Confirm that
// PatternBenefit orders impossibleToMatch below every real benefit;
// otherwise such entries would stay at the front of `list`.
108  llvm::stable_sort(list, cmp);
109  while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
110  LLVM_DEBUG(logImpossibleToMatch(*list.back()));
111  list.pop_back();
112  }
113  };
114  for (auto &it : patterns)
115  processPatternList(it.second);
116  processPatternList(anyOpPatterns);
117 }
118 
/// Invoke `walk` on every pattern in the frozen pattern list, in this order:
/// op-specific native patterns, "match any op" native patterns, then PDL
/// bytecode patterns (if the list has bytecode). The walk reads
/// `frozenPatternList` directly, so patterns that applyCostModel removed from
/// the applicator's working lists are still visited.
/// NOTE(review): source line 119 (the `walkAllPatterns(` signature line, per
/// this page's index) is missing from this extraction.
120  function_ref<void(const Pattern &)> walk) {
121  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
122  for (const auto &pattern : it.second)
123  walk(*pattern);
124  for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
125  walk(it);
126  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
127  for (const Pattern &it : bytecode->getPatterns())
128  walk(it);
129  }
130 }
131 
/// Try candidate patterns on `op` in order of benefit and apply the first one
/// whose rewrite succeeds.
///
/// Candidates come from three sources, merged on the fly:
/// - native patterns registered for `op`'s name (`patterns`);
/// - native "match any op" patterns (`anyOpPatterns`);
/// - PDL bytecode matches, computed up front by `bytecode->match`.
///
/// \param canApply  If set and it returns false for a pattern, that pattern
///                  is skipped without an attempt; `onFailure` is not called.
/// \param onFailure If set, called after an attempted pattern did not
///                  succeed. This includes the case where `onSuccess`
///                  rejected it.
/// \param onSuccess If set, called after a successful rewrite. Returning
///                  failure turns the attempt into a failure, and the search
///                  continues with the next candidate.
/// \returns success if a pattern was applied (and accepted by `onSuccess`),
///          failure otherwise.
///
/// NOTE(review): source line 132 (the opening of the `matchAndRewrite(`
/// signature, per this page's index) is missing from this extraction.
133  Operation *op, PatternRewriter &rewriter,
134  function_ref<bool(const Pattern &)> canApply,
135  function_ref<void(const Pattern &)> onFailure,
136  function_ref<LogicalResult(const Pattern &)> onSuccess) {
137  // Before checking native patterns, first match against the bytecode. This
138  // won't automatically perform any rewrites so there is no need to worry about
139  // conflicts.
// NOTE(review): source line 140, which declares `pdlMatches` (the list of
// PDLByteCode::MatchResult filled below), is missing from this extraction.
141  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
142  if (bytecode)
143  bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
144 
145  // Check to see if there are patterns matching this specific operation type.
// NOTE(review): source line 146, which declares `opPatterns` (left empty when
// no op-specific list exists), is missing from this extraction.
147  auto patternIt = patterns.find(op->getName());
148  if (patternIt != patterns.end())
149  opPatterns = patternIt->second;
150 
151  // Process the patterns for that match the specific operation type, and any
152  // operation type in an interleaved fashion.
153  unsigned opIt = 0, opE = opPatterns.size();
154  unsigned anyIt = 0, anyE = anyOpPatterns.size();
155  unsigned pdlIt = 0, pdlE = pdlMatches.size();
156  LogicalResult result = failure();
157  do {
158  // Find the next pattern with the highest benefit.
// Each challenger must compare strictly greater, so ties favor an
// op-specific pattern over an any-op pattern, and either over a PDL match.
// NOTE(review): native candidates are compared by their own getBenefit(),
// not by the cost-model benefit applyCostModel used for sorting. Confirm
// this is intended.
159  const Pattern *bestPattern = nullptr;
160  unsigned *bestPatternIt = &opIt;
161 
162  /// Operation specific patterns.
163  if (opIt < opE)
164  bestPattern = opPatterns[opIt];
165  /// Operation agnostic patterns.
166  if (anyIt < anyE &&
167  (!bestPattern ||
168  bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
169  bestPatternIt = &anyIt;
170  bestPattern = anyOpPatterns[anyIt];
171  }
172 
173  const PDLByteCode::MatchResult *pdlMatch = nullptr;
174  /// PDL patterns.
175  if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
176  pdlMatches[pdlIt].benefit)) {
177  bestPatternIt = &pdlIt;
178  pdlMatch = &pdlMatches[pdlIt];
179  bestPattern = pdlMatch->pattern;
180  }
181 
// All three candidate sources are exhausted.
182  if (!bestPattern)
183  break;
184 
185  // Update the pattern iterator on failure so that this pattern isn't
186  // attempted again.
// The cursor advances before `canApply`, so a rejected pattern is not
// reconsidered either.
187  ++(*bestPatternIt);
188 
189  // Check that the pattern can be applied.
190  if (canApply && !canApply(*bestPattern))
191  continue;
192 
193  // Try to match and rewrite this pattern. The patterns are sorted by
194  // benefit, so if we match we can immediately rewrite. For PDL patterns, the
195  // match has already been performed, we just need to rewrite.
196  bool matched = false;
// NOTE(review): source line 197 is missing from this extraction. It is the
// opening of the call that receives this lambda together with the trailing
// `{op}, *bestPattern` arguments. This page's index (executeAction and the
// pattern-application Action type) suggests it dispatches the lambda as an
// action — TODO confirm against upstream.
198  [&]() {
199  rewriter.setInsertionPoint(op);
200 #ifndef NDEBUG
201  // Operation `op` may be invalidated after applying the rewrite
202  // pattern.
203  Operation *dumpRootOp = getDumpRootOp(op);
204 #endif
205  if (pdlMatch) {
206  result =
207  bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
208  } else {
209  LDBG() << "Trying to match \"" << bestPattern->getDebugName()
210  << "\"";
211  const auto *pattern =
212  static_cast<const RewritePattern *>(bestPattern);
213 
// Debug builds only: temporarily install a logging listener that chains to
// the previous one. The scope-exit guard restores the old listener however
// this branch exits.
214 #ifndef NDEBUG
215  OpBuilder::Listener *oldListener = rewriter.getListener();
216  auto loggingListener =
217  std::make_unique<RewriterBase::PatternLoggingListener>(
218  oldListener, pattern->getDebugName());
219  rewriter.setListener(loggingListener.get());
220  auto resetListenerCallback = llvm::make_scope_exit(
221  [&] { rewriter.setListener(oldListener); });
222 #endif
223  result = pattern->matchAndRewrite(op, rewriter);
224  LDBG() << " -> matchAndRewrite "
225  << (succeeded(result) ? "successful" : "failed");
226  }
227 
228  // Process the result of the pattern application.
229  if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
230  result = failure();
231  if (succeeded(result)) {
232  LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
233  matched = true;
234  return;
235  }
236 
237  // Perform any necessary cleanups.
238  if (onFailure)
239  onFailure(*bestPattern);
240  },
241  {op}, *bestPattern);
242  if (matched)
243  break;
244  } while (true);
245 
// Release per-call bytecode state whether or not a pattern was applied.
246  if (mutableByteCodeState)
247  mutableByteCodeState->cleanupAfterMatchAndRewrite();
248  return result;
249 }
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
static void logSucessfulPatternApplication(Operation *op)
static Operation * getDumpRootOp(Operation *op)
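Return the operation whose IR is dumped after a pattern is applied: the op's isolated parent (`isolatedParent`) if it has one, otherwise the op itself.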
This is the type of Action that is dispatched when a pattern is applied.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const detail::PDLByteCode * getPDLByteCode() const
Return the compiled PDL bytecode held by this list.
iterator_range< llvm::pointee_iterator< NativePatternListT::const_iterator > > getMatchAnyOpNativePatterns() const
Return the "match any" native patterns held by this list.
const OpSpecificNativePatternListT & getOpSpecificNativePatterns() const
Return the op specific native patterns held by this list.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:274
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:123
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
const PDLByteCodePattern * pattern
Definition: ByteCode.h:244