36 #define DEBUG_TYPE "lower-vector-transpose"
45 size_t numTransposedDims = transpose.size();
46 for (
size_t transpDim : llvm::reverse(transpose)) {
47 if (transpDim != numTransposedDims - 1)
52 result.append(transpose.begin(), transpose.begin() + numTransposedDims);
57 return lowering == VectorTransposeLowering::Shuffle1D ||
58 lowering == VectorTransposeLowering::Shuffle16x16;
73 assert(numBits % 128 == 0 &&
"expected numBits is a multiple of 128");
74 int numElem = numBits / 32;
76 for (
int i = 0; i < numElem; i += 4)
77 for (int64_t v : vals)
90 int numElem = numBits / 32;
91 return b.
create<vector::ShuffleOp>(
104 int numElem = numBits / 32;
105 return b.
create<vector::ShuffleOp>(
119 int numElem = numBits / 32;
120 auto shuffle = b.
create<vector::ShuffleOp>(
134 int numElem = numBits / 32;
135 return b.
create<vector::ShuffleOp>(
159 assert(cast<VectorType>(v1.
getType()).getShape()[0] == 16 &&
160 "expected a vector with length=16");
162 auto appendToMask = [&](int64_t base, uint8_t control) {
166 base + 2, base + 3});
170 base + 6, base + 7});
174 base + 10, base + 11});
178 base + 14, base + 15});
181 llvm_unreachable(
"control > 3 : overflow");
184 uint8_t b01 = mask & 0x3;
185 uint8_t b23 = (mask >> 2) & 0x3;
186 uint8_t b45 = (mask >> 4) & 0x3;
187 uint8_t b67 = (mask >> 6) & 0x3;
188 appendToMask(0, b01);
189 appendToMask(0, b23);
190 appendToMask(16, b45);
191 appendToMask(16, b67);
192 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
200 for (int64_t
j = 0;
j < n; ++
j)
201 for (int64_t i = 0; i < m; ++i)
202 mask.push_back(i * n +
j);
203 return b.
create<vector::ShuffleOp>(source.
getLoc(), source, source, mask);
212 for (int64_t i = 0; i < m; ++i)
213 vs.push_back(b.
create<vector::ExtractOp>(source, i));
294 {m, n}, cast<VectorType>(source.
getType()).getElementType());
297 for (int64_t i = 0; i < m; ++i)
298 res = b.
create<vector::InsertOp>(vs[i], res, i);
319 vectorTransformOptions(vectorTransformOptions) {}
325 Value input = op.getVector();
326 VectorType inputType = op.getSourceVectorType();
327 VectorType resType = op.getResultVectorType();
332 if (
isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
335 op,
"Options specifies lowering to shuffle");
337 if (vectorTransformOptions.useShapeCast) {
346 if (resType.getRank() == 2 &&
347 ((resType.getShape().front() == 1 &&
348 !resType.getScalableDims().front()) ||
349 (resType.getShape().back() == 1 &&
350 !resType.getScalableDims().back())) &&
357 if (inputType.isScalable())
361 if (vectorTransformOptions.vectorTransposeLowering ==
362 vector::VectorTransposeLowering::Flat &&
363 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
367 rewriter.
create<vector::ShapeCastOp>(loc, flattenedType, input);
370 Value trans = rewriter.
create<vector::FlatTransposeOp>(
371 loc, flattenedType, matrix, rows, columns);
383 size_t numPrunedDims = transp.size() - prunedTransp.size();
384 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
395 for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
397 auto extractIdxs =
delinearize(linearIdx, prunedInStrides);
401 rewriter.
create<vector::ExtractOp>(loc, input, extractIdxs);
403 rewriter.
create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
422 class TransposeOp2DToShuffleLowering
427 TransposeOp2DToShuffleLowering(
431 vectorTransformOptions(vectorTransformOptions) {}
435 if (!
isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
437 op,
"not using vector shuffle based lowering");
442 op,
"expected transposition on a 2D slice");
444 VectorType srcType = op.getSourceVectorType();
445 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
446 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
451 auto flattenedType =
VectorType::get({n * m}, srcType.getElementType());
452 auto reshInputType =
VectorType::get({m, n}, srcType.getElementType());
453 auto reshInput = rewriter.
create<vector::ShapeCastOp>(loc, flattenedType,
457 if (vectorTransformOptions.vectorTransposeLowering ==
458 VectorTransposeLowering::Shuffle16x16 &&
459 m == 16 && n == 16) {
461 rewriter.
create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
469 op, op.getResultVectorType(), res);
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
static llvm::ManagedStatic< PassManagerOptions > options
static int64_t getNumElements(ShapedType type)
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.