VectorRewritePatterns.h
//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

namespace arith {
class AndIOp;
class NarrowTypeEmulationConverter;
class TruncIOp;
} // namespace arith

namespace vector {
struct VectorTransformsOptions;

/// Options that control the vector unrolling.
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = std::move(constraint);
    return *this;
  }

  using NativeShapeFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns
  /// `std::nullopt`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = std::move(fn);
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t> tsShape(shape);
    nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native
  /// vector size.
  using UnrollTraversalOrderFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};
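
// Example usage (illustrative sketch): the setters above return `*this`, so an
// UnrollVectorOptions can be configured by chaining. The 2x2 native shape and
// the contraction-only filter are arbitrary choices made for illustration.
//
//   UnrollVectorOptions options;
//   options.setNativeShape({2, 2}).setFilterConstraint([](Operation *op) {
//     // Only unroll vector.contract ops; leave everything else untouched.
//     return success(isa<vector::ContractionOp>(op));
//   });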

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
/// operands organized such that the reduction dimension is contiguous.
/// Example:
/// ```
///  vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                                    affine_map<(m, n, k) -> (n, k)>,
///                                    affine_map<(m, n, k) -> (m, n)>],
///                   iterator_types = ["parallel", "parallel", "reduction"],
///                   kind = #vector.kind<add>} %a, %b, %c : ...
/// ```
///
/// The `constraint` predicate is used to decide which `vector.contraction` ops
/// to filter out.
void populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint =
        [](vector::ContractionOp) { return success(); },
    PatternBenefit = 1);

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
///   - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///       // fast path, direct cast
///       memref.cast %A: memref<A...> to compatibleMemRefType
///       scf.yield %view : compatibleMemRefType, index, index
///    } else {
///       // slow path, not in-bounds vector.transfer or linalg.copy.
///       memref.cast %alloc: memref<B...> to compatibleMemRefType
///       scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, sized to
/// hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

/// Collect a set of patterns to collapse the most inner unit dims in xfer Ops
///
/// These patterns reduce the rank of the operands of vector transfer ops to
/// operate on vectors without trailing unit dims. This helps reduce the rank of
/// the operands, which can be helpful when lowering to dialects that only
/// support 1-D vector types, such as LLVM.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by re-ordering them with
/// e.g. elementwise Ops:
/// ```
/// %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
/// At the moment, these patterns are limited to vector.broadcast and
/// vector.transpose.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by merging them with load/store
/// ops
/// ```
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
/// vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reductions.
/// Note that these patterns change the order of reduction, which may not always
/// produce bit-identical results on some floating point inputs.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ```
/// is transformed into:
/// ```
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);

/// Patterns to break down vector reductions into a series of arith reductions
/// over vector elements. This is intended to simplify code with reductions
/// over small vector types and to avoid more specialized reduction lowerings
/// when possible.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
void populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///      insert the k-D source.
///   2. k-D -> (n-1)-D InsertStridedSlice op
///   3. InsertOp that is the reverse of 1.
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For an n-D ExtractStridedSliceOp, we can rewrite it to ExtractOp + a
/// lower-rank ExtractStridedSliceOp + InsertOp.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
/// ops into a chain of Extract ops to extract each element from the source, and
/// then a chain of Insert ops to insert to the target vector.
///
/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
/// which `controlFn` returns true. Otherwise it runs on all such ops.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
/// based on the native destination vector shape and inserted based on the ratio
/// of the bitwidths.
///
/// This acts as a last resort way to break down vector.bitcast ops to smaller
/// vector sizes. Because this pattern composes until it is bitcasting to a
/// single element of the higher bitwidth, there is an optional control
/// function. If `controlFn` is not nullptr, the pattern will only apply to ops
/// where `controlFn` returns true; otherwise it applies to all bitcast ops.
void populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(BitCastOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);
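
// Example usage (illustrative sketch): supplying the optional `controlFn` so
// the breakdown only fires on bitcasts going from a narrower to a wider
// element type. `ctx`, the predicate, and the element-bitwidth accessors used
// below are assumptions made for illustration.
//
//   RewritePatternSet patterns(ctx);
//   populateBreakDownVectorBitCastOpPatterns(
//       patterns, /*controlFn=*/[](BitCastOp op) {
//         return op.getSourceVectorType().getElementTypeBitWidth() <
//                op.getResultVectorType().getElementTypeBitWidth();
//       });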

/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
///
/// [ConvertSameRankInsertStridedSliceIntoShuffle]
/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice
/// `[offset : offset + size * stride : stride]` (i.e. begin : end : stride):
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D
///   3. the destination subvector is inserted back in the proper place
///   4. InsertOp that is the reverse of 1.
///
/// [Convert1DExtractStridedSliceIntoShuffle]
/// =========================================
/// A 1-D ExtractStridedSliceOp can be lowered directly to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);


/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// The `options` structure controls which operations are unrolled and the
/// target shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
///      `numUnrolledInstances` are computed from the `targetShape`. For now it
///      is assumed the unrolling factors divide the vector sizes.
///   2. ExtractStridedSlice are created to break-up the vector operands.
///   3. the original op is cloned `numUnrolledInstances` times, once for each
///      result.
///   4. InsertStridedSlice are inserted to re-assemble the slices into the
///      original vector shape.
///
/// Example:
///
///    opA(operand0, operand1)  // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\
///          f00 f01 f02                f10 f11 f12
///        <---------- clone op 3 times --------->
///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /
///      <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);
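
// Example usage (illustrative sketch): register the unroll patterns and run
// them with the greedy rewrite driver. The driver entry point is assumed to be
// `applyPatternsGreedily` (older releases spell it
// `applyPatternsAndFoldGreedily`); `ctx` and `op` come from the surrounding
// pass, and the 2x2 native shape is an arbitrary illustration.
//
//   RewritePatternSet patterns(ctx);
//   populateVectorUnrollPatterns(
//       patterns, UnrollVectorOptions().setNativeShape({2, 2}));
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     return signalPassFailure();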

/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand.
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
                                            PatternBenefit benefit = 1);

/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension.
void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Collect a set of one dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
///
/// These patterns use vector.shape_cast to remove unit dims from e.g.
/// arithmetic operations on Vectors. The newly inserted shape_casts will either
/// cancel each other out or will be folded away when combined with other
/// patterns.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller than or equal to `targetVectorBitwidth`.
void populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns,
    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
    PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types. The `disableAtomicRMW` flag indicates whether to use a
/// normal read-modify-write sequence instead of `memref.generic_atomic_rmw` to
/// perform sub-byte stores.
void populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW = false);

/// Populates patterns for both MemRef flattening and Vector narrow type
/// emulation.
///
/// Patterns for narrow-type emulation require "flattened" MemRef(s), so this
/// composite populate* method can be used for narrow-type emulation of Ops
/// operating on MemRef(s) with rank > 2.
void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);
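
// Example usage (illustrative sketch): narrow-type emulation is normally
// driven as a dialect conversion. The converter constructor shown (taking the
// target bitwidth, e.g. widening sub-byte accesses to i8) is an assumption
// about arith::NarrowTypeEmulationConverter, and legality configuration on
// `target` is elided; `ctx` and `op` come from the surrounding pass.
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(typeConverter,
//                                                             patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     return signalPassFailure();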

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                        vector::BitCastOp bitCastOp,
                                        arith::TruncIOp truncOp,
                                        vector::BroadcastOp maybeBroadcastOp);

/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
                                     vector::BitCastOp bitCastOp,
                                     vector::BroadcastOp maybeBroadcastOp);

/// Appends patterns for rewriting vector operations over narrow types with
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);


/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
///
/// Definition: here 'linearization' means converting a single operation with
/// 1+ vector operand/result of rank>1, into a new single operation whose
/// vector operands and results are all of rank<=1.
///
/// This function registers (1) which operations are legal, and hence should not
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
/// materialize the conversion (with shape_cast).
///
/// Note: the set of legal operations can be extended by a user if, for example,
/// certain rank>1 vectors are considered valid, by adding additional
/// dynamically legal ops to `conversionTarget`.
///
/// Further note: the choice to use a dialect conversion design for
/// linearization is to make it easy to reuse generic structural type
/// conversions for linearizing scf/cf/func operations.
void populateForVectorLinearize(TypeConverter &typeConverter,
                                ConversionTarget &conversionTarget);


/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
/// contains patterns for converting ConstantLike, Vectorizable, and
/// vector::BitCast ops.
void populateVectorLinearizeBasePatterns(const TypeConverter &,
                                         const ConversionTarget &,
                                         RewritePatternSet &patterns);

/// Populates `patterns` for linearizing ND (N >= 2) vector operations
/// to 1D vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
                                                   const ConversionTarget &,
                                                   RewritePatternSet &patterns);
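
// Example usage (illustrative sketch): linearization is driven as a dialect
// conversion, combining the three populate* entry points above; `ctx` and `op`
// come from the surrounding pass.
//
//   TypeConverter typeConverter;
//   ConversionTarget target(*ctx);
//   populateForVectorLinearize(typeConverter, target);
//
//   RewritePatternSet patterns(ctx);
//   populateVectorLinearizeBasePatterns(typeConverter, target, patterns);
//   populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
//                                                 patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     return signalPassFailure();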

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H