22 #include "llvm/Support/Format.h"
23 #include "llvm/Support/FormatVariadic.h"
36 const auto *asmTp =
"vblendps $0, $1, $2, {0}";
40 auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, 2)).str();
41 auto asmOp = b.
create<LLVM::InlineAsmOp>(
44 false, asmDialectAttr,
46 return asmOp.getResult(0);
51 return b.
create<vector::ShuffleOp>(
57 return b.
create<vector::ShuffleOp>(
67 uint8_t b01, b23, b45, b67;
68 MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
70 b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
71 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
83 auto appendToMask = [&](uint8_t control) {
86 else if (control == 1)
88 else if (control == 2)
90 else if (control == 3)
93 llvm_unreachable(
"control > 3 : overflow");
96 MaskHelper::extractPermute(mask, b03, b47);
99 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
107 for (
int i = 0; i < 8; ++i) {
108 bool isSet = mask & (1 << i);
109 shuffleMask.push_back(!isSet ? i : i + 8);
111 return b.
create<vector::ShuffleOp>(v1, v2, shuffleMask);
119 assert(vs.size() == 4 &&
"expects 4 vectors");
120 assert(llvm::all_of(
ValueRange{vs}.getTypes(),
121 [&](
Type t) {
return t == vt; }) &&
122 "expects all types to be vector<8xf32>");
144 assert(vs.size() == 8 &&
"expects 8 vectors");
145 assert(llvm::all_of(
ValueRange{vs}.getTypes(),
146 [&](
Type t) {
return t == vt; }) &&
147 "expects all types to be vector<8xf32>");
165 mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
167 mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
169 mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
171 mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
173 mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
175 mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
177 mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
179 mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
215 loweringOptions(loweringOptions) {}
219 auto loc = op.getLoc();
223 VectorType srcType = op.getSourceVectorType();
224 if (!srcType.getElementType().isF32())
228 if (failed(srcGtOneDims))
230 op,
"expected transposition on a 2D slice");
234 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
235 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
237 auto applyRewrite = [&]() {
245 auto reshInputType =
VectorType::get({m, n}, srcType.getElementType());
247 ib.
create<vector::ShapeCastOp>(flattenedType, op.getVector());
248 reshInput = ib.
create<vector::ShapeCastOp>(reshInputType, reshInput);
252 for (int64_t i = 0; i < m; ++i)
253 vs.push_back(ib.
create<vector::ExtractOp>(reshInput, i));
263 Value res = ib.
create<arith::ConstantOp>(reshInputType,
265 for (int64_t i = 0; i < m; ++i)
266 res = ib.
create<vector::InsertOp>(vs[i], res, i);
271 res = ib.
create<vector::ShapeCastOp>(flattenedType, res);
272 res = ib.
create<vector::ShapeCastOp>(op.getResultVectorType(), res);
277 if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
278 return applyRewrite();
279 if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
280 return applyRewrite();
static llvm::ManagedStatic< PassManagerOptions > options
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
If bit i of mask is zero, take f32@i from v1 else take it from v2.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
a a b b a a b b Take an 8 bit mask, 2 bit for each position of a[0, 3) and b[0, 4): 0:127 | 128:255 b...
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
8x8xf32-specific AVX2 transpose lowering.
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
Generic lowerings may either use intrin or inline_asm depending on needs.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options for controlling specialized AVX2 lowerings.