MLIR 22.0.0git
AVXTranspose.cpp
Go to the documentation of this file.
1//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements vector.transpose rewrites as AVX patterns for particular
10// sizes of interest.
11//
12//===----------------------------------------------------------------------===//
13
20#include "llvm/Support/Format.h"
21#include "llvm/Support/FormatVariadic.h"
22
23using namespace mlir;
24using namespace mlir::vector;
25using namespace mlir::x86vector;
26using namespace mlir::x86vector::avx2;
28using namespace mlir::x86vector::avx2::intrin;
29
31 ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
32 auto asmDialectAttr =
33 LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
34 const auto *asmTp = "vblendps $0, $1, $2, {0}";
35 const auto *asmCstr =
36 "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
37 SmallVector<Value> asmVals{v1, v2};
38 auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
39 auto asmOp = LLVM::InlineAsmOp::create(
40 b, v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
41 /*constraints=*/asmCstr, /*has_side_effects=*/false,
42 /*is_align_stack=*/false, LLVM::TailCallKind::None,
43 /*asm_dialect=*/asmDialectAttr,
44 /*operand_attrs=*/ArrayAttr());
45 return asmOp.getResult(0);
46}
47
49 Value v1, Value v2) {
50 return vector::ShuffleOp::create(b, v1, v2,
51 ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
52}
53
55 Value v1, Value v2) {
56 return vector::ShuffleOp::create(
57 b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
58}
59/// a a b b a a b b
60/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
61/// 0:127 | 128:255
62/// b01 b23 b45+8 b67+8 | b01+4 b23+4 b45+8+4 b67+8+4
64 Value v1, Value v2,
65 uint8_t mask) {
66 uint8_t b01, b23, b45, b67;
67 MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
68 SmallVector<int64_t> shuffleMask = {
69 b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
70 return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
71}
72
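73// imm[0:1] out of imm[0:3] selects the low 128 bits of the result:
74//   0: a[0:127], 1: a[128:255], 2: b[0:127], 3: b[128:255]
75// imm[0:1] out of imm[4:7] (i.e. imm[4:5]) selects the high 128 bits of the
76// result from the same four choices:
77//   0: a[0:127], 1: a[128:255], 2: b[0:127], 3: b[128:255]
78// Nibble values greater than 3 are not supported by this lowering.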
80 ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
81 SmallVector<int64_t> shuffleMask;
82 auto appendToMask = [&](uint8_t control) {
83 if (control == 0)
84 llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
85 else if (control == 1)
86 llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
87 else if (control == 2)
88 llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
89 else if (control == 3)
90 llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
91 else
92 llvm_unreachable("control > 3 : overflow");
93 };
94 uint8_t b03, b47;
95 MaskHelper::extractPermute(mask, b03, b47);
96 appendToMask(b03);
97 appendToMask(b47);
98 return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
99}
100
101/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
103 Value v1, Value v2,
104 uint8_t mask) {
105 SmallVector<int64_t, 8> shuffleMask;
106 for (int i = 0; i < 8; ++i) {
107 bool isSet = mask & (1 << i);
108 shuffleMask.push_back(!isSet ? i : i + 8);
109 }
110 return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
111}
112
113/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
116#ifndef NDEBUG
117 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
118 assert(vs.size() == 4 && "expects 4 vectors");
119 assert(llvm::all_of(ValueRange{vs}.getTypes(),
120 [&](Type t) { return t == vt; }) &&
121 "expects all types to be vector<8xf32>");
122#endif
123
124 Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
125 Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
126 Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
127 Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
132 vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
133 vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
134 vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
135 vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
136}
137
138/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
141 auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
142 (void)vt;
143 assert(vs.size() == 8 && "expects 8 vectors");
144 assert(llvm::all_of(ValueRange{vs}.getTypes(),
145 [&](Type t) { return t == vt; }) &&
146 "expects all types to be vector<8xf32>");
147
148 Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
149 Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
150 Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
151 Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
152 Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
153 Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
154 Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
155 Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
156
162
163 Value s0 =
165 Value s1 =
167 Value s2 =
169 Value s3 =
171 Value s4 =
173 Value s5 =
175 Value s6 =
177 Value s7 =
179
180 vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
181 vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
182 vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
183 vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
184 vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
185 vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
186 vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
187 vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
188}
189
190/// Rewrite AVX2-specific vector.transpose, for the supported cases and
191/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
192/// transpose cases and n-D cases that have been decomposed into 2-D
193/// transposition slices. For example, a 3-D transpose:
194///
195/// %0 = vector.transpose %arg0, [2, 0, 1]
196/// : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
197///
198/// could be sliced into 2-D transposes by tiling two of its dimensions to one
199/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
200///
201/// %0 = vector.transpose %arg0, [2, 0, 1]
202/// : vector<1x4x8xf32> to vector<8x1x4xf32>
203///
204/// This lowering will analyze the n-D vector.transpose and determine if it's a
205/// supported 2-D transposition slice where any of the AVX2 patterns can be
206/// applied.
207class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
208public:
209 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
210
212 int benefit)
213 : OpRewritePattern<vector::TransposeOp>(context, benefit),
214 loweringOptions(loweringOptions) {}
215
216 LogicalResult matchAndRewrite(vector::TransposeOp op,
217 PatternRewriter &rewriter) const override {
218 auto loc = op.getLoc();
219
220 // Check if the source vector type is supported. AVX2 patterns can only be
221 // applied to f32 vector types with two dimensions greater than one.
222 VectorType srcType = op.getSourceVectorType();
223 if (!srcType.getElementType().isF32())
224 return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
225
226 auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
227 if (failed(srcGtOneDims))
228 return rewriter.notifyMatchFailure(
229 op, "expected transposition on a 2D slice");
230
231 // Retrieve the sizes of the two dimensions greater than one to be
232 // transposed.
233 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
234 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
235
236 auto applyRewrite = [&]() {
237 ImplicitLocOpBuilder ib(loc, rewriter);
239
240 // Reshape the n-D input vector with only two dimensions greater than one
241 // to a 2-D vector.
242 auto flattenedType =
243 VectorType::get({n * m}, op.getSourceVectorType().getElementType());
244 auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
245 auto reshInput =
246 vector::ShapeCastOp::create(ib, flattenedType, op.getVector());
247 reshInput = vector::ShapeCastOp::create(ib, reshInputType, reshInput);
248
249 // Extract 1-D vectors from the higher-order dimension of the input
250 // vector.
251 for (int64_t i = 0; i < m; ++i)
252 vs.push_back(vector::ExtractOp::create(ib, reshInput, i));
253
254 // Transpose set of 1-D vectors.
255 if (m == 4)
256 transpose4x8xf32(ib, vs);
257 if (m == 8)
258 transpose8x8xf32(ib, vs);
259
260 // Insert transposed 1-D vectors into the higher-order dimension of the
261 // output vector.
262 Value res = arith::ConstantOp::create(ib, reshInputType,
263 ib.getZeroAttr(reshInputType));
264 for (int64_t i = 0; i < m; ++i)
265 res = vector::InsertOp::create(ib, vs[i], res, i);
266
267 // The output vector still has the shape of the input vector (e.g., 4x8).
268 // We have to transpose their dimensions and retrieve its original rank
269 // (e.g., 1x8x1x4x1).
270 res = vector::ShapeCastOp::create(ib, flattenedType, res);
271 res = vector::ShapeCastOp::create(ib, op.getResultVectorType(), res);
272 rewriter.replaceOp(op, res);
273 return success();
274 };
275
276 if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
277 return applyRewrite();
278 if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
279 return applyRewrite();
280 return failure();
281 }
282
283private:
284 LoweringOptions loweringOptions;
285};
286
return success()
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
MLIRContext * getContext() const
Definition Builders.h:56
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
If bit i of mask is zero, take f32@i from v1 else take it from v2.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
a a b b a a b b Takes an 8-bit mask, 2 bits for each position of a[0, 4) and b[0, 4): 0:127 | 128:255 b...
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
Helpers extracted from:
Definition Transforms.h:106
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
8x8xf32-specific AVX2 transpose lowering.
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
Generic lowerings may either use intrin or inline_asm depending on needs.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47)
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Definition Transforms.h:76
static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23, uint8_t &b45, uint8_t &b67)
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
Definition Transforms.h:60
static uint8_t shuffle()
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
Definition Transforms.h:52
static uint8_t blend()
b0 captures the lowest bit, b7 captures the highest bit.
Definition Transforms.h:29
static uint8_t permute()
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Definition Transforms.h:70
Options for controlling specialized AVX2 lowerings.
Definition Transforms.h:171