9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
19 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
22 class ConversionTarget;
23 class RewritePatternSet;
28 class NarrowTypeEmulationConverter;
33 struct VectorTransformsOptions;
47 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
71 std::function<std::optional<SmallVector<int64_t>>(
Operation *op)>;
97 std::function<LogicalResult(vector::ContractionOp)> constraint =
98 [](vector::ContractionOp) {
return success(); },
104 PatternBenefit benefit = 1);
147 PatternBenefit benefit = 1);
164 PatternBenefit benefit = 1);
177 void populateSinkVectorMemOpsPatterns(RewritePatternSet &
patterns,
178 PatternBenefit benefit = 1);
196 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &
patterns,
197 PatternBenefit benefit = 1);
214 void populateBreakDownVectorReductionPatterns(
215 RewritePatternSet &
patterns,
unsigned maxNumElementsToExtract = 2,
216 PatternBenefit benefit = 1);
240 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
241 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
249 void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
251 std::function<
bool(ExtractStridedSliceOp)> controlFn =
nullptr,
252 PatternBenefit benefit = 1);
265 void populateBreakDownVectorBitCastOpPatterns(
267 std::function<
bool(BitCastOp)> controlFn =
nullptr,
268 PatternBenefit benefit = 1);
288 void populateVectorInsertExtractStridedSliceTransforms(
289 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
321 void populateVectorUnrollPatterns(RewritePatternSet &
patterns,
322 const UnrollVectorOptions &
options,
323 PatternBenefit benefit = 1);
327 void populateVectorToElementsUnrollPatterns(RewritePatternSet &
patterns,
328 PatternBenefit benefit = 1);
332 void populateVectorFromElementsUnrollPatterns(RewritePatternSet &
patterns,
333 PatternBenefit benefit = 1);
341 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &
patterns,
342 PatternBenefit benefit = 1);
349 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &
patterns,
350 PatternBenefit benefit = 1);
358 void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &
patterns,
359 PatternBenefit benefit = 1);
370 void populateFlattenVectorTransferPatterns(
373 PatternBenefit benefit = 1);
380 void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &
patterns,
381 PatternBenefit benefit = 1);
384 void populateVectorMaskMaterializationPatterns(RewritePatternSet &
patterns,
385 bool force32BitVectorIndices,
386 PatternBenefit benefit = 1);
392 void populateVectorNarrowTypeEmulationPatterns(
393 const arith::NarrowTypeEmulationConverter &typeConverter,
394 RewritePatternSet &
patterns,
bool disableAtomicRMW =
false);
402 void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
403 arith::NarrowTypeEmulationConverter &typeConverter,
409 FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
410 vector::BitCastOp bitCastOp,
411 arith::TruncIOp truncOp,
412 vector::BroadcastOp maybeBroadcastOp);
417 FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
418 vector::BitCastOp bitCastOp,
419 vector::BroadcastOp maybeBroadcastOp);
424 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &
patterns,
425 PatternBenefit benefit = 1);
428 void populateVectorTransposeNarrowTypeRewritePatterns(
429 RewritePatternSet &
patterns, PatternBenefit benefit = 1);
448 void populateForVectorLinearize(TypeConverter &typeConverter,
449 ConversionTarget &conversionTarget);
454 void populateVectorLinearizeBasePatterns(
const TypeConverter &,
455 const ConversionTarget &,
460 void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &,
461 const ConversionTarget &,
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Operation is the basic unit of execution within MLIR.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction with MMT semantics (matrix-matrix multiplication with the RHS transposed).
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g. elementwise Ops.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the contract.
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Options that control the vector unrolling.
FilterConstraintFnType filterConstraint
Callback function that indicates whether vector unrolling should be attempted on the operation.
UnrollVectorOptions & setFilterConstraint(FilterConstraintFnType constraint)
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
UnrollVectorOptions & setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn)
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> UnrollTraversalOrderFnType
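Function that returns the traversal order (in terms of "for loop order", i.e. slowest varying dimension to fastest varying dimension) that should be used when unrolling the given operation into units of the native vector size.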
std::function< LogicalResult(Operation *op)> FilterConstraintFnType
NativeShapeFnType nativeShape
Function that returns the shape of the vector to unroll to for a given operation.
UnrollVectorOptions & setNativeShape(ArrayRef< int64_t > shape)
Set the native shape to use for unrolling.
UnrollTraversalOrderFnType traversalOrderCallback
std::function< std::optional< SmallVector< int64_t > >(Operation *op)> NativeShapeFnType