MLIR  16.0.0git
VectorRewritePatterns.h
Go to the documentation of this file.
1 //===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
11 
12 #include <utility>
13 
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 class RewritePatternSet;
21 
22 namespace vector {
23 
24 //===----------------------------------------------------------------------===//
25 // Vector transformation options exposed as auxiliary structs.
26 //===----------------------------------------------------------------------===//
27 /// Enum to control the lowering of `vector.transpose` operations.
29  /// Lower transpose into element-wise extract and inserts.
30  EltWise = 0,
31  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
32  /// intrinsics.
33  Flat = 1,
34  /// Lower 2-D transpose to `vector.shuffle`.
35  Shuffle = 2,
36 };
37 /// Enum to control the lowering of `vector.multi_reduction` operations.
39  /// Lower multi_reduction into outer-reduction and inner-parallel ops.
40  InnerParallel = 0,
41  /// Lower multi_reduction into outer-parallel and inner-reduction ops.
42  InnerReduction = 1,
43 };
44 /// Enum to control the lowering of `vector.contract` operations.
46  /// Progressively lower to finer grained `vector.contract` and dot-products.
47  Dot = 0,
48  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
49  Matmul = 1,
50  /// Lower to `vector.outerproduct`.
51  OuterProduct = 2,
52  /// Lower contract with all reduction dimensions unrolled to 1 to a vector
53  /// elementwise operations.
54  ParallelArith = 3,
55 };
/// Selects whether `vector.transfer` operations are split into in-bounds and
/// out-of-bounds variants, and which form that split takes.
///
/// Enumerator values are explicit so that their numeric encoding is stable.
enum class VectorTransferSplit {
  /// Leave vector transfer operations unsplit.
  None = 0,
  /// Split into an in-bounds and an out-of-bounds vector.transfer operation.
  VectorTransfer = 1,
  /// Split into an in-bounds vector.transfer plus linalg.fill + linalg.copy
  /// operations.
  LinalgCopy = 2,
  /// Perform no split; instead mark the vector transfer operation "in-bounds".
  ForceInBounds = 3,
};
69 /// Structure to control the behavior of vector transform patterns.
71  /// Option to control the lowering of vector.contract.
75  vectorContractLowering = opt;
76  return *this;
77  }
78  /// Option to control the lowering of vector.multi_reduction.
79  VectorMultiReductionLowering vectorMultiReductionLowering =
83  vectorMultiReductionLowering = opt;
84  return *this;
85  }
86  /// Option to control the lowering of vector.transpose.
87  VectorTransposeLowering vectorTransposeLowering =
91  vectorTransposeLowering = opt;
92  return *this;
93  }
94  /// Option to control the splitting of vector transfers.
97  vectorTransferSplit = opt;
98  return *this;
99  }
100 };
101 
102 /// Options that control the vector unrolling.
104  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
105  /// Callback function that indicates whether vector unrolling should be
106  /// attempted on the operation.
107  FilterConstraintFnType filterConstraint = nullptr;
109  filterConstraint = std::move(constraint);
110  return *this;
111  }
112 
113  using NativeShapeFnType =
114  std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
115  /// Function that returns the shape of the vector to unroll to for a given
116  /// operation. The unrolling is aborted if the function returns `llvm::None`.
117  NativeShapeFnType nativeShape = nullptr;
119  nativeShape = std::move(fn);
120  return *this;
121  }
122 
123  /// Set the native shape to use for unrolling.
125  SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
126  nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
127  return tsShape;
128  };
129  return *this;
130  }
131 
132  /// Function that returns the traversal order (in terms of "for loop order",
133 /// i.e. slowest varying dimension to fastest varying dimension) that should
134  /// be used when unrolling the given operation into units of the native vector
135  /// size.
137  std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
138  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
141  traversalOrderCallback = std::move(traversalOrderFn);
142  return *this;
143  }
144 };
145 
146 //===----------------------------------------------------------------------===//
147 // Vector transformation exposed as populate functions over rewrite patterns.
148 //===----------------------------------------------------------------------===//
149 
150 /// Insert TransposeLowering patterns into extraction/insertion.
152  RewritePatternSet &patterns,
154  PatternBenefit benefit = 1);
155 
156 /// Collect a set of patterns to convert vector.multi_reduction op into
157 /// a sequence of vector.reduction ops. The patterns comprise:
158 /// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
159 /// that all reduction dimensions are either innermost or outermost, by adding
160 /// the proper vector.transpose operations.
161 /// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
162 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
163 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
164 /// back.
165 /// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
166 /// form, with an **outermost** reduction dimension, unroll the outer dimension
167 /// to obtain a sequence of 1-D vector ops. This also has an opportunity for
168 /// tree-reduction (in the future).
169 /// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
170 /// with an **innermost** reduction dimension, unroll the outer dimension to
171 /// obtain a sequence of extract + vector.reduction + insert. This can further
172 /// lower to horizontal reduction ops.
173 /// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
174 /// reduction (and are thus missing either a parallel or a reduction), we lift
175 /// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
176 /// the other patterns can kick in, thus fully exiting out of the
177 /// vector.multi_reduction abstraction.
180  PatternBenefit benefit = 1);
181 
182 /// Collects patterns to progressively lower vector contraction ops on high-D
183 /// into low-D reduction and product ops.
185  RewritePatternSet &patterns,
187  PatternBenefit benefit = 1);
188 
189 /// Collect patterns to convert reduction op to vector.contract and fold
190 /// transpose/broadcast ops into the contract.
192  PatternBenefit benefit = 1);
193 
194 /// Collect patterns to lower `vector.scan` ops.
196  PatternBenefit benefit = 1);
197 
198 //===----------------------------------------------------------------------===//
199 // Vector.transfer patterns.
200 //===----------------------------------------------------------------------===//
201 /// Collect a set of transfer read/write lowering patterns that simplify the
202 /// permutation map (e.g., converting it to a minor identity map) by inserting
203 /// broadcasts and transposes. More specifically:
204 ///
205 /// [TransferReadPermutationLowering]
206 /// Lower transfer_read op with permutation into a transfer_read with a
207 /// permutation map composed of leading zeros followed by a minor identity +
208 /// vector.transpose op.
209 /// Ex:
210 /// vector.transfer_read ...
211 /// permutation_map: (d0, d1, d2) -> (0, d1)
212 /// into:
213 /// %v = vector.transfer_read ...
214 /// permutation_map: (d0, d1, d2) -> (d1, 0)
215 /// vector.transpose %v, [1, 0]
216 ///
217 /// vector.transfer_read ...
218 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
219 /// into:
220 /// %v = vector.transfer_read ...
221 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
222 /// vector.transpose %v, [0, 1, 3, 2, 4]
223 /// Note that an alternative is to transform it to linalg.transpose +
224 /// vector.transfer_read to do the transpose in memory instead.
225 ///
226 /// [TransferWritePermutationLowering]
227 /// Lower transfer_write op with permutation into a transfer_write with a
228 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
229 /// Ex:
230 /// vector.transfer_write %v ...
231 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
232 /// into:
233 /// %tmp = vector.transpose %v, [2, 0, 1]
234 /// vector.transfer_write %tmp ...
235 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
236 ///
237 /// vector.transfer_write %v ...
238 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
239 /// into:
240 /// %tmp = vector.transpose %v, [1, 0]
241 /// %v = vector.transfer_write %tmp ...
242 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
243 ///
244 /// [TransferOpReduceRank]
245 /// Lower transfer_read op with broadcast in the leading dimensions into
246 /// transfer_read of lower rank + vector.broadcast.
247 /// Ex: vector.transfer_read ...
248 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
249 /// into:
250 /// %v = vector.transfer_read ...
251 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
252 /// vector.broadcast %v
254  RewritePatternSet &patterns, PatternBenefit benefit = 1);
255 
256 /// Collect a set of patterns to reduce the rank of the operands of vector
257 /// transfer ops to operate on the largest contiguous vector.
258 /// These patterns are useful when lowering to dialects with 1-D vector types
259 /// such as LLVM, and they result in fewer memory reads.
261  RewritePatternSet &patterns, PatternBenefit benefit = 1);
262 
263 /// Populate `patterns` with the following patterns.
264 ///
265 /// [DecomposeDifferentRankInsertStridedSlice]
266 /// ==========================================
267 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
268 /// have different ranks.
269 ///
270 /// When ranks are different, InsertStridedSlice needs to extract a properly
271 /// ranked vector from the destination vector into which to insert. This pattern
272 /// only takes care of this extraction part and forwards the rest to
273 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
274 ///
275 /// For a k-D source and n-D destination vector (k < n), we emit:
276 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
277 /// insert the k-D source.
278 /// 2. k-D -> (n-1)-D InsertStridedSlice op
279 /// 3. InsertOp that is the reverse of 1.
280 ///
281 /// [DecomposeNDExtractStridedSlice]
282 /// ================================
283 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
284 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
286  RewritePatternSet &patterns, PatternBenefit benefit = 1);
287 
288 /// Populate `patterns` with the following patterns.
289 ///
290 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
291 ///
292 /// [ConvertSameRankInsertStridedSliceIntoShuffle]
293 /// ==============================================
294 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
295 /// have the same rank. For each outermost index in the slice:
296 /// begin end stride
297 /// [offset : offset+size*stride : stride]
298 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
299 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
300 /// 3. the destination subvector is inserted back in the proper place with an
301 ///    InsertOp that is the reverse of 1.
302 ///
303 /// [Convert1DExtractStridedSliceIntoShuffle]
304 /// =========================================
305 /// For a 1-D ExtractStridedSliceOp, we can lower it to a ShuffleOp.
307  RewritePatternSet &patterns, PatternBenefit benefit = 1);
308 
309 /// Collect a set of patterns to unroll vector operations to smaller shapes.
310 /// `options` structure controls which operations are unrolled and the target
311 /// shape.
312 /// `op` is unrolled to the `targetShape` as follows, for each of its operands:
313 /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
314 /// `numUnrolledInstances` are computed from the `targetShape`. For now it is
315 /// assumed the unrolling factors divide the vector sizes.
316 /// 2. ExtractStridedSlice are created to break-up the vector operands.
317 /// 3. the original op is cloned `numUnrolledInstances` times, once for each
318 /// result.
319 /// 4. InsertStridedSlice are inserted to re-assemble the slices into the
320 /// original vector shape.
321 ///
322 /// Example:
323 ///
324 /// opA(operand0, operand1) // numUnrolledInstances = 3
325 ///
326 /// operand0 operand1
327 /// | |
328 /// fork fork
329 /// <----------gather all fork ops --------->
330 /// /|\ /|\
331 /// f00 f01 f02 f10 f11 f12
332 /// <---------- clone op 3 times --------->
333 /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
334 /// \ | /
335 /// <-------------------- join ------------------------->
336 ///
337 /// Other local patterns then kick in iteratively (including DCE) and compose
338 /// to combine the ExtractStridedSlice/InsertStridedSlice.
340  const UnrollVectorOptions &options,
341  PatternBenefit benefit = 1);
342 
343 //===----------------------------------------------------------------------===//
344 // Finer-grained patterns exposed for more control over individual lowerings.
345 //===----------------------------------------------------------------------===//
346 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
347 /// may take an extra filter to perform selection at a finer granularity.
349  using FilterConstraintType =
350  std::function<LogicalResult(VectorTransferOpInterface op)>;
351 
353  MLIRContext *context,
355  FilterConstraintType filter =
356  [](VectorTransferOpInterface op) { return success(); },
357  PatternBenefit benefit = 1)
358  : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
359  filter(std::move(filter)) {}
360 
361  /// Performs the rewrite.
362  LogicalResult matchAndRewrite(Operation *op,
363  PatternRewriter &rewriter) const override;
364 
365 private:
367  FilterConstraintType filter;
368 };
369 
370 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
371 /// semantics to:
372 /// ```
373 /// %flattened_a = vector.shape_cast %a
374 /// %flattened_b = vector.shape_cast %b
375 /// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
376 /// %d = vector.shape_cast %flattened_d
377 /// %e = add %c, %d
378 /// ```
379 /// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
380 ///
381 /// This only kicks in when VectorTransformsOptions is set to Matmul and
382 /// the vector.contract op is a row-major matrix multiply.
384  : public OpRewritePattern<vector::ContractionOp> {
385 public:
387 
388  using FilterConstraintType =
389  std::function<LogicalResult(vector::ContractionOp op)>;
390 
391  static LogicalResult defaultFilter(vector::ContractionOp op) {
392  return success();
393  }
394 
396  vector::VectorTransformsOptions vectorTransformOptions,
397  MLIRContext *context, PatternBenefit benefit = 1,
398  FilterConstraintType constraint = defaultFilter)
399  : OpRewritePattern<vector::ContractionOp>(context, benefit),
400  vectorTransformOptions(vectorTransformOptions),
401  filter(std::move(constraint)) {}
402 
403  LogicalResult matchAndRewrite(vector::ContractionOp op,
404  PatternRewriter &rewriter) const override;
405 
406 private:
407  /// Options to control the vector patterns.
408  vector::VectorTransformsOptions vectorTransformOptions;
409  FilterConstraintType filter;
410 };
411 
412 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
413 /// semantics to a reduction_size-unrolled sequence:
414 /// ```
415 /// %at = vector.transpose %a, [1, 0]
416 /// %bRow0 = vector.extract %b[0]
417 /// %atRow0 = vector.extract %at[0]
418 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
419 /// ...
420 /// %bRowK = vector.extract %b[K]
421 /// %atRowK = vector.extract %at[K]
422 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
423 /// ```
424 ///
425 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
426 /// the vector.contract op is a row-major matrix multiply.
428  : public OpRewritePattern<vector::ContractionOp> {
429 public:
431 
432  using FilterConstraintType =
433  std::function<LogicalResult(vector::ContractionOp op)>;
434 
435  static LogicalResult defaultFilter(vector::ContractionOp op) {
436  return success();
437  }
438 
440  vector::VectorTransformsOptions vectorTransformOptions,
441  MLIRContext *context, PatternBenefit benefit = 1,
442  FilterConstraintType constraint = defaultFilter)
443  : OpRewritePattern<vector::ContractionOp>(context, benefit),
444  vectorTransformOptions(vectorTransformOptions),
445  filter(std::move(constraint)) {}
446 
447  LogicalResult matchAndRewrite(vector::ContractionOp op,
448  PatternRewriter &rewriter) const override;
449 
450 private:
451  /// Options to control the vector patterns.
452  vector::VectorTransformsOptions vectorTransformOptions;
453  FilterConstraintType filter;
454 };
455 
456 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
457 /// semantics to an output-size-unrolled sequence:
458 /// ```
459 /// %out = arith.constant ... : vector<MxNxelt_type>
460 /// %bt = vector.transpose %b, [1, 0]
461 /// %aRow0 = vector.extract %a[0]
462 /// %btRow0 = vector.extract %bt[0]
463 /// %c00 = vector.reduce %aRow0, %btRow0
464 /// %out00 = vector.insert %c00, %out[0, 0]
465 /// ...
466 /// %aRowLast = vector.extract %a[M-1]
467 /// %btRowLast = vector.extract %bt[N-1]
468 /// %cLastLast = vector.reduce %aRowLast, %btRowLast
469 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
470 /// ```
471 ///
472 /// This only kicks in when VectorTransformsOptions is set to Dot and
473 /// the vector.contract op is a row-major matmul or matvec.
475  : public OpRewritePattern<vector::ContractionOp> {
476 public:
478 
479  using FilterConstraintType =
480  std::function<LogicalResult(vector::ContractionOp op)>;
481 
482  static LogicalResult defaultFilter(vector::ContractionOp op) {
483  return success();
484  }
485 
487  vector::VectorTransformsOptions vectorTransformOptions,
488  MLIRContext *context, PatternBenefit benefit = 1,
489  const FilterConstraintType &constraint = defaultFilter)
490  : OpRewritePattern<vector::ContractionOp>(context, benefit),
491  vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
492 
493  LogicalResult matchAndRewrite(vector::ContractionOp op,
494  PatternRewriter &rewriter) const override;
495 
496 private:
497  /// Options to control the vector patterns.
498  vector::VectorTransformsOptions vectorTransformOptions;
499  FilterConstraintType filter;
500 };
501 
502 /// Progressive lowering of ContractionOp.
503 ///
504 /// One:
505 /// %x = vector.contract with at least one free/batch dimension
506 /// is replaced by:
507 /// %a = vector.contract with one less free/batch dimension
508 /// %b = vector.contract with one less free/batch dimension
509 /// ..
510 /// %x = combine %a %b ..
511 /// until a pure contraction is reached (no free/batch dimensions),
512 /// which is replaced by a dot-product.
513 ///
514 /// This only kicks in when either VectorTransformsOptions is set
515 /// to Dot or when other contraction patterns fail.
516 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
517 public:
519  using FilterConstraintType =
520  std::function<LogicalResult(vector::ContractionOp op)>;
521 
522  static LogicalResult defaultFilter(vector::ContractionOp op) {
523  return success();
524  }
525 
527  MLIRContext *context, PatternBenefit benefit = 1,
528  FilterConstraintType constraint = defaultFilter)
529  : OpRewritePattern<vector::ContractionOp>(context, benefit),
530  vectorTransformOptions(vectorTransformOptions),
531  filter(std::move(constraint)) {}
532 
533  LogicalResult matchAndRewrite(vector::ContractionOp op,
534  PatternRewriter &rewriter) const override;
535 
536 private:
537  /// Options to control the vector patterns.
538  vector::VectorTransformsOptions vectorTransformOptions;
539  FilterConstraintType filter;
540  // Lower one parallel dimension.
541  FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
542  int64_t rhsIndex,
543  PatternRewriter &rewriter) const;
544  // Lower one reduction dimension.
545  FailureOr<Value> lowerReduction(vector::ContractionOp op,
546  PatternRewriter &rewriter) const;
547 };
548 
549 } // namespace vector
550 } // namespace mlir
551 
552 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to a reduction_size...
ContractionOpToOuterProductOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
Lower transpose into element-wise extract and inserts.
Apply splitFullAndPartialTransfer selectively via a pattern.
VectorTransferFullPartialRewriter(MLIRContext *context, VectorTransformsOptions options=VectorTransformsOptions(), FilterConstraintType filter=[](VectorTransferOpInterface op) { return success();}, PatternBenefit benefit=1)
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to an output-size-u...
Progressively lower to finer grained vector.contract and dot-products.
VectorContractLowering
Enum to control the lowering of vector.contract operations.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
static LogicalResult defaultFilter(vector::ContractionOp op)
Options that control the vector unrolling.
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
VectorMultiReductionLowering
Enum to control the lowering of vector.multi_reduction operations.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static LogicalResult defaultFilter(vector::ContractionOp op)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
Progressive lowering of ContractionOp.
Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations...
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Collects patterns to progressively lower vector contraction ops on high-D into low-D reduction and pr...
Split using in-bounds + out-of-bounds vector.transfer operations.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
Lower 2-D transpose to vector.flat_transpose, maps 1-1 to LLVM matrix intrinsics. ...
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
VectorTransformsOptions & setVectorTransposeLowering(VectorTransposeLowering opt)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
This class acts as a special tag that makes the desire to match "any" operation type explicit...
Definition: PatternMatch.h:157
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
Lower to vector.matrix_multiply, maps 1-1 to LLVM matrix intrinsics.
Lower multi_reduction into outer-parallel and inner-reduction ops.
ContractionOpToMatmulOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Insert TransposeLowering patterns into extraction/insertion.
std::function< LogicalResult(VectorTransferOpInterface op)> FilterConstraintType
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
std::function< Optional< SmallVector< int64_t, 4 > >(Operation *op)> NativeShapeFnType
static LogicalResult defaultFilter(vector::ContractionOp op)
Do not split vector transfer operation but instead mark it as "in-bounds".
VectorTransformsOptions & setVectorTransformsOptions(VectorContractLowering opt)
ContractionOpToDotLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
static llvm::ManagedStatic< PassManagerOptions > options
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
VectorTransformsOptions & setVectorMultiReductionLowering(VectorMultiReductionLowering opt)
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
VectorTransposeLowering
Enum to control the lowering of vector.transpose operations.
static LogicalResult defaultFilter(vector::ContractionOp op)
Do not split vector transfer operations.
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to: %flattened_a = ...
Split using an in-bounds vector.transfer + linalg.fill + linalg.copy operations.
Structure to control the behavior of vector transform patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Lower multi_reduction into outer-reduction and inner-parallel ops.
VectorTransformsOptions & setVectorTransferSplit(VectorTransferSplit opt)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e...
VectorTransferSplit
Enum to control the splitting of vector.transfer operations into in-bounds and out-of-bounds variants...
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
std::function< Optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector...
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert scan op.
Lower 2-D transpose to vector.shuffle.