//===- AVXTranspose.cpp - Lower Vector transpose to AVX ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for
// particular sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: the constraint parser is very brittle: no whitespace!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = LLVM::InlineAsmOp::create(
      b, v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, LLVM::TailCallKind::None,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}
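// Illustration: for mask = 0xcc (e.g. MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>(),
// assuming the helper packs bit 0 first, as the transposes below rely on), the
// formatted asm string is roughly "vblendps $0, $1, $2, 0xcc" with constraints
// "=x,x,x": a single register-to-register vblendps whose destination is
// returned as the op's only result.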

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return vector::ShuffleOp::create(b, v1, v2,
                                   ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}
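// Illustration: with v1 = [a0..a7] and v2 = [b0..b7], the shuffle above yields
// [a0, b0, a1, b1, a4, b4, a5, b5], i.e. vunpcklps applied per 128-bit half.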

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return vector::ShuffleOp::create(
      b, v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
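// Illustration: with v1 = [a0..a7] and v2 = [b0..b7], the shuffle above yields
// [a2, b2, a3, b3, a6, b6, a7, b7], i.e. vunpckhps applied per 128-bit half.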
///                              a    a    b    b  |    a     a     b    b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                                   0:127         |        128:255
///                              b01  b23  C8   D8  |  b01+4 b23+4 C8+4 D8+4
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask = {
      b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
}
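// Illustration: each 2-bit field selects one f32 within a 128-bit half; b01
// and b23 index into v1, b45 and b67 into v2, and the selection is repeated
// for both halves. E.g. b01 = 0, b23 = 1, b45 = 0, b67 = 1 (presumably
// MaskHelper::shuffle<1, 0, 1, 0>() as used below) yields
// [v1[0], v1[1], v2[0], v2[1], v1[4], v1[5], v2[4], v2[5]].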

// imm[0:1] out of imm[0:3] is:
//    0            1            2            3
// a[0:127]  or a[128:255] or b[0:127] or b[128:255]    |
//              a[0:127]  or a[128:255] or b[0:127] or b[128:255]
//                 0            1            2            3
// imm[0:1] out of imm[4:7].
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
}
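// Illustration: the two controls extracted from `mask` each pick one 128-bit
// half of the inputs (0 = v1[0:127], 1 = v1[128:255], 2 = v2[0:127],
// 3 = v2[128:255]). Assuming MaskHelper::permute<2, 0>() encodes imm 0x20, it
// concatenates v1[0:127] and v2[0:127]; MaskHelper::permute<3, 1>() (imm 0x31)
// concatenates the two upper halves, as _mm256_permute2f128_ps would.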

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
}
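// Illustration: mask = 0xcc (0b11001100) produces the shuffle mask
// [0, 1, 10, 11, 4, 5, 14, 15], i.e. elements 2, 3, 6, 7 are taken from v2 and
// the rest from v1, mirroring _mm256_blend_ps.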

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}
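// Worked example (under the mask-encoding assumptions noted above): with input
// rows vs = {A, B, C, D}, A = [a0..a7] and so on, the result packs the
// transposed 8x4 matrix two rows per vector:
//   vs[0] = [a0, b0, c0, d0, a1, b1, c1, d1]
//   vs[1] = [a2, b2, c2, d2, a3, b3, c3, d3]
//   vs[2] = [a4, b4, c4, d4, a5, b5, c5, d5]
//   vs[3] = [a6, b6, c6, d6, a7, b7, c7, d7]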

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
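// Worked example (under the same mask-encoding assumptions as above): with
// input rows vs = {A, B, ..., H}, A = [a0..a7] and so on, the result holds the
// transposed columns:
//   vs[0] = [a0, b0, c0, d0, e0, f0, g0, h0]
//   vs[1] = [a1, b1, c1, d1, e1, f1, g1, h1]
//   ...
//   vs[7] = [a7, b7, c7, d7, e7, f7, g7, h7]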

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice to which any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
      auto flattenedType =
          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          vector::ShapeCastOp::create(ib, flattenedType, op.getVector());
      reshInput = vector::ShapeCastOp::create(ib, reshInputType, reshInput);
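      // E.g. for a vector<1x4x8xf32> source (m = 4, n = 8) this is
      // vector<1x4x8xf32> -> vector<32xf32> -> vector<4x8xf32>.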

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(vector::ExtractOp::create(ib, reshInput, i));

      // Transpose set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector.
      Value res = arith::ConstantOp::create(ib, reshInputType,
                                            ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = vector::InsertOp::create(ib, vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We have to transpose its dimensions and restore the original rank of
      // the result (e.g., 1x8x1x4x1).
      res = vector::ShapeCastOp::create(ib, flattenedType, res);
      res = vector::ShapeCastOp::create(ib, op.getResultVectorType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
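// Typical usage is a sketch along these lines (assuming the fluent option
// setters declared in X86Vector/Transforms.h and the greedy pattern driver;
// exact spellings may differ):
//   RewritePatternSet patterns(ctx);
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, x86vector::avx2::LoweringOptions().setTransposeOptions(
//                     x86vector::avx2::TransposeLoweringOptions()
//                         .lower4x8xf32()
//                         .lower8x8xf32()));
//   (void)applyPatternsGreedily(rootOp, std::move(patterns));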