MLIR  19.0.0git
VectorTransformOps.cpp
Go to the documentation of this file.
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 using namespace mlir::transform;
28 
29 //===----------------------------------------------------------------------===//
30 // Apply...ConversionPatternsOp
31 //===----------------------------------------------------------------------===//
32 
33 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
34  TypeConverter &typeConverter, RewritePatternSet &patterns) {
36  static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37  getReassociateFpReductions(), getForce_32bitVectorIndices());
38 }
39 
41 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
42  transform::TypeConverterBuilderOpInterface builder) {
43  if (builder.getTypeConverterType() != "LLVMTypeConverter")
44  return emitOpError("expected LLVMTypeConverter");
45  return success();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Apply...PatternsOp
50 //===----------------------------------------------------------------------===//
51 
52 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
53  RewritePatternSet &patterns) {
55 }
56 
57 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
58  RewritePatternSet &patterns) {
60 }
61 
62 void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
63  RewritePatternSet &patterns) {
65 }
66 
67 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
68  RewritePatternSet &patterns) {
70 }
71 
72 void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
73  RewritePatternSet &patterns) {
75 }
76 
77 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
78  RewritePatternSet &patterns) {
80 }
81 
82 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
83  RewritePatternSet &patterns) {
85 }
86 
87 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
88  RewritePatternSet &patterns) {
90 }
91 
92 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
93  RewritePatternSet &patterns) {
95 }
96 
97 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
98  RewritePatternSet &patterns) {
99  vector::VectorTransformsOptions vectorTransformOptions;
100  vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
101  populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
102  /*benefit=*/1,
103  /*disableOuterProductLowering=*/true);
104 }
105 
106 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
107  RewritePatternSet &patterns) {
109 }
110 
111 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
112  RewritePatternSet &patterns) {
114 }
115 
116 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
117  RewritePatternSet &patterns) {
119  /*force32BitVectorIndices=*/false);
120 }
121 
122 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
123  RewritePatternSet &patterns) {
124  vector::VectorTransformsOptions vectorTransformOptions;
125  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
127  patterns, vectorTransformOptions.vectorMultiReductionLowering);
128 }
129 
130 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
131  RewritePatternSet &patterns) {
133 }
134 
135 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
136  RewritePatternSet &patterns) {
138 }
139 
140 void transform::ApplyLowerScanPatternsOp::populatePatterns(
141  RewritePatternSet &patterns) {
143 }
144 
145 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
146  RewritePatternSet &patterns) {
148 }
149 
150 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
151  RewritePatternSet &patterns) {
153  getMaxTransferRank());
154 }
155 
156 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
157  RewritePatternSet &patterns) {
159  patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
160  getLoweringStrategy()));
161  if (getAvx2LoweringStrategy()) {
162  auto avx2LoweringOptions =
165  .lower4x8xf32(true)
166  .lower8x8xf32(true));
168  patterns, avx2LoweringOptions, /*benefit=*/10);
169  }
170 }
171 
172 void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
173  RewritePatternSet &patterns) {
175 }
176 
177 void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
178  RewritePatternSet &patterns) {
180 }
181 
182 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
183  RewritePatternSet &patterns) {
186 }
187 
188 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
189  RewritePatternSet &patterns) {
190  vector::VectorTransformsOptions vectorTransformOptions;
191  vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
192  populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
193 }
194 
195 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
196  RewritePatternSet &patterns) {
197  VectorTransferToSCFOptions vectorTransferToSCFOptions =
199  .enableFullUnroll(getFullUnroll())
200  .setTargetRank(getMaxTransferRank());
201  populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // Transform op registration
206 //===----------------------------------------------------------------------===//
207 
208 namespace {
209 /// Registers the Vector transform ops and declares the Vector and LLVM
210 /// dialects as generated dialects, since the registered ops may create ops from them.
211 class VectorTransformDialectExtension
213  VectorTransformDialectExtension> {
214 public:
215  VectorTransformDialectExtension() {
216  declareGeneratedDialect<vector::VectorDialect>();
217  declareGeneratedDialect<LLVM::LLVMDialect>();
218  registerTransformOps<
219 #define GET_OP_LIST
220 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
221  >();
222  }
223 };
224 } // namespace
225 
226 #define GET_OP_CLASSES
227 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
228 
230  DialectRegistry &registry) {
231  registry.addExtensions<VectorTransformDialectExtension>();
232 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
Type conversion class.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for rewriting vector operations over narrow types with ops over wider types.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of one dimension removal patterns.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
void populateVectorTransposeNarrowTypeRewritePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Appends patterns for emulating a sub-byte vector transpose.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToSCFConversionPatterns(RewritePatternSet &patterns, const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Collect a set of patterns to convert from the Vector dialect to SCF + func.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, a temporary buffer is creat...
Definition: VectorToSCF.h:52
VectorTransferToSCFOptions & enableFullUnroll(bool u=true)
Definition: VectorToSCF.h:68
VectorTransferToSCFOptions & setTargetRank(unsigned r)
Definition: VectorToSCF.h:55
Structure to control the behavior of vector transform patterns.
VectorTransformsOptions & setVectorMultiReductionLowering(VectorMultiReductionLowering opt)
VectorMultiReductionLowering vectorMultiReductionLowering
Option to control the lowering of vector.multi_reduction.
VectorTransformsOptions & setVectorTransferSplit(VectorTransferSplit opt)
VectorTransformsOptions & setVectorTransformsOptions(VectorContractLowering opt)
Options for controlling specialized AVX2 lowerings.
Definition: Transforms.h:159
LoweringOptions & setTransposeOptions(TransposeLoweringOptions options)
Definition: Transforms.h:162
Structure to control the behavior of specialized AVX2 transpose lowering.
Definition: Transforms.h:145