MLIR 22.0.0git
LoweringPatterns.h
Go to the documentation of this file.
1//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
10#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
11
14
15namespace mlir {
17
18namespace vector {
19
20//===----------------------------------------------------------------------===//
21// Lowering pattern populate functions
22//===----------------------------------------------------------------------===//
23
24/// Populate the pattern set with the following patterns:
25///
26/// [OuterProductOpLowering]
27/// Progressively lower a `vector.outerproduct` to linearized
28/// `vector.extract` + `vector.fma` + `vector.insert`.
29///
30/// [ContractionOpLowering]
31/// Progressive lowering of ContractionOp by unrolling one free/batch
32/// dimension per step, e.g.:
33/// %x = vector.contract with at least one free/batch dimension
34/// is replaced by:
35/// %a = vector.contract with one less free/batch dimension
36/// %b = vector.contract with one less free/batch dimension
37///
38/// [ContractionOpToMatmulOpLowering]
39/// Progressively lower a `vector.contract` with row-major matmul semantics to
40/// linearized `vector.shape_cast` + `vector.matmul` on the way to
41/// `llvm.matrix.multiply`.
42///
43/// [ContractionOpToDotLowering]
44/// Progressively lower a `vector.contract` with row-major matmul semantics to
45/// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
46///
47/// [ContractionOpToOuterProductOpLowering]
48/// Progressively lower a `vector.contract` with row-major matmul semantics to
49/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
51 RewritePatternSet &patterns,
52 VectorContractLowering vectorContractLoweringOption,
53 PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
54
55/// Populate the pattern set with the following patterns:
56///
57/// [OuterProductOpLowering]
58/// Progressively lower a `vector.outerproduct` to linearized
59/// `vector.extract` + `vector.fma` + `vector.insert`.
61 PatternBenefit benefit = 1);
62
63/// Collect a set of patterns to convert vector.multi_reduction op into
64/// a sequence of vector.reduction ops. The patterns comprise:
65///
66/// [InnerOuterDimReductionConversion]
67/// Rewrites vector.multi_reduction such that all reduction dimensions are
68/// either innermost or outermost, by adding the proper vector.transpose
69/// operations.
70///
71/// [ReduceMultiDimReductionRank]
72/// Once in innermost or outermost reduction
73/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
74/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
75/// back.
76///
77/// [TwoDimMultiReductionToElementWise]
78/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
79/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
80/// ops. This also has an opportunity for tree-reduction (in the future).
81///
82/// [TwoDimMultiReductionToReduction]
83/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
84/// dimension, unroll the outer dimension to obtain a sequence of extract +
85/// vector.reduction + insert. This can further lower to horizontal reduction
86/// ops.
87///
88/// [OneDimMultiReductionToTwoDim]
89/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
90/// either a parallel or a reduction dimension), we lift them back up to 2-D
91/// with a simple vector.shape_cast to vector<1xk> so that the other patterns
92/// can kick in, thus fully exiting out of the vector.multi_reduction abstraction.
94 RewritePatternSet &patterns, VectorMultiReductionLowering options,
95 PatternBenefit benefit = 1);
96
97/// Populate the pattern set with the following patterns:
98///
99/// [TransferReadToVectorLoadLowering]
100/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
101/// BroadcastOp until dim 1.
103 PatternBenefit benefit = 1);
104
105/// Populate the pattern set with the following patterns:
106///
107/// [CreateMaskOp]
108/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
109///
110/// [ConstantMaskOp]
111/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
112/// dim 1.
113void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
114 PatternBenefit benefit = 1);
115
116/// Collects patterns that lower scalar vector transfer ops to memref loads and
117/// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
118/// are applied to vector transfer reads with any number of uses. Otherwise,
119/// only vector transfer reads with a single use will be lowered.
121 PatternBenefit benefit,
122 bool allowMultipleUses);
123
124/// Populate the pattern set with the following patterns:
125///
126/// [ShapeCastOp2DDownCastRewritePattern]
127/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
128/// vectors progressively.
129///
130/// [ShapeCastOp2DUpCastRewritePattern]
131/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 1-D to 2-D
132/// vectors progressively.
133///
134/// [ShapeCastOpRewritePattern]
135/// Reference lowering to fully unrolled sequences of single element ExtractOp +
136/// InsertOp. Note that applying this pattern can almost always be considered a
137/// performance bug.
139 PatternBenefit benefit = 1);
140
141/// Populate the pattern set with the following patterns:
142///
143/// [TransposeOpLowering]
144///
145/// [TransposeOp2DToShuffleLowering]
146///
148 RewritePatternSet &patterns,
149 VectorTransposeLowering vectorTransposeLowering,
150 PatternBenefit benefit = 1);
151
152/// Populate the pattern set with the following patterns:
153///
154/// [TransferReadToVectorLoadLowering]
155/// Progressive lowering of transfer_read. This pattern supports lowering of
156/// `vector.transfer_read` to a combination of `vector.load` and
157/// `vector.broadcast`.
158///
159/// [TransferWriteToVectorStoreLowering]
160/// Progressive lowering of transfer_write. This pattern supports lowering of
161/// `vector.transfer_write` to `vector.store`
162///
163/// These patterns lower transfer ops to simpler ops like `vector.load`,
164/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
165/// of at most `maxTransferRank` are lowered. This is useful when combined with
166/// VectorToSCF, which reduces the rank of vector transfer ops.
168 RewritePatternSet &patterns,
169 std::optional<unsigned> maxTransferRank = std::nullopt,
170 PatternBenefit benefit = 1);
171
172/// Collect a set of transfer read/write lowering patterns that simplify the
173/// permutation map (e.g., converting it to a minor identity map) by inserting
174/// broadcasts and transposes. More specifically:
175///
176/// [TransferReadPermutationLowering]
177/// Lower transfer_read op with permutation into a transfer_read with a
178/// permutation map composed of leading zeros followed by a minor identity +
179/// vector.transpose op.
180/// Ex:
181/// vector.transfer_read ...
182/// permutation_map: (d0, d1, d2) -> (0, d1)
183/// into:
184/// %v = vector.transfer_read ...
185/// permutation_map: (d0, d1, d2) -> (d1, 0)
186/// vector.transpose %v, [1, 0]
187///
188/// vector.transfer_read ...
189/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
190/// into:
191/// %v = vector.transfer_read ...
192/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
193/// vector.transpose %v, [0, 1, 3, 2, 4]
194/// Note that an alternative is to transform it to linalg.transpose +
195/// vector.transfer_read to do the transpose in memory instead.
196///
197/// [TransferWritePermutationLowering]
198/// Lower transfer_write op with permutation into a transfer_write with a
199/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
200/// Ex:
201/// vector.transfer_write %v ...
202/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
203/// into:
204/// %tmp = vector.transpose %v, [1, 2, 0]
205/// vector.transfer_write %tmp ...
206/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
207///
208/// vector.transfer_write %v ...
209/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
210/// into:
211/// %tmp = vector.transpose %v, [1, 0]
212/// vector.transfer_write %tmp ...
213/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
214///
215/// [TransferOpReduceRank]
216/// Lower transfer_read op with broadcast in the leading dimensions into
217/// transfer_read of lower rank + vector.broadcast.
218/// Ex: vector.transfer_read ...
219/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
220/// into:
221/// %v = vector.transfer_read ...
222/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
223/// vector.broadcast %v
225 RewritePatternSet &patterns, PatternBenefit benefit = 1);
226
227/// Populate the pattern set with the following patterns:
228///
229/// [ScanToArithOps]
230/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
231/// vector.extract_strided_slice.
232void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
233 PatternBenefit benefit = 1);
234
235/// Populate the pattern set with the following patterns:
236///
237/// [StepToArithConstantOp]
238/// Convert vector.step op into arith ops if not using scalable vectors.
239void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
240 PatternBenefit benefit = 1);
241
242/// Populate the pattern set with the following patterns:
243///
244/// [UnrollGather]
245/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
246/// outermost dimension.
247void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
248 PatternBenefit benefit = 1);
249
250/// Populate the pattern set with the following patterns:
251///
252/// [Gather1DToConditionalLoads]
253/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
254/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
255/// loads/extracts are made conditional using `scf.if` ops.
257 PatternBenefit benefit = 1);
258
259/// Populates instances of `MaskOpRewritePattern` to lower masked operations
260/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
261/// not its nested `MaskableOpInterface`.
263 RewritePatternSet &patterns);
264
265/// Populate the pattern set with the following patterns:
266///
267/// [VectorMaskedLoadOpConverter]
268/// Turns vector.maskedload to scf.if + memref.load
269///
270/// [VectorMaskedStoreOpConverter]
271/// Turns vector.maskedstore to scf.if + memref.store
273 PatternBenefit benefit = 1);
274
275/// Populate the pattern set with the following patterns:
276///
277/// [UnrollInterleaveOp]
278/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
279/// InterleaveOp (of `targetRank`) + InsertOp.
281 int64_t targetRank = 1,
282 PatternBenefit benefit = 1);
283
285 PatternBenefit benefit = 1);
286
287/// Populates the pattern set with the following patterns:
288///
289/// [UnrollBitCastOp]
290/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
291/// BitCastOp (of `targetRank`) + InsertOp.
292void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
293 int64_t targetRank = 1,
294 PatternBenefit benefit = 1);
295
296void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
297 PatternBenefit benefit = 1);
298
299/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
300/// n > 1.
301void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
302
303/// Populate patterns to rewrite sequences of `vector.to_elements` +
304/// `vector.from_elements` operations into a tree of `vector.shuffle`
305/// operations.
307 RewritePatternSet &patterns, PatternBenefit benefit = 1);
308
309/// Populate the pattern set with the following patterns:
310///
311/// [ContractionOpToMatmulOpLowering]
312/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
313///
314/// Given the high benefit, this will be prioritised over other
315/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
316/// only run this registration conditionally.
318 PatternBenefit benefit = 100);
319
320/// Populate the pattern set with the following patterns:
321///
322/// [TransposeOpLowering]
323/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
324///
325/// Given the high benefit, this will be prioritised over other
326/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
327/// only run this registration conditionally.
329 PatternBenefit benefit = 100);
330
331} // namespace vector
332} // namespace mlir
333
334#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a...
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns