//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

namespace mlir {
class RewritePatternSet;

namespace vector {
struct VectorTransformsOptions;

/// Options that control the vector unrolling.
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = std::move(constraint);
    return *this;
  }

  using NativeShapeFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns
  /// `std::nullopt`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = std::move(fn);
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t> tsShape(shape.begin(), shape.end());
    nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native
  /// vector size.
  using UnrollTraversalOrderFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};
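
// A minimal sketch of configuring `UnrollVectorOptions` via the chaining
// setters above; the 4x4 native shape and the contraction-only filter are
// illustrative assumptions, not defaults:
//
//   UnrollVectorOptions options;
//   options.setNativeShape(ArrayRef<int64_t>{4, 4})
//       .setFilterConstraint([](Operation *op) {
//         return success(isa<vector::ContractionOp>(op));
//       });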

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
/// operands organized such that the reduction dimension is contiguous.
/// Example:
/// ```
/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                                    affine_map<(m, n, k) -> (n, k)>,
///                                    affine_map<(m, n, k) -> (m, n)>],
///                  iterator_types = ["parallel", "parallel", "reduction"],
///                  kind = #vector.kind<add>} %a, %b, %c : ...
/// ```
///
/// The `constraint` predicate is used to decide which `vector.contraction` ops
/// to filter out.
void populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint =
        [](vector::ContractionOp) { return success(); },
    PatternBenefit = 1);
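
// A sketch of restricting the MMT canonicalization through the `constraint`
// hook; the f32-only predicate is a hypothetical example, not a default:
//
//   populateVectorContractCanonicalizeMatmulToMMT(
//       patterns, [](vector::ContractionOp op) {
//         return success(op.getLhsType().getElementType().isF32());
//       });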

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
///   - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, large
/// enough to hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map.
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contiguous vector.
/// These patterns are useful when lowering to dialects with 1-D vector types,
/// such as LLVM, and result in fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///      insert the k-D source.
///   2. k-D -> (n-1)-D InsertStridedSlice op.
///   3. InsertOp that is the reverse of 1.
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For the n-D case, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
/// ops into a chain of Extract ops to extract each element from the source, and
/// then a chain of Insert ops to insert to the target vector.
///
/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
/// which `controlFn` returns true. Otherwise it runs on all ops.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);
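
// A sketch of gating the decomposition with `controlFn`; the size-based
// predicate is a hypothetical example chosen only to illustrate the hook:
//
//   populateVectorExtractStridedSliceToExtractInsertChainPatterns(
//       patterns, [](ExtractStridedSliceOp op) {
//         return cast<VectorType>(op.getType()).getNumElements() <= 8;
//       });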

/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
/// based on the native destination vector shape and inserted based on the ratio
/// of the bitwidths.
///
/// This acts as a last resort way to break down vector.bitcast ops to smaller
/// vector sizes. Because this pattern composes until it is bitcasting to a
/// single element of the higher bitwidth, there is an optional control
/// function. If `controlFn` is not nullptr, the pattern will only apply to ops
/// where `controlFn` returns true; otherwise it applies to all bitcast ops.
void populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(BitCastOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
///
/// [ConvertSameRankInsertStridedSliceIntoShuffle]
/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice:
///   begin    end             stride
/// [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D.
///   3. The destination subvector is inserted back in the proper place.
///   4. InsertOp that is the reverse of 1.
///
/// [Convert1DExtractStridedSliceIntoShuffle]
/// =========================================
/// For the 1-D case, we can lower it to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
/// vector.transfer_write -> tensor.insert_slice op chains into vector transfer
/// read and write ops.
///
/// If `controlFn` is not nullptr, the pattern will only apply to ops where
/// `controlFn` returns true, given the vector transfer read/write op as input.
void populateVectorTransferTensorSliceTransforms(
    RewritePatternSet &patterns,
    std::function<bool(Operation *vectorOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// The `options` structure controls which operations are unrolled and the
/// target shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. The unrolled type `unrolledVectorType` and number of unrolled instances
///      `numUnrolledInstances` are computed from the `targetShape`. For now it
///      is assumed the unrolling factors divide the vector sizes.
///   2. ExtractStridedSlice ops are created to break up the vector operands.
///   3. The original op is cloned `numUnrolledInstances` times, once for each
///      result.
///   4. InsertStridedSlice ops are inserted to re-assemble the slices into the
///      original vector shape.
///
/// Example:
///
///    opA(operand0, operand1)          // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\
///          f00 f01 f02                f10 f11 f12
///        <---------- clone op 3 times --------->
///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /
///      <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);
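
// A typical driver, as a sketch inside a pass: `funcOp` and the 4x4 native
// shape are assumptions for illustration, and `applyPatternsAndFoldGreedily`
// comes from "mlir/Transforms/GreedyPatternRewriteDriver.h":
//
//   RewritePatternSet patterns(funcOp.getContext());
//   UnrollVectorOptions options;
//   options.setNativeShape(ArrayRef<int64_t>{4, 4});
//   populateVectorUnrollPatterns(patterns, options);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();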

/// Collect a set of vector.shape_cast folding patterns.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Collect a set of one dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);
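
// These populate* entry points compose: several pattern sets can be collected
// into one RewritePatternSet and applied in a single greedy rewrite. A sketch,
// assuming `module` is the op being rewritten:
//
//   RewritePatternSet patterns(module.getContext());
//   populateVectorMaskMaterializationPatterns(
//       patterns, /*force32BitVectorIndices=*/false);
//   populateShapeCastFoldingPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));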

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H