35 #define DEBUG_TYPE "lower-vector-transpose"
44 size_t numTransposedDims =
transpose.size();
45 for (
size_t transpDim : llvm::reverse(
transpose)) {
46 if (transpDim != numTransposedDims - 1)
56 return lowering == VectorTransposeLowering::Shuffle1D ||
57 lowering == VectorTransposeLowering::Shuffle16x16;
72 assert(numBits % 128 == 0 &&
"expected numBits is a multiple of 128");
73 int numElem = numBits / 32;
75 for (
int i = 0; i < numElem; i += 4)
76 for (int64_t v : vals)
89 int numElem = numBits / 32;
90 return b.
create<vector::ShuffleOp>(
103 int numElem = numBits / 32;
104 return b.
create<vector::ShuffleOp>(
118 int numElem = numBits / 32;
119 auto shuffle = b.
create<vector::ShuffleOp>(
133 int numElem = numBits / 32;
134 return b.
create<vector::ShuffleOp>(
158 assert(cast<VectorType>(v1.
getType()).getShape()[0] == 16 &&
159 "expected a vector with length=16");
161 auto appendToMask = [&](int64_t base, uint8_t control) {
165 base + 2, base + 3});
169 base + 6, base + 7});
173 base + 10, base + 11});
177 base + 14, base + 15});
180 llvm_unreachable(
"control > 3 : overflow");
183 uint8_t b01 = mask & 0x3;
184 uint8_t b23 = (mask >> 2) & 0x3;
185 uint8_t b45 = (mask >> 4) & 0x3;
186 uint8_t b67 = (mask >> 6) & 0x3;
187 appendToMask(0, b01);
188 appendToMask(0, b23);
189 appendToMask(16, b45);
190 appendToMask(16, b67);
191 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
199 for (int64_t
j = 0;
j < n; ++
j)
200 for (int64_t i = 0; i < m; ++i)
201 mask.push_back(i * n +
j);
202 return b.
create<vector::ShuffleOp>(source.
getLoc(), source, source, mask);
211 for (int64_t i = 0; i < m; ++i)
212 vs.push_back(b.
create<vector::ExtractOp>(source, i));
293 {m, n}, cast<VectorType>(source.
getType()).getElementType());
296 for (int64_t i = 0; i < m; ++i)
297 res = b.
create<vector::InsertOp>(vs[i], res, i);
318 vectorTransformOptions(vectorTransformOptions) {}
320 LogicalResult matchAndRewrite(vector::TransposeOp op,
322 auto loc = op.getLoc();
324 Value input = op.getVector();
325 VectorType inputType = op.getSourceVectorType();
326 VectorType resType = op.getResultVectorType();
328 if (inputType.isScalable())
330 op,
"This lowering does not support scalable vectors");
335 if (
isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
338 op,
"Options specifies lowering to shuffle");
341 if (vectorTransformOptions.vectorTransposeLowering ==
342 vector::VectorTransposeLowering::Flat &&
343 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
347 rewriter.
create<vector::ShapeCastOp>(loc, flattenedType, input);
350 Value trans = rewriter.
create<vector::FlatTransposeOp>(
351 loc, flattenedType, matrix,
rows, columns);
363 size_t numPrunedDims = transp.size() - prunedTransp.size();
364 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
375 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
377 auto extractIdxs =
delinearize(linearIdx, prunedInStrides);
381 rewriter.
create<vector::ExtractOp>(loc, input, extractIdxs);
383 rewriter.
create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
422 class Transpose2DWithUnitDimToShapeCast
427 Transpose2DWithUnitDimToShapeCast(
MLIRContext *context,
431 LogicalResult matchAndRewrite(vector::TransposeOp op,
433 Value input = op.getVector();
434 VectorType resType = op.getResultVectorType();
439 if (resType.getRank() == 2 &&
440 ((resType.getShape().front() == 1 &&
441 !resType.getScalableDims().front()) ||
442 (resType.getShape().back() == 1 &&
443 !resType.getScalableDims().back())) &&
460 class TransposeOp2DToShuffleLowering
465 TransposeOp2DToShuffleLowering(
469 vectorTransformOptions(vectorTransformOptions) {}
471 LogicalResult matchAndRewrite(vector::TransposeOp op,
473 if (!
isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
475 op,
"not using vector shuffle based lowering");
477 if (op.getSourceVectorType().isScalable())
479 op,
"vector shuffle lowering not supported for scalable vectors");
482 if (failed(srcGtOneDims))
484 op,
"expected transposition on a 2D slice");
486 VectorType srcType = op.getSourceVectorType();
487 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
488 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
493 auto flattenedType =
VectorType::get({n * m}, srcType.getElementType());
494 auto reshInputType =
VectorType::get({m, n}, srcType.getElementType());
495 auto reshInput = rewriter.
create<vector::ShapeCastOp>(loc, flattenedType,
499 if (vectorTransformOptions.vectorTransposeLowering ==
500 VectorTransposeLowering::Shuffle16x16 &&
501 m == 16 && n == 16) {
503 rewriter.
create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
511 op, op.getResultVectorType(), res);
525 patterns.
add<Transpose2DWithUnitDimToShapeCast>(patterns.
getContext(),
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask from v1 and v2.
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLoweringOptions.
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the results of the original op with the results of the new one).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.