MLIR  19.0.0git
AVXTranspose.cpp
Go to the documentation of this file.
1 //===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements vector.transpose rewrites as AVX patterns for particular
10 // sizes of interest.
11 //
12 //===----------------------------------------------------------------------===//
13 
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/Support/Format.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 using namespace mlir::x86vector;
28 using namespace mlir::x86vector::avx2;
29 using namespace mlir::x86vector::avx2::inline_asm;
30 using namespace mlir::x86vector::avx2::intrin;
31 
33  ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
34  auto asmDialectAttr =
35  LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
36  const auto *asmTp = "vblendps $0, $1, $2, {0}";
37  const auto *asmCstr =
38  "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
39  SmallVector<Value> asmVals{v1, v2};
40  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
41  auto asmOp = b.create<LLVM::InlineAsmOp>(
42  v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
43  /*constraints=*/asmCstr, /*has_side_effects=*/false,
44  /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
45  /*operand_attrs=*/ArrayAttr());
46  return asmOp.getResult(0);
47 }
48 
50  Value v1, Value v2) {
51  return b.create<vector::ShuffleOp>(
52  v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
53 }
54 
56  Value v1, Value v2) {
57  return b.create<vector::ShuffleOp>(
58  v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
59 }
60 /// a a b b a a b b
61 /// Takes an 8 bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
62 /// 0:127 | 128:255
63 /// b01 b23 b45+8 b67+8 | b01+4 b23+4 b45+8+4 b67+8+4
65  Value v1, Value v2,
66  uint8_t mask) {
67  uint8_t b01, b23, b45, b67;
68  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
69  SmallVector<int64_t> shuffleMask{b01, b23, b45 + 8, b67 + 8,
70  b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
71  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
72 }
73 
74 // imm[0:1] out of imm[0:3] (selects result[0:127]; a = v1, b = v2) is:
75 //    0              1              2              3
76 // a[0:127]  or  a[128:255]  or  b[0:127]  or  b[128:255]   |
77 //           a[0:127]  or  a[128:255]  or  b[0:127]  or  b[128:255]
78 //              0              1              2              3
79 // imm[0:1] out of imm[4:7] (selects result[128:255]).
81  ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
82  SmallVector<int64_t> shuffleMask;
83  auto appendToMask = [&](uint8_t control) {
84  if (control == 0)
85  llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
86  else if (control == 1)
87  llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
88  else if (control == 2)
89  llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
90  else if (control == 3)
91  llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
92  else
93  llvm_unreachable("control > 3 : overflow");
94  };
95  uint8_t b03, b47;
96  MaskHelper::extractPermute(mask, b03, b47);
97  appendToMask(b03);
98  appendToMask(b47);
99  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
100 }
101 
102 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
104  Value v1, Value v2,
105  uint8_t mask) {
106  SmallVector<int64_t, 8> shuffleMask;
107  for (int i = 0; i < 8; ++i) {
108  bool isSet = mask & (1 << i);
109  shuffleMask.push_back(!isSet ? i : i + 8);
110  }
111  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
112 }
113 
114 /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
117 #ifndef NDEBUG
118  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
119  assert(vs.size() == 4 && "expects 4 vectors");
120  assert(llvm::all_of(ValueRange{vs}.getTypes(),
121  [&](Type t) { return t == vt; }) &&
122  "expects all types to be vector<8xf32>");
123 #endif
124 
125  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
126  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
127  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
128  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
129  Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
130  Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
131  Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
132  Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
133  vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
134  vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
135  vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
136  vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
137 }
138 
139 /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
142  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
143  (void)vt;
144  assert(vs.size() == 8 && "expects 8 vectors");
145  assert(llvm::all_of(ValueRange{vs}.getTypes(),
146  [&](Type t) { return t == vt; }) &&
147  "expects all types to be vector<8xf32>");
148 
149  Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
150  Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
151  Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
152  Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
153  Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
154  Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
155  Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
156  Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
157 
159  Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
160  Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
161  Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
162  Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());
163 
164  Value s0 =
165  mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
166  Value s1 =
167  mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
168  Value s2 =
169  mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
170  Value s3 =
171  mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
172  Value s4 =
173  mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
174  Value s5 =
175  mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
176  Value s6 =
177  mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
178  Value s7 =
179  mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
180 
181  vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
182  vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
183  vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
184  vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
185  vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
186  vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
187  vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
188  vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
189 }
190 
191 /// Rewrite AVX2-specific vector.transpose, for the supported cases and
192 /// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
193 /// transpose cases and n-D cases that have been decomposed into 2-D
194 /// transposition slices. For example, a 3-D transpose:
195 ///
196 /// %0 = vector.transpose %arg0, [2, 0, 1]
197 /// : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
198 ///
199 /// could be sliced into 2-D transposes by tiling two of its dimensions to the
200 /// vector sizes supported by the AVX2 patterns (e.g., 4x8) and the rest to 1:
201 ///
202 /// %0 = vector.transpose %arg0, [2, 0, 1]
203 /// : vector<1x4x8xf32> to vector<8x1x4xf32>
204 ///
205 /// This lowering will analyze the n-D vector.transpose and determine if it's a
206 /// supported 2-D transposition slice where any of the AVX2 patterns can be
207 /// applied.
208 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
209 public:
211 
213  int benefit)
214  : OpRewritePattern<vector::TransposeOp>(context, benefit),
215  loweringOptions(loweringOptions) {}
216 
217  LogicalResult matchAndRewrite(vector::TransposeOp op,
218  PatternRewriter &rewriter) const override {
219  auto loc = op.getLoc();
220 
221  // Check if the source vector type is supported. AVX2 patterns can only be
222  // applied to f32 vector types with two dimensions greater than one.
223  VectorType srcType = op.getSourceVectorType();
224  if (!srcType.getElementType().isF32())
225  return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
226 
227  auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
228  if (failed(srcGtOneDims))
229  return rewriter.notifyMatchFailure(
230  op, "expected transposition on a 2D slice");
231 
232  // Retrieve the sizes of the two dimensions greater than one to be
233  // transposed.
234  int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
235  int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
236 
237  auto applyRewrite = [&]() {
238  ImplicitLocOpBuilder ib(loc, rewriter);
240 
241  // Reshape the n-D input vector with only two dimensions greater than one
242  // to a 2-D vector.
243  auto flattenedType =
244  VectorType::get({n * m}, op.getSourceVectorType().getElementType());
245  auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
246  auto reshInput =
247  ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
248  reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);
249 
250  // Extract 1-D vectors from the higher-order dimension of the input
251  // vector.
252  for (int64_t i = 0; i < m; ++i)
253  vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));
254 
255  // Transpose set of 1-D vectors.
256  if (m == 4)
257  transpose4x8xf32(ib, vs);
258  if (m == 8)
259  transpose8x8xf32(ib, vs);
260 
261  // Insert transposed 1-D vectors into the higher-order dimension of the
262  // output vector.
263  Value res = ib.create<arith::ConstantOp>(reshInputType,
264  ib.getZeroAttr(reshInputType));
265  for (int64_t i = 0; i < m; ++i)
266  res = ib.create<vector::InsertOp>(vs[i], res, i);
267 
268  // The output vector still has the shape of the input vector (e.g., 4x8).
269  // We have to transpose their dimensions and retrieve its original rank
270  // (e.g., 1x8x1x4x1).
271  res = ib.create<vector::ShapeCastOp>(flattenedType, res);
272  res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
273  rewriter.replaceOp(op, res);
274  return success();
275  };
276 
277  if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
278  return applyRewrite();
279  if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
280  return applyRewrite();
281  return failure();
282  }
283 
284 private:
285  LoweringOptions loweringOptions;
286 };
287 
289  RewritePatternSet &patterns, LoweringOptions options, int benefit) {
290  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
291 }
static llvm::ManagedStatic< PassManagerOptions > options
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:80
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
If bit i of mask is zero, take f32@i from v1 else take it from v2.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
a a b b a a b b Take an 8 bit mask, 2 bits for each position of a[0, 4) and b[0, 4): 0:127 | 128:255 b...
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
Helpers extracted from:
Definition: Transforms.h:94
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
8x8xf32-specific AVX2 transpose lowering.
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
Generic lowerings may either use intrin or inline_asm depending on needs.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Options for controlling specialized AVX2 lowerings.
Definition: Transforms.h:159