MLIR  22.0.0git
VectorTransformOps.cpp
Go to the documentation of this file.
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 using namespace mlir::transform;
26 
27 //===----------------------------------------------------------------------===//
28 // Apply...ConversionPatternsOp
29 //===----------------------------------------------------------------------===//
30 
/// Populates `patterns` with Vector-to-LLVM conversion patterns. The generic
/// `typeConverter` is downcast to LLVMTypeConverter. The op's
/// reassociate-FP-reductions, force-32-bit-vector-indices and
/// use-vector-alignment getters are forwarded as flags.
/// NOTE(review): listing line 33 (the callee) is missing from this extract.
/// The argument list matches populateVectorToLLVMConversionPatterns — confirm.
/// The static_cast is unchecked; it presumably relies on verifyTypeConverter
/// having rejected builders that are not LLVMTypeConverter — confirm.
31 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
32  TypeConverter &typeConverter, RewritePatternSet &patterns) {
34  static_cast<LLVMTypeConverter &>(typeConverter), patterns,
35  getReassociateFpReductions(), getForce_32bitVectorIndices(),
36  getUseVectorAlignment());
37 }
38 
39 LogicalResult
40 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
41  transform::TypeConverterBuilderOpInterface builder) {
42  if (builder.getTypeConverterType() != "LLVMTypeConverter")
43  return emitOpError("expected LLVMTypeConverter");
44  return success();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Apply...PatternsOp
49 //===----------------------------------------------------------------------===//
50 
/// Adds vector::populateCastAwayVectorLeadingOneDimPatterns to `patterns`.
/// NOTE(review): listing line 52 (the parameter list) is missing from this
/// extract.
51 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
53  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
54 }
55 
/// NOTE(review): the parameter list and body (listing lines 57-58) are missing
/// from this extract. Judging by the name, this presumably forwards to
/// populateFoldArithExtensionPatterns (folds floating-point arithmetic
/// extensions into vector.contract) — confirm against the full source.
56 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
59 }
60 
/// NOTE(review): the parameter list and body (listing lines 62-63) are missing
/// from this extract. This presumably forwards to
/// populateElementwiseToVectorOpsPatterns (folds elementwise ops on vectors to
/// the vector dialect) — confirm against the full source.
61 void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
64 }
65 
/// NOTE(review): the parameter list and body (listing lines 67-68) are missing
/// from this extract. This presumably forwards to
/// populateVectorReductionToContractPatterns (converts reductions to
/// vector.contract) — confirm against the full source.
66 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
69 }
70 
/// NOTE(review): the parameter list and body (listing lines 72-73) are missing
/// from this extract, so the pattern set this op adds cannot be determined from
/// here — confirm against the full source.
71 void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
74 }
75 
/// Adds vector::populateVectorTransferDropUnitDimsPatterns to `patterns`.
/// NOTE(review): listing line 77 (the parameter list) is missing from this
/// extract.
76 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
78  vector::populateVectorTransferDropUnitDimsPatterns(patterns);
79 }
80 
/// NOTE(review): the parameter list and body (listing lines 82-83) are missing
/// from this extract. This presumably forwards to
/// populateVectorTransferPermutationMapLoweringPatterns (simplifies transfer
/// read/write permutation maps) — confirm against the full source.
81 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
84 }
85 
/// Adds vector::populateDropUnitDimWithShapeCastPatterns to `patterns`.
/// NOTE(review): listing line 87 (the parameter list) is missing from this
/// extract.
86 void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
88  vector::populateDropUnitDimWithShapeCastPatterns(patterns);
89 }
90 
/// NOTE(review): the body (listing line 93) is missing from this extract. This
/// presumably forwards to populateDropInnerMostUnitDimsXferOpPatterns
/// (collapses innermost unit dims in transfer ops) — confirm against the full
/// source.
91 void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
92  populatePatterns(RewritePatternSet &patterns) {
94 }
95 
/// NOTE(review): the parameter list and body (listing lines 97-98) are missing
/// from this extract. This presumably forwards to
/// populateVectorBitCastLoweringPatterns — confirm against the full source.
96 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
99 }
100 
/// NOTE(review): the parameter list and body (listing lines 102-103) are
/// missing from this extract. This presumably forwards to
/// populateVectorBroadcastLoweringPatterns — confirm against the full source.
101 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
104 }
105 
/// Adds vector.contract lowering patterns using the op's lowering strategy, at
/// benefit 1.
/// Outer-product lowering is explicitly disabled here, although the callee's
/// default for disableOuterProductLowering is false. Outer products are
/// presumably left to ApplyLowerOuterProductPatternsOp — confirm.
/// NOTE(review): listing line 107 (the parameter list) is missing from this
/// extract.
106 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
108  populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
109  /*benefit=*/1,
110  /*disableOuterProductLowering=*/true);
111 }
112 
/// NOTE(review): the parameter list and body (listing lines 114-115) are
/// missing from this extract. This presumably forwards to
/// populateVectorMaskOpLoweringPatterns — confirm against the full source.
113 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
116 }
117 
/// NOTE(review): the parameter list and body (listing lines 119-120) are
/// missing from this extract. This presumably forwards to
/// populateVectorMaskLoweringPatternsForSideEffectingOps — confirm against the
/// full source.
118 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
121 }
122 
/// Adds mask materialization patterns. force32BitVectorIndices is hard-coded
/// to false here; the LLVM conversion op instead reads the equivalent flag
/// from getForce_32bitVectorIndices().
/// NOTE(review): listing line 124 (the parameter list) is missing from this
/// extract.
123 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
125  populateVectorMaskMaterializationPatterns(patterns,
126  /*force32BitVectorIndices=*/false);
127 }
128 
/// Adds vector.multi_reduction lowering patterns. The op's lowering strategy
/// is stored in a local VectorTransformsOptions. Only that object's
/// `vectorMultiReductionLowering` field is passed on.
/// NOTE(review): listing lines 130 (parameter list) and 133 (callee) are
/// missing from this extract. The callee is presumably
/// populateVectorMultiReductionLoweringPatterns — confirm.
129 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
131  vector::VectorTransformsOptions vectorTransformOptions;
132  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
134  patterns, vectorTransformOptions.vectorMultiReductionLowering);
135 }
136 
/// NOTE(review): the parameter list and body (listing lines 138-139) are
/// missing from this extract. This presumably forwards to
/// populateVectorOuterProductLoweringPatterns — confirm against the full
/// source.
137 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
140 }
141 
/// NOTE(review): the parameter list and body (listing lines 143-144) are
/// missing from this extract. This presumably forwards to
/// populateVectorGatherLoweringPatterns — confirm against the full source.
142 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
145 }
146 
/// Adds vector::populateVectorFromElementsUnrollPatterns to `patterns`.
/// NOTE(review): listing line 148 (the parameter list) is missing from this
/// extract.
147 void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
149  vector::populateVectorFromElementsUnrollPatterns(patterns);
150 }
151 
/// Adds vector::populateVectorToElementsUnrollPatterns to `patterns`.
/// NOTE(review): listing line 153 (the parameter list) is missing from this
/// extract.
152 void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
154  vector::populateVectorToElementsUnrollPatterns(patterns);
155 }
156 
/// NOTE(review): the parameter list and body (listing lines 158-159) are
/// missing from this extract. This presumably forwards to
/// populateVectorScanLoweringPatterns — confirm against the full source.
157 void transform::ApplyLowerScanPatternsOp::populatePatterns(
160 }
161 
/// NOTE(review): the parameter list and body (listing lines 163-164) are
/// missing from this extract. This presumably forwards to
/// populateVectorShapeCastLoweringPatterns — confirm against the full source.
162 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
165 }
166 
/// Adds vector transfer lowering patterns and passes the op's max transfer
/// rank through.
/// NOTE(review): listing lines 168-169 (the parameter list and the start of
/// the call) are missing from this extract. The callee is presumably
/// populateVectorTransferLoweringPatterns — confirm.
167 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
170  getMaxTransferRank());
171 }
172 
/// Adds vector.transpose lowering patterns using the op's lowering strategy.
/// When the AVX2 lowering strategy is requested, it also adds specialized
/// transpose patterns for the 4x8xf32 and 8x8xf32 cases, at benefit 10.
/// NOTE(review): the higher benefit presumably makes these patterns win over
/// the generic ones at the default benefit of 1 — confirm.
/// NOTE(review): listing lines 174-175, 179-180 and 183 (parameter list,
/// generic callee, options builder head, specialized callee) are missing from
/// this extract.
173 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
176  getLoweringStrategy());
177  if (getAvx2LoweringStrategy()) {
178  auto avx2LoweringOptions =
181  .lower4x8xf32(true)
182  .lower8x8xf32(true));
184  patterns, avx2LoweringOptions, /*benefit=*/10);
185  }
186 }
187 
/// NOTE(review): the parameter list and body (listing lines 189-190) are
/// missing from this extract. This presumably forwards to
/// populateVectorInterleaveLoweringPatterns — confirm against the full source.
188 void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
191 }
192 
/// NOTE(review): the parameter list and body (listing lines 194-195) are
/// missing from this extract. This presumably forwards to
/// populateVectorInterleaveToShufflePatterns — confirm against the full source.
193 void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
196 }
197 
/// Adds the general narrow-type rewrite patterns and the transpose-specific
/// narrow-type rewrite patterns to `patterns`.
/// NOTE(review): listing line 199 (the parameter list) is missing from this
/// extract.
198 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
200  populateVectorNarrowTypeRewritePatterns(patterns);
201  populateVectorTransposeNarrowTypeRewritePatterns(patterns);
202 }
203 
/// Adds full/partial transfer-split patterns. The op's split strategy is
/// stored in a local VectorTransformsOptions, which is passed to the
/// populate function as a whole.
/// NOTE(review): listing line 205 (the parameter list) is missing from this
/// extract.
204 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
206  vector::VectorTransformsOptions vectorTransformOptions;
207  vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
208  populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
209 }
210 
/// Adds Vector-to-SCF conversion patterns. Full unrolling comes from
/// getFullUnroll() and the target rank from getMaxTransferRank().
/// NOTE(review): listing lines 212 (parameter list) and 214 (options
/// construction) are missing from this extract.
211 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
213  VectorTransferToSCFOptions vectorTransferToSCFOptions =
215  .enableFullUnroll(getFullUnroll())
216  .setTargetRank(getMaxTransferRank());
217  populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
218 }
219 
/// NOTE(review): the parameter list and body (listing lines 221-222) are
/// missing from this extract. This presumably forwards to
/// populateSinkVectorOpsPatterns (removes redundant vector ops by reordering
/// them) — confirm against the full source.
220 void transform::ApplySinkVectorPatternsOp::populatePatterns(
223 }
224 
/// Adds vector::populateSinkVectorMemOpsPatterns to `patterns`.
/// NOTE(review): listing line 226 (the parameter list) is missing from this
/// extract.
225 void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
227  vector::populateSinkVectorMemOpsPatterns(patterns);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Transform op registration
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 /// Transform-dialect extension that registers the Vector transform ops and
236 /// declares the Vector and LLVM dialects as generated dialects.
/// NOTE(review): listing line 238 (the start of the base-class clause) is
/// missing from this extract.
237 class VectorTransformDialectExtension
239  VectorTransformDialectExtension> {
240 public:
241  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
242
243  VectorTransformDialectExtension() {
// Dialects whose ops the transform ops may generate.
244  declareGeneratedDialect<vector::VectorDialect>();
245  declareGeneratedDialect<LLVM::LLVMDialect>();
// Register every op in the GET_OP_LIST section of VectorTransformOps.cpp.inc.
246  registerTransformOps<
247 #define GET_OP_LIST
248 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
249  >();
250  }
251 };
252 } // namespace
253 
// Expand the GET_OP_CLASSES section of VectorTransformOps.cpp.inc in this
// translation unit (presumably the generated op class definitions).
254 #define GET_OP_CLASSES
255 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
256 
/// Registers VectorTransformDialectExtension with `registry`.
/// NOTE(review): listing line 257 (the function-name line) is missing from this
/// extract. The declaration is
/// `void registerTransformDialectExtension(DialectRegistry &registry)`.
258  DialectRegistry &registry) {
259  registry.addExtensions<VectorTransformDialectExtension>();
260 }
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Type conversion class.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the innermost unit dims in xfer ops.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateVectorToSCFConversionPatterns(RewritePatternSet &patterns, const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Collect a set of patterns to convert from the Vector dialect to SCF + func.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, a temporary buffer is creat...
Definition: VectorToSCF.h:52
VectorTransferToSCFOptions & enableFullUnroll(bool u=true)
Definition: VectorToSCF.h:68
VectorTransferToSCFOptions & setTargetRank(unsigned r)
Definition: VectorToSCF.h:55
Structure to control the behavior of vector transform patterns.
VectorTransformsOptions & setVectorMultiReductionLowering(VectorMultiReductionLowering opt)
VectorMultiReductionLowering vectorMultiReductionLowering
Option to control the lowering of vector.multi_reduction.
VectorTransformsOptions & setVectorTransferSplit(VectorTransferSplit opt)
Options for controlling specialized AVX2 lowerings.
Definition: Transforms.h:159
LoweringOptions & setTransposeOptions(TransposeLoweringOptions options)
Definition: Transforms.h:162
Structure to control the behavior of specialized AVX2 transpose lowering.
Definition: Transforms.h:145