MLIR 23.0.0git
LoweringPatterns.h
Go to the documentation of this file.
1//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
10#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
11
14
15namespace mlir {
17
18namespace vector {
19
20//===----------------------------------------------------------------------===//
21// Lowering pattern populate functions
22//===----------------------------------------------------------------------===//
23
24/// Populate the pattern set with the following patterns:
25///
26/// [OuterProductOpLowering]
27/// Progressively lower a `vector.outerproduct` to linearized
28/// `vector.extract` + `vector.fma` + `vector.insert`.
29///
30/// [ContractionOpLowering]
31/// Progressive lowering of ContractionOp.
32/// One:
33/// %x = vector.contract with at least one free/batch dimension
34/// is replaced by:
35/// %a = vector.contract with one less free/batch dimension
36/// %b = vector.contract with one less free/batch dimension
37///
38/// [ContractionOpToMatmulOpLowering]
39/// Progressively lower a `vector.contract` with row-major matmul semantics to
40/// linearized `vector.shape_cast` + `vector.matmul` on the way to
41/// `llvm.matrix.multiply`.
42///
43/// [ContractionOpToDotLowering]
44/// Progressively lower a `vector.contract` with row-major matmul semantics to
45/// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
46///
47/// [ContractionOpToOuterProductOpLowering]
48/// Progressively lower a `vector.contract` with row-major matmul semantics to
49/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
51 RewritePatternSet &patterns,
52 VectorContractLowering vectorContractLoweringOption,
53 PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
54
55/// Populate the pattern set with the following patterns:
56///
57/// [OuterProductOpLowering]
58/// Progressively lower a `vector.outerproduct` to linearized
59/// `vector.extract` + `vector.fma` + `vector.insert`.
60void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
61 PatternBenefit benefit = 1);
62
63/// Populate the pattern set with the following patterns:
64///
65/// [InnerOuterDimReductionConversion]
66/// Rewrites vector.multi_reduction such that all reduction dimensions are
67/// either innermost or outermost, by adding the proper vector.transpose
68/// operations.
70 RewritePatternSet &patterns, VectorMultiReductionLowering options,
71 PatternBenefit benefit = 1);
72
73/// Populate the pattern set with the following patterns:
74///
75/// [ReduceMultiDimReductionRank]
76/// Once in innermost or outermost reduction
77/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
78/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
79/// back.
81 RewritePatternSet &patterns, VectorMultiReductionLowering options,
82 PatternBenefit benefit = 1);
83
84/// Populate the pattern set with the following patterns:
85///
86/// [OneDimMultiReductionToReduction]
87/// Converts 1-D vector.multi_reduction to vector.reduction.
88///
89/// [TwoDimMultiReductionToElementWise]
90/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
91/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
92/// ops. This also has an opportunity for tree-reduction (in the future).
93///
94/// [TwoDimMultiReductionToReduction]
95/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
96/// dimension, unroll the outer dimension to obtain a sequence of extract +
97/// vector.reduction + insert. This can further lower to horizontal reduction
98/// ops.
100 RewritePatternSet &patterns, VectorMultiReductionLowering options,
101 PatternBenefit benefit = 1);
102
103/// Populate the pattern set with the following patterns:
104///
105/// [TransferReadToVectorLoadLowering]
106/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
107/// BroadcastOp until dim 1.
108void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
109 PatternBenefit benefit = 1);
110
111/// Populate the pattern set with the following patterns:
112///
113/// [CreateMaskOp]
114/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
115///
116/// [ConstantMaskOp]
117/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
118/// dim 1.
119void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
120 PatternBenefit benefit = 1);
121
122/// Collects patterns that lower scalar vector transfer ops to memref loads and
123/// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
124/// are applied to vector transfer reads with any number of uses. Otherwise,
125/// only vector transfer reads with a single use will be lowered.
126void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
127 PatternBenefit benefit,
128 bool allowMultipleUses);
129
130/// Populate the pattern set with the following patterns:
131///
132/// [ShapeCastOp2DDownCastRewritePattern]
133/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
134/// vectors progressively.
135///
136/// [ShapeCastOp2DUpCastRewritePattern]
137/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
138/// vectors progressively.
139///
140/// [ShapeCastOpRewritePattern]
141/// Reference lowering to fully unrolled sequences of single element ExtractOp +
142/// InsertOp. Note that applying this pattern can almost always be considered a
143/// performance bug.
144void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
145 PatternBenefit benefit = 1);
146
147/// Populate the pattern set with the following patterns:
148///
149/// [TransposeOpLowering]
150///
151/// [TransposeOp2DToShuffleLowering]
152///
154 RewritePatternSet &patterns,
155 VectorTransposeLowering vectorTransposeLowering,
156 PatternBenefit benefit = 1);
157
158/// Populate the pattern set with the following patterns:
159///
160/// [TransferReadToVectorLoadLowering]
161/// Progressive lowering of transfer_read.This pattern supports lowering of
162/// `vector.transfer_read` to a combination of `vector.load` and
163/// `vector.broadcast`
164///
165/// [TransferWriteToVectorStoreLowering]
166/// Progressive lowering of transfer_write. This pattern supports lowering of
167/// `vector.transfer_write` to `vector.store`
168///
169/// These patterns lower transfer ops to simpler ops like `vector.load`,
170/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
171/// of a most `maxTransferRank` are lowered. This is useful when combined with
172/// VectorToSCF, which reduces the rank of vector transfer ops.
174 RewritePatternSet &patterns,
175 std::optional<unsigned> maxTransferRank = std::nullopt,
176 PatternBenefit benefit = 1);
177
178/// Collect a set of transfer read/write lowering patterns that simplify the
179/// permutation map (e.g., converting it to a minor identity map) by inserting
180/// broadcasts and transposes. More specifically:
181///
182/// [TransferReadPermutationLowering]
183/// Lower transfer_read op with permutation into a transfer_read with a
184/// permutation map composed of leading zeros followed by a minor identity +
185/// vector.transpose op.
186/// Ex:
187/// vector.transfer_read ...
188/// permutation_map: (d0, d1, d2) -> (0, d1)
189/// into:
190/// %v = vector.transfer_read ...
191/// permutation_map: (d0, d1, d2) -> (d1, 0)
192/// vector.transpose %v, [1, 0]
193///
194/// vector.transfer_read ...
195/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
196/// into:
197/// %v = vector.transfer_read ...
198/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
199/// vector.transpose %v, [0, 1, 3, 2, 4]
200/// Note that an alternative is to transform it to linalg.transpose +
201/// vector.transfer_read to do the transpose in memory instead.
202///
203/// [TransferWritePermutationLowering]
204/// Lower transfer_write op with permutation into a transfer_write with a
205/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
206/// Ex:
207/// vector.transfer_write %v ...
208/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
209/// into:
210/// %tmp = vector.transpose %v, [2, 0, 1]
211/// vector.transfer_write %tmp ...
212/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
213///
214/// vector.transfer_write %v ...
215/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
216/// into:
217/// %tmp = vector.transpose %v, [1, 0]
218/// %v = vector.transfer_write %tmp ...
219/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
220///
221/// [TransferOpReduceRank]
222/// Lower transfer_read op with broadcast in the leading dimensions into
223/// transfer_read of lower rank + vector.broadcast.
224/// Ex: vector.transfer_read ...
225/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
226/// into:
227/// %v = vector.transfer_read ...
228/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
229/// vector.broadcast %v
231 RewritePatternSet &patterns, PatternBenefit benefit = 1);
232
233/// Populate the pattern set with the following patterns:
234///
235/// [ScanToArithOps]
236/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
237/// vector.extract_strided_slice.
238void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
239 PatternBenefit benefit = 1);
240
241/// Populate the pattern set with the following patterns:
242///
243/// [StepToArithConstantOp]
244/// Convert vector.step op into arith ops if not using scalable vectors
245void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
246 PatternBenefit benefit = 1);
247
248/// Populate the pattern set with the following patterns:
249///
250/// [UnrollGather]
251/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
252/// outermost dimension.
253void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
254 PatternBenefit benefit = 1);
255
256/// Populate the pattern set with the following patterns:
257///
258/// [Gather1DToConditionalLoads]
259/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
260/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
261/// loads/extracts are made conditional using `scf.if` ops.
262void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
263 PatternBenefit benefit = 1);
264
265/// Populates instances of `MaskOpRewritePattern` to lower masked operations
266/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
267/// not its nested `MaskableOpInterface`.
269 RewritePatternSet &patterns);
270
271/// Populate the pattern set with the following patterns:
272///
273/// [VectorMaskedLoadOpConverter]
274/// Turns vector.maskedload to scf.if + memref.load
275///
276/// [VectorMaskedStoreOpConverter]
277/// Turns vector.maskedstore to scf.if + memref.store
278void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
279 PatternBenefit benefit = 1);
280
281/// Populate the pattern set with the following patterns:
282///
283/// [UnrollInterleaveOp]
284/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
285/// InterleaveOp (of `targetRank`) + InsertOp.
286void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
287 int64_t targetRank = 1,
288 PatternBenefit benefit = 1);
289
290void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
291 PatternBenefit benefit = 1);
292
293/// Populates the pattern set with the following patterns:
294///
295/// [UnrollBitCastOp]
296/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
297/// BitCastOp (of `targetRank`) + InsertOp.
298void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
299 int64_t targetRank = 1,
300 PatternBenefit benefit = 1);
301
302void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
303 PatternBenefit benefit = 1);
304
305/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
306/// n > 1.
307void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
308
309/// Populate patterns to rewrite sequences of `vector.to_elements` +
310/// `vector.from_elements` operations into a tree of `vector.shuffle`
311/// operations.
313 RewritePatternSet &patterns, PatternBenefit benefit = 1);
314
315/// Populate the pattern set with the following patterns:
316///
317/// [ContractionOpToMatmulOpLowering]
318/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
319///
320/// Given the high benefit, this will be prioriotised over other
321/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
322/// only run this registration conditionally.
323void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
324 PatternBenefit benefit = 100);
325
326/// Populate the pattern set with the following patterns:
327///
328/// [TransposeOpLowering]
329/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
330///
331/// Given the high benefit, this will be prioriotised over other
332/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
333/// only run this registration conditionally.
334void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
335 PatternBenefit benefit = 100);
336
337} // namespace vector
338} // namespace mlir
339
340#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e.g., converting it to a minor identity map) by inserting broadcasts and transposes.
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a tree of vector.shuffle operations.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector.mask.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionReorderPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.