26#define DEBUG_TYPE "lower-vector-transpose"
35 size_t numTransposedDims = transpose.size();
36 for (
size_t transpDim : llvm::reverse(transpose)) {
37 if (transpDim != numTransposedDims - 1)
42 result.append(transpose.begin(), transpose.begin() + numTransposedDims);
47 return lowering == VectorTransposeLowering::Shuffle1D ||
48 lowering == VectorTransposeLowering::Shuffle16x16;
63 assert(numBits % 128 == 0 &&
"expected numBits is a multiple of 128");
64 int numElem = numBits / 32;
66 for (
int i = 0; i < numElem; i += 4)
80 int numElem = numBits / 32;
81 return vector::ShuffleOp::create(
94 int numElem = numBits / 32;
95 return vector::ShuffleOp::create(
109 int numElem = numBits / 32;
110 auto shuffle = vector::ShuffleOp::create(
124 int numElem = numBits / 32;
125 return vector::ShuffleOp::create(
149 assert(cast<VectorType>(v1.
getType()).getShape()[0] == 16 &&
150 "expected a vector with length=16");
152 auto appendToMask = [&](
int64_t base, uint8_t control) {
156 base + 2, base + 3});
160 base + 6, base + 7});
164 base + 10, base + 11});
168 base + 14, base + 15});
171 llvm_unreachable(
"control > 3 : overflow");
174 uint8_t b01 = mask & 0x3;
175 uint8_t b23 = (mask >> 2) & 0x3;
176 uint8_t b45 = (mask >> 4) & 0x3;
177 uint8_t b67 = (mask >> 6) & 0x3;
178 appendToMask(0, b01);
179 appendToMask(0, b23);
180 appendToMask(16, b45);
181 appendToMask(16, b67);
182 return vector::ShuffleOp::create(
b, v1, v2, shuffleMask);
191 for (
int64_t i = 0; i < m; ++i)
192 mask.push_back(i * n +
j);
193 return vector::ShuffleOp::create(
b, source.
getLoc(), source, source, mask);
202 for (
int64_t i = 0; i < m; ++i)
203 vs.push_back(
b.createOrFold<vector::ExtractOp>(source, i));
283 auto reshInputType = VectorType::get(
284 {m, n}, cast<VectorType>(source.
getType()).getElementType());
285 Value res = ub::PoisonOp::create(
b, reshInputType);
286 for (
int64_t i = 0; i < m; ++i)
287 res = vector::InsertOp::create(
b, vs[i], res, i);
306 MLIRContext *context, PatternBenefit benefit = 1)
308 vectorTransposeLowering(vectorTransposeLowering) {}
311 PatternRewriter &rewriter)
const override {
312 auto loc = op.getLoc();
314 Value input = op.getVector();
315 VectorType inputType = op.getSourceVectorType();
316 VectorType resType = op.getResultVectorType();
318 if (inputType.isScalable())
320 op,
"This lowering does not support scalable vectors");
323 ArrayRef<int64_t> transp = op.getPermutation();
328 op,
"Options specifies lowering to shuffle");
335 SmallVector<int64_t> prunedTransp;
337 size_t numPrunedDims = transp.size() - prunedTransp.size();
338 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
345 Value
result = ub::PoisonOp::create(rewriter, loc, resType);
346 int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
348 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
350 auto extractIdxs =
delinearize(linearIdx, prunedInStrides);
351 SmallVector<int64_t> insertIdxs(extractIdxs);
354 rewriter.
createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
365 vector::VectorTransposeLowering vectorTransposeLowering;
395class Transpose2DWithUnitDimToShapeCast
400 Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
401 PatternBenefit benefit = 1)
402 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
404 LogicalResult matchAndRewrite(vector::TransposeOp op,
405 PatternRewriter &rewriter)
const override {
406 Value input = op.getVector();
407 VectorType resType = op.getResultVectorType();
410 ArrayRef<int64_t> transp = op.getPermutation();
412 if (resType.getRank() == 2 &&
413 ((resType.getShape().front() == 1 &&
414 !resType.getScalableDims().front()) ||
415 (resType.getShape().back() == 1 &&
416 !resType.getScalableDims().back())) &&
417 transp == ArrayRef<int64_t>({1, 0})) {
433class TransposeOp2DToShuffleLowering
438 TransposeOp2DToShuffleLowering(
439 vector::VectorTransposeLowering vectorTransposeLowering,
440 MLIRContext *context, PatternBenefit benefit = 1)
441 : OpRewritePattern<vector::TransposeOp>(context, benefit),
442 vectorTransposeLowering(vectorTransposeLowering) {}
444 LogicalResult matchAndRewrite(vector::TransposeOp op,
445 PatternRewriter &rewriter)
const override {
448 op,
"not using vector shuffle based lowering");
450 if (op.getSourceVectorType().isScalable())
452 op,
"vector shuffle lowering not supported for scalable vectors");
457 op,
"expected transposition on a 2D slice");
459 VectorType srcType = op.getSourceVectorType();
460 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
461 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
465 Location loc = op.getLoc();
466 auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
467 auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
468 auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType,
472 if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
473 m == 16 && n == 16) {
475 vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput);
483 op, op.getResultVectorType(), res);
490 vector::VectorTransposeLowering vectorTransposeLowering;
496 VectorTransposeLowering vectorTransposeLowering,
PatternBenefit benefit) {
500 vectorTransposeLowering,
patterns.getContext(), benefit);
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.