// Doxygen extract banner: MLIR 19.0.0git — LowerVectorTranspose.cpp
//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transpose' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "lower-vector-transpose"

using namespace mlir;
using namespace mlir::vector;

41 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
42 /// transposed.
44  SmallVectorImpl<int64_t> &result) {
45  size_t numTransposedDims = transpose.size();
46  for (size_t transpDim : llvm::reverse(transpose)) {
47  if (transpDim != numTransposedDims - 1)
48  break;
49  numTransposedDims--;
50  }
51 
52  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
53 }
54 
55 /// Returns true if the lowering option is a vector shuffle based approach.
56 static bool isShuffleLike(VectorTransposeLowering lowering) {
57  return lowering == VectorTransposeLowering::Shuffle1D ||
58  lowering == VectorTransposeLowering::Shuffle16x16;
59 }
60 
61 /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
62 /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
63 /// create the mask for `numBits` bits vector. The `numBits` have to be a
64 /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
65 /// 512, there should be 16 elements in the final result. It constructs the
66 /// below mask to get the unpack elements.
67 /// [0, 1, 16, 17,
68 /// 0+4, 1+4, 16+4, 17+4,
69 /// 0+8, 1+8, 16+8, 17+8,
70 /// 0+12, 1+12, 16+12, 17+12]
73  assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
74  int numElem = numBits / 32;
76  for (int i = 0; i < numElem; i += 4)
77  for (int64_t v : vals)
78  res.push_back(v + i);
79  return res;
80 }
81 
82 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
83 /// example, if it is targeting 512 bit vector, returns
84 /// vector.shuffle on v1, v2, [0, 1, 16, 17,
85 /// 0+4, 1+4, 16+4, 17+4,
86 /// 0+8, 1+8, 16+8, 17+8,
87 /// 0+12, 1+12, 16+12, 17+12].
89  int numBits) {
90  int numElem = numBits / 32;
91  return b.create<vector::ShuffleOp>(
92  v1, v2,
93  getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
94 }
95 
96 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
97 /// example, if it is targeting 512 bit vector, returns
98 /// vector.shuffle, v1, v2, [2, 3, 18, 19,
99 /// 2+4, 3+4, 18+4, 19+4,
100 /// 2+8, 3+8, 18+8, 19+8,
101 /// 2+12, 3+12, 18+12, 19+12].
103  int numBits) {
104  int numElem = numBits / 32;
105  return b.create<vector::ShuffleOp>(
106  v1, v2,
107  getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
108  numBits));
109 }
110 
111 /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
112 /// example, if it is targeting 512 bit vector, returns
113 /// vector.shuffle, v1, v2, [0, 16, 1, 17,
114 /// 0+4, 16+4, 1+4, 17+4,
115 /// 0+8, 16+8, 1+8, 17+8,
116 /// 0+12, 16+12, 1+12, 17+12].
118  int numBits) {
119  int numElem = numBits / 32;
120  auto shuffle = b.create<vector::ShuffleOp>(
121  v1, v2,
122  getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
123  return shuffle;
124 }
125 
126 /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
127 /// example, if it is targeting 512 bit vector, returns
128 /// vector.shuffle, v1, v2, [2, 18, 3, 19,
129 /// 2+4, 18+4, 3+4, 19+4,
130 /// 2+8, 18+8, 3+8, 19+8,
131 /// 2+12, 18+12, 3+12, 19+12].
133  int numBits) {
134  int numElem = numBits / 32;
135  return b.create<vector::ShuffleOp>(
136  v1, v2,
137  getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
138  numBits));
139 }
140 
141 /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
142 /// elements) selected by `mask` from `v1` and `v2`. I.e.,
143 ///
144 /// DEFINE SELECT4(src, control) {
145 /// CASE(control[1:0]) OF
146 /// 0: tmp[127:0] := src[127:0]
147 /// 1: tmp[127:0] := src[255:128]
148 /// 2: tmp[127:0] := src[383:256]
149 /// 3: tmp[127:0] := src[511:384]
150 /// ESAC
151 /// RETURN tmp[127:0]
152 /// }
153 /// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
154 /// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
155 /// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
156 /// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
158  uint8_t mask) {
159  assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
160  "expected a vector with length=16");
161  SmallVector<int64_t> shuffleMask;
162  auto appendToMask = [&](int64_t base, uint8_t control) {
163  switch (control) {
164  case 0:
165  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
166  base + 2, base + 3});
167  break;
168  case 1:
169  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
170  base + 6, base + 7});
171  break;
172  case 2:
173  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
174  base + 10, base + 11});
175  break;
176  case 3:
177  llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
178  base + 14, base + 15});
179  break;
180  default:
181  llvm_unreachable("control > 3 : overflow");
182  }
183  };
184  uint8_t b01 = mask & 0x3;
185  uint8_t b23 = (mask >> 2) & 0x3;
186  uint8_t b45 = (mask >> 4) & 0x3;
187  uint8_t b67 = (mask >> 6) & 0x3;
188  appendToMask(0, b01);
189  appendToMask(0, b23);
190  appendToMask(16, b45);
191  appendToMask(16, b67);
192  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
193 }
194 
195 /// Lowers the value to a vector.shuffle op. The `source` is expected to be a
196 /// 1-D vector and have `m`x`n` elements.
197 static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
199  mask.reserve(m * n);
200  for (int64_t j = 0; j < n; ++j)
201  for (int64_t i = 0; i < m; ++i)
202  mask.push_back(i * n + j);
203  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
204 }
205 
206 /// Lowers the value to a sequence of vector.shuffle ops. The `source` is
207 /// expected to be a 16x16 vector.
208 static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
209  int n) {
210  ImplicitLocOpBuilder b(source.getLoc(), builder);
212  for (int64_t i = 0; i < m; ++i)
213  vs.push_back(b.create<vector::ExtractOp>(source, i));
214 
215  // Interleave 32-bit lanes using
216  // 8x _mm512_unpacklo_epi32
217  // 8x _mm512_unpackhi_epi32
218  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
219  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
220  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
221  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
222  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
223  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
224  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
225  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
226  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
227  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
228  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
229  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
230  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
231  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
232  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
233  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
234 
235  // Interleave 64-bit lanes using
236  // 8x _mm512_unpacklo_epi64
237  // 8x _mm512_unpackhi_epi64
238  Value r0 = createUnpackLoPd(b, t0, t2, 512);
239  Value r1 = createUnpackHiPd(b, t0, t2, 512);
240  Value r2 = createUnpackLoPd(b, t1, t3, 512);
241  Value r3 = createUnpackHiPd(b, t1, t3, 512);
242  Value r4 = createUnpackLoPd(b, t4, t6, 512);
243  Value r5 = createUnpackHiPd(b, t4, t6, 512);
244  Value r6 = createUnpackLoPd(b, t5, t7, 512);
245  Value r7 = createUnpackHiPd(b, t5, t7, 512);
246  Value r8 = createUnpackLoPd(b, t8, ta, 512);
247  Value r9 = createUnpackHiPd(b, t8, ta, 512);
248  Value ra = createUnpackLoPd(b, t9, tb, 512);
249  Value rb = createUnpackHiPd(b, t9, tb, 512);
250  Value rc = createUnpackLoPd(b, tc, te, 512);
251  Value rd = createUnpackHiPd(b, tc, te, 512);
252  Value re = createUnpackLoPd(b, td, tf, 512);
253  Value rf = createUnpackHiPd(b, td, tf, 512);
254 
255  // Permute 128-bit lanes using
256  // 16x _mm512_shuffle_i32x4
257  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
258  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
259  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
260  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
261  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
262  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
263  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
264  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
265  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
266  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
267  ta = create4x128BitSuffle(b, ra, re, 0x88);
268  tb = create4x128BitSuffle(b, rb, rf, 0x88);
269  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
270  td = create4x128BitSuffle(b, r9, rd, 0xdd);
271  te = create4x128BitSuffle(b, ra, re, 0xdd);
272  tf = create4x128BitSuffle(b, rb, rf, 0xdd);
273 
274  // Permute 256-bit lanes using again
275  // 16x _mm512_shuffle_i32x4
276  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
277  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
278  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
279  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
280  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
281  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
282  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
283  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
284  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
285  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
286  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
287  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
288  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
289  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
290  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
291  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
292 
293  auto reshInputType = VectorType::get(
294  {m, n}, cast<VectorType>(source.getType()).getElementType());
295  Value res =
296  b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
297  for (int64_t i = 0; i < m; ++i)
298  res = b.create<vector::InsertOp>(vs[i], res, i);
299  return res;
300 }
301 
302 namespace {
303 /// Progressive lowering of TransposeOp.
304 /// One:
305 /// %x = vector.transpose %y, [1, 0]
306 /// is replaced by:
307 /// %z = arith.constant dense<0.000000e+00>
308 /// %0 = vector.extract %y[0, 0]
309 /// %1 = vector.insert %0, %z [0, 0]
310 /// ..
311 /// %x = vector.insert .., .. [.., ..]
312 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
313 public:
315 
317  MLIRContext *context, PatternBenefit benefit = 1)
318  : OpRewritePattern<vector::TransposeOp>(context, benefit),
319  vectorTransformOptions(vectorTransformOptions) {}
320 
321  LogicalResult matchAndRewrite(vector::TransposeOp op,
322  PatternRewriter &rewriter) const override {
323  auto loc = op.getLoc();
324 
325  Value input = op.getVector();
326  VectorType inputType = op.getSourceVectorType();
327  VectorType resType = op.getResultVectorType();
328 
329  // Set up convenience transposition table.
330  ArrayRef<int64_t> transp = op.getPermutation();
331 
332  if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
334  return rewriter.notifyMatchFailure(
335  op, "Options specifies lowering to shuffle");
336 
337  // Replace:
338  // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
339  // vector<1xnxelty>
340  // with:
341  // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
342  //
343  // Source with leading unit dim (inverse) is also replaced. Unit dim must
344  // be fixed. Non-unit can be scalable.
345  if (resType.getRank() == 2 &&
346  ((resType.getShape().front() == 1 &&
347  !resType.getScalableDims().front()) ||
348  (resType.getShape().back() == 1 &&
349  !resType.getScalableDims().back())) &&
350  transp == ArrayRef<int64_t>({1, 0})) {
351  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
352  return success();
353  }
354 
355  if (inputType.isScalable())
356  return failure();
357 
358  // Handle a true 2-D matrix transpose differently when requested.
359  if (vectorTransformOptions.vectorTransposeLowering ==
360  vector::VectorTransposeLowering::Flat &&
361  resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
362  Type flattenedType =
363  VectorType::get(resType.getNumElements(), resType.getElementType());
364  auto matrix =
365  rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
366  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
367  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
368  Value trans = rewriter.create<vector::FlatTransposeOp>(
369  loc, flattenedType, matrix, rows, columns);
370  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
371  return success();
372  }
373 
374  // Generate unrolled extract/insert ops. We do not unroll the rightmost
375  // (i.e., highest-order) dimensions that are not transposed and leave them
376  // in vector form to improve performance. Therefore, we prune those
377  // dimensions from the shape/transpose data structures used to generate the
378  // extract/insert ops.
379  SmallVector<int64_t> prunedTransp;
380  pruneNonTransposedDims(transp, prunedTransp);
381  size_t numPrunedDims = transp.size() - prunedTransp.size();
382  auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
383  auto prunedInStrides = computeStrides(prunedInShape);
384 
385  // Generates the extract/insert operations for every scalar/vector element
386  // of the leftmost transposed dimensions. We traverse every transpose
387  // element using a linearized index that we delinearize to generate the
388  // appropriate indices for the extract/insert operations.
389  Value result = rewriter.create<arith::ConstantOp>(
390  loc, resType, rewriter.getZeroAttr(resType));
391  int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
392 
393  for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
394  ++linearIdx) {
395  auto extractIdxs = delinearize(linearIdx, prunedInStrides);
396  SmallVector<int64_t> insertIdxs(extractIdxs);
397  applyPermutationToVector(insertIdxs, prunedTransp);
398  Value extractOp =
399  rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
400  result =
401  rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
402  }
403 
404  rewriter.replaceOp(op, result);
405  return success();
406  }
407 
408 private:
409  /// Options to control the vector patterns.
410  vector::VectorTransformsOptions vectorTransformOptions;
411 };
412 
413 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
414 /// If the strategy is Shuffle1D, it will be lowered to:
415 /// vector.shape_cast 2D -> 1D
416 /// vector.shuffle
417 /// vector.shape_cast 1D -> 2D
418 /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
419 /// ops on 16xf32 vectors.
420 class TransposeOp2DToShuffleLowering
421  : public OpRewritePattern<vector::TransposeOp> {
422 public:
424 
425  TransposeOp2DToShuffleLowering(
426  vector::VectorTransformsOptions vectorTransformOptions,
427  MLIRContext *context, PatternBenefit benefit = 1)
428  : OpRewritePattern<vector::TransposeOp>(context, benefit),
429  vectorTransformOptions(vectorTransformOptions) {}
430 
431  LogicalResult matchAndRewrite(vector::TransposeOp op,
432  PatternRewriter &rewriter) const override {
433  if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
434  return rewriter.notifyMatchFailure(
435  op, "not using vector shuffle based lowering");
436 
437  if (op.getSourceVectorType().isScalable())
438  return rewriter.notifyMatchFailure(
439  op, "vector shuffle lowering not supported for scalable vectors");
440 
441  auto srcGtOneDims = isTranspose2DSlice(op);
442  if (failed(srcGtOneDims))
443  return rewriter.notifyMatchFailure(
444  op, "expected transposition on a 2D slice");
445 
446  VectorType srcType = op.getSourceVectorType();
447  int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
448  int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
449 
450  // Reshape the n-D input vector with only two dimensions greater than one
451  // to a 2-D vector.
452  Location loc = op.getLoc();
453  auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
454  auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
455  auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
456  op.getVector());
457 
458  Value res;
459  if (vectorTransformOptions.vectorTransposeLowering ==
460  VectorTransposeLowering::Shuffle16x16 &&
461  m == 16 && n == 16) {
462  reshInput =
463  rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
464  res = transposeToShuffle16x16(rewriter, reshInput, m, n);
465  } else {
466  // Fallback to shuffle on 1D approach.
467  res = transposeToShuffle1D(rewriter, reshInput, m, n);
468  }
469 
470  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
471  op, op.getResultVectorType(), res);
472 
473  return success();
474  }
475 
476 private:
477  /// Options to control the vector patterns.
478  vector::VectorTransformsOptions vectorTransformOptions;
479 };
480 } // namespace
481 
484  PatternBenefit benefit) {
485  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
486  options, patterns.getContext(), benefit);
487 }
/* Doxygen cross-reference listing (extraction residue, preserved for
   reference — not part of the original source file):

static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask.
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask.
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, int n)
Lowers the value to a sequence of vector.shuffle ops.
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask.
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2, int numBits)
Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask.
static void pruneNonTransposedDims(ArrayRef< int64_t > transpose, SmallVectorImpl< int64_t > &result)
Given a 'transpose' pattern, prune the rightmost dimensions that are not transposed.
static bool isShuffleLike(VectorTransposeLowering lowering)
Returns true if the lowering option is a vector shuffle based approach.
static SmallVector< int64_t > getUnpackShufflePermFor128Lane(ArrayRef< int64_t > vals, int numBits)
Returns a shuffle mask that builds on vals.
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit elements) selected by mask...
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n)
Lowers the value to a vector.shuffle op.
static llvm::ManagedStatic< PassManagerOptions > options
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:80
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Structure to control the behavior of vector transform patterns.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
*/