28 #define DEBUG_TYPE "lower-vector-transpose"
37 size_t numTransposedDims =
transpose.size();
38 for (
size_t transpDim : llvm::reverse(
transpose)) {
39 if (transpDim != numTransposedDims - 1)
49 return lowering == VectorTransposeLowering::Shuffle1D ||
50 lowering == VectorTransposeLowering::Shuffle16x16;
65 assert(numBits % 128 == 0 &&
"expected numBits is a multiple of 128");
66 int numElem = numBits / 32;
68 for (
int i = 0; i < numElem; i += 4)
69 for (int64_t v : vals)
82 int numElem = numBits / 32;
83 return b.
create<vector::ShuffleOp>(
96 int numElem = numBits / 32;
97 return b.
create<vector::ShuffleOp>(
111 int numElem = numBits / 32;
112 auto shuffle = b.
create<vector::ShuffleOp>(
126 int numElem = numBits / 32;
127 return b.
create<vector::ShuffleOp>(
151 assert(cast<VectorType>(v1.
getType()).getShape()[0] == 16 &&
152 "expected a vector with length=16");
154 auto appendToMask = [&](int64_t base, uint8_t control) {
158 base + 2, base + 3});
162 base + 6, base + 7});
166 base + 10, base + 11});
170 base + 14, base + 15});
173 llvm_unreachable(
"control > 3 : overflow");
176 uint8_t b01 = mask & 0x3;
177 uint8_t b23 = (mask >> 2) & 0x3;
178 uint8_t b45 = (mask >> 4) & 0x3;
179 uint8_t b67 = (mask >> 6) & 0x3;
180 appendToMask(0, b01);
181 appendToMask(0, b23);
182 appendToMask(16, b45);
183 appendToMask(16, b67);
184 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
192 for (int64_t
j = 0;
j < n; ++
j)
193 for (int64_t i = 0; i < m; ++i)
194 mask.push_back(i * n +
j);
195 return b.
create<vector::ShuffleOp>(source.
getLoc(), source, source, mask);
204 for (int64_t i = 0; i < m; ++i)
205 vs.push_back(b.
createOrFold<vector::ExtractOp>(source, i));
286 {m, n}, cast<VectorType>(source.
getType()).getElementType());
288 for (int64_t i = 0; i < m; ++i)
289 res = b.
create<vector::InsertOp>(vs[i], res, i);
310 vectorTransposeLowering(vectorTransposeLowering) {}
312 LogicalResult matchAndRewrite(vector::TransposeOp op,
314 auto loc = op.getLoc();
316 Value input = op.getVector();
317 VectorType inputType = op.getSourceVectorType();
318 VectorType resType = op.getResultVectorType();
320 if (inputType.isScalable())
322 op,
"This lowering does not support scalable vectors");
330 op,
"Options specifies lowering to shuffle");
333 if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
334 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
338 rewriter.
create<vector::ShapeCastOp>(loc, flattenedType, input);
341 Value trans = rewriter.
create<vector::FlatTransposeOp>(
342 loc, flattenedType, matrix,
rows, columns);
354 size_t numPrunedDims = transp.size() - prunedTransp.size();
355 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
362 Value result = rewriter.
create<ub::PoisonOp>(loc, resType);
365 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
367 auto extractIdxs =
delinearize(linearIdx, prunedInStrides);
371 rewriter.
createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
372 result = rewriter.
createOrFold<vector::InsertOp>(loc, extractOp, result,
382 vector::VectorTransposeLowering vectorTransposeLowering;
412 class Transpose2DWithUnitDimToShapeCast
417 Transpose2DWithUnitDimToShapeCast(
MLIRContext *context,
421 LogicalResult matchAndRewrite(vector::TransposeOp op,
423 Value input = op.getVector();
424 VectorType resType = op.getResultVectorType();
429 if (resType.getRank() == 2 &&
430 ((resType.getShape().front() == 1 &&
431 !resType.getScalableDims().front()) ||
432 (resType.getShape().back() == 1 &&
433 !resType.getScalableDims().back())) &&
450 class TransposeOp2DToShuffleLowering
455 TransposeOp2DToShuffleLowering(
456 vector::VectorTransposeLowering vectorTransposeLowering,
459 vectorTransposeLowering(vectorTransposeLowering) {}
461 LogicalResult matchAndRewrite(vector::TransposeOp op,
465 op,
"not using vector shuffle based lowering");
467 if (op.getSourceVectorType().isScalable())
469 op,
"vector shuffle lowering not supported for scalable vectors");
472 if (failed(srcGtOneDims))
474 op,
"expected transposition on a 2D slice");
476 VectorType srcType = op.getSourceVectorType();
477 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
478 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
483 auto flattenedType =
VectorType::get({n * m}, srcType.getElementType());
484 auto reshInputType =
VectorType::get({m, n}, srcType.getElementType());
485 auto reshInput = rewriter.
create<vector::ShapeCastOp>(loc, flattenedType,
489 if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
490 m == 16 && n == 16) {
492 rewriter.
create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
500 op, op.getResultVectorType(), res);
507 vector::VectorTransposeLowering vectorTransposeLowering;
513 VectorTransposeLowering vectorTransposeLowering,
PatternBenefit benefit) {
517 vectorTransposeLowering,
patterns.getContext(), benefit);
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
IntegerAttr getI32IntegerAttr(int32_t value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void createOrFold(llvm::SmallVectorImpl< Value > &results, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.