MLIR  19.0.0git
PatternApplicator.h
Go to the documentation of this file.
1 //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user-defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
15 #define MLIR_REWRITE_PATTERNAPPLICATOR_H
16 
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

#include "mlir/IR/Action.h"
20 
21 namespace mlir {
22 class PatternRewriter;
23 
24 namespace detail {
25 class PDLByteCodeMutableState;
26 } // namespace detail
27 
28 /// This is the type of Action that is dispatched when a pattern is applied.
29 /// It captures the pattern to apply on top of the usual context.
30 class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
31 public:
34  : Base(irUnits), pattern(pattern) {}
35  static constexpr StringLiteral tag = "apply-pattern";
36  static constexpr StringLiteral desc =
37  "Encapsulate the application of rewrite patterns";
38 
39  void print(raw_ostream &os) const override {
40  os << "`" << tag << " pattern: " << pattern.getDebugName();
41  }
42 
43 private:
44  const Pattern &pattern;
45 };
46 
47 /// This class manages the application of a group of rewrite patterns, with a
48 /// user-provided cost model.
50 public:
51  /// The cost model dynamically assigns a PatternBenefit to a particular
52  /// pattern. Users can query contained patterns and pass analysis results to
53  /// applyCostModel. Patterns to be discarded should have a benefit of
54  /// `impossibleToMatch`.
56 
57  explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList);
59 
60  /// Attempt to match and rewrite the given op with any pattern, allowing a
61  /// predicate to decide if a pattern can be applied or not, and hooks for if
62  /// the pattern match was a success or failure.
63  ///
64  /// canApply: called before each match and rewrite attempt; return false to
65  /// skip pattern.
66  /// onFailure: called when a pattern fails to match to perform cleanup.
67  /// onSuccess: called when a pattern match succeeds; return failure() to
68  /// invalidate the match and try another pattern.
71  function_ref<bool(const Pattern &)> canApply = {},
72  function_ref<void(const Pattern &)> onFailure = {},
73  function_ref<LogicalResult(const Pattern &)> onSuccess = {});
74 
75  /// Apply a cost model to the patterns within this applicator.
76  void applyCostModel(CostModel model);
77 
78  /// Apply the default cost model that solely uses the pattern's static
79  /// benefit.
81  applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
82  }
83 
84  /// Walk all of the patterns within the applicator.
85  void walkAllPatterns(function_ref<void(const Pattern &)> walk);
86 
87 private:
88  /// The list that owns the patterns used within this applicator.
89  const FrozenRewritePatternSet &frozenPatternList;
90  /// The set of patterns to match for each operation, stable sorted by benefit.
92  /// The set of patterns that may match against any operation type, stable
93  /// sorted by benefit.
95  /// The mutable state used during execution of the PDL bytecode.
96  std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
97 };
98 
99 } // namespace mlir
100 
101 #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
This is the type of Action that is dispatched when a pattern is applied.
ApplyPatternAction(ArrayRef< IRUnit > irUnits, const Pattern &pattern)
static constexpr StringLiteral desc
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents a frozen set of patterns that can be processed by a pattern applicator.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
void walkAllPatterns(function_ref< void(const Pattern &)> walk)
Walk all of the patterns within the applicator.
void applyCostModel(CostModel model)
Apply a cost model to the patterns within this applicator.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
PatternApplicator(const FrozenRewritePatternSet &frozenPatternList)
function_ref< PatternBenefit(const Pattern &)> CostModel
The cost model dynamically assigns a PatternBenefit to a particular pattern.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:73
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
Definition: PatternMatch.h:123
StringRef getDebugName() const
Return a readable name for this pattern.
Definition: PatternMatch.h:140
CRTP Implementation of an action.
Definition: Action.h:77
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition: Action.h:67
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26