19#include "llvm/ADT/SmallVectorExtras.h"
21#define DEBUG_TYPE "vector-interleave-lowering"
50class UnrollInterleaveOp final :
public OpRewritePattern<vector::InterleaveOp> {
56 LogicalResult matchAndRewrite(vector::InterleaveOp op,
58 VectorType resultType = op.getResultVectorType();
63 auto loc = op.getLoc();
64 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
66 for (
auto position : *unrollIterator) {
68 ExtractOp::create(rewriter, loc, op.getLhs(), position);
70 ExtractOp::create(rewriter, loc, op.getRhs(), position);
72 InterleaveOp::create(rewriter, loc, extractLhs, extractRhs);
73 result = InsertOp::create(rewriter, loc, interleave,
result, position);
114class UnrollDeinterleaveOp final
121 LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
123 VectorType resultType = op.getResultVectorType();
128 auto loc = op.getLoc();
129 Value emptyResult = arith::ConstantOp::create(
130 rewriter, loc, resultType, rewriter.
getZeroAttr(resultType));
131 Value evenResult = emptyResult;
132 Value oddResult = emptyResult;
134 for (
auto position : *unrollIterator) {
136 vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
138 vector::DeinterleaveOp::create(rewriter, loc, extractSrc);
139 evenResult = vector::InsertOp::create(
140 rewriter, loc, deinterleave.getRes1(), evenResult, position);
141 oddResult = vector::InsertOp::create(
142 rewriter, loc, deinterleave.getRes2(), oddResult, position);
170 LogicalResult matchAndRewrite(vector::InterleaveOp op,
172 VectorType sourceType = op.getSourceVectorType();
173 if (sourceType.getRank() > 1 || sourceType.isScalable()) {
176 int64_t n = sourceType.getNumElements();
177 auto seq = llvm::seq<int64_t>(2 * n);
178 auto zip = llvm::map_to_vector(
179 seq, [n](
int64_t i) {
return (i % 2 ? n : 0) + i / 2; });
200struct DeinterleaveToShuffle final :
OpRewritePattern<vector::DeinterleaveOp> {
203 LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
205 VectorType sourceType = op.getSourceVectorType();
206 if (sourceType.getRank() != 1 || sourceType.isScalable()) {
210 auto seq = llvm::seq<int64_t>(sourceType.getNumElements() / 2);
211 auto evenZip = llvm::map_to_vector(seq, [](
int64_t i) {
return i * 2; });
212 auto oddZip = llvm::map_to_vector(evenZip, [](
int64_t i) {
return i + 1; });
214 Value evenResult = vector::ShuffleOp::create(
215 rewriter, op.getLoc(), op.getOperand(), op.getOperand(), evenZip);
216 Value oddResult = vector::ShuffleOp::create(
217 rewriter, op.getLoc(), op.getOperand(), op.getOperand(), oddZip);
228 patterns.
add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
234 patterns.
add<InterleaveToShuffle>(patterns.
getContext(), benefit);
239 patterns.
add<DeinterleaveToShuffle>(patterns.
getContext(), benefit);
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorDeinterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...