MLIR  14.0.0git
PatternApplicator.cpp
Go to the documentation of this file.
1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "pattern-application"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
24  const FrozenRewritePatternSet &frozenPatternList)
25  : frozenPatternList(frozenPatternList) {
26  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27  mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28  bytecode->initializeMutableState(*mutableByteCodeState);
29  }
30 }
32 
33 #ifndef NDEBUG
34 /// Log a message for a pattern that is impossible to match.
35 static void logImpossibleToMatch(const Pattern &pattern) {
36  llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37  << "' because it is impossible to match or cannot lead "
38  "to legal IR (by cost model)\n";
39 }
40 
41 /// Log IR after pattern application.
44 }
46  llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
47  op->dump();
48  llvm::dbgs() << "\n\n";
49 }
50 #endif
51 
53  // Apply the cost model to the bytecode patterns first, and then the native
54  // patterns.
55  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
56  for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
57  mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
58  }
59 
60  // Copy over the patterns so that we can sort by benefit based on the cost
61  // model. Patterns that are already impossible to match are ignored.
62  patterns.clear();
63  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
64  for (const RewritePattern *pattern : it.second) {
65  if (pattern->getBenefit().isImpossibleToMatch())
66  LLVM_DEBUG(logImpossibleToMatch(*pattern));
67  else
68  patterns[it.first].push_back(pattern);
69  }
70  }
71  anyOpPatterns.clear();
72  for (const RewritePattern &pattern :
73  frozenPatternList.getMatchAnyOpNativePatterns()) {
74  if (pattern.getBenefit().isImpossibleToMatch())
75  LLVM_DEBUG(logImpossibleToMatch(pattern));
76  else
77  anyOpPatterns.push_back(&pattern);
78  }
79 
80  // Sort the patterns using the provided cost model.
81  llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
82  auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
83  return benefits[lhs] > benefits[rhs];
84  };
85  auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
86  // Special case for one pattern in the list, which is the most common case.
87  if (list.size() == 1) {
88  if (model(*list.front()).isImpossibleToMatch()) {
89  LLVM_DEBUG(logImpossibleToMatch(*list.front()));
90  list.clear();
91  }
92  return;
93  }
94 
95  // Collect the dynamic benefits for the current pattern list.
96  benefits.clear();
97  for (const Pattern *pat : list)
98  benefits.try_emplace(pat, model(*pat));
99 
100  // Sort patterns with highest benefit first, and remove those that are
101  // impossible to match.
102  std::stable_sort(list.begin(), list.end(), cmp);
103  while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
104  LLVM_DEBUG(logImpossibleToMatch(*list.back()));
105  list.pop_back();
106  }
107  };
108  for (auto &it : patterns)
109  processPatternList(it.second);
110  processPatternList(anyOpPatterns);
111 }
112 
114  function_ref<void(const Pattern &)> walk) {
115  for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
116  for (const auto &pattern : it.second)
117  walk(*pattern);
118  for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
119  walk(it);
120  if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
121  for (const Pattern &it : bytecode->getPatterns())
122  walk(it);
123  }
124 }
125 
127  Operation *op, PatternRewriter &rewriter,
128  function_ref<bool(const Pattern &)> canApply,
129  function_ref<void(const Pattern &)> onFailure,
130  function_ref<LogicalResult(const Pattern &)> onSuccess) {
131  // Before checking native patterns, first match against the bytecode. This
132  // won't automatically perform any rewrites so there is no need to worry about
133  // conflicts.
135  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
136  if (bytecode)
137  bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
138 
139  // Check to see if there are patterns matching this specific operation type.
141  auto patternIt = patterns.find(op->getName());
142  if (patternIt != patterns.end())
143  opPatterns = patternIt->second;
144 
145  // Process the patterns that match the specific operation type, and those
146  // that match any operation type, in an interleaved fashion.
147  unsigned opIt = 0, opE = opPatterns.size();
148  unsigned anyIt = 0, anyE = anyOpPatterns.size();
149  unsigned pdlIt = 0, pdlE = pdlMatches.size();
150  LogicalResult result = failure();
151  do {
152  // Find the next pattern with the highest benefit.
153  const Pattern *bestPattern = nullptr;
154  unsigned *bestPatternIt = &opIt;
155  const PDLByteCode::MatchResult *pdlMatch = nullptr;
156 
157  /// Operation specific patterns.
158  if (opIt < opE)
159  bestPattern = opPatterns[opIt];
160  /// Operation agnostic patterns.
161  if (anyIt < anyE &&
162  (!bestPattern ||
163  bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
164  bestPatternIt = &anyIt;
165  bestPattern = anyOpPatterns[anyIt];
166  }
167  /// PDL patterns.
168  if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
169  pdlMatches[pdlIt].benefit)) {
170  bestPatternIt = &pdlIt;
171  pdlMatch = &pdlMatches[pdlIt];
172  bestPattern = pdlMatch->pattern;
173  }
174  if (!bestPattern)
175  break;
176 
177  // Update the pattern iterator on failure so that this pattern isn't
178  // attempted again.
179  ++(*bestPatternIt);
180 
181  // Check that the pattern can be applied.
182  if (canApply && !canApply(*bestPattern))
183  continue;
184 
185  // Try to match and rewrite this pattern. The patterns are sorted by
186  // benefit, so if we match we can immediately rewrite. For PDL patterns, the
187  // match has already been performed, we just need to rewrite.
188  rewriter.setInsertionPoint(op);
189 #ifndef NDEBUG
190  // Operation `op` may be invalidated after applying the rewrite pattern.
191  Operation *dumpRootOp = getDumpRootOp(op);
192 #endif
193  if (pdlMatch) {
194  bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
195  result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
196  } else {
197  const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
198 
199  LLVM_DEBUG(llvm::dbgs()
200  << "Trying to match \"" << pattern->getDebugName() << "\"\n");
201  result = pattern->matchAndRewrite(op, rewriter);
202  LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
203  << succeeded(result) << "\n");
204 
205  if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
206  result = failure();
207  }
208  if (succeeded(result)) {
209  LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
210  break;
211  }
212 
213  // Perform any necessary cleanups.
214  if (onFailure)
215  onFailure(*bestPattern);
216  } while (true);
217 
218  if (mutableByteCodeState)
219  mutableByteCodeState->cleanupAfterMatchAndRewrite();
220  return result;
221 }
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
Include the generated interface declarations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This class represents a frozen set of patterns that can be processed by a pattern applicator...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
Definition: ByteCode.h:124
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:130
static void logImpossibleToMatch(const Pattern &pattern)
Log a message for a pattern that is impossible to match.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
static void logSucessfulPatternApplication(Operation *op)
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:71
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Definition: ByteCode.cpp:2192
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
const detail::PDLByteCode * getPDLByteCode() const
Return the compiled PDL bytecode held by this list.
The bytecode class is also the interpreter.
Definition: ByteCode.h:120
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
This class provides the API for ops that are known to be isolated from above.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
const OpSpecificNativePatternListT & getOpSpecificNativePatterns() const
Return the op specific native patterns held by this list.
const PDLByteCodePattern * pattern
The originating pattern that was matched.
Definition: ByteCode.h:143
static Operation * getDumpRootOp(Operation *op)
Log IR after pattern application.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches...
Definition: ByteCode.cpp:2169
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:121
iterator_range< llvm::pointee_iterator< NativePatternListT::const_iterator > > getMatchAnyOpNativePatterns() const
Return the "match any" native patterns held by this list.
Optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
Definition: PatternMatch.h:92