9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
16 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
22 class RewritePatternSet;
40 VectorMultiReductionLowering::InnerParallel;
48 VectorTransposeLowering::EltWise;
74 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
98 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
114 VectorTransformsOptions
options = VectorTransformsOptions(),
147 VectorTransformsOptions
options = VectorTransformsOptions(),
257 std::function<
bool(ExtractStridedSliceOp)> controlFn =
nullptr,
312 const UnrollVectorOptions &
options,
328 [](VectorTransferOpInterface op) {
return success(); },
330 :
RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
331 filter(std::move(filter)) {}
338 VectorTransformsOptions options;
372 vectorTransformOptions(vectorTransformOptions),
373 filter(std::move(constraint)) {}
416 vectorTransformOptions(vectorTransformOptions),
417 filter(std::move(constraint)) {}
463 vectorTransformOptions(vectorTransformOptions), filter(
defaultFilter) {}
502 vectorTransformOptions(vectorTransformOptions),
503 filter(std::move(constraint)) {}
static llvm::ManagedStatic< PassManagerOptions > options
This class provides support for representing a failure result, or a valid value of type T.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Progressive lowering of ContractionOp.
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
Progressive lowering of ContractionOp.
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to an output-size-u...
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
ContractionOpToDotLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to:
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractionOpToMatmulOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
Progressively lower a vector.contract a, b, c with row-major matmul semantics to:
Progressive lowering of a vector.contract a, b, c with row-major matmul semantics to a reduction_size...
ContractionOpToOuterProductOpLowering(vector::VectorTransformsOptions vectorTransformOptions, MLIRContext *context, PatternBenefit benefit=1, FilterConstraintType constraint=defaultFilter)
static LogicalResult defaultFilter(vector::ContractionOp op)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
Progressively lower a vector.contract a, b, c with row-major matmul semantics to a reduction_size-unr...
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Collects patterns to progressively lower vector contraction ops on high-D into low-D reduction and pr...
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to breaks down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Insert TransposeLowering patterns into extraction/insertion.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to reduce the rank of the operands of vector transfer ops to operate on the...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert scan op.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Options that control the vector unrolling.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
Function that returns the traversal order (in terms of "for loop order", i.e.
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType
Apply splitFullAndPartialTransfer selectively via a pattern.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Performs the rewrite.
std::function< LogicalResult(VectorTransferOpInterface op)> FilterConstraintType
VectorTransferFullPartialRewriter(MLIRContext *context, VectorTransformsOptions options=VectorTransformsOptions(), FilterConstraintType filter=[](VectorTransferOpInterface op) { return success();}, PatternBenefit benefit=1)