36 #define DEBUG_TYPE "lower-vector-transpose"
45 size_t numTransposedDims =
transpose.size();
46 for (
size_t transpDim : llvm::reverse(
transpose)) {
47 if (transpDim != numTransposedDims - 1)
57 return lowering == VectorTransposeLowering::Shuffle1D ||
58 lowering == VectorTransposeLowering::Shuffle16x16;
73 assert(numBits % 128 == 0 &&
"expected numBits is a multiple of 128");
74 int numElem = numBits / 32;
76 for (
int i = 0; i < numElem; i += 4)
77 for (int64_t v : vals)
90 int numElem = numBits / 32;
91 return b.
create<vector::ShuffleOp>(
104 int numElem = numBits / 32;
105 return b.
create<vector::ShuffleOp>(
119 int numElem = numBits / 32;
120 auto shuffle = b.
create<vector::ShuffleOp>(
134 int numElem = numBits / 32;
135 return b.
create<vector::ShuffleOp>(
159 assert(cast<VectorType>(v1.
getType()).getShape()[0] == 16 &&
160 "expected a vector with length=16");
162 auto appendToMask = [&](int64_t base, uint8_t control) {
166 base + 2, base + 3});
170 base + 6, base + 7});
174 base + 10, base + 11});
178 base + 14, base + 15});
181 llvm_unreachable(
"control > 3 : overflow");
184 uint8_t b01 = mask & 0x3;
185 uint8_t b23 = (mask >> 2) & 0x3;
186 uint8_t b45 = (mask >> 4) & 0x3;
187 uint8_t b67 = (mask >> 6) & 0x3;
188 appendToMask(0, b01);
189 appendToMask(0, b23);
190 appendToMask(16, b45);
191 appendToMask(16, b67);
192 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
200 for (int64_t
j = 0;
j < n; ++
j)
201 for (int64_t i = 0; i < m; ++i)
202 mask.push_back(i * n +
j);
203 return b.
create<vector::ShuffleOp>(source.
getLoc(), source, source, mask);
212 for (int64_t i = 0; i < m; ++i)
213 vs.push_back(b.
create<vector::ExtractOp>(source, i));
294 {m, n}, cast<VectorType>(source.
getType()).getElementType());
297 for (int64_t i = 0; i < m; ++i)
298 res = b.
create<vector::InsertOp>(vs[i], res, i);
319 vectorTransformOptions(vectorTransformOptions) {}
325 Value input = op.getVector();
326 VectorType inputType = op.getSourceVectorType();
327 VectorType resType = op.getResultVectorType();
332 if (
isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
335 op,
"Options specifies lowering to shuffle");
345 if (resType.getRank() == 2 &&
346 ((resType.getShape().front() == 1 &&
347 !resType.getScalableDims().front()) ||
348 (resType.getShape().back() == 1 &&
349 !resType.getScalableDims().back())) &&
355 if (inputType.isScalable())
359 if (vectorTransformOptions.vectorTransposeLowering ==
360 vector::VectorTransposeLowering::Flat &&
361 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
365 rewriter.
create<vector::ShapeCastOp>(loc, flattenedType, input);
368 Value trans = rewriter.
create<vector::FlatTransposeOp>(
369 loc, flattenedType, matrix, rows, columns);
381 size_t numPrunedDims = transp.size() - prunedTransp.size();
382 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
393 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
395 auto extractIdxs =
delinearize(linearIdx, prunedInStrides);
399 rewriter.
create<vector::ExtractOp>(loc, input, extractIdxs);
401 rewriter.
create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
420 class TransposeOp2DToShuffleLowering
425 TransposeOp2DToShuffleLowering(
429 vectorTransformOptions(vectorTransformOptions) {}
433 if (!
isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
435 op,
"not using vector shuffle based lowering");
437 if (op.getSourceVectorType().isScalable())
439 op,
"vector shuffle lowering not supported for scalable vectors");
444 op,
"expected transposition on a 2D slice");
446 VectorType srcType = op.getSourceVectorType();
447 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
448 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
453 auto flattenedType =
VectorType::get({n * m}, srcType.getElementType());
454 auto reshInputType =
VectorType::get({m, n}, srcType.getElementType());
455 auto reshInput = rewriter.
create<vector::ShapeCastOp>(loc, flattenedType,
459 if (vectorTransformOptions.vectorTransposeLowering ==
460 VectorTransposeLowering::Shuffle16x16 &&
461 m == 16 && n == 16) {
463 rewriter.
create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
471 op, op.getResultVectorType(), res);
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
static llvm::ManagedStatic< PassManagerOptions > options
static int64_t getNumElements(ShapedType type)
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
static void transpose(llvm::ArrayRef< int64_t > trans, std::vector< int64_t > &shape)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.