MLIR 23.0.0git
VectorTransformOps.cpp
Go to the documentation of this file.
1//===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
22
23using namespace mlir;
24using namespace mlir::vector;
25using namespace mlir::transform;
26
27//===----------------------------------------------------------------------===//
28// Apply...ConversionPatternsOp
29//===----------------------------------------------------------------------===//
30
31void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
32 TypeConverter &typeConverter, RewritePatternSet &patterns) {
34 static_cast<LLVMTypeConverter &>(typeConverter), patterns,
35 getReassociateFpReductions(), getForce_32bitVectorIndices(),
36 getUseVectorAlignment());
37}
38
39LogicalResult
40transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
41 transform::TypeConverterBuilderOpInterface builder) {
42 if (builder.getTypeConverterType() != "LLVMTypeConverter")
43 return emitOpError("expected LLVMTypeConverter");
44 return success();
45}
46
47//===----------------------------------------------------------------------===//
48// Apply...PatternsOp
49//===----------------------------------------------------------------------===//
50
51void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
52 RewritePatternSet &patterns) {
53 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
54}
55
56void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
57 RewritePatternSet &patterns) {
59}
60
61void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
62 RewritePatternSet &patterns) {
64}
65
66void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
67 RewritePatternSet &patterns) {
69}
70
71void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
72 RewritePatternSet &patterns) {
74}
75
76void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
77 RewritePatternSet &patterns) {
78 vector::populateVectorTransferDropUnitDimsPatterns(patterns);
79}
80
81void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
82 RewritePatternSet &patterns) {
84}
85
86void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
87 RewritePatternSet &patterns) {
88 vector::populateDropUnitDimWithShapeCastPatterns(patterns);
89}
90
91void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
92 populatePatterns(RewritePatternSet &patterns) {
94}
95
96void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
97 RewritePatternSet &patterns) {
99}
100
101void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
102 RewritePatternSet &patterns) {
104}
105
106void transform::ApplyLowerContractionPatternsOp::populatePatterns(
107 RewritePatternSet &patterns) {
108 populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
109 /*benefit=*/1,
110 /*disableOuterProductLowering=*/true);
111}
112
113void transform::ApplyLowerMasksPatternsOp::populatePatterns(
114 RewritePatternSet &patterns) {
116}
117
118void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
119 RewritePatternSet &patterns) {
121}
122
123void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
124 RewritePatternSet &patterns) {
125 populateVectorMaskMaterializationPatterns(patterns,
126 /*force32BitVectorIndices=*/false);
127}
128
129//===----------------------------------------------------------------------===//
130// Multi-reduction patterns
131//===----------------------------------------------------------------------===//
132void transform::ApplyReorderMultiReductionPatternsOp::populatePatterns(
133 RewritePatternSet &patterns) {
134 vector::VectorTransformsOptions vectorTransformOptions;
135 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
137 patterns, vectorTransformOptions.vectorMultiReductionLowering);
138}
139
140void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
141 RewritePatternSet &patterns) {
142 vector::VectorTransformsOptions vectorTransformOptions;
143 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
145 patterns, vectorTransformOptions.vectorMultiReductionLowering);
146}
147
148void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
149 RewritePatternSet &patterns) {
150 vector::VectorTransformsOptions vectorTransformOptions;
151 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
153 patterns, vectorTransformOptions.vectorMultiReductionLowering);
154}
155
156void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
157 RewritePatternSet &patterns) {
159}
160
161void transform::ApplyLowerGatherPatternsOp::populatePatterns(
162 RewritePatternSet &patterns) {
164}
165
166void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
167 RewritePatternSet &patterns) {
168 vector::populateVectorFromElementsUnrollPatterns(patterns);
169}
170
171void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
172 RewritePatternSet &patterns) {
173 vector::populateVectorToElementsUnrollPatterns(patterns);
174}
175
176void transform::ApplyLowerScanPatternsOp::populatePatterns(
177 RewritePatternSet &patterns) {
179}
180
181void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
182 RewritePatternSet &patterns) {
184}
185
186void transform::ApplyLowerTransferPatternsOp::populatePatterns(
187 RewritePatternSet &patterns) {
189 getMaxTransferRank());
190}
191
192void transform::ApplyLowerTransposePatternsOp::populatePatterns(
193 RewritePatternSet &patterns) {
195 getLoweringStrategy());
196 if (getAvx2LoweringStrategy()) {
197 auto avx2LoweringOptions = x86::avx2::LoweringOptions().setTransposeOptions(
198 x86::avx2::TransposeLoweringOptions().lower4x8xf32(true).lower8x8xf32(
199 true));
201 patterns, avx2LoweringOptions, /*benefit=*/10);
202 }
203}
204
205void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
206 RewritePatternSet &patterns) {
208}
209
210void transform::ApplyInterleaveAndDeinterleaveToShufflePatternsOp::
211 populatePatterns(RewritePatternSet &patterns) {
214}
215
216void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
217 RewritePatternSet &patterns) {
218 populateVectorNarrowTypeRewritePatterns(patterns);
219 populateVectorTransposeNarrowTypeRewritePatterns(patterns);
220}
221
222void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
223 RewritePatternSet &patterns) {
224 vector::VectorTransformsOptions vectorTransformOptions;
225 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
226 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
227}
228
229void transform::ApplyTransferToScfPatternsOp::populatePatterns(
230 RewritePatternSet &patterns) {
231 VectorTransferToSCFOptions vectorTransferToSCFOptions =
233 .enableFullUnroll(getFullUnroll())
234 .setTargetRank(getMaxTransferRank());
235 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
236}
237
238void transform::ApplySinkVectorPatternsOp::populatePatterns(
239 RewritePatternSet &patterns) {
241}
242
243void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
244 RewritePatternSet &patterns) {
245 vector::populateSinkVectorMemOpsPatterns(patterns);
246}
247
248void transform::ApplyFlattenVectorTransferOpsPatternsOp::populatePatterns(
249 RewritePatternSet &patterns) {
250 vector::populateFlattenVectorTransferPatterns(patterns,
251 getTargetVectorBitwidth());
252}
253
254//===----------------------------------------------------------------------===//
255// Transform op registration
256//===----------------------------------------------------------------------===//
257
258namespace {
259/// Registers new ops and declares PDL as dependent dialect since the additional
260/// ops are using PDL types for operands and results.
261class VectorTransformDialectExtension
263 VectorTransformDialectExtension> {
264public:
265 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
266
267 VectorTransformDialectExtension() {
268 declareGeneratedDialect<vector::VectorDialect>();
269 declareGeneratedDialect<LLVM::LLVMDialect>();
270 registerTransformOps<
271#define GET_OP_LIST
272#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
273 >();
274 }
275};
276} // namespace
277
278#define GET_OP_CLASSES
279#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
280
282 DialectRegistry &registry) {
283 registry.addExtensions<VectorTransformDialectExtension>();
284}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorDeinterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionReorderPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
Include the generated interface declarations.
void populateVectorToSCFConversionPatterns(RewritePatternSet &patterns, const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Collect a set of patterns to convert from the Vector dialect to SCF + func.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, a temporary buffer is creat...
Definition VectorToSCF.h:52
VectorTransferToSCFOptions & setTargetRank(unsigned r)
Definition VectorToSCF.h:55
VectorTransferToSCFOptions & enableFullUnroll(bool u=true)
Definition VectorToSCF.h:68
Structure to control the behavior of vector transform patterns.
VectorMultiReductionLowering vectorMultiReductionLowering
Option to control the lowering of vector.multi_reduction.
VectorTransformsOptions & setVectorMultiReductionLowering(VectorMultiReductionLowering opt)
VectorTransformsOptions & setVectorTransferSplit(VectorTransferSplit opt)
Options for controlling specialized AVX2 lowerings.
Definition Transforms.h:190
LoweringOptions & setTransposeOptions(TransposeLoweringOptions options)
Definition Transforms.h:193
Structure to control the behavior of specialized AVX2 transpose lowering.
Definition Transforms.h:176