//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

namespace arith {
class AndIOp;
class NarrowTypeEmulationConverter;
class TruncIOp;
} // namespace arith

namespace vector {
struct VectorTransformsOptions;

/// Options that control the vector unrolling.
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = std::move(constraint);
    return *this;
  }

  using NativeShapeFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns
  /// `std::nullopt`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = std::move(fn);
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t> tsShape(shape.begin(), shape.end());
    nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native vector
  /// size.
  using UnrollTraversalOrderFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};
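
// A minimal configuration sketch (illustrative, not part of the upstream
// header): pick a native shape per operation and a traversal order. The shapes
// and the restriction to `vector::ContractionOp` are arbitrary example choices.
//
//   vector::UnrollVectorOptions unrollOptions;
//   unrollOptions.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         if (isa<vector::ContractionOp>(op))
//           return SmallVector<int64_t>{4, 4, 2};
//         return std::nullopt; // Abort unrolling for every other op.
//       });
//   unrollOptions.setUnrollTraversalOrderFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         // Loop order m, k, n (slowest to fastest varying).
//         return SmallVector<int64_t>{0, 2, 1};
//       });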

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
/// operands organized such that the reduction dimension is contiguous.
/// Example:
/// ```
/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                                   affine_map<(m, n, k) -> (n, k)>,
///                                   affine_map<(m, n, k) -> (m, n)>],
///                  iterator_types = ["parallel", "parallel", "reduction"],
///                  kind = #vector.kind<add>} %a, %b, %c : ...
/// ```
///
/// The `constraint` predicate is used to decide which `vector.contraction` ops
/// to filter out.
void populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint =
        [](vector::ContractionOp) { return success(); },
    PatternBenefit = 1);

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
///   - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, large
/// enough to hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

/// Collect a set of patterns to collapse the innermost unit dims in xfer Ops.
///
/// These patterns reduce the rank of the operands of vector transfer ops so
/// that they operate on vectors without trailing unit dims, which can be
/// helpful when lowering to dialects that only support 1-D vector types,
/// such as LLVM.
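///
/// For example (schematically; the exact rank-reduced subview layout is
/// omitted):
/// ```
///    %v = vector.transfer_read %mem[...] : memref<...x4x1xf32>, vector<4x1xf32>
/// ```
/// becomes a transfer of `vector<4xf32>` on a rank-reduced `memref.subview` of
/// `%mem`, followed by a `vector.shape_cast` back to `vector<4x1xf32>`.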
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by re-ordering them with
/// e.g. elementwise Ops:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
/// ```
/// At the moment, these patterns are limited to vector.broadcast and
/// vector.transpose.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by merging them with load/store
/// ops:
/// ```
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reductions.
/// Note that these patterns change the order of reduction, which may not
/// produce bit-identical results for some floating point inputs.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ```
/// is transformed into:
/// ```
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);

/// Patterns to break down vector reductions into a series of arith reductions
/// over vector elements. This is intended to simplify code with reductions
/// over small vector types and avoid more specialized reduction lowering when
/// possible.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
void populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///      insert the k-D source.
///   2. k-D -> (n-1)-D InsertStridedSlice op
///   3. InsertOp that is the reverse of 1.
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For such cases, we can rewrite it to ExtractOp + lower rank
/// ExtractStridedSliceOp + InsertOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
/// ops into a chain of Extract ops to extract each element from the source, and
/// then a chain of Insert ops to insert to the target vector.
///
/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
/// which `controlFn` returns true. Otherwise it is applied to all
/// extract_strided_slice ops.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
/// based on the native destination vector shape and inserted based on the ratio
/// of the bitwidths.
///
/// This acts as a last resort way to break down vector.bitcast ops to smaller
/// vector sizes. Because this pattern composes until it is bitcasting to a
/// single element of the higher bitwidth, there is an optional control
/// function. If `controlFn` is not nullptr, the pattern will only apply to ops
/// where `controlFn` returns true, otherwise it applies to all bitcast ops.
void populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(BitCastOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);
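
// A minimal sketch (illustrative, not part of the upstream header) of using
// `controlFn` to limit the break-down; the 64-bit threshold is an arbitrary
// example choice and `ctx` stands in for an MLIRContext pointer.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateBreakDownVectorBitCastOpPatterns(
//       patterns, /*controlFn=*/[](vector::BitCastOp op) {
//         // Only break down bitcasts whose result element type is wider than
//         // 64 bits.
//         return op.getResultVectorType().getElementTypeBitWidth() > 64;
//       });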

/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
///
/// [ConvertSameRankInsertStridedSliceIntoShuffle]
/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice:
///    begin    end                  stride
///   [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D
///   3. InsertOp that is the reverse of 1: the destination subvector is
///      inserted back in the proper place.
///
/// [Convert1DExtractStridedSliceIntoShuffle]
/// =========================================
/// For such cases, we can lower it to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// The `options` structure controls which operations are unrolled and the
/// target shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
///      `numUnrolledInstances` are computed from the `targetShape`. For now it
///      is assumed the unrolling factors divide the vector sizes.
///   2. ExtractStridedSlice are created to break-up the vector operands.
///   3. the original op is cloned `numUnrolledInstances` times, once for each
///      result.
///   4. InsertStridedSlice are inserted to re-assemble the slices into the
///      original vector shape.
///
/// Example:
///
///   opA(operand0, operand1)  // numUnrolledInstances = 3
///
///           operand0                   operand1
///              |                          |
///            fork                       fork
///       <----------gather all fork ops --------->
///             /|\                        /|\
///         f00 f01 f02                f10 f11 f12
///       <---------- clone op 3 times --------->
///         opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                \            |            /
///     <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);
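
// A minimal application sketch (illustrative, not part of the upstream
// header), written as it would appear in a pass's runOnOperation(). It assumes
// the greedy driver from "mlir/Transforms/GreedyPatternRewriteDriver.h"
// (`applyPatternsGreedily`; older releases spell it
// `applyPatternsAndFoldGreedily`). `ctx`, `rootOp` and `unrollOptions` are
// placeholders.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorUnrollPatterns(patterns, unrollOptions);
//   if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
//     return signalPassFailure();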

/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand.
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
                                            PatternBenefit benefit = 1);

/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension.
void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Collect a set of unit dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
///
/// These patterns use vector.shape_cast to remove unit dims from e.g.
/// arithmetic operations on Vectors. The newly inserted shape_casts will either
/// cancel each other out or will be folded away when combined with other
/// patterns.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller or equal to `targetVectorBitwidth`.
void populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns,
    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
    PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types. The `disableAtomicRMW` flag indicates whether to use a
/// normal read-modify-write sequence instead of using
/// `memref.generic_atomic_rmw` to perform sub-byte stores.
void populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW = false);

/// Populates patterns for both MemRef flattening and vector narrow type
/// emulation.
///
/// Patterns for narrow-type emulation require "flattened" MemRef(s), so this
/// composite populate* method can be used for narrow-type emulation for Ops
/// operating on MemRef(s) with rank > 2.
void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                        vector::BitCastOp bitCastOp,
                                        arith::TruncIOp truncOp,
                                        vector::BroadcastOp maybeBroadcastOp);

/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
                                     vector::BitCastOp bitCastOp,
                                     vector::BroadcastOp maybeBroadcastOp);

/// Appends patterns for rewriting vector operations over narrow types with
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
///
/// Definition: here 'linearization' means converting a single operation with
/// 1+ vector operand/result of rank>1, into a new single operation whose
/// vector operands and results are all of rank<=1.
///
/// This function registers (1) which operations are legal, and hence should not
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
/// materialize the conversion (with shape_cast).
///
/// Note: the set of legal operations can be extended by a user if for example
/// certain rank>1 vectors are considered valid, by adding additional
/// dynamically legal ops to `conversionTarget`.
///
/// Further note: the choice to use a dialect conversion design for
/// linearization is to make it easy to reuse generic structural type
/// conversions for linearizing scf/cf/func operations.
void populateForVectorLinearize(TypeConverter &typeConverter,
                                ConversionTarget &conversionTarget);

/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
/// contains patterns for converting ConstantLike, Vectorizable, and
/// vector::BitCast ops.
void populateVectorLinearizeBasePatterns(const TypeConverter &,
                                         const ConversionTarget &,
                                         RewritePatternSet &patterns);

/// Populates `patterns` for linearizing ND (N >= 2) vector operations
/// to 1D vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
                                                   const ConversionTarget &,
                                                   RewritePatternSet &patterns);
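
// A minimal end-to-end sketch (illustrative, not part of the upstream header)
// of driving linearization as a dialect conversion, written as it would appear
// in a pass's runOnOperation(). It assumes `applyPartialConversion` from
// "mlir/Transforms/DialectConversion.h"; `ctx` and `rootOp` are placeholders.
//
//   TypeConverter typeConverter;
//   ConversionTarget target(*ctx);
//   vector::populateForVectorLinearize(typeConverter, target);
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorLinearizeBasePatterns(typeConverter, target,
//                                               patterns);
//   vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
//                                                         target, patterns);
//   if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
//     return signalPassFailure();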

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H