MLIR  21.0.0git
LoweringPatterns.h
Go to the documentation of this file.
1 //===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
11 
14 
15 namespace mlir {
16 class RewritePatternSet;
17 
18 namespace vector {
19 
20 //===----------------------------------------------------------------------===//
21 // Lowering pattern populate functions
22 //===----------------------------------------------------------------------===//
23 
24 /// Populate the pattern set with the following patterns:
25 ///
26 /// [OuterProductOpLowering]
27 /// Progressively lower a `vector.outerproduct` to linearized
28 /// `vector.extract` + `vector.fma` + `vector.insert`.
29 ///
30 /// [ContractionOpLowering]
31 /// Progressive lowering of ContractionOp.
32 /// One step:
33 /// %x = vector.contract with at least one free/batch dimension
34 /// is replaced by:
35 /// %a = vector.contract with one less free/batch dimension
36 /// %b = vector.contract with one less free/batch dimension
37 ///
38 /// [ContractionOpToMatmulOpLowering]
39 /// Progressively lower a `vector.contract` with row-major matmul semantics to
40 /// linearized `vector.shape_cast` + `vector.matmul` on the way to
41 /// `llvm.matrix.multiply`.
42 ///
43 /// [ContractionOpToDotLowering]
44 /// Progressively lower a `vector.contract` with row-major matmul semantics to
45 /// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
46 ///
47 /// [ContractionOpToOuterProductOpLowering]
48 /// Progressively lower a `vector.contract` with row-major matmul semantics to
49 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
51  RewritePatternSet &patterns,
52  VectorContractLowering vectorContractLoweringOption,
53  PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
54 
55 /// Populate the pattern set with the following patterns:
56 ///
57 /// [OuterProductOpLowering]
58 /// Progressively lower a `vector.outerproduct` to linearized
59 /// `vector.extract` + `vector.fma` + `vector.insert`.
61  PatternBenefit benefit = 1);
62 
63 /// Collect a set of patterns to convert vector.multi_reduction op into
64 /// a sequence of vector.reduction ops. The patterns comprise:
65 ///
66 /// [InnerOuterDimReductionConversion]
67 /// Rewrites vector.multi_reduction such that all reduction dimensions are
68 /// either innermost or outermost, by adding the proper vector.transpose
69 /// operations.
70 ///
71 /// [ReduceMultiDimReductionRank]
72 /// Once in innermost or outermost reduction
73 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
74 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
75 /// back.
76 ///
77 /// [TwoDimMultiReductionToElementWise]
78 /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
79 /// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
80 /// ops. This also has an opportunity for tree-reduction (in the future).
81 ///
82 /// [TwoDimMultiReductionToReduction]
83 /// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
84 /// dimension, unroll the outer dimension to obtain a sequence of extract +
85 /// vector.reduction + insert. This can further lower to horizontal reduction
86 /// ops.
87 ///
88 /// [OneDimMultiReductionToTwoDim]
89 /// For cases that reduce to 1-D vector<k> reduction (and are thus missing
90 /// either a parallel or a reduction), we lift them back up to 2-D with a simple
91 /// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
92 /// thus fully exiting out of the vector.multi_reduction abstraction.
94  RewritePatternSet &patterns, VectorMultiReductionLowering options,
95  PatternBenefit benefit = 1);
96 
97 /// Populate the pattern set with the following patterns:
98 ///
99 /// [TransferReadToVectorLoadLowering]
100 /// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
101 /// BroadcastOp until dim 1.
102 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
103  PatternBenefit benefit = 1);
104 
105 /// Populate the pattern set with the following patterns:
106 ///
107 /// [CreateMaskOp]
108 /// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
109 ///
110 /// [ConstantMaskOp]
111 /// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
112 /// dim 1.
113 void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
114  PatternBenefit benefit = 1);
115 
116 /// Collects patterns that lower scalar vector transfer ops to memref loads and
117 /// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
118 /// are applied to vector transfer reads with any number of uses. Otherwise,
119 /// only vector transfer reads with a single use will be lowered.
121  PatternBenefit benefit,
122  bool allowMultipleUses);
123 
124 /// Populate the pattern set with the following patterns:
125 ///
126 /// [ShapeCastOp2DDownCastRewritePattern]
127 /// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
128 /// vectors progressively.
129 ///
130 /// [ShapeCastOp2DUpCastRewritePattern]
131 /// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
132 /// vectors progressively.
133 ///
134 /// [ShapeCastOpRewritePattern]
135 /// Reference lowering to fully unrolled sequences of single element ExtractOp +
136 /// InsertOp. Note that applying this pattern can almost always be considered a
137 /// performance bug.
138 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
139  PatternBenefit benefit = 1);
140 
141 /// Populate the pattern set with the following patterns:
142 ///
143 /// [TransposeOpLowering]
144 ///
145 /// [TransposeOp2DToShuffleLowering]
146 ///
148  RewritePatternSet &patterns,
149  VectorTransposeLowering vectorTransposeLowering,
150  PatternBenefit benefit = 1);
151 
152 /// Populate the pattern set with the following patterns:
153 ///
154 /// [TransferReadToVectorLoadLowering]
155 /// Progressive lowering of transfer_read. This pattern supports lowering of
156 /// `vector.transfer_read` to a combination of `vector.load` and
157 /// `vector.broadcast`.
158 ///
159 /// [TransferWriteToVectorStoreLowering]
160 /// Progressive lowering of transfer_write. This pattern supports lowering of
161 /// `vector.transfer_write` to `vector.store`.
162 ///
163 /// These patterns lower transfer ops to simpler ops like `vector.load`,
164 /// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
165 /// of at most `maxTransferRank` are lowered. This is useful when combined with
166 /// VectorToSCF, which reduces the rank of vector transfer ops.
168  RewritePatternSet &patterns,
169  std::optional<unsigned> maxTransferRank = std::nullopt,
170  PatternBenefit benefit = 1);
171 
172 /// Collect a set of transfer read/write lowering patterns that simplify the
173 /// permutation map (e.g., converting it to a minor identity map) by inserting
174 /// broadcasts and transposes. More specifically:
175 ///
176 /// [TransferReadPermutationLowering]
177 /// Lower transfer_read op with permutation into a transfer_read with a
178 /// permutation map composed of leading zeros followed by a minor identity +
179 /// vector.transpose op.
180 /// Ex:
181 /// vector.transfer_read ...
182 /// permutation_map: (d0, d1, d2) -> (0, d1)
183 /// into:
184 /// %v = vector.transfer_read ...
185 /// permutation_map: (d0, d1, d2) -> (d1, 0)
186 /// vector.transpose %v, [1, 0]
187 ///
188 /// vector.transfer_read ...
189 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
190 /// into:
191 /// %v = vector.transfer_read ...
192 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
193 /// vector.transpose %v, [0, 1, 3, 2, 4]
194 /// Note that an alternative is to transform it to linalg.transpose +
195 /// vector.transfer_read to do the transpose in memory instead.
196 ///
197 /// [TransferWritePermutationLowering]
198 /// Lower transfer_write op with permutation into a transfer_write with a
199 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
200 /// Ex:
201 /// vector.transfer_write %v ...
202 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
203 /// into:
204 /// %tmp = vector.transpose %v, [1, 2, 0]
205 /// vector.transfer_write %tmp ...
206 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
207 ///
208 /// vector.transfer_write %v ...
209 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
210 /// into:
211 /// %tmp = vector.transpose %v, [1, 0]
212 /// vector.transfer_write %tmp ...
213 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
214 ///
215 /// [TransferOpReduceRank]
216 /// Lower transfer_read op with broadcast in the leading dimensions into
217 /// transfer_read of lower rank + vector.broadcast.
218 /// Ex: vector.transfer_read ...
219 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
220 /// into:
221 /// %v = vector.transfer_read ...
222 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
223 /// vector.broadcast %v
225  RewritePatternSet &patterns, PatternBenefit benefit = 1);
226 
227 /// Populate the pattern set with the following patterns:
228 ///
229 /// [ScanToArithOps]
230 /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
231 /// vector.extract_strided_slice.
232 void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
233  PatternBenefit benefit = 1);
234 
235 /// Populate the pattern set with the following patterns:
236 ///
237 /// [StepToArithConstantOp]
238 /// Convert vector.step op into arith ops if not using scalable vectors.
239 void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
240  PatternBenefit benefit = 1);
241 
242 /// Populate the pattern set with the following patterns:
243 ///
244 /// [FlattenGather]
245 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
246 /// outermost dimension.
247 ///
248 /// [Gather1DToConditionalLoads]
249 /// Turns 1-D `vector.gather` into a scalarized sequence of `vector.load`s or
250 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
251 /// loads/extracts are made conditional using `scf.if` ops.
252 void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
253  PatternBenefit benefit = 1);
254 
255 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
256 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
257 /// not its nested `MaskableOpInterface`.
259  RewritePatternSet &patterns);
260 
261 /// Populate the pattern set with the following patterns:
262 ///
263 /// [VectorMaskedLoadOpConverter]
264 /// Turns vector.maskedload to scf.if + memref.load
265 ///
266 /// [VectorMaskedStoreOpConverter]
267 /// Turns vector.maskedstore to scf.if + memref.store
269  PatternBenefit benefit = 1);
270 
271 /// Populate the pattern set with the following patterns:
272 ///
273 /// [UnrollInterleaveOp]
274 /// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
275 /// InterleaveOp (of `targetRank`) + InsertOp.
276 void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
277  int64_t targetRank = 1,
278  PatternBenefit benefit = 1);
279 
281  PatternBenefit benefit = 1);
282 
283 /// Populates the pattern set with the following patterns:
284 ///
285 /// [UnrollBitCastOp]
286 /// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
287 /// BitCastOp (of `targetRank`) + InsertOp.
288 void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
289  int64_t targetRank = 1,
290  PatternBenefit benefit = 1);
291 
292 /// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
293 /// n > 1.
294 void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
295 
296 } // namespace vector
297 } // namespace mlir
298 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns