MLIR 22.0.0git
ByteCode.h
Go to the documentation of this file.
1//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
10// The byte-code is constructed from the PDL Interpreter dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_REWRITE_BYTECODE_H_
15#define MLIR_REWRITE_BYTECODE_H_
16
18
19#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20
21namespace mlir {
22namespace pdl_interp {
23class RecordMatchOp;
24} // namespace pdl_interp
25
26namespace detail {
27class PDLByteCode;
28
29/// Use generic bytecode types. ByteCodeField is the 16-bit type of the actual
30/// bytecode entries. ByteCodeAddr is the 32-bit type of indices into the bytecode.
31using ByteCodeField = uint16_t;
32using ByteCodeAddr = uint32_t;
33
34//===----------------------------------------------------------------------===//
35// PDLByteCodePattern
36//===----------------------------------------------------------------------===//
37
38/// All of the data pertaining to a specific pattern within the bytecode: the
/// `Pattern` base state plus the bytecode address of the pattern's rewriter and
/// an optional configuration set.
39class PDLByteCodePattern : public Pattern {
40public:
 /// Create a bytecode pattern from `matchOp`, associating it with `configSet`
 /// (which may be null, see `getConfigSet`) and with the bytecode address of
 /// its rewriter. Only declared here; the definition is not in this header.
41 static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
42 PDLPatternConfigSet *configSet,
43 ByteCodeAddr rewriterAddr);
44
45 /// Return the bytecode address of the rewriter for this pattern.
46 ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
47
48 /// Return the configuration set for this pattern, or null if there is none.
49 PDLPatternConfigSet *getConfigSet() const { return configSet; }
50
51private:
 /// Construct a pattern that records `rewriterAddr` and `configSet`, and
 /// forwards `patternArgs` unchanged to the `Pattern` base class constructor.
 /// Private, so instances are presumably obtained through `create`.
52 template <typename... Args>
53 PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
54 Args &&...patternArgs)
55 : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
56 configSet(configSet) {}
57
58 /// The address of the rewriter for this pattern.
59 ByteCodeAddr rewriterAddr;
60
61 /// The optional config set for this pattern (null if there is none).
62 PDLPatternConfigSet *configSet;
63};
64
65//===----------------------------------------------------------------------===//
66// PDLByteCodeMutableState
67//===----------------------------------------------------------------------===//
68
69/// This class contains the mutable state of a bytecode instance. This allows
70/// for a bytecode instance to be cached and reused across various different
71/// threads/drivers.
// NOTE(review): the class declaration line (listing line 72) is absent from
// this listing; presumably `class PDLByteCodeMutableState {` -- confirm against
// the original header.
73public:
74 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
75 /// to the position of the pattern within the range returned by
76 /// `PDLByteCode::getPatterns`.
77 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
78
79 /// Clean up any allocated state after a match/rewrite has been completed. This
80 /// method should be called regardless of whether the match+rewrite was a
81 /// success or not.
 // NOTE(review): the declaration documented above (listing line 82) is absent
 // from this listing; presumably `void cleanupAfterMatchAndRewrite();` --
 // confirm against the original header.
