MLIR  22.0.0git
LowerVectorTranspose.cpp
Go to the documentation of this file.
1 //===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.transpose' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 
26 #define DEBUG_TYPE "lower-vector-transpose"
27 
28 using namespace mlir;
29 using namespace mlir::vector;
30 
31 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
32 /// transposed.
34  SmallVectorImpl<int64_t> &result) {
35  size_t numTransposedDims = transpose.size();
36  for (size_t transpDim : llvm::reverse(transpose)) {
37  if (transpDim != numTransposedDims - 1)
38  break;
39  numTransposedDims--;
40  }
41 
42  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
43 }
44 
45 /// Returns true if the lowering option is a vector shuffle based approach.
46 static bool isShuffleLike(VectorTransposeLowering lowering) {
47  return lowering == VectorTransposeLowering::Shuffle1D ||
48  lowering == VectorTransposeLowering::Shuffle16x16;
49 }
50 
51 /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
52 /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
53 /// create the mask for `numBits` bits vector. The `numBits` have to be a
54 /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
55 /// 512, there should be 16 elements in the final result. It constructs the
56 /// below mask to get the unpack elements.
57 /// [0, 1, 16, 17,
58 /// 0+4, 1+4, 16+4, 17+4,
59 /// 0+8, 1+8, 16+8, 17+8,
60 /// 0+12, 1+12, 16+12, 17+12]
63  assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
64  int numElem = numBits / 32;
66  for (int i = 0; i < numElem; i += 4)
67  for (int64_t v : vals)
68  res.push_back(v + i);
69  return res;
70 }
71 
72 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
73 /// example, if it is targeting 512 bit vector, returns
74 /// vector.shuffle on v1, v2, [0, 1, 16, 17,
75 /// 0+4, 1+4, 16+4, 17+4,
76 /// 0+8, 1+8, 16+8, 17+8,
77 /// 0+12, 1+12, 16+12, 17+12].
79  int numBits) {
80  int numElem = numBits / 32;
81  return vector::ShuffleOp::create(
82  b, v1, v2,
83  getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
84 }
85 
86 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
87 /// example, if it is targeting 512 bit vector, returns
88 /// vector.shuffle, v1, v2, [2, 3, 18, 19,
89 /// 2+4, 3+4, 18+4, 19+4,
90 /// 2+8, 3+8, 18+8, 19+8,
91 /// 2+12, 3+12, 18+12, 19+12].
93  int numBits) {
94  int numElem = numBits / 32;
95  return vector::ShuffleOp::create(
96  b, v1, v2,
97  getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
98  numBits));
99 }
100 
101 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
102 /// example, if it is targeting 512 bit vector, returns
103 /// vector.shuffle, v1, v2, [0, 16, 1, 17,
104 /// 0+4, 16+4, 1+4, 17+4,
105 /// 0+8, 16+8, 1+8, 17+8,
106 /// 0+12, 16+12, 1+12, 17+12].
108  int numBits) {
109  int numElem = numBits / 32;
110  auto shuffle = vector::ShuffleOp::create(
111  b, v1, v2,
112  getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
113  return shuffle;
114 }
115 
116 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
117 /// example, if it is targeting 512 bit vector, returns
118 /// vector.shuffle, v1, v2, [2, 18, 3, 19,
119 /// 2+4, 18+4, 3+4, 19+4,
120 /// 2+8, 18+8, 3+8, 19+8,
121 /// 2+12, 18+12, 3+12, 19+12].
123  int numBits) {
124  int numElem = numBits / 32;
125  return vector::ShuffleOp::create(
126  b, v1, v2,
127  getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
128  numBits));
129 }
130 
131 /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
132 /// elements) selected by `mask` from `v1` and `v2`. I.e.,
133 ///
134 /// DEFINE SELECT4(src, control) {
135 /// CASE(control[1:0]) OF
136 /// 0: tmp[127:0] := src[127:0]
137 /// 1: tmp[127:0] := src[255:128]
138 /// 2: tmp[127:0] := src[383:256]
139 /// 3: tmp[127:0] := src[511:384]
140 /// ESAC
141 /// RETURN tmp[127:0]
142 /// }
143 /// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
144 /// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
145 /// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
146 /// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
148  uint8_t mask) {
149  assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
150  "expected a vector with length=16");
151  SmallVector<int64_t> shuffleMask;
152  auto appendToMask = [&](int64_t base, uint8_t control) {
153  switch (control) {
154  case 0:
155  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
156  base + 2, base + 3});
157  break;
158  case 1:
159  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
160  base + 6, base + 7});
161  break;
162  case 2:
163  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
164  base + 10, base + 11});
165  break;
166  case 3:
167  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
168  base + 14, base + 15});
169  break;
170  default:
171  llvm_unreachable("control > 3 : overflow");
172  }
173  };
174  uint8_t b01 = mask & 0x3;
175  uint8_t b23 = (mask >> 2) & 0x3;
176  uint8_t b45 = (mask >> 4) & 0x3;
177  uint8_t b67 = (mask >> 6) & 0x3;
178  appendToMask(0, b01);
179  appendToMask(0, b23);
180  appendToMask(16, b45);
181  appendToMask(16, b67);
182  return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
183 }
184 
185 /// Lowers the value to a vector.shuffle op. The `source` is expected to be a
186 /// 1-D vector and have `m`x`n` elements.
187 static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
189  mask.reserve(m * n);
190  for (int64_t j = 0; j < n; ++j)
191  for (int64_t i = 0; i < m; ++i)
192  mask.push_back(i * n + j);
193  return vector::ShuffleOp::create(b, source.getLoc(), source, source, mask);
194 }
195 
196 /// Lowers the value to a sequence of vector.shuffle ops. The `source` is
197 /// expected to be a 16x16 vector.
198 static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
199  int n) {
200  ImplicitLocOpBuilder b(source.getLoc(), builder);
202  for (int64_t i = 0; i < m; ++i)
203  vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
204 
205  // Interleave 32-bit lanes using
206  // 8x _mm512_unpacklo_epi32
207  // 8x _mm512_unpackhi_epi32
208  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
209  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
210  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
211  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
212  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
213  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
214  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
215  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
216  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
217  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
218  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
219  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
220  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
221  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
222  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
223  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
224 
225  // Interleave 64-bit lanes using
226  // 8x _mm512_unpacklo_epi64
227  // 8x _mm512_unpackhi_epi64
228  Value r0 = createUnpackLoPd(b, t0, t2, 512);
229  Value r1 = createUnpackHiPd(b, t0, t2, 512);
230  Value r2 = createUnpackLoPd(b, t1, t3, 512);
231  Value r3 = createUnpackHiPd(b, t1, t3, 512);
232  Value r4 = createUnpackLoPd(b, t4, t6, 512);
233  Value r5 = createUnpackHiPd(b, t4, t6, 512);
234  Value r6 = createUnpackLoPd(b, t5, t7, 512);
235  Value r7 = createUnpackHiPd(b, t5, t7, 512);
236  Value r8 = createUnpackLoPd(b, t8, ta, 512);
237  Value r9 = createUnpackHiPd(b, t8, ta, 512);
238  Value ra = createUnpackLoPd(b, t9, tb, 512);
239  Value rb = createUnpackHiPd(b, t9, tb, 512);
240  Value rc = createUnpackLoPd(b, tc, te, 512);
241  Value rd = createUnpackHiPd(b, tc, te, 512);
242  Value re = createUnpackLoPd(b, td, tf, 512);
243  Value rf = createUnpackHiPd(b, td, tf, 512);
244 
245  // Permute 128-bit lanes using
246  // 16x _mm512_shuffle_i32x4
247  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
248  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
249  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
250  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
251  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
252  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
253  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
254  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
255  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
256  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
257  ta = create4x128BitSuffle(b, ra, re, 0x88);
258  tb = create4x128BitSuffle(b, rb, rf, 0x88);
259  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
260  td = create4x128BitSuffle(b, r9, rd, 0xdd);
261  te = create4x128BitSuffle(b, ra, re, 0xdd);
262  tf = create4x128BitSuffle(b, rb, rf, 0xdd);
263 
264  // Permute 256-bit lanes using again
265  // 16x _mm512_shuffle_i32x4
266  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
267  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
268  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
269  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
270  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
271  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
272  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
273  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
274  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
275  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
276  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
277  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
278  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
279  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
280  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
281  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
282 
283  auto reshInputType = VectorType::get(
284  {m, n}, cast<VectorType>(source.getType()).getElementType());
285  Value res = ub::PoisonOp::create(b, reshInputType);
286  for (int64_t i = 0; i < m; ++i)
287  res = vector::InsertOp::create(b, vs[i], res, i);
288  return res;
289 }
290 
291 namespace {
292 /// Progressive lowering of TransposeOp.
293 /// One:
294 /// %x = vector.transpose %y, [1, 0]
295 /// is replaced by:
296 /// %z = arith.constant dense<0.000000e+00>
297 /// %0 = vector.extract %y[0, 0]
298 /// %1 = vector.insert %0, %z [0, 0]
299 /// ..
300 /// %x = vector.insert .., .. [.., ..]
301 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
302 public:
304 
305  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
306  MLIRContext *context, PatternBenefit benefit = 1)
307  : OpRewritePattern<vector::TransposeOp>(context, benefit),
308  vectorTransposeLowering(vectorTransposeLowering) {}
309 
310  LogicalResult matchAndRewrite(vector::TransposeOp op,
311  PatternRewriter &rewriter) const override {
312  auto loc = op.getLoc();
313 
314  Value input = op.getVector();
315  VectorType inputType = op.getSourceVectorType();
316  VectorType resType = op.getResultVectorType();
317 
318  if (inputType.isScalable())
319  return rewriter.notifyMatchFailure(
320  op, "This lowering does not support scalable vectors");
321 
322  // Set up convenience transposition table.
323  ArrayRef<int64_t> transp = op.getPermutation();
324 
325  if (isShuffleLike(vectorTransposeLowering) &&
326  succeeded(isTranspose2DSlice(op)))
327  return rewriter.notifyMatchFailure(
328  op, "Options specifies lowering to shuffle");
329 
330  // Generate unrolled extract/insert ops. We do not unroll the rightmost
331  // (i.e., highest-order) dimensions that are not transposed and leave them
332  // in vector form to improve performance. Therefore, we prune those
333  // dimensions from the shape/transpose data structures used to generate the
334  // extract/insert ops.
335  SmallVector<int64_t> prunedTransp;
336  pruneNonTransposedDims(transp, prunedTransp);
337  size_t numPrunedDims = transp.size() - prunedTransp.size();
338  auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
339  auto prunedInStrides = computeStrides(prunedInShape);
340 
341  // Generates the extract/insert operations for every scalar/vector element
342  // of the leftmost transposed dimensions. We traverse every transpose
343  // element using a linearized index that we delinearize to generate the
344  // appropriate indices for the extract/insert operations.
345  Value result = ub::PoisonOp::create(rewriter, loc, resType);
346  int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
347 
348  for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
349  ++linearIdx) {
350  auto extractIdxs = delinearize(linearIdx, prunedInStrides);
351  SmallVector<int64_t> insertIdxs(extractIdxs);
352  applyPermutationToVector(insertIdxs, prunedTransp);
353  Value extractOp =
354  rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
355  result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
356  insertIdxs);
357  }
358 
359  rewriter.replaceOp(op, result);
360  return success();
361  }
362 
363 private:
364  /// Options to control the vector patterns.
365  vector::VectorTransposeLowering vectorTransposeLowering;
366 };
367 
368 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
369 /// to 2D vectors with at least one unit dim. For example:
370 ///
371 /// Replace:
372 /// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
373 /// vector<1x4xi32>
374 /// with:
375 /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
376 ///
377 /// Source with leading unit dim (inverse) is also replaced. Unit dim must
378 /// be fixed. Non-unit dim can be scalable.
379 ///
380 /// TODO: This pattern was introduced specifically to help lower scalable
381 /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
382 /// to cancel out) would be preferable:
383 ///
384 /// BEFORE:
385 /// %0 = some_op
386 /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
387 /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
388 /// AFTER:
389 /// %0 = some_op
390 /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
391 ///
392 /// Given the context above, we may want to consider (re-)moving this pattern
393 /// at some later time. I am leaving it for now in case there are other users
394 /// that I am not aware of.
395 class Transpose2DWithUnitDimToShapeCast
396  : public OpRewritePattern<vector::TransposeOp> {
397 public:
399 
400  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
401  PatternBenefit benefit = 1)
402  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
403 
404  LogicalResult matchAndRewrite(vector::TransposeOp op,
405  PatternRewriter &rewriter) const override {
406  Value input = op.getVector();
407  VectorType resType = op.getResultVectorType();
408 
409  // Set up convenience transposition table.
410  ArrayRef<int64_t> transp = op.getPermutation();
411 
412  if (resType.getRank() == 2 &&
413  ((resType.getShape().front() == 1 &&
414  !resType.getScalableDims().front()) ||
415  (resType.getShape().back() == 1 &&
416  !resType.getScalableDims().back())) &&
417  transp == ArrayRef<int64_t>({1, 0})) {
418  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
419  return success();
420  }
421 
422  return failure();
423  }
424 };
425 
426 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
427 /// If the strategy is Shuffle1D, it will be lowered to:
428 /// vector.shape_cast 2D -> 1D
429 /// vector.shuffle
430 /// vector.shape_cast 1D -> 2D
431 /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
432 /// ops on 16xf32 vectors.
433 class TransposeOp2DToShuffleLowering
434  : public OpRewritePattern<vector::TransposeOp> {
435 public:
437 
438  TransposeOp2DToShuffleLowering(
439  vector::VectorTransposeLowering vectorTransposeLowering,
440  MLIRContext *context, PatternBenefit benefit = 1)
441  : OpRewritePattern<vector::TransposeOp>(context, benefit),
442  vectorTransposeLowering(vectorTransposeLowering) {}
443 
444  LogicalResult matchAndRewrite(vector::TransposeOp op,
445  PatternRewriter &rewriter) const override {
446  if (!isShuffleLike(vectorTransposeLowering))
447  return rewriter.notifyMatchFailure(
448  op, "not using vector shuffle based lowering");
449 
450  if (op.getSourceVectorType().isScalable())
451  return rewriter.notifyMatchFailure(
452  op, "vector shuffle lowering not supported for scalable vectors");
453 
454  auto srcGtOneDims = isTranspose2DSlice(op);
455  if (failed(srcGtOneDims))
456  return rewriter.notifyMatchFailure(
457  op, "expected transposition on a 2D slice");
458 
459  VectorType srcType = op.getSourceVectorType();
460  int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
461  int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
462 
463  // Reshape the n-D input vector with only two dimensions greater than one
464  // to a 2-D vector.
465  Location loc = op.getLoc();
466  auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
467  auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
468  auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType,
469  op.getVector());
470 
471  Value res;
472  if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
473  m == 16 && n == 16) {
474  reshInput =
475  vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput);
476  res = transposeToShuffle16x16(rewriter, reshInput, m, n);
477  } else {
478  // Fallback to shuffle on 1D approach.
479  res = transposeToShuffle1D(rewriter, reshInput, m, n);
480  }
481 
482  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
483  op, op.getResultVectorType(), res);
484 
485  return success();
486  }
487 
488 private:
489  /// Options to control the vector patterns.
490  vector::VectorTransposeLowering vectorTransposeLowering;
491 };
492 } // namespace
493 
496  VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
497  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
498  benefit);
499  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
500  vectorTransposeLowering, patterns.getContext(), benefit);
501 }
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:621
void createOrFold(llvm::SmallVectorImpl< Value > &results, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:672
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:82
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.