MLIR  19.0.0git
PatternApplicator.cpp
Go to the documentation of this file.
1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user-defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "pattern-application"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
24  const FrozenRewritePatternSet &frozenPatternList)
25  : frozenPatternList(frozenPatternList) {
26  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27  mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28  bytecode->initializeMutableState(*mutableByteCodeState);
29  }
30 }
32 
33 #ifndef NDEBUG
34 /// Log a message for a pattern that is impossible to match.
35 static void logImpossibleToMatch(const Pattern &pattern) {
36  llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37  << "' because it is impossible to match or cannot lead "
38  "to legal IR (by cost model)\n";
39 }
40 
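41 /// Return the operation whose IR is dumped after a successful pattern application: the enclosing isolated parent of `op` when one exists, otherwise `op` itself.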
43  Operation *isolatedParent =
45  if (isolatedParent)
46  return isolatedParent;
47  return op;
48 }
50  llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
51  op->dump();
52  llvm::dbgs() << "\n\n";
53 }
54 #endif
55 
57  // Apply the cost model to the bytecode patterns first, and then the native
58  // patterns.
59  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
60  for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
61  mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
62  }
63 
64  // Copy over the patterns so that we can sort by benefit based on the cost
65  // model. Patterns that are already impossible to match are ignored.
66  patterns.clear();
67  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
68  for (const RewritePattern *pattern : it.second) {
69  if (pattern->getBenefit().isImpossibleToMatch())
70  LLVM_DEBUG(logImpossibleToMatch(*pattern));
71  else
72  patterns[it.first].push_back(pattern);
73  }
74  }
75  anyOpPatterns.clear();
76  for (const RewritePattern &pattern :
77  frozenPatternList.getMatchAnyOpNativePatterns()) {
78  if (pattern.getBenefit().isImpossibleToMatch())
79  LLVM_DEBUG(logImpossibleToMatch(pattern));
80  else
81  anyOpPatterns.push_back(&pattern);
82  }
83 
84  // Sort the patterns using the provided cost model.
85  llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
86  auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
87  return benefits[lhs] > benefits[rhs];
88  };
89  auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
90  // Special case for one pattern in the list, which is the most common case.
91  if (list.size() == 1) {
92  if (model(*list.front()).isImpossibleToMatch()) {
93  LLVM_DEBUG(logImpossibleToMatch(*list.front()));
94  list.clear();
95  }
96  return;
97  }
98 
99  // Collect the dynamic benefits for the current pattern list.
100  benefits.clear();
101  for (const Pattern *pat : list)
102  benefits.try_emplace(pat, model(*pat));
103 
104  // Sort patterns with highest benefit first, and remove those that are
105  // impossible to match.
106  std::stable_sort(list.begin(), list.end(), cmp);
107  while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
108  LLVM_DEBUG(logImpossibleToMatch(*list.back()));
109  list.pop_back();
110  }
111  };
112  for (auto &it : patterns)
113  processPatternList(it.second);
114  processPatternList(anyOpPatterns);
115 }
116 
118  function_ref<void(const Pattern &)> walk) {
119  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
120  for (const auto &pattern : it.second)
121  walk(*pattern);
122  for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
123  walk(it);
124  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
125  for (const Pattern &it : bytecode->getPatterns())
126  walk(it);
127  }
128 }
129 
131  Operation *op, PatternRewriter &rewriter,
132  function_ref<bool(const Pattern &)> canApply,
133  function_ref<void(const Pattern &)> onFailure,
134  function_ref<LogicalResult(const Pattern &)> onSuccess) {
135  // Before checking native patterns, first match against the bytecode. This
136  // won't automatically perform any rewrites so there is no need to worry about
137  // conflicts.
139  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
140  if (bytecode)
141  bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
142 
143  // Check to see if there are patterns matching this specific operation type.
145  auto patternIt = patterns.find(op->getName());
146  if (patternIt != patterns.end())
147  opPatterns = patternIt->second;
148 
149  // Process the patterns for that match the specific operation type, and any
150  // operation type in an interleaved fashion.
151  unsigned opIt = 0, opE = opPatterns.size();
152  unsigned anyIt = 0, anyE = anyOpPatterns.size();
153  unsigned pdlIt = 0, pdlE = pdlMatches.size();
154  LogicalResult result = failure();
155  do {
156  // Find the next pattern with the highest benefit.
157  const Pattern *bestPattern = nullptr;
158  unsigned *bestPatternIt = &opIt;
159 
160  /// Operation specific patterns.
161  if (opIt < opE)
162  bestPattern = opPatterns[opIt];
163  /// Operation agnostic patterns.
164  if (anyIt < anyE &&
165  (!bestPattern ||
166  bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
167  bestPatternIt = &anyIt;
168  bestPattern = anyOpPatterns[anyIt];
169  }
170 
171  const PDLByteCode::MatchResult *pdlMatch = nullptr;
172  /// PDL patterns.
173  if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
174  pdlMatches[pdlIt].benefit)) {
175  bestPatternIt = &pdlIt;
176  pdlMatch = &pdlMatches[pdlIt];
177  bestPattern = pdlMatch->pattern;
178  }
179 
180  if (!bestPattern)
181  break;
182 
183  // Update the pattern iterator on failure so that this pattern isn't
184  // attempted again.
185  ++(*bestPatternIt);
186 
187  // Check that the pattern can be applied.
188  if (canApply && !canApply(*bestPattern))
189  continue;
190 
191  // Try to match and rewrite this pattern. The patterns are sorted by
192  // benefit, so if we match we can immediately rewrite. For PDL patterns, the
193  // match has already been performed, we just need to rewrite.
194  bool matched = false;
196  [&]() {
197  rewriter.setInsertionPoint(op);
198 #ifndef NDEBUG
199  // Operation `op` may be invalidated after applying the rewrite
200  // pattern.
201  Operation *dumpRootOp = getDumpRootOp(op);
202 #endif
203  if (pdlMatch) {
204  result =
205  bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
206  } else {
207  LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208  << bestPattern->getDebugName() << "\"\n");
209 
210  const auto *pattern =
211  static_cast<const RewritePattern *>(bestPattern);
212  result = pattern->matchAndRewrite(op, rewriter);
213 
214  LLVM_DEBUG(llvm::dbgs()
215  << "\"" << bestPattern->getDebugName() << "\" result "
216  << succeeded(result) << "\n");
217  }
218 
219  // Process the result of the pattern application.
220  if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
221  result = failure();
222  if (succeeded(result)) {
223  LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
224  matched = true;
225  return;
226  }
227 
228  // Perform any necessary cleanups.
229  if (onFailure)
230  onFailure(*bestPattern);
231  },
232  {op}, *bestPattern);
233  if (matched)
234  break;
235  } while (true);
236 
237  if (mutableByteCodeState)
238  mutableByteCodeState->cleanupAfterMatchAndRewrite();
239  return result;
240 }
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
static void logSucessfulPatternApplication(Operation *op)
static Operation * getDumpRootOp(Operation *op)
Return the operation whose IR is dumped after a pattern application (the enclosing isolated parent of the op, or the op itself).
This is the type of Action that is dispatched when a pattern is applied.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const detail::PDLByteCode * getPDLByteCode() const
Return the compiled PDL bytecode held by this list.
iterator_range< llvm::pointee_iterator< NativePatternListT::const_iterator > > getMatchAnyOpNativePatterns() const
Return the "match any" native patterns held by this list.
const OpSpecificNativePatternListT & getOpSpecificNativePatterns() const
Return the op specific native patterns held by this list.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
Definition: MLIRContext.h:259
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:94
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:123
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
Detect if any of the given parameter types has a sub-element handler.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
const PDLByteCodePattern * pattern
Definition: ByteCode.h:244