83
84private:
85 /// Allow access to data fields.
86 friend class PDLByteCode;
87
88 /// The mutable block of memory used during the matching and rewriting phases
89 /// of the bytecode.
90 std::vector<const void *> memory;
91
92 /// A mutable block of memory used during the matching and rewriting phase of
93 /// the bytecode to store ranges of operations. These are always stored by
94 /// owning references, because at no point in the execution of the byte code
95 /// do we get an indexed range (view) of operations.
96 std::vector<std::vector<Operation *>> opRangeMemory;
97
98 /// A mutable block of memory used during the matching and rewriting phase of
99 /// the bytecode to store ranges of types.
100 std::vector<TypeRange> typeRangeMemory;
101 /// A set of type ranges that have been allocated by the byte code interpreter
102 /// to provide a guaranteed lifetime.
103 std::vector<std::vector<Type>> allocatedTypeRangeMemory;
104
105 /// A mutable block of memory used during the matching and rewriting phase of
106 /// the bytecode to store ranges of values.
107 std::vector<ValueRange> valueRangeMemory;
108 /// A set of value ranges that have been allocated by the byte code
109 /// interpreter to provide a guaranteed lifetime.
110 std::vector<std::vector<Value>> allocatedValueRangeMemory;
111
112 /// The current index of ranges being iterated over for each level of nesting.
113 /// These are always maintained at 0 for the loops that are not active, so we
114 /// do not need to have a separate initialization phase for each loop.
115 std::vector<unsigned> loopIndex;
116
117 /// The up-to-date benefits of the patterns held by the bytecode. The order
118 /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
119 std::vector<PatternBenefit> currentPatternBenefits;
120};
121
122//===----------------------------------------------------------------------===//
123// PDLByteCode
124//===----------------------------------------------------------------------===//
125
126/// The bytecode class is also the interpreter. Contains the bytecode itself,
127/// the static info, addresses of the rewriter functions, the interpreter
128/// memory buffer, and the execution context.
129class PDLByteCode {
130public:
131 /// Each successful match returns a MatchResult, which contains information
132 /// necessary to execute the rewriter and indicates the originating pattern.
133 struct MatchResult {
 /// Construct a result for `pattern` at location `loc` with the given
 /// `benefit`. The value and range buffers are default-constructed (empty).
134 MatchResult(Location loc, const PDLByteCodePattern &pattern,
135 PatternBenefit benefit)
136 : location(loc), pattern(&pattern), benefit(benefit) {}
 /// Results are move-only: copying is disabled, moving is defaulted.
137 MatchResult(const MatchResult &) = delete;
138 MatchResult &operator=(const MatchResult &) = delete;
139 MatchResult(MatchResult &&other) = default;
140 MatchResult &operator=(MatchResult &&) = default;
141
142 /// The location of operations to be replaced.
143 Location location;
144 /// Memory values defined in the matcher that are passed to the rewriter.
145 SmallVector<const void *> values;
146 /// Memory used for the range input values.
147 SmallVector<TypeRange, 0> typeRangeValues;
148 SmallVector<ValueRange, 0> valueRangeValues;
149
150 /// The originating pattern that was matched. This is always non-null, but
151 /// represented with a pointer to allow for assignment.
152 const PDLByteCodePattern *pattern;
153 /// The current benefit of the pattern that was matched.
154 PatternBenefit benefit;
155 };
156
157 /// Create a ByteCode instance from the given module containing operations in
158 /// the PDL interpreter dialect.
 // NOTE(review): listing line 161 is absent from this listing; presumably one
 // more constructor parameter sits between `configs` and `constraintFns` --
 // confirm against the original header.
159 PDLByteCode(ModuleOp module,
160 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
162 llvm::StringMap<PDLConstraintFunction> constraintFns,
163 llvm::StringMap<PDLRewriteFunction> rewriteFns);
164
165 /// Return the patterns held by the bytecode.
166 ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
167
168 /// Initialize the given state such that it can be used to execute the current
169 /// bytecode.
170 void initializeMutableState(PDLByteCodeMutableState &state) const;
171
172 /// Run the pattern matcher on the given root operation, collecting the
173 /// matched patterns in `matches`.
174 void match(Operation *op, PatternRewriter &rewriter,
175 SmallVectorImpl<MatchResult> &matches,
176 PDLByteCodeMutableState &state) const;
177
178 /// Run the rewriter of the given pattern that was previously matched in
179 /// `match`. The result indicates whether a failure was encountered during the rewrite.
180 LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
181 PDLByteCodeMutableState &state) const;
182
183private:
184 /// Execute the given byte code starting at the provided instruction `inst`.
185 /// `matches` is an optional field provided when this function is executed in
186 /// a matching context.
187 void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
188 PDLByteCodeMutableState &state,
189 SmallVectorImpl<MatchResult> *matches) const;
190
191 /// The set of pattern configs referenced within the bytecode.
192 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
193
194 /// A vector containing pointers to uniqued data. The storage is intentionally
195 /// opaque such that we can store a wide range of data types. The types of
196 /// data stored here include:
197 /// * Attribute, OperationName, Type
198 std::vector<const void *> uniquedData;
199
200 /// A vector containing the generated bytecode for the matcher.
201 SmallVector<ByteCodeField, 64> matcherByteCode;
202
203 /// A vector containing the generated bytecode for all of the rewriters.
204 SmallVector<ByteCodeField, 64> rewriterByteCode;
205
206 /// The set of patterns contained within the bytecode.
207 SmallVector<PDLByteCodePattern, 32> patterns;
208
209 /// The sets of user-defined constraint and rewrite functions invoked via PDL.
210 std::vector<PDLConstraintFunction> constraintFunctions;
211 std::vector<PDLRewriteFunction> rewriteFunctions;
212
213 /// The maximum memory index used by a value.
214 ByteCodeField maxValueMemoryIndex = 0;
215
216 /// The maximum number of ranges of each kind (operation, type, and value).
217 ByteCodeField maxOpRangeCount = 0;
218 ByteCodeField maxTypeRangeCount = 0;
219 ByteCodeField maxValueRangeCount = 0;
220
221 /// The maximum number of nested loops.
222 ByteCodeField maxLoopLevel = 0;
223};
224
225} // namespace detail
226} // namespace mlir
227
228#else
229
230namespace mlir::detail {
// Fallback definitions for the `#else` branch of
// `#if MLIR_ENABLE_PDL_IN_PATTERNMATCH`. The member functions visible here have
// empty bodies, except `rewrite`, which always returns `failure()`.
// NOTE(review): several lines of this section (class headers and some members,
// e.g. listing lines 232, 234, 240-245, 247, 249, 255) are absent from this
// listing -- confirm against the original header.
231
233public:
 /// No-op stand-in for updating a pattern benefit.
235 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {}
236};
237
/// Empty stand-in for the bytecode pattern; adds nothing to `Pattern`.
238class PDLByteCodePattern : public Pattern {};
239
241public:
246
 /// No-op stand-in for matching; the body is empty.
248 void match(Operation *op, PatternRewriter &rewriter,
250 PDLByteCodeMutableState &state) const {}
 /// Stand-in for rewriting; always returns `failure()`.
251 LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
252 PDLByteCodeMutableState &state) const {
253 return failure();
254 }
256};
257
258} // namespace mlir::detail
259
260#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
261
262#endif // MLIR_REWRITE_BYTECODE_H_
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition ByteCode.h:235
void cleanupAfterMatchAndRewrite()
Clean up any allocated state after a full match/rewrite has been completed.
Definition ByteCode.h:234
ArrayRef< PDLByteCodePattern > getPatterns() const
Definition ByteCode.h:255
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition ByteCode.h:248
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition ByteCode.h:247
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition ByteCode.h:251
AttrTypeReplacer.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
const PDLByteCodePattern * pattern
Definition ByteCode.h:243