26 #define DEBUG_TYPE "lower-vector-transpose"
35 size_t numTransposedDims = transpose.size();
36 for (
size_t transpDim : llvm::reverse(transpose)) {
37 if (transpDim != numTransposedDims - 1)
42 result.append(transpose.begin(), transpose.begin() + numTransposedDims);
47 return lowering == VectorTransposeLowering::Shuffle1D ||
48 lowering == VectorTransposeLowering::Shuffle16x16;
63 assert(numBits % 128 == 0 &&
"expected numBits is a multiple of 128");
64 int numElem = numBits / 32;
66 for (
int i = 0; i < numElem; i += 4)
67 for (int64_t v : vals)
80 int numElem = numBits / 32;
81 return vector::ShuffleOp::create(
94 int numElem = numBits / 32;
95 return vector::ShuffleOp::create(
109 int numElem = numBits / 32;
110 auto shuffle = vector::ShuffleOp::create(
124 int numElem = numBits / 32;
125 return vector::ShuffleOp::create(
149 assert(cast<VectorType>(v1.
getType()).getShape()[0] == 16 &&
150 "expected a vector with length=16");
152 auto appendToMask = [&](int64_t base, uint8_t control) {
156 base + 2, base + 3});
160 base + 6, base + 7});
164 base + 10, base + 11});
168 base + 14, base + 15});
171 llvm_unreachable(
"control > 3 : overflow");
174 uint8_t b01 = mask & 0x3;
175 uint8_t b23 = (mask >> 2) & 0x3;
176 uint8_t b45 = (mask >> 4) & 0x3;
177 uint8_t b67 = (mask >> 6) & 0x3;
178 appendToMask(0, b01);
179 appendToMask(0, b23);
180 appendToMask(16, b45);
181 appendToMask(16, b67);
182 return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
190 for (int64_t
j = 0;
j < n; ++
j)
191 for (int64_t i = 0; i < m; ++i)
192 mask.push_back(i * n +
j);
193 return vector::ShuffleOp::create(b, source.
getLoc(), source, source, mask);
202 for (int64_t i = 0; i < m; ++i)
203 vs.push_back(b.
createOrFold<vector::ExtractOp>(source, i));
284 {m, n}, cast<VectorType>(source.
getType()).getElementType());
285 Value res = ub::PoisonOp::create(b, reshInputType);
286 for (int64_t i = 0; i < m; ++i)
287 res = vector::InsertOp::create(b, vs[i], res, i);
308 vectorTransposeLowering(vectorTransposeLowering) {}
310 LogicalResult matchAndRewrite(vector::TransposeOp op,
312 auto loc = op.getLoc();
314 Value input = op.getVector();
315 VectorType inputType = op.getSourceVectorType();
316 VectorType resType = op.getResultVectorType();
318 if (inputType.isScalable())
320 op,
"This lowering does not support scalable vectors");
328 op,
"Options specifies lowering to shuffle");
337 size_t numPrunedDims = transp.size() - prunedTransp.size();
338 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
345 Value result = ub::PoisonOp::create(rewriter, loc, resType);
348 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
350 auto extractIdxs =
delinearize(linearIdx, prunedInStrides);
354 rewriter.
createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
355 result = rewriter.
createOrFold<vector::InsertOp>(loc, extractOp, result,
365 vector::VectorTransposeLowering vectorTransposeLowering;
395 class Transpose2DWithUnitDimToShapeCast
400 Transpose2DWithUnitDimToShapeCast(
MLIRContext *context,
404 LogicalResult matchAndRewrite(vector::TransposeOp op,
406 Value input = op.getVector();
407 VectorType resType = op.getResultVectorType();
412 if (resType.getRank() == 2 &&
413 ((resType.getShape().front() == 1 &&
414 !resType.getScalableDims().front()) ||
415 (resType.getShape().back() == 1 &&
416 !resType.getScalableDims().back())) &&
433 class TransposeOp2DToShuffleLowering
438 TransposeOp2DToShuffleLowering(
439 vector::VectorTransposeLowering vectorTransposeLowering,
442 vectorTransposeLowering(vectorTransposeLowering) {}
444 LogicalResult matchAndRewrite(vector::TransposeOp op,
448 op,
"not using vector shuffle based lowering");
450 if (op.getSourceVectorType().isScalable())
452 op,
"vector shuffle lowering not supported for scalable vectors");
457 op,
"expected transposition on a 2D slice");
459 VectorType srcType = op.getSourceVectorType();
460 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
461 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
466 auto flattenedType =
VectorType::get({n * m}, srcType.getElementType());
467 auto reshInputType =
VectorType::get({m, n}, srcType.getElementType());
468 auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType,
472 if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
473 m == 16 && n == 16) {
475 vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput);
483 op, op.getResultVectorType(), res);
490 vector::VectorTransposeLowering vectorTransposeLowering;
496 VectorTransposeLowering vectorTransposeLowering,
PatternBenefit benefit) {
500 vectorTransposeLowering,
patterns.getContext(), benefit);
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask from v1 and v2.
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLoweringOptions.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
void createOrFold(llvm::SmallVectorImpl< Value > &results, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the given op).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.