9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
19 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
22 class ConversionTarget;
23 class RewritePatternSet;
28 class NarrowTypeEmulationConverter;
33 struct VectorTransformsOptions;
47 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
71 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
97 std::function<LogicalResult(vector::ContractionOp)> constraint =
98 [](vector::ContractionOp) {
return success(); },
104 PatternBenefit benefit = 1);
145 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
162 PatternBenefit benefit = 1);
175 void populateSinkVectorMemOpsPatterns(RewritePatternSet &
patterns,
176 PatternBenefit benefit = 1);
194 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &
patterns,
195 PatternBenefit benefit = 1);
212 void populateBreakDownVectorReductionPatterns(
213 RewritePatternSet &
patterns,
unsigned maxNumElementsToExtract = 2,
214 PatternBenefit benefit = 1);
238 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
239 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
247 void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
249 std::function<
bool(ExtractStridedSliceOp)> controlFn =
nullptr,
250 PatternBenefit benefit = 1);
263 void populateBreakDownVectorBitCastOpPatterns(
265 std::function<
bool(BitCastOp)> controlFn =
nullptr,
266 PatternBenefit benefit = 1);
286 void populateVectorInsertExtractStridedSliceTransforms(
287 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
319 void populateVectorUnrollPatterns(RewritePatternSet &
patterns,
320 const UnrollVectorOptions &
options,
321 PatternBenefit benefit = 1);
329 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &
patterns,
330 PatternBenefit benefit = 1);
337 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &
patterns,
338 PatternBenefit benefit = 1);
346 void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &
patterns,
347 PatternBenefit benefit = 1);
358 void populateFlattenVectorTransferPatterns(
361 PatternBenefit benefit = 1);
368 void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &
patterns,
369 PatternBenefit benefit = 1);
372 void populateVectorMaskMaterializationPatterns(RewritePatternSet &
patterns,
373 bool force32BitVectorIndices,
374 PatternBenefit benefit = 1);
380 void populateVectorNarrowTypeEmulationPatterns(
381 const arith::NarrowTypeEmulationConverter &typeConverter,
382 RewritePatternSet &
patterns,
bool disableAtomicRMW =
false);
387 FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
388 vector::BitCastOp bitCastOp,
389 arith::TruncIOp truncOp,
390 vector::BroadcastOp maybeBroadcastOp);
395 FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
396 vector::BitCastOp bitCastOp,
397 vector::BroadcastOp maybeBroadcastOp);
402 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &
patterns,
403 PatternBenefit benefit = 1);
406 void populateVectorTransposeNarrowTypeRewritePatterns(
407 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
412 void populateVectorLinearizeTypeConversionsAndLegality(
413 TypeConverter &typeConverter, RewritePatternSet &
patterns,
414 ConversionTarget &target,
unsigned targetBitWidth);
418 void populateVectorLinearizeShuffleLikeOpsPatterns(
419 const TypeConverter &typeConverter, RewritePatternSet &
patterns,
420 ConversionTarget &target,
unsigned targetBitWidth);
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Operation is the basic unit of execution within MLIR.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Options that control the vector unrolling.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType