//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

namespace arith {
class AndIOp;
class NarrowTypeEmulationConverter;
class TruncIOp;
} // namespace arith

namespace vector {
struct VectorTransformsOptions;

/// Options that control the vector unrolling.
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = std::move(constraint);
    return *this;
  }

  using NativeShapeFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns
  /// `std::nullopt`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = std::move(fn);
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t> tsShape(shape);
    nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native
  /// vector size.
  using UnrollTraversalOrderFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};
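
// Illustrative usage (a sketch, not part of this header): the options are
// meant to be configured fluently and handed to populateVectorUnrollPatterns,
// declared further below.
//
//   UnrollVectorOptions options;
//   options.setNativeShape(ArrayRef<int64_t>{4, 4})
//       .setFilterConstraint([](Operation *op) {
//         // Only unroll contractions; leave other ops untouched.
//         return success(isa<vector::ContractionOp>(op));
//       });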

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
/// operands organized such that the reduction dimension is contiguous.
/// Example:
/// ```
/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                                   affine_map<(m, n, k) -> (n, k)>,
///                                   affine_map<(m, n, k) -> (m, n)>],
///                  iterator_types = ["parallel", "parallel", "reduction"],
///                  kind = #vector.kind<add>} %a, %b, %c : ...
/// ```
///
/// The `constraint` predicate is used to decide which `vector.contraction` ops
/// to filter out.
void populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint =
        [](vector::ContractionOp) { return success(); },
    PatternBenefit = 1);
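
// Illustrative sketch (not part of this header): a custom `constraint` can
// restrict the canonicalization, e.g. to contractions with f32 operands.
//
//   populateVectorContractCanonicalizeMatmulToMMT(
//       patterns, [](vector::ContractionOp op) {
//         return success(op.getLhsType().getElementType().isF32());
//       });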

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
///   // fast path, direct cast
///   memref.cast %A: memref<A...> to compatibleMemRefType
///   scf.yield %view : compatibleMemRefType, index, index
/// } else {
///   // slow path, not in-bounds vector.transfer or linalg.copy.
///   memref.cast %alloc: memref<B...> to compatibleMemRefType
///   scf.yield %4 : compatibleMemRefType, index, index
/// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///    must be equal. This will be relaxed in the future but requires
///    rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops so that they operate on the largest contiguous vector.
/// These patterns are useful when lowering to dialects with 1-D vector types
/// such as LLVM, and they result in fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by re-ordering them with
/// e.g. elementwise Ops:
/// ```
/// %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
/// At the moment, these patterns are limited to vector.broadcast and
/// vector.transpose.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reductions.
/// Note that these patterns change the order of reduction, which may not
/// always produce bit-identical results for some floating-point inputs.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ```
/// is transformed into:
/// ```
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);

/// Patterns to break down vector reductions into a series of arith reductions
/// over vector elements. This is intended to simplify code with reductions
/// over small vector types and to avoid more specialized reduction lowering
/// when possible.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
void populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
    PatternBenefit benefit = 1);
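
// Illustrative sketch (not part of this header): raising the element
// threshold makes the breakdown fire on larger vectors, e.g. up to 4 elements.
//
//   populateBreakDownVectorReductionPatterns(patterns,
//                                            /*maxNumElementsToExtract=*/4);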

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///    insert the k-D source.
/// 2. k-D -> (n-1)-D InsertStridedSlice op
/// 3. InsertOp that is the reverse of 1.
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern that breaks down 1-D
/// extract_strided_slice ops into a chain of Extract ops to extract each
/// element from the source, and then a chain of Insert ops to insert to the
/// target vector.
///
/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
/// which `controlFn` returns true. Otherwise it runs on all
/// extract_strided_slice ops.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);
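
// Illustrative sketch (not part of this header): a `controlFn` can limit the
// decomposition, e.g. to small source vectors only.
//
//   populateVectorExtractStridedSliceToExtractInsertChainPatterns(
//       patterns, [](ExtractStridedSliceOp op) {
//         return op.getSourceVectorType().getNumElements() <= 8;
//       });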

/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
/// based on the native destination vector shape and inserted based on the ratio
/// of the bitwidths.
///
/// This acts as a last resort way to break down vector.bitcast ops to smaller
/// vector sizes. Because this pattern composes until it is bitcasting to a
/// single element of the higher bitwidth, there is an optional control
/// function. If `controlFn` is not nullptr, the pattern will only apply to ops
/// where `controlFn` returns true; otherwise it applies to all bitcast ops.
void populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(BitCastOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
///
/// [ConvertSameRankInsertStridedSliceIntoShuffle]
/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice:
///   begin      end                 stride
///   [offset : offset+size*stride : stride]
/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
/// 3. the destination subvector is inserted back in the proper place
/// 4. InsertOp that is the reverse of 1.
///
/// [Convert1DExtractStridedSliceIntoShuffle]
/// =========================================
/// For such cases, we can lower it to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// The `options` structure controls which operations are unrolled and the
/// target shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
///    `numUnrolledInstances` are computed from the `targetShape`. For now it is
///    assumed the unrolling factors divide the vector sizes.
/// 2. ExtractStridedSlice are created to break-up the vector operands.
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
///    result.
/// 4. InsertStridedSlice are inserted to re-assemble the slices into the
///    original vector shape.
///
/// Example:
///
///    opA(operand0, operand1)  // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\
///          f00 f01 f02                f10 f11 f12
///        <---------- clone op 3 times --------->
///        opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /
///      <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);
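
// Illustrative usage (a sketch, not part of this header; `funcOp` is a
// placeholder and the greedy driver comes from
// "mlir/Transforms/GreedyPatternRewriteDriver.h"):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   populateVectorUnrollPatterns(
//       patterns, UnrollVectorOptions().setNativeShape({4, 4}));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));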

/// Collect a set of vector.shape_cast folding patterns.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Collect a set of unit dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
///
/// These patterns use vector.shape_cast to remove unit dims from e.g.
/// arithmetic operations on Vectors. The newly inserted shape_casts will either
/// cancel each other out or will be folded away when combined with other
/// patterns.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller than or equal to `targetVectorBitwidth`.
void populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns,
    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
    PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
void populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);
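
// Illustrative sketch (an assumption, not confirmed by this header): the
// converter is typically constructed with the target bitwidth (e.g. emulating
// sub-byte vectors with i8) and shared with the corresponding memref/arith
// emulation patterns.
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);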

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                        vector::BitCastOp bitCastOp,
                                        arith::TruncIOp truncOp,
                                        vector::BroadcastOp maybeBroadcastOp);
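
// Illustrative sketch (a hypothetical pattern, not part of this header):
// these rewrite helpers are meant to be called from inside a matched rewrite.
//
//   struct BitCastOfTruncIRewrite : OpRewritePattern<vector::BitCastOp> {
//     using OpRewritePattern::OpRewritePattern;
//     LogicalResult matchAndRewrite(vector::BitCastOp op,
//                                   PatternRewriter &rewriter) const override {
//       auto truncOp = op.getSource().getDefiningOp<arith::TruncIOp>();
//       if (!truncOp)
//         return failure();
//       FailureOr<Value> repl = rewriteBitCastOfTruncI(
//           rewriter, op, truncOp, /*maybeBroadcastOp=*/{});
//       if (failed(repl))
//         return failure();
//       rewriter.replaceOp(op, *repl);
//       return success();
//     }
//   };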

/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
                                     vector::BitCastOp bitCastOp,
                                     vector::BroadcastOp maybeBroadcastOp);

/// Appends patterns for rewriting vector operations over narrow types with
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
/// provided ConversionTarget with the appropriate legality configuration for
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth);
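
// Illustrative usage (a sketch, not part of this header; assumes the dialect
// conversion driver from "mlir/Transforms/DialectConversion.h" and a
// placeholder `op`):
//
//   MLIRContext *ctx = op->getContext();
//   TypeConverter typeConverter;
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateVectorLinearizeTypeConversionsAndLegality(
//       typeConverter, patterns, target, /*targetBitWidth=*/512);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();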

/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth);

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H