MLIR  17.0.0git
VectorRewritePatterns.h
Go to the documentation of this file.
1 //===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
11 
12 #include <utility>
13 #include <optional>
14 
16 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 namespace mlir {
22 class RewritePatternSet;
23 
24 namespace vector {
25 
26 //===----------------------------------------------------------------------===//
27 // Vector transformation options exposed as auxiliary structs.
28 //===----------------------------------------------------------------------===//
29 /// Structure to control the behavior of vector transform patterns.
31  /// Option to control the lowering of vector.contract.
32  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
34  setVectorTransformsOptions(VectorContractLowering opt) {
36  return *this;
37  }
38  /// Option to control the lowering of vector.multi_reduction.
39  VectorMultiReductionLowering vectorMultiReductionLowering =
40  VectorMultiReductionLowering::InnerParallel;
42  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
44  return *this;
45  }
46  /// Option to control the lowering of vector.transpose.
47  VectorTransposeLowering vectorTransposeLowering =
48  VectorTransposeLowering::EltWise;
50  setVectorTransposeLowering(VectorTransposeLowering opt) {
52  return *this;
53  }
54  /// Option to control the splitting of vector transfers.
56  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
57  vectorTransferSplit = opt;
58  return *this;
59  }
60 };
61 
62 /// Options that control the vector unrolling.
64  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
65  /// Callback function that indicates whether vector unrolling should be
66  /// attempted on the operation.
69  filterConstraint = std::move(constraint);
70  return *this;
71  }
72 
74  std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
75  /// Function that returns the shape of the vector to unroll to for a given
76  /// operation. The unrolling is aborted if the function returns
77  /// `std::nullopt`.
80  nativeShape = std::move(fn);
81  return *this;
82  }
83 
84  /// Set the native shape to use for unrolling.
86  SmallVector<int64_t> tsShape(shape.begin(), shape.end());
87  nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
88  return tsShape;
89  };
90  return *this;
91  }
92 
93  /// Function that returns the traversal order (in terms of "for loop order",
94  /// i.e. slowest varying dimension to fastest varying dimension) that shoudl
95  /// be used when unrolling the given operation into units of the native vector
96  /// size.
98  std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
102  traversalOrderCallback = std::move(traversalOrderFn);
103  return *this;
104  }
105 };
106 
107 //===----------------------------------------------------------------------===//
108 // Vector transformation exposed as populate functions over rewrite patterns.
109 //===----------------------------------------------------------------------===//
110 
111 /// Insert TransposeLowering patterns into extraction/insertion.
113  RewritePatternSet &patterns,
114  VectorTransformsOptions options = VectorTransformsOptions(),
115  PatternBenefit benefit = 1);
116 
117 /// Collect a set of patterns to convert vector.multi_reduction op into
118 /// a sequence of vector.reduction ops. The patterns comprise:
119 /// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
120 /// that all reduction dimensions are either innermost or outermost, by adding
121 /// the proper vector.transpose operations.
122 /// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
123 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
124 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
125 /// back.
126 /// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
127 /// form, with an **outermost** reduction dimension, unroll the outer dimension
128 /// to obtain a sequence of 1-D vector ops. This also has an opportunity for
129 /// tree-reduction (in the future).
130 /// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
131 /// with an **innermost** reduction dimension, unroll the outer dimension to
132 /// obtain a sequence of extract + vector.reduction + insert. This can further
133 /// lower to horizontal reduction ops.
134 /// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
135 /// reduction (and are thus missing either a parallel or a reduction), we lift
136 /// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
137 /// the other patterns can kick in, thus fully exiting out of the
138 /// vector.multi_reduction abstraction.
140  RewritePatternSet &patterns, VectorMultiReductionLowering options,
141  PatternBenefit benefit = 1);
142 
143 /// Collects patterns to progressively lower vector contraction ops on high-D
144 /// into low-D reduction and product ops.
146  RewritePatternSet &patterns,
147  VectorTransformsOptions options = VectorTransformsOptions(),
148  PatternBenefit benefit = 1);
149 
150 /// Collect patterns to convert reduction op to vector.contract and fold
151 /// transpose/broadcast ops into the contract.
153  PatternBenefit benefit = 1);
154 
155 /// Collect patterns to convert scan op
157  PatternBenefit benefit = 1);
158 
159 //===----------------------------------------------------------------------===//
160 // Vector.transfer patterns.
161 //===----------------------------------------------------------------------===//
162 /// Collect a set of transfer read/write lowering patterns that simplify the
163 /// permutation map (e.g., converting it to a minor identity map) by inserting
164 /// broadcasts and transposes. More specifically:
165 ///
166 /// [TransferReadPermutationLowering]
167 /// Lower transfer_read op with permutation into a transfer_read with a
168 /// permutation map composed of leading zeros followed by a minor identity +
169 /// vector.transpose op.
170 /// Ex:
171 /// vector.transfer_read ...
172 /// permutation_map: (d0, d1, d2) -> (0, d1)
173 /// into:
174 /// %v = vector.transfer_read ...
175 /// permutation_map: (d0, d1, d2) -> (d1, 0)
176 /// vector.transpose %v, [1, 0]
177 ///
178 /// vector.transfer_read ...
179 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
180 /// into:
181 /// %v = vector.transfer_read ...
182 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
183 /// vector.transpose %v, [0, 1, 3, 2, 4]
184 /// Note that an alternative is to transform it to linalg.transpose +
185 /// vector.transfer_read to do the transpose in memory instead.
186 ///
187 /// [TransferWritePermutationLowering]
188 /// Lower transfer_write op with permutation into a transfer_write with a
189 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
190 /// Ex:
191 /// vector.transfer_write %v ...
192 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
193 /// into:
194 /// %tmp = vector.transpose %v, [2, 0, 1]
195 /// vector.transfer_write %tmp ...
196 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
197 ///
198 /// vector.transfer_write %v ...
199 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
200 /// into:
201 /// %tmp = vector.transpose %v, [1, 0]
202 /// %v = vector.transfer_write %tmp ...
203 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
204 ///
205 /// [TransferOpReduceRank]
206 /// Lower transfer_read op with broadcast in the leading dimensions into
207 /// transfer_read of lower rank + vector.broadcast.
208 /// Ex: vector.transfer_read ...
209 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
210 /// into:
211 /// %v = vector.transfer_read ...
212 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
213 /// vector.broadcast %v
215  RewritePatternSet &patterns, PatternBenefit benefit = 1);
216 
217 /// Collect a set of patterns to reduce the rank of the operands of vector
218 /// transfer ops to operate on the largest contigious vector.
219 /// These patterns are useful when lowering to dialects with 1d vector type
220 /// such as llvm and it will result fewer memory reads.
222  RewritePatternSet &patterns, PatternBenefit benefit = 1);
223 
224 /// Populate `patterns` with the following patterns.
225 ///
226 /// [DecomposeDifferentRankInsertStridedSlice]
227 /// ==========================================
228 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
229 /// have different ranks.
230 ///
231 /// When ranks are different, InsertStridedSlice needs to extract a properly
232 /// ranked vector from the destination vector into which to insert. This pattern
233 /// only takes care of this extraction part and forwards the rest to
234 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
235 ///
236 /// For a k-D source and n-D destination vector (k < n), we emit:
237 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
238 /// insert the k-D source.
239 /// 2. k-D -> (n-1)-D InsertStridedSlice op
240 /// 3. InsertOp that is the reverse of 1.
241 ///
242 /// [DecomposeNDExtractStridedSlice]
243 /// ================================
244 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
245 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
247  RewritePatternSet &patterns, PatternBenefit benefit = 1);
248 
249 /// Populate `patterns` with a pattern to breaks down 1-D extract_strided_slice
250 /// ops into a chain of Extract ops to extract each element from the source, and
251 /// then a chain of Insert ops to insert to the target vector.
252 ///
253 /// If `controlFn` is not nullptr, the pattern will only be invoked on ops that
254 /// `controlFn` returns true. Otherwise runs on ops.
256  RewritePatternSet &patterns,
257  std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
258  PatternBenefit benefit = 1);
259 
260 /// Populate `patterns` with the following patterns.
261 ///
262 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
263 ///
264 /// [ConvertSameRankInsertStridedSliceIntoShuffle]
265 /// ==============================================
266 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
267 /// have the same rank. For each outermost index in the slice:
268 /// begin end stride
269 /// [offset : offset+size*stride : stride]
270 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
271 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
272 /// 3. the destination subvector is inserted back in the proper place
273 /// 3. InsertOp that is the reverse of 1.
274 ///
275 /// [Convert1DExtractStridedSliceIntoShuffle]
276 /// =========================================
277 /// For such cases, we can lower it to a ShuffleOp.
279  RewritePatternSet &patterns, PatternBenefit benefit = 1);
280 
281 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
282 /// `options` structure controls which operations are unrolled and the target
283 /// shape.
284 /// `op` is unrolled to the `targetShape` as follows, for each of its operands:
285 /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
286 /// `numUnrolledInstances` are computed from the `targetShape`. For now it is
287 /// assumed the unrolling factors divide the vector sizes.
288 /// 2. ExtractStridedSlice are created to break-up the vector operands.
289 /// 3. the original op is cloned `numUnrolledInstances` times, once for each
290 /// result.
291 /// 4. InsertStridedSlice are inserted to re-assemble the slices into the
292 /// original vectore shape.
293 ///
294 /// Example:
295 ///
296 /// opA(operand0, operand1) // numUnrolledInstances = 3
297 ///
298 /// operand0 operand1
299 /// | |
300 /// fork fork
301 /// <----------gather all fork ops --------->
302 /// /|\ /|\
303 /// f00 f01 f02 f10 f11 f12
304 /// <---------- clone op 3 times --------->
305 /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
306 /// \ | /
307 /// <-------------------- join ------------------------->
308 ///
309 /// Other local patterns then kick in iteratively (including DCE) and compose
310 /// to combine the ExtractStridedSlice/InsertStridedSlice.
312  const UnrollVectorOptions &options,
313  PatternBenefit benefit = 1);
314 
315 //===----------------------------------------------------------------------===//
316 // Finer-grained patterns exposed for more control over individual lowerings.
317 //===----------------------------------------------------------------------===//
318 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
319 /// may take an extra filter to perform selection at a finer granularity.
322  std::function<LogicalResult(VectorTransferOpInterface op)>;
323 
325  MLIRContext *context,
327  FilterConstraintType filter =
328  [](VectorTransferOpInterface op) { return success(); },
329  PatternBenefit benefit = 1)
330  : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
331  filter(std::move(filter)) {}
332 
333  /// Performs the rewrite.
335  PatternRewriter &rewriter) const override;
336 
337 private:
338  VectorTransformsOptions options;
339  FilterConstraintType filter;
340 };
341 
342 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
343 /// semantics to:
344 /// ```
345 /// %flattened_a = vector.shape_cast %a
346 /// %flattened_b = vector.shape_cast %b
347 /// %flattened_d = vector.matmul %flattened_a, %flattened_b
348 /// %d = vector.shape_cast %%flattened_d
349 /// %e = add %c, %d
350 /// ```
351 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
352 //
353 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
354 /// the vector.contract op is a row-major matrix multiply.
356  : public OpRewritePattern<vector::ContractionOp> {
357 public:
359 
361  std::function<LogicalResult(vector::ContractionOp op)>;
362 
363  static LogicalResult defaultFilter(vector::ContractionOp op) {
364  return success();
365  }
366 
368  vector::VectorTransformsOptions vectorTransformOptions,
369  MLIRContext *context, PatternBenefit benefit = 1,
370  FilterConstraintType constraint = defaultFilter)
371  : OpRewritePattern<vector::ContractionOp>(context, benefit),
372  vectorTransformOptions(vectorTransformOptions),
373  filter(std::move(constraint)) {}
374 
375  LogicalResult matchAndRewrite(vector::ContractionOp op,
376  PatternRewriter &rewriter) const override;
377 
378 private:
379  /// Options to control the vector patterns.
380  vector::VectorTransformsOptions vectorTransformOptions;
381  FilterConstraintType filter;
382 };
383 
384 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
385 /// semantics to a reduction_size-unrolled sequence:
386 /// ```
387 /// %at = vector.transpose %a, [1, 0]
388 /// %bRow0 = vector.extract %b[0]
389 /// %atRow0 = vector.extract %at[0]
390 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
391 /// ...
392 /// %bRowK = vector.extract %b[K]
393 /// %atRowK = vector.extract %at[K]
394 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
395 /// ```
396 ///
397 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
398 /// the vector.contract op is a row-major matrix multiply.
400  : public OpRewritePattern<vector::ContractionOp> {
401 public:
403 
405  std::function<LogicalResult(vector::ContractionOp op)>;
406 
407  static LogicalResult defaultFilter(vector::ContractionOp op) {
408  return success();
409  }
410 
412  vector::VectorTransformsOptions vectorTransformOptions,
413  MLIRContext *context, PatternBenefit benefit = 1,
414  FilterConstraintType constraint = defaultFilter)
415  : OpRewritePattern<vector::ContractionOp>(context, benefit),
416  vectorTransformOptions(vectorTransformOptions),
417  filter(std::move(constraint)) {}
418 
419  LogicalResult matchAndRewrite(vector::ContractionOp op,
420  PatternRewriter &rewriter) const override;
421 
422 private:
423  /// Options to control the vector patterns.
424  vector::VectorTransformsOptions vectorTransformOptions;
425  FilterConstraintType filter;
426 };
427 
428 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
429 /// semantics to an output-size-unrolled sequence:
430 /// ```
431 /// %out = arith.constant ... : vector<MxNxelt_type>
432 /// %bt = vector.transpose %b, [1, 0]
433 /// %aRow0 = vector.extract %a[0]
434 /// %btRow0 = vector.extract %bt[0]
435 /// %c00 = vector.reduce %atRow0, %bRow0
436 /// %out00 = vector.insert %c00, %out[0, 0]
437 /// ...
438 /// %aRowLast = vector.extract %at[M-1]
439 /// %btRowLast = vector.extract %b[N-1]
440 /// %cLastLast = vector.reduce %atRowLast, %bRowLast
441 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
442 /// ```
443 ///
444 /// This only kicks in when VectorTransformsOptions is set to Dot and
445 /// the vector.contract op is a row-major matmul or matvec.
447  : public OpRewritePattern<vector::ContractionOp> {
448 public:
450 
452  std::function<LogicalResult(vector::ContractionOp op)>;
453 
454  static LogicalResult defaultFilter(vector::ContractionOp op) {
455  return success();
456  }
457 
459  vector::VectorTransformsOptions vectorTransformOptions,
460  MLIRContext *context, PatternBenefit benefit = 1,
461  const FilterConstraintType &constraint = defaultFilter)
462  : OpRewritePattern<vector::ContractionOp>(context, benefit),
463  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
464 
465  LogicalResult matchAndRewrite(vector::ContractionOp op,
466  PatternRewriter &rewriter) const override;
467 
468 private:
469  /// Options to control the vector patterns.
470  vector::VectorTransformsOptions vectorTransformOptions;
471  FilterConstraintType filter;
472 };
473 
474 /// Progressive lowering of ContractionOp.
475 ///
476 /// One:
477 /// %x = vector.contract with at least one free/batch dimension
478 /// is replaced by:
479 /// %a = vector.contract with one less free/batch dimension
480 /// %b = vector.contract with one less free/batch dimension
481 /// ..
482 /// %x = combine %a %b ..
483 /// until a pure contraction is reached (no free/batch dimensions),
484 /// which is replaced by a dot-product.
485 ///
486 /// This only kicks in when either VectorTransformsOptions is set
487 /// to Dot or when other contraction patterns fail.
488 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
489 public:
492  std::function<LogicalResult(vector::ContractionOp op)>;
493 
494  static LogicalResult defaultFilter(vector::ContractionOp op) {
495  return success();
496  }
497 
499  MLIRContext *context, PatternBenefit benefit = 1,
500  FilterConstraintType constraint = defaultFilter)
501  : OpRewritePattern<vector::ContractionOp>(context, benefit),
502  vectorTransformOptions(vectorTransformOptions),
503  filter(std::move(constraint)) {}
504 
505  LogicalResult matchAndRewrite(vector::ContractionOp op,
506  PatternRewriter &rewriter) const override;
507 
508 private:
509  /// Options to control the vector patterns.
510  vector::VectorTransformsOptions vectorTransformOptions;
511  FilterConstraintType filter;
512  // Lower one parallel dimension.
513  FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
514  int64_t rhsIndex,
515  PatternRewriter &rewriter) const;
516  // Lower one reduction dimension.
517  FailureOr<Value> lowerReduction(vector::ContractionOp op,
518  PatternRewriter &rewriter) const;
519 };
520 
521 } // namespace vector
522 } // namespace mlir
523 
524 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
@ None
static llvm::ManagedStatic< PassManagerOptions > options
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:245
Progressive lowering of ContractionOp.
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
Progressive lowering of ContractionOp.
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to an output-size-u...
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
ContractionOpToDotLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to:
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractionOpToMatmulOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
Progressively lower a vector.contract a, b, c with row-major matmul semantics to:
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to a reduction_size...
ContractionOpToOuterProductOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
static LogicalResult defaultFilter(vector::ContractionOp op)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
Progressively lower a vector.contract a, b, c with row-major matmul semantics to a reduction_size-unr...
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Collects patterns to progressively lower vector contraction ops on high-D into low-D reduction and pr...
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern that breaks down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Insert TransposeLowering patterns into extraction/insertion.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of patterns to unroll vector operations to smaller shapes.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert scan op.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
Options that control the vector unrolling.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType
Apply splitFullAndPartialTransfer selectively via a pattern.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Performs the rewrite.
std::function< LogicalResult(VectorTransferOpInterface op)> FilterConstraintType
VectorTransferFullPartialRewriter(MLIRContext *context, VectorTransformsOptions options=VectorTransformsOptions(), FilterConstraintType filter=[](VectorTransferOpInterface op) { return success();}, PatternBenefit benefit=1)
Structure to control the behavior of vector transform patterns.
VectorTransferSplit vectorTransferSplit
Option to control the splitting of vector transfers.
VectorTransformsOptions & setVectorMultiReductionLowering(VectorMultiReductionLowering opt)
VectorMultiReductionLowering vectorMultiReductionLowering
Option to control the lowering of vector.multi_reduction.
VectorTransformsOptions & setVectorTransposeLowering(VectorTransposeLowering opt)
VectorTransformsOptions & setVectorTransferSplit(VectorTransferSplit opt)
VectorTransformsOptions & setVectorTransformsOptions(VectorContractLowering opt)
VectorContractLowering vectorContractLowering
Option to control the lowering of vector.contract.
VectorTransposeLowering vectorTransposeLowering
Option to control the lowering of vector.transpose.