//===- AVXTranspose.cpp - Lower Vector transpose to AVX ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, LLVM::TailCallKind::None,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}
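// For illustration only: with mask = 0xcc, the builder above is expected to
// emit an op roughly of the form (a sketch; the exact llvm.inline_asm printed
// syntax varies across MLIR versions):
//   %r = llvm.inline_asm asm_dialect = intel
//            "vblendps $0, $1, $2, 0xcc", "=x,x,x" %v1, %v2
//          : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>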

Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
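// Note: the two intrin:: helpers above emulate _mm256_unpacklo_ps and
// _mm256_unpackhi_ps with plain vector.shuffle ops: they interleave the low
// (resp. high) halves of each 128-bit lane of v1 and v2, and the LLVM backend
// is expected to select vunpcklps / vunpckhps for them.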
///                            a  a   b   b  a  a   b   b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 3) **and** b[0, 4):
///                                 0:127    |         128:255
///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask = {
      b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
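// Note: this mirrors _mm256_shuffle_ps. Within each 128-bit lane, the first
// two result elements come from v1 and the last two from v2, each selected by
// a 2-bit field of `mask`.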

// imm[0:1] out of imm[0:3] is:
//    0             1           2             3
// a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
//           a[0:127] or a[128:255] or b[0:127] or b[128:255]
//              0             1           2             3
// imm[0:1] out of imm[4:7].
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
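// Note: this emulates _mm256_permute2f128_ps: each 128-bit half of the result
// is one of the four 128-bit halves of the (v1, v2) pair, chosen by the
// corresponding control field of `mask`.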

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
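// Example: mask = 0xcc (0b11001100) keeps elements 0, 1, 4, 5 from v1 and
// takes elements 2, 3, 6, 7 from v2, i.e. it builds the shuffle mask
// [0, 1, 10, 11, 4, 5, 14, 15].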

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}
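// Illustration: if the inputs are the rows A, B, C, D of a 4x8 matrix, then on
// return vs[i] packs rows 2*i and 2*i+1 of the transposed 8x4 matrix, e.g.
// vs[0] = [a0 b0 c0 d0 a1 b1 c1 d1]; the caller (TransposeOpLowering below)
// reassembles the 8x4 shape with shape casts.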

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;
  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}
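// Note: the blends in transpose8x8xf32 go through the inline_asm helper
// (mm256BlendPsAsm) rather than the shuffle-based mm256BlendPs, presumably so
// that instruction selection reliably emits vblendps for these steps; the
// intrin:: form remains available as an alternative.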

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
    if (failed(srcGtOneDims))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Retrieve the sizes of the two dimensions greater than one to be
    // transposed.
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
      auto flattenedType =
          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose the set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert the transposed 1-D vectors into the higher-order dimension of
      // the output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8).
      // We still have to transpose its dimensions and restore its original
      // rank (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};
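// For illustration, assuming the 4x8 lowering is enabled: a transpose such as
//   %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
// is rewritten into shape casts and extracts feeding the transpose4x8xf32
// sequence above, followed by inserts and shape casts back to vector<8x4xf32>.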

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
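// Usage sketch (not part of this file; the pass boilerplate, the option
// builder calls, and the greedy driver call are assumptions about the
// surrounding code, not something this file defines):
//   RewritePatternSet patterns(&getContext());
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns, x86vector::avx2::LoweringOptions().setTransposeOptions(
//                     x86vector::avx2::TransposeLoweringOptions()
//                         .lower4x8xf32()
//                         .lower8x8xf32()),
//       /*benefit=*/10);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));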