20 #define DEBUG_TYPE "vector-interleave-lowering"
49 class UnrollInterleaveOp final :
public OpRewritePattern<vector::InterleaveOp> {
51 UnrollInterleaveOp(int64_t targetRank,
MLIRContext *context,
55 LogicalResult matchAndRewrite(vector::InterleaveOp op,
57 VectorType resultType = op.getResultVectorType();
62 auto loc = op.getLoc();
65 for (
auto position : *unrollIterator) {
66 Value extractLhs = rewriter.
create<ExtractOp>(loc, op.getLhs(), position);
67 Value extractRhs = rewriter.
create<ExtractOp>(loc, op.getRhs(), position);
69 rewriter.
create<InterleaveOp>(loc, extractLhs, extractRhs);
70 result = rewriter.
create<InsertOp>(loc, interleave, result, position);
78 int64_t targetRank = 1;
111 class UnrollDeinterleaveOp final
114 UnrollDeinterleaveOp(int64_t targetRank,
MLIRContext *context,
118 LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
120 VectorType resultType = op.getResultVectorType();
125 auto loc = op.getLoc();
126 Value emptyResult = rewriter.
create<arith::ConstantOp>(
127 loc, resultType, rewriter.
getZeroAttr(resultType));
128 Value evenResult = emptyResult;
129 Value oddResult = emptyResult;
131 for (
auto position : *unrollIterator) {
133 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), position);
135 rewriter.
create<vector::DeinterleaveOp>(loc, extractSrc);
136 evenResult = rewriter.
create<vector::InsertOp>(
137 loc, deinterleave.getRes1(), evenResult, position);
138 oddResult = rewriter.
create<vector::InsertOp>(loc, deinterleave.getRes2(),
139 oddResult, position);
146 int64_t targetRank = 1;
166 LogicalResult matchAndRewrite(vector::InterleaveOp op,
168 VectorType sourceType = op.getSourceVectorType();
169 if (sourceType.getRank() != 1 || sourceType.isScalable()) {
172 int64_t n = sourceType.getNumElements();
173 auto seq = llvm::seq<int64_t>(2 * n);
174 auto zip = llvm::to_vector(llvm::map_range(
175 seq, [n](int64_t i) {
return (i % 2 ? n : 0) + i / 2; }));
185 patterns.
add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
191 patterns.
add<InterleaveToShuffle>(patterns.
getContext(), benefit);
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacement).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.