AVXTranspose.cpp
//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;

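// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp. Here an
// Intel-syntax `vblendps` with an immediate mask is emitted directly as inline
// assembly.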
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  const auto *asmTp = "vblendps $0, $1, $2, {0}";
  const auto *asmCstr =
      "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
  SmallVector<Value> asmVals{v1, v2};
  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
  auto asmOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
      /*constraints=*/asmCstr, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
  return asmOp.getResult(0);
}

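// Lowers to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13]; the intrin
// methods emulate clang's implementation of the X86 intrinsics (here
// _mm256_unpacklo_ps).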
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}

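// Lowers to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15], emulating
// _mm256_unpackhi_ps.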
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  return b.create<vector::ShuffleOp>(
      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
///                            a  a   b   b  a  a   b   b
/// Takes an 8 bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                                 0:127    |         128:255
///                            b01  b23  C8  D8  |  b01+4  b23+4  C8+4  D8+4
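/// For example (assuming MaskHelper::extractShuffle splits `mask` into four
/// 2-bit fields from least to most significant), mask = 0b01000100 yields
/// b01 = 0, b23 = 1, b45 = 0, b67 = 1 and therefore the shuffle mask
/// {0, 1, 8, 9, 4, 5, 12, 13}: per 128-bit half, the low two f32s of v1
/// followed by the low two f32s of v2.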
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t b01, b23, b45, b67;
  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

// imm[0:1] out of imm[0:3] is:
//    0             1           2             3
// a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
//    a[0:127] or a[128:255] or b[0:127] or b[128:255]
//             0             1           2             3
// imm[0:1] out of imm[4:7].
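// For example (assuming MaskHelper::permute<2, 0>() encodes the high and low
// 4-bit controls of the immediate, i.e. 0x20), the result takes a[0:127] for
// its low 128 bits and b[0:127] for its high 128 bits, matching
// _mm256_permute2f128_ps(a, b, 0x20).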
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendToMask = [&](uint8_t control) {
    if (control == 0)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
    else if (control == 1)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
    else if (control == 2)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
    else if (control == 3)
      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
    else
      llvm_unreachable("control > 3 : overflow");
  };
  uint8_t b03, b47;
  MaskHelper::extractPermute(mask, b03, b47);
  appendToMask(b03);
  appendToMask(b47);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
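/// E.g., mask = 0b11001100 takes elements 0, 1, 4, 5 from v1 and elements
/// 2, 3, 6, 7 from v2.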
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
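/// The strategy is: unpack lo/hi pairs of adjacent rows, recombine 2-element
/// groups across rows with shuffles, then exchange 128-bit halves with
/// permute2f128 to produce the transposed rows.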
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");
#endif

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
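/// Same structure as the 4x8 case, extended to eight rows: unpack lo/hi pairs,
/// shuffle 2-element groups across rows, blend them back together (using the
/// inline-asm vblendps variant), and finally exchange 128-bit halves with
/// permute2f128.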
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)vt;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == vt; }) &&
         "expects all types to be vector<8xf32>");

  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

  Value s0 =
      mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s1 =
      mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s2 =
      mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s3 =
      mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s4 =
      mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s5 =
      mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
  Value s6 =
      mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
  Value s7 =
      mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
/// transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
/// Return false: dim0 and dim1 are *not* transposed within the context of
/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
                                       ArrayRef<int64_t> transp) {
  // Perform a linear scan along the dimensions of the transposed pattern. If
  // dim0 is found first, dim0 and dim1 are not transposed within the context
  // of their 2D slice. Otherwise, 'dim1' is found first and they are
  // transposed.
  for (int64_t permDim : transp) {
    if (permDim == dim0)
      return false;
    if (permDim == dim1)
      return true;
  }

  llvm_unreachable("Ill-formed transpose pattern");
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Check if the source vector type is supported. AVX2 patterns can only be
    // applied to f32 vector types with two dimensions greater than one.
    VectorType srcType = op.getVectorType();
    if (!srcType.getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    SmallVector<int64_t> srcGtOneDims;
    for (auto &en : llvm::enumerate(srcType.getShape()))
      if (en.value() > 1)
        srcGtOneDims.push_back(en.index());

    if (srcGtOneDims.size() != 2)
      return rewriter.notifyMatchFailure(op, "Unsupported vector type");

    SmallVector<int64_t> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
270 
271  // Check whether the two source vector dimensions that are greater than one
272  // must be transposed with each other so that we can apply one of the 2-D
273  // AVX2 transpose pattens. Otherwise, these patterns are not applicable.
274  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
275  return rewriter.notifyMatchFailure(
276  op, "Not applicable to this transpose permutation");
277 
278  // Retrieve the sizes of the two dimensions greater than one to be
279  // transposed.
280  auto srcShape = srcType.getShape();
281  int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];
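    // E.g., for vector<1x4x1x8x1xf32>, srcGtOneDims is {1, 3}, so m = 4 and
    // n = 8.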

    auto applyRewrite = [&]() {
      ImplicitLocOpBuilder ib(loc, rewriter);
      SmallVector<Value> vs;

      // Reshape the n-D input vector with only two dimensions greater than one
      // to a 2-D vector.
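      // E.g., vector<1x4x1x8x1xf32> is first cast to vector<32xf32> and then
      // to vector<4x8xf32>.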
      auto flattenedType =
          VectorType::get({n * m}, op.getVectorType().getElementType());
      auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
      auto reshInput =
          ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
      reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

      // Extract 1-D vectors from the higher-order dimension of the input
      // vector.
      for (int64_t i = 0; i < m; ++i)
        vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

      // Transpose set of 1-D vectors.
      if (m == 4)
        transpose4x8xf32(ib, vs);
      if (m == 8)
        transpose8x8xf32(ib, vs);

      // Insert transposed 1-D vectors into the higher-order dimension of the
      // output vector.
      Value res = ib.create<arith::ConstantOp>(reshInputType,
                                               ib.getZeroAttr(reshInputType));
      for (int64_t i = 0; i < m; ++i)
        res = ib.create<vector::InsertOp>(vs[i], res, i);

      // The output vector still has the shape of the input vector (e.g., 4x8),
      // even though its elements are already transposed. Cast it back to the
      // result type to transpose the dimensions and restore the original rank
      // (e.g., 1x8x1x4x1).
      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
      rewriter.replaceOp(op, res);
      return success();
    };

    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
      return applyRewrite();
    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
      return applyRewrite();
    return failure();
  }

private:
  LoweringOptions loweringOptions;
};

void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}
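
// Example usage from a pass (a sketch; assumes the LoweringOptions /
// TransposeLoweringOptions builder methods declared in
// mlir/Dialect/X86Vector/Transforms.h and the greedy pattern rewrite driver):
//
//   RewritePatternSet patterns(ctx);
//   x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
//       patterns,
//       x86vector::avx2::LoweringOptions().setTransposeOptions(
//           x86vector::avx2::TransposeLoweringOptions().lower8x8xf32()));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));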