MLIR  20.0.0git
VectorRewritePatterns.h
Go to the documentation of this file.
1 //===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
11 
#include <limits>
#include <optional>
#include <utility>

// NOTE(review): the two project includes below were dropped by the doc-dump
// extraction; VectorOps.h is required for ExtractStridedSliceOp / BitCastOp /
// ContractionOp referenced in this header — confirm against upstream.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

21 namespace mlir {
22 class ConversionTarget;
23 class RewritePatternSet;
24 class TypeConverter;
25 
26 namespace arith {
27 class AndIOp;
28 class NarrowTypeEmulationConverter;
29 class TruncIOp;
30 } // namespace arith
31 
32 namespace vector {
33 struct VectorTransformsOptions;
34 
35 /// Options that control the vector unrolling.
37  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
38  /// Callback function that indicates whether vector unrolling should be
39  /// attempted on the operation.
42  filterConstraint = std::move(constraint);
43  return *this;
44  }
45 
47  std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
48  /// Function that returns the shape of the vector to unroll to for a given
49  /// operation. The unrolling is aborted if the function returns
50  /// `std::nullopt`.
53  nativeShape = std::move(fn);
54  return *this;
55  }
56 
57  /// Set the native shape to use for unrolling.
59  SmallVector<int64_t> tsShape(shape.begin(), shape.end());
60  nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
61  return tsShape;
62  };
63  return *this;
64  }
65 
66  /// Function that returns the traversal order (in terms of "for loop order",
67  /// i.e. slowest varying dimension to fastest varying dimension) that should
68  /// be used when unrolling the given operation into units of the native vector
69  /// size.
71  std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
75  traversalOrderCallback = std::move(traversalOrderFn);
76  return *this;
77  }
78 };
79 
80 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
81 /// semantics to a contraction with MMT semantics (matrix matrix multiplication
82 /// with the RHS transposed). This specific form is meant to have the vector
83 /// operands are organized such that the reduction dimension is contiguous.
84 /// Example:
85 /// ```
86 /// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
87 /// affine_map<(m, n, k) -> (n, k)>,
88 /// affine_map<(m, n, k) -> (m, n)>],
89 /// iterator_types = ["parallel", "parallel", "reduction"],
90 /// kind = #vector.kind<add>} %a, %b, %c : ...
91 /// ```
92 ///
93 /// The `constraint` predicate is used to decide which `vector.contraction` ops
94 /// to filter out.
96  RewritePatternSet &patterns,
97  std::function<LogicalResult(vector::ContractionOp)> constraint =
98  [](vector::ContractionOp) { return success(); },
99  PatternBenefit = 1);
100 
101 /// Collect patterns to convert reduction op to vector.contract and fold
102 /// transpose/broadcast ops into the contract.
103 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
104  PatternBenefit benefit = 1);
105 
106 /// Populate `patterns` with the following patterns.
107 ///
108 /// - VectorTransferFullPartialRewriter
109 ///
110 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
111 /// masking) fast path and a slow path.
112 ///
113 /// Example (a 2-D vector.transfer_read):
114 /// ```
115 /// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
116 /// ```
117 /// is transformed into:
118 /// ```
119 /// %1:3 = scf.if (%inBounds) {
120 /// // fast path, direct cast
121 /// memref.cast %A: memref<A...> to compatibleMemRefType
122 /// scf.yield %view : compatibleMemRefType, index, index
123 /// } else {
124 /// // slow path, not in-bounds vector.transfer or linalg.copy.
125 /// memref.cast %alloc: memref<B...> to compatibleMemRefType
126 /// scf.yield %4 : compatibleMemRefType, index, index
127 // }
128 /// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
129 /// ```
130 /// where `alloc` is a top of the function alloca'ed buffer of one vector.
131 ///
132 /// Preconditions:
133 /// 1. `xferOp.permutation_map()` must be a minor identity map
134 /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
135 /// must be equal. This will be relaxed in the future but requires
136 /// rank-reducing subviews.
138  RewritePatternSet &patterns, const VectorTransformsOptions &options);
139 
140 /// Collect a set of patterns to reduce the rank of the operands of vector
141 /// transfer ops to operate on the largest contigious vector.
142 /// These patterns are useful when lowering to dialects with 1d vector type
143 /// such as llvm and it will result fewer memory reads.
145  RewritePatternSet &patterns, PatternBenefit benefit = 1);
146 
147 /// Patterns that remove redundant vector broadcasts.
148 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
149  PatternBenefit benefit = 1);
150 
151 /// Patterns that fold chained vector reductions. These patterns assume that
152 /// elementwise operations (e.g., `arith.addf` with vector operands) are
153 /// cheaper than vector reduction.
154 /// Note that these patterns change the order of reduction which may not always
155 /// produce bit-identical results on some floating point inputs.
156 ///
157 /// Example:
158 /// ```
159 /// %a = vector.reduction <add> %x, %acc
160 /// %b = vector.reduction <add> %y, %a
161 /// ```
162 /// is transformed into:
163 /// ```
164 /// %a = arith.addf %x, %y
165 /// %b = vector.reduction <add> %a, %acc
166 /// ```
167 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
168  PatternBenefit benefit = 1);
169 
170 /// Patterns to break down vector reductions into a series of arith reductions
171 /// over vector elements. This is intended to be simplify code with reductions
172 /// over small vector types and avoid more specialized reduction lowering when
173 /// possible.
174 ///
175 /// Example:
176 /// ```
177 /// %a = vector.reduction <add> %x : vector<2xf32> into f32
178 /// ```
179 /// is transformed into:
180 /// ```
181 /// %y = vector.extract %x[0] : f32 from vector<2xf32>
182 /// %z = vector.extract %x[1] : f32 from vector<2xf32>
183 /// %a = arith.addf %y, %z : f32
184 /// ```
186  RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
187  PatternBenefit benefit = 1);
188 
189 /// Populate `patterns` with the following patterns.
190 ///
191 /// [DecomposeDifferentRankInsertStridedSlice]
192 /// ==========================================
193 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
194 /// have different ranks.
195 ///
196 /// When ranks are different, InsertStridedSlice needs to extract a properly
197 /// ranked vector from the destination vector into which to insert. This pattern
198 /// only takes care of this extraction part and forwards the rest to
199 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
200 ///
201 /// For a k-D source and n-D destination vector (k < n), we emit:
202 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
203 /// insert the k-D source.
204 /// 2. k-D -> (n-1)-D InsertStridedSlice op
205 /// 3. InsertOp that is the reverse of 1.
206 ///
207 /// [DecomposeNDExtractStridedSlice]
208 /// ================================
209 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
210 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
212  RewritePatternSet &patterns, PatternBenefit benefit = 1);
213 
214 /// Populate `patterns` with a pattern to breaks down 1-D extract_strided_slice
215 /// ops into a chain of Extract ops to extract each element from the source, and
216 /// then a chain of Insert ops to insert to the target vector.
217 ///
218 /// If `controlFn` is not nullptr, the pattern will only be invoked on ops that
219 /// `controlFn` returns true. Otherwise runs on ops.
221  RewritePatternSet &patterns,
222  std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
223  PatternBenefit benefit = 1);
224 
225 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
226 /// based on the destination vector shape. Bitcasts from a lower bitwidth
227 /// element type to a higher bitwidth one are extracted from the lower bitwidth
228 /// based on the native destination vector shape and inserted based on the ratio
229 /// of the bitwidths.
230 ///
231 /// This acts as a last resort way to break down vector.bitcast ops to smaller
232 /// vector sizes. Because this pattern composes until it is bitcasting to a
233 /// single element of the higher bitwidth, the is an optional control function.
234 /// If `controlFn` is not nullptr, the pattern will only apply to ops where
235 /// `controlFn` returns true, otherwise applies to all bitcast ops.
237  RewritePatternSet &patterns,
238  std::function<bool(BitCastOp)> controlFn = nullptr,
239  PatternBenefit benefit = 1);
240 
241 /// Populate `patterns` with the following patterns.
242 ///
243 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
244 ///
245 /// [ConvertSameRankInsertStridedSliceIntoShuffle]
246 /// ==============================================
247 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
248 /// have the same rank. For each outermost index in the slice:
249 /// begin end stride
250 /// [offset : offset+size*stride : stride]
251 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
252 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
253 /// 3. the destination subvector is inserted back in the proper place
254 /// 3. InsertOp that is the reverse of 1.
255 ///
256 /// [Convert1DExtractStridedSliceIntoShuffle]
257 /// =========================================
258 /// For such cases, we can lower it to a ShuffleOp.
260  RewritePatternSet &patterns, PatternBenefit benefit = 1);
261 
262 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
263 /// `options` structure controls which operations are unrolled and the target
264 /// shape.
265 /// `op` is unrolled to the `targetShape` as follows, for each of its operands:
266 /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
267 /// `numUnrolledInstances` are computed from the `targetShape`. For now it is
268 /// assumed the unrolling factors divide the vector sizes.
269 /// 2. ExtractStridedSlice are created to break-up the vector operands.
270 /// 3. the original op is cloned `numUnrolledInstances` times, once for each
271 /// result.
272 /// 4. InsertStridedSlice are inserted to re-assemble the slices into the
273 /// original vectore shape.
274 ///
275 /// Example:
276 ///
277 /// opA(operand0, operand1) // numUnrolledInstances = 3
278 ///
279 /// operand0 operand1
280 /// | |
281 /// fork fork
282 /// <----------gather all fork ops --------->
283 /// /|\ /|\
284 /// f00 f01 f02 f10 f11 f12
285 /// <---------- clone op 3 times --------->
286 /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
287 /// \ | /
288 /// <-------------------- join ------------------------->
289 ///
290 /// Other local patterns then kick in iteratively (including DCE) and compose
291 /// to combine the ExtractStridedSlice/InsertStridedSlice.
294  PatternBenefit benefit = 1);
295 
296 /// Collect a set of vector.shape_cast folding patterns.
298  PatternBenefit benefit = 1);
299 
300 /// Collect a set of leading one dimension removal patterns.
301 ///
302 /// These patterns insert vector.shape_cast to remove leading one dimensions
303 /// to expose more canonical forms of read/write/insert/extract operations.
304 /// With them, there are more chances that we can cancel out extract-insert
305 /// pairs or forward write-read pairs.
307  PatternBenefit benefit = 1);
308 
309 /// Collect a set of one dimension removal patterns.
310 ///
311 /// These patterns insert rank-reducing memref.subview ops to remove one
312 /// dimensions. With them, there are more chances that we can avoid
313 /// potentially expensive vector.shape_cast operations.
315  PatternBenefit benefit = 1);
316 
317 /// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
318 ///
319 /// These patterns use vector.shape_cast to remove unit dims from e.g.
320 /// arithmetic operations on Vectors. The newly inserted shape_casts will either
321 /// cancel each other out or will be folded away when combined with other
322 /// patterns.
324  PatternBenefit benefit = 1);
325 
326 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
327 /// memref.
328 ///
329 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
330 /// to transform multiple small n-D transfers into a larger 1-D transfer where
331 /// the memref contiguity properties allow it.
332 ///
333 /// Flattening is only applied if the bitwidth of the trailing vector dimension
334 /// is smaller or equal to `targetVectorBitwidth`.
336  RewritePatternSet &patterns,
337  unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
338  PatternBenefit benefit = 1);
339 
340 /// Collect a set of patterns that bubble up/down bitcast ops.
341 ///
342 /// These patterns move vector.bitcast ops to be before insert ops or after
343 /// extract ops where suitable. With them, bitcast will happen on smaller
344 /// vectors and there are more chances to share extract/insert ops.
346  PatternBenefit benefit = 1);
347 
348 /// These patterns materialize masks for various vector ops such as transfers.
350  bool force32BitVectorIndices,
351  PatternBenefit benefit = 1);
352 
353 /// Appends patterns for emulating vector operations over narrow types with ops
354 /// over wider types.
357  RewritePatternSet &patterns);
358 
359 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
360 /// vector operations comprising `shuffle` and `bitwise` ops.
361 /// Warning: these patterns currently only work for little endian targets.
362 FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
363  vector::BitCastOp bitCastOp,
364  arith::TruncIOp truncOp,
365  vector::BroadcastOp maybeBroadcastOp);
366 
367 /// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
368 /// vector operations comprising `shuffle` and `bitwise` ops.
369 /// Warning: these patterns currently only work for little endian targets.
370 FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
371  vector::BitCastOp bitCastOp,
372  vector::BroadcastOp maybeBroadcastOp);
373 
374 /// Appends patterns for rewriting vector operations over narrow types with
375 /// ops over wider types.
376 /// Warning: these patterns currently only work for little endian targets.
378  PatternBenefit benefit = 1);
379 
380 /// Appends patterns for emulating a sub-byte vector transpose.
382  RewritePatternSet &patterns, PatternBenefit benefit = 1);
383 
384 /// Populates patterns for ND vectors (N >= 2) linearization and sets up the
385 /// provided ConversionTarget with the appropriate legality configuration for
386 /// the ops to get converted properly.
388  TypeConverter &typeConverter, RewritePatternSet &patterns,
389  ConversionTarget &target, unsigned targetBitWidth);
390 
391 /// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
392 /// vector shuffle operations.
394  RewritePatternSet &patterns,
395  ConversionTarget &target,
396  unsigned targetBitWidth);
397 
398 } // namespace vector
399 } // namespace mlir
400 
401 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class describes a specific conversion target.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
Type conversion class.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that use vector.shape_cast to help fold unit dims.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for rewriting vector operations over narrow types with ops over wider types.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, unsigned targetVectorBitwidth=std::numeric_limits< unsigned >::max(), PatternBenefit benefit=1)
Collect a set of patterns to flatten n-D vector transfers on contiguous memref.
void populateBreakDownVectorBitCastOpPatterns(RewritePatternSet &patterns, std::function< bool(BitCastOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D vector.bitcast ops based on the destination vector...
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to breaks down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for linearizing ND (N >= 2) vector operations to 1D vector shuffle operations.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of one dimension removal patterns.
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that fold chained vector reductions.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that bubble up/down bitcast ops.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
FailureOr< Value > rewriteBitCastOfTruncI(RewriterBase &rewriter, vector::BitCastOp bitCastOp, arith::TruncIOp truncOp, vector::BroadcastOp maybeBroadcastOp)
Rewrite a vector bitcast(trunci) to use a more efficient sequence of vector operations comprising shu...
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorNarrowTypeEmulationPatterns(arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating vector operations over narrow types with ops over wider types.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
void populateBreakDownVectorReductionPatterns(RewritePatternSet &patterns, unsigned maxNumElementsToExtract=2, PatternBenefit benefit=1)
Patterns to break down vector reductions into a series of arith reductions over vector elements.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorTransposeNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for emulating a sub-byte vector transpose.
FailureOr< Value > rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp, vector::BitCastOp bitCastOp, vector::BroadcastOp maybeBroadcastOp)
Rewrite a vector ext(bitcast) to use a more efficient sequence of vector operations comprising shuffl...
void populateVectorLinearizeTypeConversionsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth)
Populates patterns for ND vectors (N >= 2) linearization and sets up the provided ConversionTarget wi...
Include the generated interface declarations.
Options that control the vector unrolling.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType