MLIR  21.0.0git
LowerVectorTranspose.cpp
Go to the documentation of this file.
1 //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.transpose' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
27 
28 #define DEBUG_TYPE "lower-vector-transpose"
29 
30 using namespace mlir;
31 using namespace mlir::vector;
32 
33 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
34 /// transposed.
36  SmallVectorImpl<int64_t> &result) {
37  size_t numTransposedDims = transpose.size();
38  for (size_t transpDim : llvm::reverse(transpose)) {
39  if (transpDim != numTransposedDims - 1)
40  break;
41  numTransposedDims--;
42  }
43 
44  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
45 }
46 
47 /// Returns true if the lowering option is a vector shuffle based approach.
48 static bool isShuffleLike(VectorTransposeLowering lowering) {
49  return lowering == VectorTransposeLowering::Shuffle1D ||
50  lowering == VectorTransposeLowering::Shuffle16x16;
51 }
52 
53 /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
54 /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
55 /// create the mask for `numBits` bits vector. The `numBits` have to be a
56 /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
57 /// 512, there should be 16 elements in the final result. It constructs the
58 /// below mask to get the unpack elements.
59 /// [0, 1, 16, 17,
60 /// 0+4, 1+4, 16+4, 17+4,
61 /// 0+8, 1+8, 16+8, 17+8,
62 /// 0+12, 1+12, 16+12, 17+12]
65  assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
66  int numElem = numBits / 32;
68  for (int i = 0; i < numElem; i += 4)
69  for (int64_t v : vals)
70  res.push_back(v + i);
71  return res;
72 }
73 
74 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
75 /// example, if it is targeting 512 bit vector, returns
76 /// vector.shuffle on v1, v2, [0, 1, 16, 17,
77 /// 0+4, 1+4, 16+4, 17+4,
78 /// 0+8, 1+8, 16+8, 17+8,
79 /// 0+12, 1+12, 16+12, 17+12].
81  int numBits) {
82  int numElem = numBits / 32;
83  return b.create<vector::ShuffleOp>(
84  v1, v2,
85  getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
86 }
87 
88 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
89 /// example, if it is targeting 512 bit vector, returns
90 /// vector.shuffle, v1, v2, [2, 3, 18, 19,
91 /// 2+4, 3+4, 18+4, 19+4,
92 /// 2+8, 3+8, 18+8, 19+8,
93 /// 2+12, 3+12, 18+12, 19+12].
95  int numBits) {
96  int numElem = numBits / 32;
97  return b.create<vector::ShuffleOp>(
98  v1, v2,
99  getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
100  numBits));
101 }
102 
103 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
104 /// example, if it is targeting 512 bit vector, returns
105 /// vector.shuffle, v1, v2, [0, 16, 1, 17,
106 /// 0+4, 16+4, 1+4, 17+4,
107 /// 0+8, 16+8, 1+8, 17+8,
108 /// 0+12, 16+12, 1+12, 17+12].
110  int numBits) {
111  int numElem = numBits / 32;
112  auto shuffle = b.create<vector::ShuffleOp>(
113  v1, v2,
114  getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
115  return shuffle;
116 }
117 
118 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
119 /// example, if it is targeting 512 bit vector, returns
120 /// vector.shuffle, v1, v2, [2, 18, 3, 19,
121 /// 2+4, 18+4, 3+4, 19+4,
122 /// 2+8, 18+8, 3+8, 19+8,
123 /// 2+12, 18+12, 3+12, 19+12].
125  int numBits) {
126  int numElem = numBits / 32;
127  return b.create<vector::ShuffleOp>(
128  v1, v2,
129  getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
130  numBits));
131 }
132 
133 /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
134 /// elements) selected by `mask` from `v1` and `v2`. I.e.,
135 ///
136 /// DEFINE SELECT4(src, control) {
137 /// CASE(control[1:0]) OF
138 /// 0: tmp[127:0] := src[127:0]
139 /// 1: tmp[127:0] := src[255:128]
140 /// 2: tmp[127:0] := src[383:256]
141 /// 3: tmp[127:0] := src[511:384]
142 /// ESAC
143 /// RETURN tmp[127:0]
144 /// }
145 /// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
146 /// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
147 /// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
148 /// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
150  uint8_t mask) {
151  assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
152  "expected a vector with length=16");
153  SmallVector<int64_t> shuffleMask;
154  auto appendToMask = [&](int64_t base, uint8_t control) {
155  switch (control) {
156  case 0:
157  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
158  base + 2, base + 3});
159  break;
160  case 1:
161  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
162  base + 6, base + 7});
163  break;
164  case 2:
165  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
166  base + 10, base + 11});
167  break;
168  case 3:
169  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
170  base + 14, base + 15});
171  break;
172  default:
173  llvm_unreachable("control > 3 : overflow");
174  }
175  };
176  uint8_t b01 = mask & 0x3;
177  uint8_t b23 = (mask >> 2) & 0x3;
178  uint8_t b45 = (mask >> 4) & 0x3;
179  uint8_t b67 = (mask >> 6) & 0x3;
180  appendToMask(0, b01);
181  appendToMask(0, b23);
182  appendToMask(16, b45);
183  appendToMask(16, b67);
184  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
185 }
186 
187 /// Lowers the value to a vector.shuffle op. The `source` is expected to be a
188 /// 1-D vector and have `m`x`n` elements.
189 static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
191  mask.reserve(m * n);
192  for (int64_t j = 0; j < n; ++j)
193  for (int64_t i = 0; i < m; ++i)
194  mask.push_back(i * n + j);
195  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
196 }
197 
198 /// Lowers the value to a sequence of vector.shuffle ops. The `source` is
199 /// expected to be a 16x16 vector.
200 static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
201  int n) {
202  ImplicitLocOpBuilder b(source.getLoc(), builder);
204  for (int64_t i = 0; i < m; ++i)
205  vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
206 
207  // Interleave 32-bit lanes using
208  // 8x _mm512_unpacklo_epi32
209  // 8x _mm512_unpackhi_epi32
210  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
211  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
212  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
213  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
214  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
215  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
216  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
217  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
218  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
219  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
220  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
221  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
222  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
223  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
224  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
225  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
226 
227  // Interleave 64-bit lanes using
228  // 8x _mm512_unpacklo_epi64
229  // 8x _mm512_unpackhi_epi64
230  Value r0 = createUnpackLoPd(b, t0, t2, 512);
231  Value r1 = createUnpackHiPd(b, t0, t2, 512);
232  Value r2 = createUnpackLoPd(b, t1, t3, 512);
233  Value r3 = createUnpackHiPd(b, t1, t3, 512);
234  Value r4 = createUnpackLoPd(b, t4, t6, 512);
235  Value r5 = createUnpackHiPd(b, t4, t6, 512);
236  Value r6 = createUnpackLoPd(b, t5, t7, 512);
237  Value r7 = createUnpackHiPd(b, t5, t7, 512);
238  Value r8 = createUnpackLoPd(b, t8, ta, 512);
239  Value r9 = createUnpackHiPd(b, t8, ta, 512);
240  Value ra = createUnpackLoPd(b, t9, tb, 512);
241  Value rb = createUnpackHiPd(b, t9, tb, 512);
242  Value rc = createUnpackLoPd(b, tc, te, 512);
243  Value rd = createUnpackHiPd(b, tc, te, 512);
244  Value re = createUnpackLoPd(b, td, tf, 512);
245  Value rf = createUnpackHiPd(b, td, tf, 512);
246 
247  // Permute 128-bit lanes using
248  // 16x _mm512_shuffle_i32x4
249  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
250  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
251  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
252  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
253  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
254  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
255  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
256  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
257  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
258  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
259  ta = create4x128BitSuffle(b, ra, re, 0x88);
260  tb = create4x128BitSuffle(b, rb, rf, 0x88);
261  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
262  td = create4x128BitSuffle(b, r9, rd, 0xdd);
263  te = create4x128BitSuffle(b, ra, re, 0xdd);
264  tf = create4x128BitSuffle(b, rb, rf, 0xdd);
265 
266  // Permute 256-bit lanes using again
267  // 16x _mm512_shuffle_i32x4
268  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
269  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
270  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
271  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
272  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
273  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
274  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
275  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
276  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
277  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
278  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
279  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
280  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
281  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
282  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
283  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
284 
285  auto reshInputType = VectorType::get(
286  {m, n}, cast<VectorType>(source.getType()).getElementType());
287  Value res = b.create<ub::PoisonOp>(reshInputType);
288  for (int64_t i = 0; i < m; ++i)
289  res = b.create<vector::InsertOp>(vs[i], res, i);
290  return res;
291 }
292 
293 namespace {
294 /// Progressive lowering of TransposeOp.
295 /// One:
296 /// %x = vector.transpose %y, [1, 0]
297 /// is replaced by:
298 /// %z = arith.constant dense<0.000000e+00>
299 /// %0 = vector.extract %y[0, 0]
300 /// %1 = vector.insert %0, %z [0, 0]
301 /// ..
302 /// %x = vector.insert .., .. [.., ..]
303 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
304 public:
306 
307  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
308  MLIRContext *context, PatternBenefit benefit = 1)
309  : OpRewritePattern<vector::TransposeOp>(context, benefit),
310  vectorTransposeLowering(vectorTransposeLowering) {}
311 
312  LogicalResult matchAndRewrite(vector::TransposeOp op,
313  PatternRewriter &rewriter) const override {
314  auto loc = op.getLoc();
315 
316  Value input = op.getVector();
317  VectorType inputType = op.getSourceVectorType();
318  VectorType resType = op.getResultVectorType();
319 
320  if (inputType.isScalable())
321  return rewriter.notifyMatchFailure(
322  op, "This lowering does not support scalable vectors");
323 
324  // Set up convenience transposition table.
325  ArrayRef<int64_t> transp = op.getPermutation();
326 
327  if (isShuffleLike(vectorTransposeLowering) &&
328  succeeded(isTranspose2DSlice(op)))
329  return rewriter.notifyMatchFailure(
330  op, "Options specifies lowering to shuffle");
331 
332  // Handle a true 2-D matrix transpose differently when requested.
333  if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
334  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
335  Type flattenedType =
336  VectorType::get(resType.getNumElements(), resType.getElementType());
337  auto matrix =
338  rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
339  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
340  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
341  Value trans = rewriter.create<vector::FlatTransposeOp>(
342  loc, flattenedType, matrix, rows, columns);
343  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
344  return success();
345  }
346 
347  // Generate unrolled extract/insert ops. We do not unroll the rightmost
348  // (i.e., highest-order) dimensions that are not transposed and leave them
349  // in vector form to improve performance. Therefore, we prune those
350  // dimensions from the shape/transpose data structures used to generate the
351  // extract/insert ops.
352  SmallVector<int64_t> prunedTransp;
353  pruneNonTransposedDims(transp, prunedTransp);
354  size_t numPrunedDims = transp.size() - prunedTransp.size();
355  auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
356  auto prunedInStrides = computeStrides(prunedInShape);
357 
358  // Generates the extract/insert operations for every scalar/vector element
359  // of the leftmost transposed dimensions. We traverse every transpose
360  // element using a linearized index that we delinearize to generate the
361  // appropriate indices for the extract/insert operations.
362  Value result = rewriter.create<ub::PoisonOp>(loc, resType);
363  int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
364 
365  for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
366  ++linearIdx) {
367  auto extractIdxs = delinearize(linearIdx, prunedInStrides);
368  SmallVector<int64_t> insertIdxs(extractIdxs);
369  applyPermutationToVector(insertIdxs, prunedTransp);
370  Value extractOp =
371  rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
372  result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
373  insertIdxs);
374  }
375 
376  rewriter.replaceOp(op, result);
377  return success();
378  }
379 
380 private:
381  /// Options to control the vector patterns.
382  vector::VectorTransposeLowering vectorTransposeLowering;
383 };
384 
385 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
386 /// to 2D vectors with at least one unit dim. For example:
387 ///
388 /// Replace:
389 /// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
390 /// vector<1x4xi32>
391 /// with:
392 /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
393 ///
394 /// Source with leading unit dim (inverse) is also replaced. Unit dim must
395 /// be fixed. Non-unit dim can be scalable.
396 ///
397 /// TODO: This pattern was introduced specifically to help lower scalable
398 /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
399 /// to cancel out) would be preferable:
400 ///
401 /// BEFORE:
402 /// %0 = some_op
403 /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
404 /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
405 /// AFTER:
406 /// %0 = some_op
407 /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
408 ///
409 /// Given the context above, we may want to consider (re-)moving this pattern
410 /// at some later time. I am leaving it for now in case there are other users
411 /// that I am not aware of.
412 class Transpose2DWithUnitDimToShapeCast
413  : public OpRewritePattern<vector::TransposeOp> {
414 public:
416 
417  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
418  PatternBenefit benefit = 1)
419  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
420 
421  LogicalResult matchAndRewrite(vector::TransposeOp op,
422  PatternRewriter &rewriter) const override {
423  Value input = op.getVector();
424  VectorType resType = op.getResultVectorType();
425 
426  // Set up convenience transposition table.
427  ArrayRef<int64_t> transp = op.getPermutation();
428 
429  if (resType.getRank() == 2 &&
430  ((resType.getShape().front() == 1 &&
431  !resType.getScalableDims().front()) ||
432  (resType.getShape().back() == 1 &&
433  !resType.getScalableDims().back())) &&
434  transp == ArrayRef<int64_t>({1, 0})) {
435  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
436  return success();
437  }
438 
439  return failure();
440  }
441 };
442 
443 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
444 /// If the strategy is Shuffle1D, it will be lowered to:
445 /// vector.shape_cast 2D -> 1D
446 /// vector.shuffle
447 /// vector.shape_cast 1D -> 2D
448 /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
449 /// ops on 16xf32 vectors.
450 class TransposeOp2DToShuffleLowering
451  : public OpRewritePattern<vector::TransposeOp> {
452 public:
454 
455  TransposeOp2DToShuffleLowering(
456  vector::VectorTransposeLowering vectorTransposeLowering,
457  MLIRContext *context, PatternBenefit benefit = 1)
458  : OpRewritePattern<vector::TransposeOp>(context, benefit),
459  vectorTransposeLowering(vectorTransposeLowering) {}
460 
461  LogicalResult matchAndRewrite(vector::TransposeOp op,
462  PatternRewriter &rewriter) const override {
463  if (!isShuffleLike(vectorTransposeLowering))
464  return rewriter.notifyMatchFailure(
465  op, "not using vector shuffle based lowering");
466 
467  if (op.getSourceVectorType().isScalable())
468  return rewriter.notifyMatchFailure(
469  op, "vector shuffle lowering not supported for scalable vectors");
470 
471  auto srcGtOneDims = isTranspose2DSlice(op);
472  if (failed(srcGtOneDims))
473  return rewriter.notifyMatchFailure(
474  op, "expected transposition on a 2D slice");
475 
476  VectorType srcType = op.getSourceVectorType();
477  int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
478  int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
479 
480  // Reshape the n-D input vector with only two dimensions greater than one
481  // to a 2-D vector.
482  Location loc = op.getLoc();
483  auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
484  auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
485  auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
486  op.getVector());
487 
488  Value res;
489  if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
490  m == 16 && n == 16) {
491  reshInput =
492  rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
493  res = transposeToShuffle16x16(rewriter, reshInput, m, n);
494  } else {
495  // Fallback to shuffle on 1D approach.
496  res = transposeToShuffle1D(rewriter, reshInput, m, n);
497  }
498 
499  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
500  op, op.getResultVectorType(), res);
501 
502  return success();
503  }
504 
505 private:
506  /// Options to control the vector patterns.
507  vector::VectorTransposeLowering vectorTransposeLowering;
508 };
509 } // namespace
510 
513  VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
514  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
515  benefit);
516  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
517  vectorTransposeLowering, patterns.getContext(), benefit);
518 }
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
int64_t rows
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void createOrFold(llvm::SmallVectorImpl< Value > &results, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:84
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.