20#include "llvm/Support/Format.h"
21#include "llvm/Support/FormatVariadic.h"
33 LLVM::AsmDialectAttr::get(
b.getContext(), LLVM::AsmDialect::AD_Intel);
34 const auto *asmTp =
"vblendps $0, $1, $2, {0}";
38 auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, 2)).str();
39 auto asmOp = LLVM::InlineAsmOp::create(
42 false, LLVM::TailCallKind::None,
45 return asmOp.getResult(0);
50 return vector::ShuffleOp::create(
b, v1, v2,
56 return vector::ShuffleOp::create(
57 b, v1, v2,
ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
66 uint8_t b01, b23, b45, b67;
69 b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
70 return vector::ShuffleOp::create(
b, v1, v2, shuffleMask);
82 auto appendToMask = [&](uint8_t control) {
85 else if (control == 1)
87 else if (control == 2)
89 else if (control == 3)
92 llvm_unreachable(
"control > 3 : overflow");
98 return vector::ShuffleOp::create(
b, v1, v2, shuffleMask);
106 for (
int i = 0; i < 8; ++i) {
107 bool isSet = mask & (1 << i);
108 shuffleMask.push_back(!isSet ? i : i + 8);
110 return vector::ShuffleOp::create(
b, v1, v2, shuffleMask);
117 auto vt = VectorType::get({8}, Float32Type::get(ib.
getContext()));
118 assert(vs.size() == 4 &&
"expects 4 vectors");
120 [&](
Type t) {
return t == vt; }) &&
121 "expects all types to be vector<8xf32>");
141 auto vt = VectorType::get({8}, Float32Type::get(ib.
getContext()));
143 assert(vs.size() == 8 &&
"expects 8 vectors");
145 [&](
Type t) {
return t == vt; }) &&
146 "expects all types to be vector<8xf32>");
164 mm256BlendPsAsm(ib, t0, sh0,
MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
166 mm256BlendPsAsm(ib, t2, sh0,
MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
168 mm256BlendPsAsm(ib, t1, sh2,
MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
170 mm256BlendPsAsm(ib, t3, sh2,
MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
172 mm256BlendPsAsm(ib, t4, sh4,
MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
174 mm256BlendPsAsm(ib, t6, sh4,
MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
176 mm256BlendPsAsm(ib, t5, sh6,
MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
178 mm256BlendPsAsm(ib, t7, sh6,
MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
214 loweringOptions(loweringOptions) {}
218 auto loc = op.getLoc();
222 VectorType srcType = op.getSourceVectorType();
223 if (!srcType.getElementType().isF32())
227 if (failed(srcGtOneDims))
229 op,
"expected transposition on a 2D slice");
233 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
234 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
236 auto applyRewrite = [&]() {
243 VectorType::get({n * m}, op.getSourceVectorType().
getElementType());
244 auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
246 vector::ShapeCastOp::create(ib, flattenedType, op.getVector());
247 reshInput = vector::ShapeCastOp::create(ib, reshInputType, reshInput);
251 for (
int64_t i = 0; i < m; ++i)
252 vs.push_back(vector::ExtractOp::create(ib, reshInput, i));
262 Value res = arith::ConstantOp::create(ib, reshInputType,
264 for (
int64_t i = 0; i < m; ++i)
265 res = vector::InsertOp::create(ib, vs[i], res, i);
270 res = vector::ShapeCastOp::create(ib, flattenedType, res);
271 res = vector::ShapeCastOp::create(ib, op.getResultVectorType(), res);
276 if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
277 return applyRewrite();
278 if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
279 return applyRewrite();
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
If bit i of mask is zero, take f32@i from v1 else take it from v2.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
a a b b a a b b Take an 8 bit mask, 2 bits for each position of a[0, 4) and b[0, 4): 0:127 | 128:255 b...
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
8x8xf32-specific AVX2 transpose lowering.
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
Generic lowerings may either use intrin or inline_asm depending on needs.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47)
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23, uint8_t &b45, uint8_t &b67)
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
static uint8_t shuffle()
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
static uint8_t blend()
b0 captures the lowest bit, b7 captures the highest bit.
static uint8_t permute()
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Options for controlling specialized AVX2 lowerings.