MLIR  19.0.0git
VectorRewritePatterns.h
Go to the documentation of this file.
1 //===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
11 
12 #include <optional>
13 #include <utility>
14 
17 #include "mlir/IR/PatternMatch.h"
19 
20 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
21 
22 namespace mlir {
23 class ConversionTarget;
24 class RewritePatternSet;
25 class TypeConverter;
26 
27 namespace arith {
28 class AndIOp;
29 class NarrowTypeEmulationConverter;
30 class TruncIOp;
31 } // namespace arith
32 
33 namespace vector {
34 struct VectorTransformsOptions;
35 
36 /// Options that control the vector unrolling.
38  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
39  /// Callback function that indicates whether vector unrolling should be
40  /// attempted on the operation.
43  filterConstraint = std::move(constraint);
44  return *this;
45  }
46 
48  std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
49  /// Function that returns the shape of the vector to unroll to for a given
50  /// operation. The unrolling is aborted if the function returns
51  /// `std::nullopt`.
54  nativeShape = std::move(fn);
55  return *this;
56  }
57 
58  /// Set the native shape to use for unrolling.
60  SmallVector<int64_t> tsShape(shape.begin(), shape.end());
61  nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
62  return tsShape;
63  };
64  return *this;
65  }
66 
67  /// Function that returns the traversal order (in terms of "for loop order",
68  /// i.e. slowest varying dimension to fastest varying dimension) that should
69  /// be used when unrolling the given operation into units of the native vector
70  /// size.
72  std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
76  traversalOrderCallback = std::move(traversalOrderFn);
77  return *this;
78  }
79 };
80 
81 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
82 /// semantics to a contraction with MMT semantics (matrix matrix multiplication
83 /// with the RHS transposed). This specific form is meant to have the vector
84 /// operands organized such that the reduction dimension is contiguous.
85 /// Example:
86 /// ```
87 /// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
88 /// affine_map<(m, n, k) -> (n, k)>,
89 /// affine_map<(m, n, k) -> (m, n)>],
90 /// iterator_types = ["parallel", "parallel", "reduction"],
91 /// kind = #vector.kind<add>} %a, %b, %c : ...
92 /// ```
93 ///
94 /// The `constraint` predicate is used to decide which `vector.contraction` ops
95 /// to filter out.
97  RewritePatternSet &patterns,
98  std::function<LogicalResult(vector::ContractionOp)> constraint =
99  [](vector::ContractionOp) { return success(); },
100  PatternBenefit = 1);
101 
102 /// Collect patterns to convert reduction op to vector.contract and fold
103 /// transpose/broadcast ops into the contract.
104 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
105  PatternBenefit benefit = 1);
106 
107 /// Populate `patterns` with the following patterns.
108 ///
109 /// - VectorTransferFullPartialRewriter
110 ///
111 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
112 /// masking) fast path and a slow path.
113 ///
114 /// Example (a 2-D vector.transfer_read):
115 /// ```
116 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
117 /// ```
118 /// is transformed into:
119 /// ```
120 /// %1:3 = scf.if (%inBounds) {
121 /// // fast path, direct cast
122 /// memref.cast %A: memref<A...> to compatibleMemRefType
123 /// scf.yield %view : compatibleMemRefType, index, index
124 /// } else {
125 /// // slow path, not in-bounds vector.transfer or linalg.copy.
126 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
127 /// scf.yield %4 : compatibleMemRefType, index, index
128 /// }
129 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
130 /// ```
131 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
132 ///
133 /// Preconditions:
134 /// 1. `xferOp.permutation_map()` must be a minor identity map
135 /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
136 /// must be equal. This will be relaxed in the future but requires
137 /// rank-reducing subviews.
139  RewritePatternSet &patterns, const VectorTransformsOptions &options);
140 
141 /// Collect a set of patterns to reduce the rank of the operands of vector
142 /// transfer ops to operate on the largest contiguous vector.
143 /// These patterns are useful when lowering to dialects with 1d vector type
144 /// such as llvm, and they result in fewer memory reads.
146  RewritePatternSet &patterns, PatternBenefit benefit = 1);
147 
148 /// Patterns that remove redundant vector broadcasts.
149 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
150  PatternBenefit benefit = 1);
151 
152 /// Patterns that fold chained vector reductions. These patterns assume that
153 /// elementwise operations (e.g., `arith.addf` with vector operands) are
154 /// cheaper than vector reduction.
155 /// Note that these patterns change the order of reduction which may not always
156 /// produce bit-identical results on some floating point inputs.
157 ///
158 /// Example:
159 /// ```
160 /// %a = vector.reduction <add> %x, %acc
161 /// %b = vector.reduction <add> %y, %a
162 /// ```
163 /// is transformed into:
164 /// ```
165 /// %a = arith.addf %x, %y
166 /// %b = vector.reduction <add> %a, %acc
167 /// ```
168 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
169  PatternBenefit benefit = 1);
170 
171 /// Patterns to break down vector reductions into a series of arith reductions
172 /// over vector elements. This is intended to simplify code with reductions
173 /// over small vector types and avoid more specialized reduction lowering when
174 /// possible.
175 ///
176 /// Example:
177 /// ```
178 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
179 /// ```
180 /// is transformed into:
181 /// ```
182 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
183 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
184 /// %a = arith.addf %y, %z : f32
185 /// ```
187  RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
188  PatternBenefit benefit = 1);
189 
190 /// Populate `patterns` with the following patterns.
191 ///
192 /// [DecomposeDifferentRankInsertStridedSlice]
193 /// ==========================================
194 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
195 /// have different ranks.
196 ///
197 /// When ranks are different, InsertStridedSlice needs to extract a properly
198 /// ranked vector from the destination vector into which to insert. This pattern
199 /// only takes care of this extraction part and forwards the rest to
200 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
201 ///
202 /// For a k-D source and n-D destination vector (k < n), we emit:
203 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
204 /// insert the k-D source.
205 /// 2. k-D -> (n-1)-D InsertStridedSlice op
206 /// 3. InsertOp that is the reverse of 1.
207 ///
208 /// [DecomposeNDExtractStridedSlice]
209 /// ================================
210 /// For an n-D ExtractStridedSliceOp, we can rewrite it to ExtractOp/ExtractElementOp
211 /// + a lower-rank ExtractStridedSliceOp + InsertOp/InsertElementOp.
213  RewritePatternSet &patterns, PatternBenefit benefit = 1);
214 
215 /// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
216 /// ops into a chain of Extract ops to extract each element from the source, and
217 /// then a chain of Insert ops to insert to the target vector.
218 ///
219 /// If `controlFn` is not nullptr, the pattern will only be invoked on ops that
220 /// `controlFn` returns true. Otherwise it runs on all ops.
222  RewritePatternSet &patterns,
223  std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
224  PatternBenefit benefit = 1);
225 
226 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
227 /// based on the destination vector shape. Bitcasts from a lower bitwidth
228 /// element type to a higher bitwidth one are extracted from the lower bitwidth
229 /// based on the native destination vector shape and inserted based on the ratio
230 /// of the bitwidths.
231 ///
232 /// This acts as a last resort way to break down vector.bitcast ops to smaller
233 /// vector sizes. Because this pattern composes until it is bitcasting to a
234 /// single element of the higher bitwidth, there is an optional control function.
235 /// If `controlFn` is not nullptr, the pattern will only apply to ops where
236 /// `controlFn` returns true, otherwise applies to all bitcast ops.
238  RewritePatternSet &patterns,
239  std::function<bool(BitCastOp)> controlFn = nullptr,
240  PatternBenefit benefit = 1);
241 
242 /// Populate `patterns` with the following patterns.
243 ///
244 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
245 ///
246 /// [ConvertSameRankInsertStridedSliceIntoShuffle]
247 /// ==============================================
248 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
249 /// have the same rank. For each outermost index in the slice:
250 /// begin end stride
251 /// [offset : offset+size*stride : stride]
252 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
253 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
254 /// 3. the destination subvector is inserted back in the proper place
255 /// 4. InsertOp that is the reverse of 1.
256 ///
257 /// [Convert1DExtractStridedSliceIntoShuffle]
258 /// =========================================
259 /// For a 1-D ExtractStridedSliceOp, we can lower it to a ShuffleOp.
261  RewritePatternSet &patterns, PatternBenefit benefit = 1);
262 
263 /// Collect a set of patterns to unroll vector operations to smaller shapes.
264 /// `options` structure controls which operations are unrolled and the target
265 /// shape.
266 /// `op` is unrolled to the `targetShape` as follows, for each of its operands:
267 /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
268 /// `numUnrolledInstances` are computed from the `targetShape`. For now it is
269 /// assumed the unrolling factors divide the vector sizes.
270 /// 2. ExtractStridedSlice are created to break-up the vector operands.
271 /// 3. the original op is cloned `numUnrolledInstances` times, once for each
272 /// result.
273 /// 4. InsertStridedSlice are inserted to re-assemble the slices into the
274 /// original vector shape.
275 ///
276 /// Example:
277 ///
278 /// opA(operand0, operand1) // numUnrolledInstances = 3
279 ///
280 /// operand0 operand1
281 /// | |
282 /// fork fork
283 /// <----------gather all fork ops --------->
284 /// /|\ /|\
285 /// f00 f01 f02 f10 f11 f12
286 /// <---------- clone op 3 times --------->
287 /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
288 /// \ | /
289 /// <-------------------- join ------------------------->
290 ///
291 /// Other local patterns then kick in iteratively (including DCE) and compose
292 /// to combine the ExtractStridedSlice/InsertStridedSlice.
295  PatternBenefit benefit = 1);
296 
297 /// Collect a set of vector.shape_cast folding patterns.
299  PatternBenefit benefit = 1);
300 
301 /// Collect a set of leading one dimension removal patterns.
302 ///
303 /// These patterns insert vector.shape_cast to remove leading one dimensions
304 /// to expose more canonical forms of read/write/insert/extract operations.
305 /// With them, there are more chances that we can cancel out extract-insert
306 /// pairs or forward write-read pairs.
308  PatternBenefit benefit = 1);
309 
310 /// Collect a set of one dimension removal patterns.
311 ///
312 /// These patterns insert rank-reducing memref.subview ops to remove one
313 /// dimensions. With them, there are more chances that we can avoid
314 /// potentially expensive vector.shape_cast operations.
316  PatternBenefit benefit = 1);
317 
318 /// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
319 ///
320 /// These patterns use vector.shape_cast to remove unit dims from e.g.
321 /// arithmetic operations on Vectors. The newly inserted shape_casts will either
322 /// cancel each other out or will be folded away when combined with other
323 /// patterns.
325  PatternBenefit benefit = 1);
326 
327 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
328 /// memref.
329 ///
330 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
331 /// to transform multiple small n-D transfers into a larger 1-D transfer where
332 /// the memref contiguity properties allow it.
333 ///
334 /// Flattening is only applied if the bitwidth of the trailing vector dimension
335 /// is smaller or equal to `targetVectorBitwidth`.
337  RewritePatternSet &patterns,
338  unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
339  PatternBenefit benefit = 1);
340 
341 /// Collect a set of patterns that bubble up/down bitcast ops.
342 ///
343 /// These patterns move vector.bitcast ops to be before insert ops or after
344 /// extract ops where suitable. With them, bitcast will happen on smaller
345 /// vectors and there are more chances to share extract/insert ops.
347  PatternBenefit benefit = 1);
348 
349 /// These patterns materialize masks for various vector ops such as transfers.
351  bool force32BitVectorIndices,
352  PatternBenefit benefit = 1);
353 
354 /// Appends patterns for emulating vector operations over narrow types with ops
355 /// over wider types.
358  RewritePatternSet &patterns);
359 
360 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
361 /// vector operations comprising `shuffle` and `bitwise` ops.
362 /// Warning: these patterns currently only work for little endian targets.
364  vector::BitCastOp bitCastOp,
365  arith::TruncIOp truncOp,
366  vector::BroadcastOp maybeBroadcastOp);
367 
368 /// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
369 /// vector operations comprising `shuffle` and `bitwise` ops.
370 /// Warning: these patterns currently only work for little endian targets.
372  vector::BitCastOp bitCastOp,
373  vector::BroadcastOp maybeBroadcastOp);
374 
375 /// Appends patterns for rewriting vector operations over narrow types with
376 /// ops over wider types.
377 /// Warning: these patterns currently only work for little endian targets.
379  PatternBenefit benefit = 1);
380 
381 /// Appends patterns for emulating a sub-byte vector transpose.
383  RewritePatternSet &patterns, PatternBenefit benefit = 1);
384 
385 /// Populates patterns for ND vectors (N >= 2) linearization and sets up the
386 /// provided ConversionTarget with the appropriate legality configuration for
387 /// the ops to get converted properly.
389  TypeConverter &typeConverter, RewritePatternSet &patterns,
390  ConversionTarget &target);
391 
392 } // namespace vector
393 } // namespace mlir
394 
395 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class describes a specific conversion target.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
Type conversion class.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that use vector.shape_cast to help fold unit dims.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for rewriting vector operations over narrow types with ops over wider types.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, unsigned targetVectorBitwidth=std::numeric_limits< unsigned >::max(), PatternBenefit benefit=1)
Collect a set of patterns to flatten n-D vector transfers on contiguous memref.
void populateBreakDownVectorBitCastOpPatterns(RewritePatternSet &patterns, std::function< bool(BitCastOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D vector.bitcast ops based on the destination vector...
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D extract_strided_slice ops into a chain of Extract...
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of one dimension removal patterns.
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that fold chained vector reductions.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that bubble up/down bitcast ops.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
FailureOr< Value > rewriteBitCastOfTruncI(RewriterBase &rewriter, vector::BitCastOp bitCastOp, arith::TruncIOp truncOp, vector::BroadcastOp maybeBroadcastOp)
Rewrite a vector bitcast(trunci) to use a more efficient sequence of vector operations comprising shu...
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating vector operations over narrow types with ops over wider types.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of patterns to unroll vector operations to smaller shapes.
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
void populateBreakDownVectorReductionPatterns(RewritePatternSet &patterns, unsigned maxNumElementsToExtract=2, PatternBenefit benefit=1)
Patterns to break down vector reductions into a series of arith reductions over vector elements.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorTransposeNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for emulating a sub-byte vector transpose.
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
FailureOr< Value > rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp, vector::BitCastOp bitCastOp, vector::BroadcastOp maybeBroadcastOp)
Rewrite a vector ext(bitcast) to use a more efficient sequence of vector operations comprising shuffl...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options that control the vector unrolling.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType