9#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
19#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
22class ConversionTarget;
47 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
71 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
97 std::function<LogicalResult(vector::ContractionOp)> constraint =
98 [](vector::ContractionOp) {
return success(); },
104 PatternBenefit benefit = 1);
138 RewritePatternSet &patterns,
const VectorTransformsOptions &
options);
147 PatternBenefit benefit = 1);
164 PatternBenefit benefit = 1);
177void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
178 PatternBenefit benefit = 1);
196void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
197 PatternBenefit benefit = 1);
214void populateBreakDownVectorReductionPatterns(
215 RewritePatternSet &patterns,
unsigned maxNumElementsToExtract = 2,
216 PatternBenefit benefit = 1);
240void populateVectorInsertExtractStridedSliceDecompositionPatterns(
241 RewritePatternSet &patterns, PatternBenefit benefit = 1);
249void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
250 RewritePatternSet &patterns,
251 std::function<
bool(ExtractStridedSliceOp)> controlFn =
nullptr,
252 PatternBenefit benefit = 1);
265void populateBreakDownVectorBitCastOpPatterns(
266 RewritePatternSet &patterns,
267 std::function<
bool(BitCastOp)> controlFn =
nullptr,
268 PatternBenefit benefit = 1);
288void populateVectorInsertExtractStridedSliceTransforms(
289 RewritePatternSet &patterns, PatternBenefit benefit = 1);
321void populateVectorUnrollPatterns(RewritePatternSet &patterns,
322 const UnrollVectorOptions &
options,
323 PatternBenefit benefit = 1);
326SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
328 AffineMap permutationMap, Location loc,
333void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
334 PatternBenefit benefit = 1);
338void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns,
339 PatternBenefit benefit = 1);
347void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
348 PatternBenefit benefit = 1);
355void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
356 PatternBenefit benefit = 1);
364void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
365 PatternBenefit benefit = 1);
376void populateFlattenVectorTransferPatterns(
377 RewritePatternSet &patterns,
378 unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
379 PatternBenefit benefit = 1);
386void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
387 PatternBenefit benefit = 1);
390void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
391 bool force32BitVectorIndices,
392 PatternBenefit benefit = 1);
402void populateVectorNarrowTypeEmulationPatterns(
403 const arith::NarrowTypeEmulationConverter &typeConverter,
404 RewritePatternSet &patterns,
bool disableAtomicRMW =
false,
405 bool assumeAligned =
false);
413void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
414 arith::NarrowTypeEmulationConverter &typeConverter,
415 RewritePatternSet &patterns);
420FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
421 vector::BitCastOp bitCastOp,
422 arith::TruncIOp truncOp,
423 vector::BroadcastOp maybeBroadcastOp);
428FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
429 vector::BitCastOp bitCastOp,
430 vector::BroadcastOp maybeBroadcastOp);
435void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
436 PatternBenefit benefit = 1);
439void populateVectorTransposeNarrowTypeRewritePatterns(
440 RewritePatternSet &patterns, PatternBenefit benefit = 1);
465void populateVectorLinearizeBasePatterns(
const TypeConverter &,
467 RewritePatternSet &patterns);
471void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &,
473 RewritePatternSet &patterns);
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
Converts narrow integer or float types that are not supported by the target hardware to wider types.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
Options that control the vector unrolling.
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType