MLIR  20.0.0git
LowerVectorTranspose.cpp
Go to the documentation of this file.
1 //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.transpose' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
34 
35 #define DEBUG_TYPE "lower-vector-transpose"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
41 /// transposed.
43  SmallVectorImpl<int64_t> &result) {
44  size_t numTransposedDims = transpose.size();
45  for (size_t transpDim : llvm::reverse(transpose)) {
46  if (transpDim != numTransposedDims - 1)
47  break;
48  numTransposedDims--;
49  }
50 
51  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
52 }
53 
54 /// Returns true if the lowering option is a vector shuffle based approach.
55 static bool isShuffleLike(VectorTransposeLowering lowering) {
56  return lowering == VectorTransposeLowering::Shuffle1D ||
57  lowering == VectorTransposeLowering::Shuffle16x16;
58 }
59 
60 /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
61 /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
62 /// create the mask for `numBits` bits vector. The `numBits` have to be a
63 /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
64 /// 512, there should be 16 elements in the final result. It constructs the
65 /// below mask to get the unpack elements.
66 /// [0, 1, 16, 17,
67 /// 0+4, 1+4, 16+4, 17+4,
68 /// 0+8, 1+8, 16+8, 17+8,
69 /// 0+12, 1+12, 16+12, 17+12]
72  assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
73  int numElem = numBits / 32;
75  for (int i = 0; i < numElem; i += 4)
76  for (int64_t v : vals)
77  res.push_back(v + i);
78  return res;
79 }
80 
81 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
82 /// example, if it is targeting 512 bit vector, returns
83 /// vector.shuffle on v1, v2, [0, 1, 16, 17,
84 /// 0+4, 1+4, 16+4, 17+4,
85 /// 0+8, 1+8, 16+8, 17+8,
86 /// 0+12, 1+12, 16+12, 17+12].
88  int numBits) {
89  int numElem = numBits / 32;
90  return b.create<vector::ShuffleOp>(
91  v1, v2,
92  getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
93 }
94 
95 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
96 /// example, if it is targeting 512 bit vector, returns
97 /// vector.shuffle, v1, v2, [2, 3, 18, 19,
98 /// 2+4, 3+4, 18+4, 19+4,
99 /// 2+8, 3+8, 18+8, 19+8,
100 /// 2+12, 3+12, 18+12, 19+12].
102  int numBits) {
103  int numElem = numBits / 32;
104  return b.create<vector::ShuffleOp>(
105  v1, v2,
106  getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
107  numBits));
108 }
109 
110 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
111 /// example, if it is targeting 512 bit vector, returns
112 /// vector.shuffle, v1, v2, [0, 16, 1, 17,
113 /// 0+4, 16+4, 1+4, 17+4,
114 /// 0+8, 16+8, 1+8, 17+8,
115 /// 0+12, 16+12, 1+12, 17+12].
117  int numBits) {
118  int numElem = numBits / 32;
119  auto shuffle = b.create<vector::ShuffleOp>(
120  v1, v2,
121  getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
122  return shuffle;
123 }
124 
125 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
126 /// example, if it is targeting 512 bit vector, returns
127 /// vector.shuffle, v1, v2, [2, 18, 3, 19,
128 /// 2+4, 18+4, 3+4, 19+4,
129 /// 2+8, 18+8, 3+8, 19+8,
130 /// 2+12, 18+12, 3+12, 19+12].
132  int numBits) {
133  int numElem = numBits / 32;
134  return b.create<vector::ShuffleOp>(
135  v1, v2,
136  getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
137  numBits));
138 }
139 
140 /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
141 /// elements) selected by `mask` from `v1` and `v2`. I.e.,
142 ///
143 /// DEFINE SELECT4(src, control) {
144 /// CASE(control[1:0]) OF
145 /// 0: tmp[127:0] := src[127:0]
146 /// 1: tmp[127:0] := src[255:128]
147 /// 2: tmp[127:0] := src[383:256]
148 /// 3: tmp[127:0] := src[511:384]
149 /// ESAC
150 /// RETURN tmp[127:0]
151 /// }
152 /// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
153 /// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
154 /// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
155 /// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
157  uint8_t mask) {
158  assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
159  "expected a vector with length=16");
160  SmallVector<int64_t> shuffleMask;
161  auto appendToMask = [&](int64_t base, uint8_t control) {
162  switch (control) {
163  case 0:
164  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
165  base + 2, base + 3});
166  break;
167  case 1:
168  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
169  base + 6, base + 7});
170  break;
171  case 2:
172  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
173  base + 10, base + 11});
174  break;
175  case 3:
176  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
177  base + 14, base + 15});
178  break;
179  default:
180  llvm_unreachable("control > 3 : overflow");
181  }
182  };
183  uint8_t b01 = mask & 0x3;
184  uint8_t b23 = (mask >> 2) & 0x3;
185  uint8_t b45 = (mask >> 4) & 0x3;
186  uint8_t b67 = (mask >> 6) & 0x3;
187  appendToMask(0, b01);
188  appendToMask(0, b23);
189  appendToMask(16, b45);
190  appendToMask(16, b67);
191  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
192 }
193 
194 /// Lowers the value to a vector.shuffle op. The `source` is expected to be a
195 /// 1-D vector and have `m`x`n` elements.
196 static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
198  mask.reserve(m * n);
199  for (int64_t j = 0; j < n; ++j)
200  for (int64_t i = 0; i < m; ++i)
201  mask.push_back(i * n + j);
202  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
203 }
204 
205 /// Lowers the value to a sequence of vector.shuffle ops. The `source` is
206 /// expected to be a 16x16 vector.
207 static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
208  int n) {
209  ImplicitLocOpBuilder b(source.getLoc(), builder);
211  for (int64_t i = 0; i < m; ++i)
212  vs.push_back(b.create<vector::ExtractOp>(source, i));
213 
214  // Interleave 32-bit lanes using
215  // 8x _mm512_unpacklo_epi32
216  // 8x _mm512_unpackhi_epi32
217  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
218  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
219  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
220  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
221  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
222  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
223  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
224  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
225  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
226  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
227  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
228  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
229  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
230  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
231  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
232  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
233 
234  // Interleave 64-bit lanes using
235  // 8x _mm512_unpacklo_epi64
236  // 8x _mm512_unpackhi_epi64
237  Value r0 = createUnpackLoPd(b, t0, t2, 512);
238  Value r1 = createUnpackHiPd(b, t0, t2, 512);
239  Value r2 = createUnpackLoPd(b, t1, t3, 512);
240  Value r3 = createUnpackHiPd(b, t1, t3, 512);
241  Value r4 = createUnpackLoPd(b, t4, t6, 512);
242  Value r5 = createUnpackHiPd(b, t4, t6, 512);
243  Value r6 = createUnpackLoPd(b, t5, t7, 512);
244  Value r7 = createUnpackHiPd(b, t5, t7, 512);
245  Value r8 = createUnpackLoPd(b, t8, ta, 512);
246  Value r9 = createUnpackHiPd(b, t8, ta, 512);
247  Value ra = createUnpackLoPd(b, t9, tb, 512);
248  Value rb = createUnpackHiPd(b, t9, tb, 512);
249  Value rc = createUnpackLoPd(b, tc, te, 512);
250  Value rd = createUnpackHiPd(b, tc, te, 512);
251  Value re = createUnpackLoPd(b, td, tf, 512);
252  Value rf = createUnpackHiPd(b, td, tf, 512);
253 
254  // Permute 128-bit lanes using
255  // 16x _mm512_shuffle_i32x4
256  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
257  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
258  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
259  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
260  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
261  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
262  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
263  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
264  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
265  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
266  ta = create4x128BitSuffle(b, ra, re, 0x88);
267  tb = create4x128BitSuffle(b, rb, rf, 0x88);
268  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
269  td = create4x128BitSuffle(b, r9, rd, 0xdd);
270  te = create4x128BitSuffle(b, ra, re, 0xdd);
271  tf = create4x128BitSuffle(b, rb, rf, 0xdd);
272 
273  // Permute 256-bit lanes using again
274  // 16x _mm512_shuffle_i32x4
275  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
276  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
277  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
278  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
279  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
280  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
281  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
282  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
283  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
284  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
285  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
286  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
287  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
288  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
289  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
290  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
291 
292  auto reshInputType = VectorType::get(
293  {m, n}, cast<VectorType>(source.getType()).getElementType());
294  Value res =
295  b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
296  for (int64_t i = 0; i < m; ++i)
297  res = b.create<vector::InsertOp>(vs[i], res, i);
298  return res;
299 }
300 
301 namespace {
302 /// Progressive lowering of TransposeOp.
303 /// One:
304 /// %x = vector.transpose %y, [1, 0]
305 /// is replaced by:
306 /// %z = arith.constant dense<0.000000e+00>
307 /// %0 = vector.extract %y[0, 0]
308 /// %1 = vector.insert %0, %z [0, 0]
309 /// ..
310 /// %x = vector.insert .., .. [.., ..]
311 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
312 public:
314 
316  MLIRContext *context, PatternBenefit benefit = 1)
317  : OpRewritePattern<vector::TransposeOp>(context, benefit),
318  vectorTransformOptions(vectorTransformOptions) {}
319 
320  LogicalResult matchAndRewrite(vector::TransposeOp op,
321  PatternRewriter &rewriter) const override {
322  auto loc = op.getLoc();
323 
324  Value input = op.getVector();
325  VectorType inputType = op.getSourceVectorType();
326  VectorType resType = op.getResultVectorType();
327 
328  if (inputType.isScalable())
329  return rewriter.notifyMatchFailure(
330  op, "This lowering does not support scalable vectors");
331 
332  // Set up convenience transposition table.
333  ArrayRef<int64_t> transp = op.getPermutation();
334 
335  if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
336  succeeded(isTranspose2DSlice(op)))
337  return rewriter.notifyMatchFailure(
338  op, "Options specifies lowering to shuffle");
339 
340  // Handle a true 2-D matrix transpose differently when requested.
341  if (vectorTransformOptions.vectorTransposeLowering ==
342  vector::VectorTransposeLowering::Flat &&
343  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
344  Type flattenedType =
345  VectorType::get(resType.getNumElements(), resType.getElementType());
346  auto matrix =
347  rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
348  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
349  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
350  Value trans = rewriter.create<vector::FlatTransposeOp>(
351  loc, flattenedType, matrix, rows, columns);
352  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
353  return success();
354  }
355 
356  // Generate unrolled extract/insert ops. We do not unroll the rightmost
357  // (i.e., highest-order) dimensions that are not transposed and leave them
358  // in vector form to improve performance. Therefore, we prune those
359  // dimensions from the shape/transpose data structures used to generate the
360  // extract/insert ops.
361  SmallVector<int64_t> prunedTransp;
362  pruneNonTransposedDims(transp, prunedTransp);
363  size_t numPrunedDims = transp.size() - prunedTransp.size();
364  auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
365  auto prunedInStrides = computeStrides(prunedInShape);
366 
367  // Generates the extract/insert operations for every scalar/vector element
368  // of the leftmost transposed dimensions. We traverse every transpose
369  // element using a linearized index that we delinearize to generate the
370  // appropriate indices for the extract/insert operations.
371  Value result = rewriter.create<arith::ConstantOp>(
372  loc, resType, rewriter.getZeroAttr(resType));
373  int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
374 
375  for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
376  ++linearIdx) {
377  auto extractIdxs = delinearize(linearIdx, prunedInStrides);
378  SmallVector<int64_t> insertIdxs(extractIdxs);
379  applyPermutationToVector(insertIdxs, prunedTransp);
380  Value extractOp =
381  rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
382  result =
383  rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
384  }
385 
386  rewriter.replaceOp(op, result);
387  return success();
388  }
389 
390 private:
391  /// Options to control the vector patterns.
392  vector::VectorTransformsOptions vectorTransformOptions;
393 };
394 
395 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
396 /// to 2D vectors with at least one unit dim. For example:
397 ///
398 /// Replace:
399 /// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
400 /// vector<1x4xi32>
401 /// with:
402 /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
403 ///
404 /// Source with leading unit dim (inverse) is also replaced. Unit dim must
405 /// be fixed. Non-unit dim can be scalable.
406 ///
407 /// TODO: This pattern was introduced specifically to help lower scalable
408 /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
409 /// to cancel out) would be preferable:
410 ///
411 /// BEFORE:
412 /// %0 = some_op
413 /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
414 /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
415 /// AFTER:
416 /// %0 = some_op
417 /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
418 ///
419 /// Given the context above, we may want to consider (re-)moving this pattern
420 /// at some later time. I am leaving it for now in case there are other users
421 /// that I am not aware of.
422 class Transpose2DWithUnitDimToShapeCast
423  : public OpRewritePattern<vector::TransposeOp> {
424 public:
426 
427  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
428  PatternBenefit benefit = 1)
429  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
430 
431  LogicalResult matchAndRewrite(vector::TransposeOp op,
432  PatternRewriter &rewriter) const override {
433  Value input = op.getVector();
434  VectorType resType = op.getResultVectorType();
435 
436  // Set up convenience transposition table.
437  ArrayRef<int64_t> transp = op.getPermutation();
438 
439  if (resType.getRank() == 2 &&
440  ((resType.getShape().front() == 1 &&
441  !resType.getScalableDims().front()) ||
442  (resType.getShape().back() == 1 &&
443  !resType.getScalableDims().back())) &&
444  transp == ArrayRef<int64_t>({1, 0})) {
445  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
446  return success();
447  }
448 
449  return failure();
450  }
451 };
452 
453 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
454 /// If the strategy is Shuffle1D, it will be lowered to:
455 /// vector.shape_cast 2D -> 1D
456 /// vector.shuffle
457 /// vector.shape_cast 1D -> 2D
458 /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
459 /// ops on 16xf32 vectors.
460 class TransposeOp2DToShuffleLowering
461  : public OpRewritePattern<vector::TransposeOp> {
462 public:
464 
465  TransposeOp2DToShuffleLowering(
466  vector::VectorTransformsOptions vectorTransformOptions,
467  MLIRContext *context, PatternBenefit benefit = 1)
468  : OpRewritePattern<vector::TransposeOp>(context, benefit),
469  vectorTransformOptions(vectorTransformOptions) {}
470 
471  LogicalResult matchAndRewrite(vector::TransposeOp op,
472  PatternRewriter &rewriter) const override {
473  if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
474  return rewriter.notifyMatchFailure(
475  op, "not using vector shuffle based lowering");
476 
477  if (op.getSourceVectorType().isScalable())
478  return rewriter.notifyMatchFailure(
479  op, "vector shuffle lowering not supported for scalable vectors");
480 
481  auto srcGtOneDims = isTranspose2DSlice(op);
482  if (failed(srcGtOneDims))
483  return rewriter.notifyMatchFailure(
484  op, "expected transposition on a 2D slice");
485 
486  VectorType srcType = op.getSourceVectorType();
487  int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
488  int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
489 
490  // Reshape the n-D input vector with only two dimensions greater than one
491  // to a 2-D vector.
492  Location loc = op.getLoc();
493  auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
494  auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
495  auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
496  op.getVector());
497 
498  Value res;
499  if (vectorTransformOptions.vectorTransposeLowering ==
500  VectorTransposeLowering::Shuffle16x16 &&
501  m == 16 && n == 16) {
502  reshInput =
503  rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
504  res = transposeToShuffle16x16(rewriter, reshInput, m, n);
505  } else {
506  // Fallback to shuffle on 1D approach.
507  res = transposeToShuffle1D(rewriter, reshInput, m, n);
508  }
509 
510  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
511  op, op.getResultVectorType(), res);
512 
513  return success();
514  }
515 
516 private:
517  /// Options to control the vector patterns.
518  vector::VectorTransformsOptions vectorTransformOptions;
519 };
520 } // namespace
521 
524  PatternBenefit benefit) {
525  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
526  benefit);
527  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
528  options, patterns.getContext(), benefit);
529 }
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
static llvm::ManagedStatic< PassManagerOptions > options
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1545
int64_t rows
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:220
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:84
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Structure to control the behavior of vector transform patterns.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.