MLIR  20.0.0git
LoweringPatterns.h
Go to the documentation of this file.
1 //===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
11 
13 
14 namespace mlir {
15 class RewritePatternSet;
16 
17 namespace vector {
18 
19 //===----------------------------------------------------------------------===//
20 // Lowering pattern populate functions
21 //===----------------------------------------------------------------------===//
22 
23 /// Populate the pattern set with the following patterns:
24 ///
25 /// [OuterProductOpLowering]
26 /// Progressively lower a `vector.outerproduct` to linearized
27 /// `vector.extract` + `vector.fma` + `vector.insert`.
28 ///
29 /// [ContractionOpLowering]
30 /// Progressive lowering of ContractionOp.
31 /// One:
32 /// %x = vector.contract with at least one free/batch dimension
33 /// is replaced by:
34 /// %a = vector.contract with one less free/batch dimension
35 /// %b = vector.contract with one less free/batch dimension
36 ///
37 /// [ContractionOpToMatmulOpLowering]
38 /// Progressively lower a `vector.contract` with row-major matmul semantics to
39 /// linearized `vector.shape_cast` + `vector.matmul` on the way to
40 /// `llvm.matrix.multiply`.
41 ///
42 /// [ContractionOpToDotLowering]
43 /// Progressively lower a `vector.contract` with row-major matmul semantics to
44 /// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
45 ///
46 /// [ContractionOpToOuterProductOpLowering]
47 /// Progressively lower a `vector.contract` with row-major matmul semantics to
48 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
49 void populateVectorContractLoweringPatterns(
50  RewritePatternSet &patterns, VectorTransformsOptions options,
51  PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
52 
53 /// Populate the pattern set with the following patterns:
54 ///
55 /// [OuterProductOpLowering]
56 /// Progressively lower a `vector.outerproduct` to linearized
57 /// `vector.extract` + `vector.fma` + `vector.insert`.
58 void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
59  PatternBenefit benefit = 1);
60 
61 /// Collect a set of patterns to convert vector.multi_reduction op into
62 /// a sequence of vector.reduction ops. The patterns comprise:
63 ///
64 /// [InnerOuterDimReductionConversion]
65 /// Rewrites vector.multi_reduction such that all reduction dimensions are
66 /// either innermost or outermost, by adding the proper vector.transpose
67 /// operations.
68 ///
69 /// [ReduceMultiDimReductionRank]
70 /// Once in innermost or outermost reduction
71 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
72 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
73 /// back.
74 ///
75 /// [TwoDimMultiReductionToElementWise]
76 /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
77 /// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
78 /// ops. This also has an opportunity for tree-reduction (in the future).
79 ///
80 /// [TwoDimMultiReductionToReduction]
81 /// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
82 /// dimension, unroll the outer dimension to obtain a sequence of extract +
83 /// vector.reduction + insert. This can further lower to horizontal reduction
84 /// ops.
85 ///
86 /// [OneDimMultiReductionToTwoDim]
87 /// For cases that reduce to 1-D vector<k> reduction (and are thus missing
88 /// either a parallel or a reduction), we lift them back up to 2-D with a simple
89 /// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
90 /// thus fully exiting out of the vector.multi_reduction abstraction.
91 void populateVectorMultiReductionLoweringPatterns(
92  RewritePatternSet &patterns, VectorMultiReductionLowering options,
93  PatternBenefit benefit = 1);
94 
95 /// Populate the pattern set with the following patterns:
96 ///
97 /// [BroadcastOpLowering]
98 /// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
99 /// BroadcastOp until dim 1.
100 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
101  PatternBenefit benefit = 1);
102 
103 /// Populate the pattern set with the following patterns:
104 ///
105 /// [CreateMaskOp]
106 /// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
107 ///
108 /// [ConstantMaskOp]
109 /// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
110 /// dim 1.
111 void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
112  PatternBenefit benefit = 1);
113 
114 /// Collects patterns that lower scalar vector transfer ops to memref loads and
115 /// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
116 /// are applied to vector transfer reads with any number of uses. Otherwise,
117 /// only vector transfer reads with a single use will be lowered.
118 void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
119  PatternBenefit benefit,
120  bool allowMultipleUses);
121 
122 /// Populate the pattern set with the following patterns:
123 ///
124 /// [ShapeCastOp2DDownCastRewritePattern]
125 /// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
126 /// vectors progressively.
127 ///
128 /// [ShapeCastOp2DUpCastRewritePattern]
129 /// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
130 /// vectors progressively.
131 ///
132 /// [ShapeCastOpRewritePattern]
133 /// Reference lowering to fully unrolled sequences of single element ExtractOp +
134 /// InsertOp. Note that applying this pattern can almost always be considered a
135 /// performance bug.
136 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
137  PatternBenefit benefit = 1);
138 
139 /// Populate the pattern set with the following patterns:
140 ///
141 /// [TransposeOpLowering]
142 ///
143 /// [TransposeOp2DToShuffleLowering]
144 ///
145 void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
146  VectorTransformsOptions options,
147  PatternBenefit benefit = 1);
148 
149 /// Populate the pattern set with the following patterns:
150 ///
151 /// [TransferReadToVectorLoadLowering]
152 /// Progressive lowering of transfer_read. This pattern supports lowering of
153 /// `vector.transfer_read` to a combination of `vector.load` and
154 /// `vector.broadcast`.
155 ///
156 /// [TransferWriteToVectorStoreLowering]
157 /// Progressive lowering of transfer_write. This pattern supports lowering of
158 /// `vector.transfer_write` to `vector.store`.
159 ///
160 /// [VectorLoadToMemrefLoadLowering]
161 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
162 ///
163 /// [VectorStoreToMemrefStoreLowering]
164 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
165 ///
166 /// These patterns lower transfer ops to simpler ops like `vector.load`,
167 /// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
168 /// of at most `maxTransferRank` are lowered. This is useful when combined with
169 /// VectorToSCF, which reduces the rank of vector transfer ops.
170 void populateVectorTransferLoweringPatterns(
171  RewritePatternSet &patterns,
172  std::optional<unsigned> maxTransferRank = std::nullopt,
173  PatternBenefit benefit = 1);
174 
175 /// Collect a set of transfer read/write lowering patterns that simplify the
176 /// permutation map (e.g., converting it to a minor identity map) by inserting
177 /// broadcasts and transposes. More specifically:
178 ///
179 /// [TransferReadPermutationLowering]
180 /// Lower transfer_read op with permutation into a transfer_read with a
181 /// permutation map composed of leading zeros followed by a minor identity +
182 /// vector.transpose op.
183 /// Ex:
184 /// vector.transfer_read ...
185 /// permutation_map: (d0, d1, d2) -> (0, d1)
186 /// into:
187 /// %v = vector.transfer_read ...
188 /// permutation_map: (d0, d1, d2) -> (d1, 0)
189 /// vector.transpose %v, [1, 0]
190 ///
191 /// vector.transfer_read ...
192 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
193 /// into:
194 /// %v = vector.transfer_read ...
195 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
196 /// vector.transpose %v, [0, 1, 3, 2, 4]
197 /// Note that an alternative is to transform it to linalg.transpose +
198 /// vector.transfer_read to do the transpose in memory instead.
199 ///
200 /// [TransferWritePermutationLowering]
201 /// Lower transfer_write op with permutation into a transfer_write with a
202 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
203 /// Ex:
204 /// vector.transfer_write %v ...
205 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
206 /// into:
207 /// %tmp = vector.transpose %v, [2, 0, 1]
208 /// vector.transfer_write %tmp ...
209 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
210 ///
211 /// vector.transfer_write %v ...
212 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
213 /// into:
214 /// %tmp = vector.transpose %v, [1, 0]
215 /// vector.transfer_write %tmp ...
216 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
217 ///
218 /// [TransferOpReduceRank]
219 /// Lower transfer_read op with broadcast in the leading dimensions into
220 /// transfer_read of lower rank + vector.broadcast.
221 /// Ex: vector.transfer_read ...
222 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
223 /// into:
224 /// %v = vector.transfer_read ...
225 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
226 /// vector.broadcast %v
227 void populateVectorTransferPermutationMapLoweringPatterns(
228  RewritePatternSet &patterns, PatternBenefit benefit = 1);
229 
230 /// Populate the pattern set with the following patterns:
231 ///
232 /// [ScanToArithOps]
233 /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
234 /// vector.extract_strided_slice.
235 void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
236  PatternBenefit benefit = 1);
237 
238 /// Populate the pattern set with the following patterns:
239 ///
240 /// [StepToArithConstantOp]
241 /// Convert vector.step op into arith ops if not using scalable vectors.
242 void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
243  PatternBenefit benefit = 1);
244 
245 /// Populate the pattern set with the following patterns:
246 ///
247 /// [FlattenGather]
248 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
249 /// outermost dimension.
250 ///
251 /// [Gather1DToConditionalLoads]
252 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
253 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
254 /// loads/extracts are made conditional using `scf.if` ops.
255 void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
256  PatternBenefit benefit = 1);
257 
258 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
259 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
260 /// not its nested `MaskableOpInterface`.
261 void populateVectorMaskLoweringPatternsForSideEffectingOps(
262  RewritePatternSet &patterns);
263 
264 /// Populate the pattern set with the following patterns:
265 ///
266 /// [VectorMaskedLoadOpConverter]
267 /// Turns vector.maskedload to scf.if + memref.load
268 ///
269 /// [VectorMaskedStoreOpConverter]
270 /// Turns vector.maskedstore to scf.if + memref.store
271 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
272  PatternBenefit benefit = 1);
273 
274 /// Populate the pattern set with the following patterns:
275 ///
276 /// [UnrollInterleaveOp]
277 /// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
278 /// InterleaveOp (of `targetRank`) + InsertOp.
279 void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
280  int64_t targetRank = 1,
281  PatternBenefit benefit = 1);
282 
283 void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
284  PatternBenefit benefit = 1);
285 
286 /// Populates the pattern set with the following patterns:
287 ///
288 /// [UnrollBitCastOp]
289 /// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
290 /// BitCastOp (of `targetRank`) + InsertOp.
291 void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
292  int64_t targetRank = 1,
293  PatternBenefit benefit = 1);
294 
295 /// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
296 /// n > 1.
297 void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
298 
299 } // namespace vector
300 } // namespace mlir
301 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns