MLIR  20.0.0git
ByteCode.h
Go to the documentation of this file.
1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR.
10 // The byte-code is constructed from the PDL Interpreter dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_BYTECODE_H_
15 #define MLIR_REWRITE_BYTECODE_H_
16 
17 #include "mlir/IR/PatternMatch.h"
18 
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20 
21 namespace mlir {
// RecordMatchOp is only forward-declared here: this header uses it solely as a
// by-value parameter type in a function declaration
// (`PDLByteCodePattern::create`), which does not require the complete type.
namespace pdl_interp {
class RecordMatchOp;
} // end namespace pdl_interp
25 
26 namespace detail {
27 class PDLByteCode;
28 
29 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode
30 /// entries. ByteCodeAddr refers to size of indices into the bytecode.
31 using ByteCodeField = uint16_t;
32 using ByteCodeAddr = uint32_t;
33 using OwningOpRange = llvm::OwningArrayRef<Operation *>;
34 
35 //===----------------------------------------------------------------------===//
36 // PDLByteCodePattern
37 //===----------------------------------------------------------------------===//
38 
39 /// All of the data pertaining to a specific pattern within the bytecode.
40 class PDLByteCodePattern : public Pattern {
41 public:
42  static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
43  PDLPatternConfigSet *configSet,
44  ByteCodeAddr rewriterAddr);
45 
46  /// Return the bytecode address of the rewriter for this pattern.
47  ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
48 
49  /// Return the configuration set for this pattern, or null if there is none.
50  PDLPatternConfigSet *getConfigSet() const { return configSet; }
51 
52 private:
53  template <typename... Args>
54  PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
55  Args &&...patternArgs)
56  : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
57  configSet(configSet) {}
58 
59  /// The address of the rewriter for this pattern.
60  ByteCodeAddr rewriterAddr;
61 
62  /// The optional config set for this pattern.
63  PDLPatternConfigSet *configSet;
64 };
65 
66 //===----------------------------------------------------------------------===//
67 // PDLByteCodeMutableState
68 //===----------------------------------------------------------------------===//
69 
70 /// This class contains the mutable state of a bytecode instance. This allows
71 /// for a bytecode instance to be cached and reused across various different
72 /// threads/drivers.
73 class PDLByteCodeMutableState {
74 public:
75  /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
76  /// to the position of the pattern within the range returned by
77  /// `PDLByteCode::getPatterns`.
78  void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
79 
80  /// Cleanup any allocated state after a match/rewrite has been completed. This
81  /// method should be called irregardless of whether the match+rewrite was a
82  /// success or not.
84 
85 private:
86  /// Allow access to data fields.
87  friend class PDLByteCode;
88 
89  /// The mutable block of memory used during the matching and rewriting phases
90  /// of the bytecode.
91  std::vector<const void *> memory;
92 
93  /// A mutable block of memory used during the matching and rewriting phase of
94  /// the bytecode to store ranges of operations. These are always stored by
95  /// owning references, because at no point in the execution of the byte code
96  /// we get an indexed range (view) of operations.
97  std::vector<OwningOpRange> opRangeMemory;
98 
99  /// A mutable block of memory used during the matching and rewriting phase of
100  /// the bytecode to store ranges of types.
101  std::vector<TypeRange> typeRangeMemory;
102  /// A set of type ranges that have been allocated by the byte code interpreter
103  /// to provide a guaranteed lifetime.
104  std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
105 
106  /// A mutable block of memory used during the matching and rewriting phase of
107  /// the bytecode to store ranges of values.
108  std::vector<ValueRange> valueRangeMemory;
109  /// A set of value ranges that have been allocated by the byte code
110  /// interpreter to provide a guaranteed lifetime.
111  std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
112 
113  /// The current index of ranges being iterated over for each level of nesting.
114  /// These are always maintained at 0 for the loops that are not active, so we
115  /// do not need to have a separate initialization phase for each loop.
116  std::vector<unsigned> loopIndex;
117 
118  /// The up-to-date benefits of the patterns held by the bytecode. The order
119  /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
120  std::vector<PatternBenefit> currentPatternBenefits;
121 };
122 
123 //===----------------------------------------------------------------------===//
124 // PDLByteCode
125 //===----------------------------------------------------------------------===//
126 
127 /// The bytecode class is also the interpreter. Contains the bytecode itself,
128 /// the static info, addresses of the rewriter functions, the interpreter
129 /// memory buffer, and the execution context.
130 class PDLByteCode {
131 public:
132  /// Each successful match returns a MatchResult, which contains information
133  /// necessary to execute the rewriter and indicates the originating pattern.
134  struct MatchResult {
135  MatchResult(Location loc, const PDLByteCodePattern &pattern,
136  PatternBenefit benefit)
137  : location(loc), pattern(&pattern), benefit(benefit) {}
138  MatchResult(const MatchResult &) = delete;
139  MatchResult &operator=(const MatchResult &) = delete;
140  MatchResult(MatchResult &&other) = default;
141  MatchResult &operator=(MatchResult &&) = default;
142 
143  /// The location of operations to be replaced.
144  Location location;
145  /// Memory values defined in the matcher that are passed to the rewriter.
146  SmallVector<const void *> values;
147  /// Memory used for the range input values.
148  SmallVector<TypeRange, 0> typeRangeValues;
149  SmallVector<ValueRange, 0> valueRangeValues;
150 
151  /// The originating pattern that was matched. This is always non-null, but
152  /// represented with a pointer to allow for assignment.
153  const PDLByteCodePattern *pattern;
154  /// The current benefit of the pattern that was matched.
155  PatternBenefit benefit;
156  };
157 
158  /// Create a ByteCode instance from the given module containing operations in
159  /// the PDL interpreter dialect.
160  PDLByteCode(ModuleOp module,
161  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
162  const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
163  llvm::StringMap<PDLConstraintFunction> constraintFns,
164  llvm::StringMap<PDLRewriteFunction> rewriteFns);
165 
166  /// Return the patterns held by the bytecode.
167  ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
168 
169  /// Initialize the given state such that it can be used to execute the current
170  /// bytecode.
171  void initializeMutableState(PDLByteCodeMutableState &state) const;
172 
173  /// Run the pattern matcher on the given root operation, collecting the
174  /// matched patterns in `matches`.
175  void match(Operation *op, PatternRewriter &rewriter,
176  SmallVectorImpl<MatchResult> &matches,
177  PDLByteCodeMutableState &state) const;
178 
179  /// Run the rewriter of the given pattern that was previously matched in
180  /// `match`. Returns if a failure was encountered during the rewrite.
181  LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
182  PDLByteCodeMutableState &state) const;
183 
184 private:
185  /// Execute the given byte code starting at the provided instruction `inst`.
186  /// `matches` is an optional field provided when this function is executed in
187  /// a matching context.
188  void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
189  PDLByteCodeMutableState &state,
190  SmallVectorImpl<MatchResult> *matches) const;
191 
192  /// The set of pattern configs referenced within the bytecode.
193  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
194 
195  /// A vector containing pointers to uniqued data. The storage is intentionally
196  /// opaque such that we can store a wide range of data types. The types of
197  /// data stored here include:
198  /// * Attribute, OperationName, Type
199  std::vector<const void *> uniquedData;
200 
201  /// A vector containing the generated bytecode for the matcher.
202  SmallVector<ByteCodeField, 64> matcherByteCode;
203 
204  /// A vector containing the generated bytecode for all of the rewriters.
205  SmallVector<ByteCodeField, 64> rewriterByteCode;
206 
207  /// The set of patterns contained within the bytecode.
208  SmallVector<PDLByteCodePattern, 32> patterns;
209 
210  /// A set of user defined functions invoked via PDL.
211  std::vector<PDLConstraintFunction> constraintFunctions;
212  std::vector<PDLRewriteFunction> rewriteFunctions;
213 
214  /// The maximum memory index used by a value.
215  ByteCodeField maxValueMemoryIndex = 0;
216 
217  /// The maximum number of different types of ranges.
218  ByteCodeField maxOpRangeCount = 0;
219  ByteCodeField maxTypeRangeCount = 0;
220  ByteCodeField maxValueRangeCount = 0;
221 
222  /// The maximum number of nested loops.
223  ByteCodeField maxLoopLevel = 0;
224 };
225 
226 } // namespace detail
227 } // namespace mlir
228 
229 #else
230 
231 namespace mlir::detail {
232 
234 public:
236  void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {}
237 };
238 
239 class PDLByteCodePattern : public Pattern {};
240 
241 class PDLByteCode {
242 public:
243  struct MatchResult {
244  const PDLByteCodePattern *pattern = nullptr;
246  };
247 
249  void match(Operation *op, PatternRewriter &rewriter,
251  PDLByteCodeMutableState &state) const {}
252  LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
253  PDLByteCodeMutableState &state) const {
254  return failure();
255  }
257 };
258 
259 } // namespace mlir::detail
260 
261 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
262 
263 #endif // MLIR_REWRITE_BYTECODE_H_
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition: ByteCode.h:235
ArrayRef< PDLByteCodePattern > getPatterns() const
Definition: ByteCode.h:256
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition: ByteCode.h:248
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
Include the generated interface declarations.
const PDLByteCodePattern * pattern
Definition: ByteCode.h:244