9#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
19#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
22class ConversionTarget;
47 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
71 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
97 std::function<LogicalResult(vector::ContractionOp)> constraint =
98 [](vector::ContractionOp) {
return success(); },
104 PatternBenefit benefit = 1);
147 PatternBenefit benefit = 1);
164 PatternBenefit benefit = 1);
177void populateSinkVectorMemOpsPatterns(RewritePatternSet &
patterns,
178 PatternBenefit benefit = 1);
196void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &
patterns,
197 PatternBenefit benefit = 1);
214void populateBreakDownVectorReductionPatterns(
215 RewritePatternSet &
patterns,
unsigned maxNumElementsToExtract = 2,
216 PatternBenefit benefit = 1);
240void populateVectorInsertExtractStridedSliceDecompositionPatterns(
241 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
249void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
251 std::function<
bool(ExtractStridedSliceOp)> controlFn =
nullptr,
252 PatternBenefit benefit = 1);
265void populateBreakDownVectorBitCastOpPatterns(
267 std::function<
bool(BitCastOp)> controlFn =
nullptr,
268 PatternBenefit benefit = 1);
288void populateVectorInsertExtractStridedSliceTransforms(
289 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
321void populateVectorUnrollPatterns(RewritePatternSet &
patterns,
322 const UnrollVectorOptions &
options,
323 PatternBenefit benefit = 1);
327void populateVectorToElementsUnrollPatterns(RewritePatternSet &
patterns,
328 PatternBenefit benefit = 1);
332void populateVectorFromElementsUnrollPatterns(RewritePatternSet &
patterns,
333 PatternBenefit benefit = 1);
341void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &
patterns,
342 PatternBenefit benefit = 1);
349void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &
patterns,
350 PatternBenefit benefit = 1);
358void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &
patterns,
359 PatternBenefit benefit = 1);
370void populateFlattenVectorTransferPatterns(
372 unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
373 PatternBenefit benefit = 1);
380void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &
patterns,
381 PatternBenefit benefit = 1);
384void populateVectorMaskMaterializationPatterns(RewritePatternSet &
patterns,
385 bool force32BitVectorIndices,
386 PatternBenefit benefit = 1);
392void populateVectorNarrowTypeEmulationPatterns(
393 const arith::NarrowTypeEmulationConverter &typeConverter,
394 RewritePatternSet &
patterns,
bool disableAtomicRMW =
false);
402void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
403 arith::NarrowTypeEmulationConverter &typeConverter,
409FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
410 vector::BitCastOp bitCastOp,
411 arith::TruncIOp truncOp,
412 vector::BroadcastOp maybeBroadcastOp);
417FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
418 vector::BitCastOp bitCastOp,
419 vector::BroadcastOp maybeBroadcastOp);
424void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &
patterns,
425 PatternBenefit benefit = 1);
428void populateVectorTransposeNarrowTypeRewritePatterns(
429 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
454void populateVectorLinearizeBasePatterns(
const TypeConverter &,
460void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &,
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Options that control the vector unrolling.
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType