MLIR 23.0.0git
LoweringPatterns.h
Go to the documentation of this file.
1//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
10#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
11
14
15namespace mlir {
17
18namespace vector {
19
20//===----------------------------------------------------------------------===//
21// Lowering pattern populate functions
22//===----------------------------------------------------------------------===//
23
24/// Populate the pattern set with the following patterns:
25///
26/// [OuterProductOpLowering]
27/// Progressively lower a `vector.outerproduct` to linearized
28/// `vector.extract` + `vector.fma` + `vector.insert`.
29///
30/// [ContractionOpLowering]
31/// Progressive lowering of ContractionOp.
32/// One:
33/// %x = vector.contract with at least one free/batch dimension
34/// is replaced by:
35/// %a = vector.contract with one less free/batch dimension
36/// %b = vector.contract with one less free/batch dimension
37///
38/// [ContractionOpToMatmulOpLowering]
39/// Progressively lower a `vector.contract` with row-major matmul semantics to
40/// linearized `vector.shape_cast` + `vector.matmul` on the way to
41/// `llvm.matrix.multiply`.
42///
43/// [ContractionOpToDotLowering]
44/// Progressively lower a `vector.contract` with row-major matmul semantics to
45/// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
46///
47/// [ContractionOpToOuterProductOpLowering]
48/// Progressively lower a `vector.contract` with row-major matmul semantics to
49/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
51 RewritePatternSet &patterns,
52 VectorContractLowering vectorContractLoweringOption,
53 PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
54
55/// Populate the pattern set with the following patterns:
56///
57/// [OuterProductOpLowering]
58/// Progressively lower a `vector.outerproduct` to linearized
59/// `vector.extract` + `vector.fma` + `vector.insert`.
61 PatternBenefit benefit = 1);
62
63/// Populate the pattern set with the following patterns:
64///
65/// [InnerOuterDimReductionConversion]
66/// Rewrites vector.multi_reduction such that all reduction dimensions are
67/// either innermost or outermost, by adding the proper vector.transpose
68/// operations.
69///
70/// [OneDimMultiReductionToTwoDim]
71/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
72/// either a parallel or a reduction), we lift them back up to 2-D with a simple
73/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
74/// thus fully exiting out of the vector.multi_reduction abstraction.
76 RewritePatternSet &patterns, VectorMultiReductionLowering options,
77 PatternBenefit benefit = 1);
78
79/// Populate the pattern set with the following patterns:
80///
81/// [ReduceMultiDimReductionRank]
82/// Once in innermost or outermost reduction
83/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
84/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
85/// back.
87 RewritePatternSet &patterns, VectorMultiReductionLowering options,
88 PatternBenefit benefit = 1);
89
90/// Populate the pattern set with the following patterns:
91///
92/// [TwoDimMultiReductionToElementWise]
93/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
94/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
95/// ops. This also has an opportunity for tree-reduction (in the future).
96///
97/// [TwoDimMultiReductionToReduction]
98/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
99/// dimension, unroll the outer dimension to obtain a sequence of extract +
100/// vector.reduction + insert. This can further lower to horizontal reduction
101/// ops.
103 RewritePatternSet &patterns, VectorMultiReductionLowering options,
104 PatternBenefit benefit = 1);
105
106/// Collect a set of patterns to convert vector.multi_reduction op into
107/// a sequence of vector.reduction ops. These patterns are the ones
108/// populated by:
109///
110/// * populateVectorMultiReductionReorderAndExpandPatterns
111/// * populateVectorMultiReductionFlatteningPatterns
112/// * populateVectorMultiReductionUnrollingPatterns
113///
114/// This is just a convenience wrapper that we use in testing and is effectively
115/// deprecated.
116/// TODO: Delete.
118 RewritePatternSet &patterns, VectorMultiReductionLowering options,
119 PatternBenefit benefit = 1);
120
121/// Populate the pattern set with the following patterns:
122///
123/// [BroadcastOpLowering]
124/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
125/// BroadcastOp until dim 1.
127 PatternBenefit benefit = 1);
128
129/// Populate the pattern set with the following patterns:
130///
131/// [CreateMaskOp]
132/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
133///
134/// [ConstantMaskOp]
135/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
136/// dim 1.
137void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
138 PatternBenefit benefit = 1);
139
140/// Collects patterns that lower scalar vector transfer ops to memref loads and
141/// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
142/// are applied to vector transfer reads with any number of uses. Otherwise,
143/// only vector transfer reads with a single use will be lowered.
145 PatternBenefit benefit,
146 bool allowMultipleUses);
147
148/// Populate the pattern set with the following patterns:
149///
150/// [ShapeCastOp2DDownCastRewritePattern]
151/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
152/// vectors progressively.
153///
154/// [ShapeCastOp2DUpCastRewritePattern]
155/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
156/// vectors progressively.
157///
158/// [ShapeCastOpRewritePattern]
159/// Reference lowering to fully unrolled sequences of single element ExtractOp +
160/// InsertOp. Note that applying this pattern can almost always be considered a
161/// performance bug.
163 PatternBenefit benefit = 1);
164
165/// Populate the pattern set with the following patterns:
166///
167/// [TransposeOpLowering]
168///
169/// [TransposeOp2DToShuffleLowering]
170///
172 RewritePatternSet &patterns,
173 VectorTransposeLowering vectorTransposeLowering,
174 PatternBenefit benefit = 1);
175
176/// Populate the pattern set with the following patterns:
177///
178/// [TransferReadToVectorLoadLowering]
179/// Progressive lowering of transfer_read. This pattern supports lowering of
180/// `vector.transfer_read` to a combination of `vector.load` and
181/// `vector.broadcast`.
182///
183/// [TransferWriteToVectorStoreLowering]
184/// Progressive lowering of transfer_write. This pattern supports lowering of
185/// `vector.transfer_write` to `vector.store`.
186///
187/// These patterns lower transfer ops to simpler ops like `vector.load`,
188/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
189/// at most `maxTransferRank` are lowered. This is useful when combined with
190/// VectorToSCF, which reduces the rank of vector transfer ops.
192 RewritePatternSet &patterns,
193 std::optional<unsigned> maxTransferRank = std::nullopt,
194 PatternBenefit benefit = 1);
195
196/// Collect a set of transfer read/write lowering patterns that simplify the
197/// permutation map (e.g., converting it to a minor identity map) by inserting
198/// broadcasts and transposes. More specifically:
199///
200/// [TransferReadPermutationLowering]
201/// Lower transfer_read op with permutation into a transfer_read with a
202/// permutation map composed of leading zeros followed by a minor identity +
203/// vector.transpose op.
204/// Ex:
205/// vector.transfer_read ...
206/// permutation_map: (d0, d1, d2) -> (0, d1)
207/// into:
208/// %v = vector.transfer_read ...
209/// permutation_map: (d0, d1, d2) -> (d1, 0)
210/// vector.transpose %v, [1, 0]
211///
212/// vector.transfer_read ...
213/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
214/// into:
215/// %v = vector.transfer_read ...
216/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
217/// vector.transpose %v, [0, 1, 3, 2, 4]
218/// Note that an alternative is to transform it to linalg.transpose +
219/// vector.transfer_read to do the transpose in memory instead.
220///
221/// [TransferWritePermutationLowering]
222/// Lower transfer_write op with permutation into a transfer_write with a
223/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
224/// Ex:
225/// vector.transfer_write %v ...
226/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
227/// into:
228/// %tmp = vector.transpose %v, [2, 0, 1]
229/// vector.transfer_write %tmp ...
230/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
231///
232/// vector.transfer_write %v ...
233/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
234/// into:
235/// %tmp = vector.transpose %v, [1, 0]
236/// vector.transfer_write %tmp ...
237/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
238///
239/// [TransferOpReduceRank]
240/// Lower transfer_read op with broadcast in the leading dimensions into
241/// transfer_read of lower rank + vector.broadcast.
242/// Ex: vector.transfer_read ...
243/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
244/// into:
245/// %v = vector.transfer_read ...
246/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
247/// vector.broadcast %v
249 RewritePatternSet &patterns, PatternBenefit benefit = 1);
250
251/// Populate the pattern set with the following patterns:
252///
253/// [ScanToArithOps]
254/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
255/// vector.extract_strided_slice.
256void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
257 PatternBenefit benefit = 1);
258
259/// Populate the pattern set with the following patterns:
260///
261/// [StepToArithConstantOp]
262/// Convert vector.step op into arith ops if not using scalable vectors.
263void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
264 PatternBenefit benefit = 1);
265
266/// Populate the pattern set with the following patterns:
267///
268/// [UnrollGather]
269/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
270/// outermost dimension.
271void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
272 PatternBenefit benefit = 1);
273
274/// Populate the pattern set with the following patterns:
275///
276/// [Gather1DToConditionalLoads]
277/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
278/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
279/// loads/extracts are made conditional using `scf.if` ops.
281 PatternBenefit benefit = 1);
282
283/// Populates instances of `MaskOpRewritePattern` to lower masked operations
284/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
285/// not its nested `MaskableOpInterface`.
287 RewritePatternSet &patterns);
288
289/// Populate the pattern set with the following patterns:
290///
291/// [VectorMaskedLoadOpConverter]
292/// Turns vector.maskedload to scf.if + memref.load
293///
294/// [VectorMaskedStoreOpConverter]
295/// Turns vector.maskedstore to scf.if + memref.store
297 PatternBenefit benefit = 1);
298
299/// Populate the pattern set with the following patterns:
300///
301/// [UnrollInterleaveOp]
302/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
303/// InterleaveOp (of `targetRank`) + InsertOp.
305 int64_t targetRank = 1,
306 PatternBenefit benefit = 1);
307
309 PatternBenefit benefit = 1);
310
311/// Populates the pattern set with the following patterns:
312///
313/// [UnrollBitCastOp]
314/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
315/// BitCastOp (of `targetRank`) + InsertOp.
316void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
317 int64_t targetRank = 1,
318 PatternBenefit benefit = 1);
319
320void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
321 PatternBenefit benefit = 1);
322
323/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
324/// n > 1.
325void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
326
327/// Populate patterns to rewrite sequences of `vector.to_elements` +
328/// `vector.from_elements` operations into a tree of `vector.shuffle`
329/// operations.
331 RewritePatternSet &patterns, PatternBenefit benefit = 1);
332
333/// Populate the pattern set with the following patterns:
334///
335/// [ContractionOpToMatmulOpLowering]
336/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
337///
338/// Given the high benefit, this will be prioritised over other
339/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
340/// only run this registration conditionally.
342 PatternBenefit benefit = 100);
343
344/// Populate the pattern set with the following patterns:
345///
346/// [TransposeOpLowering]
347/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
348///
349/// Given the high benefit, this will be prioriotised over other
350/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
351/// only run this registration conditionally.
353 PatternBenefit benefit = 100);
354
355} // namespace vector
356} // namespace mlir
357
358#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
static llvm::ManagedStatic< PassManagerOptions > options
void populateVectorMultiReductionReorderAndExpandPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a...
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit, bool allowMultipleUses)
Collects patterns that lower scalar vector transfer ops to memref loads and stores when beneficial.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns