MLIR  22.0.0git
LowerVectorToFromElementsToShuffleTree.cpp
Go to the documentation of this file.
1 //===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements pattern rewrites to lower sequences of
10 // `vector.to_elements` and `vector.from_elements` operations into a tree of
11 // `vector.shuffle` operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
24 
25 namespace mlir {
26 namespace vector {
27 
28 #define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
29 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
30 
31 } // namespace vector
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
35 
36 using namespace mlir;
37 using namespace mlir::vector;
38 
39 namespace {
40 
/// A closed range of element positions; e.g., [0, 7] covers 8 elements.
using Interval = std::pair<unsigned, unsigned>;

// Marker stored in an interval bound that has not been assigned yet.
constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();

// Scale passed to `llvm::indent` for every level of the debug output.
[[maybe_unused]] constexpr unsigned kIndScale = 2;
48 
49 /// The VectorShuffleTreeBuilder builds a balanced binary tree of
50 /// `vector.shuffle` operations from one or more `vector.to_elements`
51 /// operations feeding a single `vector.from_elements` operation.
52 ///
53 /// The implementation generates hardware-agnostic `vector.shuffle` operations
54 /// that minimize both the number of shuffle operations and the length of
55 /// intermediate vectors (to the extent possible). The tree has the
56 /// following properties:
57 ///
58 /// 1. Vectors are shuffled in pairs by order of appearance in
59 /// the `vector.from_elements` operand list.
60 /// 2. Each vector at each level is used only once.
61 /// 3. The number of levels in the tree is:
62 /// 1 (input vectors) + ceil(max(1,log2(# `vector.to_elements` ops))).
63 /// 4. Vectors at each level of the tree have the same vector length.
64 /// 5. Vector positions that do not need to be shuffled are represented with
65 /// poison in the shuffle mask.
66 ///
67 /// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
68 ///
69 /// %0:4 = vector.to_elements %a : vector<4xf32>
70 /// %1:4 = vector.to_elements %b : vector<4xf32>
71 /// %2:4 = vector.to_elements %c : vector<4xf32>
72 /// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
73 /// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
74 /// : vector<12xf32>
75 /// =>
76 ///
77 /// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
78 /// : vector<4xf32>, vector<4xf32>
79 /// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
80 /// : vector<4xf32>, vector<4xf32>
81 /// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
82 /// 6, 7, 8, 9, 10, 11]
83 /// : vector<8xf32>, vector<8xf32>
84 ///
85 /// Comments:
86 /// * The shuffle tree has three levels:
87 /// - Level 0 = (%a, %b, %c, %c)
88 /// - Level 1 = (%shuffle0, %shuffle1)
89 /// - Level 2 = (%result)
90 /// * `%a` and `%b` are shuffled first because they appear first in the
91 /// `vector.from_elements` operand list (`%0#0` and `%1#0`).
92 /// * `%c` is shuffled with itself because the number of
93 /// `vector.from_elements` operands is odd.
94 /// * The vector length for level 1 and level 2 are 8 and 16, respectively.
95 /// * `%shuffle1` uses poison values to match the vector length of its
96 /// tree level (8).
97 ///
98 ///
99 /// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
100 ///
101 /// %0:5 = vector.to_elements %a : vector<5xf32>
102 /// %1:5 = vector.to_elements %b : vector<5xf32>
103 /// %2:5 = vector.to_elements %c : vector<5xf32>
104 /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
105 /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
106 /// =>
107 ///
108 /// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6]
109 /// : vector<5xf32>, vector<5xf32>
110 /// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1]
111 /// : vector<5xf32>, vector<5xf32>
112 /// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
113 /// : vector<8xf32>, vector<8xf32>
114 ///
115 /// Comments:
116 /// * `%c` and `%b` are shuffled first because they appear first in the
117 /// `vector.from_elements` operand list (`%2#2` and `%1#1`).
118 /// * `%a` is shuffled with itself because the number of
119 /// `vector.from_elements` operands is odd.
120 /// * The vector length for level 1 and level 2 are 8 and 9, respectively.
121 /// * `%shuffle0` uses poison values to mark unused vector positions and
122 /// match the vector length of its tree level (8).
123 ///
124 /// TODO: Implement mask compression to reduce the number of intermediate poison
125 /// values.
126 class VectorShuffleTreeBuilder {
127 public:
128  VectorShuffleTreeBuilder() = delete;
129  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
130  ArrayRef<ToElementsOp> toElemDefs);
131 
132  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
133  /// and compute the shuffle tree configuration. This method does not generate
134  /// any IR.
135  LogicalResult computeShuffleTree();
136 
137  /// Materialize the shuffle tree configuration computed by
138  /// `computeShuffleTree` in the IR.
139  Value generateShuffleTree(PatternRewriter &rewriter);
140 
141 private:
142  // IR input information.
143  FromElementsOp fromElemsOp;
144  SmallVector<ToElementsOp> toElemsDefs;
145 
146  // Shuffle tree configuration.
147  unsigned numLevels;
148  SmallVector<unsigned> vectorSizePerLevel;
149  /// Holds the range of positions each vector in the tree contributes to in the
150  /// final output vector.
151  SmallVector<SmallVector<Interval>> intervalsPerLevel;
152 
153  // Utility methods to compute the shuffle tree configuration.
154  void computeShuffleTreeIntervals();
155  void computeShuffleTreeVectorSizes();
156 
157  /// Dump the shuffle tree configuration.
158  void dump();
159 };
160 
161 VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
162  FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
163  : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) {
164  assert(fromElemsOp && "from_elements op is required");
165  assert(!toElemsDefs.empty() && "At least one to_elements op is required");
166 }
167 
168 /// Duplicate the last operation, value or interval if the total number of them
169 /// is odd. This is useful to simplify the shuffle tree algorithm given that
170 /// vectors are shuffled in pairs and duplication would lead to the last shuffle
171 /// to have a single (duplicated) input vector.
172 template <typename T>
173 static void duplicateLastIfOdd(SmallVectorImpl<T> &values) {
174  if (values.size() % 2 != 0)
175  values.push_back(values.back());
176 }
177 
178 // ===---------------------------------------------------------------------===//
179 // Shuffle Tree Analysis Utilities.
180 // ===---------------------------------------------------------------------===//
181 
182 /// Compute the intervals for all the vectors in the shuffle tree. The interval
183 /// of a vector is the range of positions that the vector contributes to in the
184 /// final output vector.
185 ///
186 /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
187 ///
188 /// %0:5 = vector.to_elements %a : vector<5xf32>
189 /// %1:5 = vector.to_elements %b : vector<5xf32>
190 /// %2:5 = vector.to_elements %c : vector<5xf32>
191 /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
192 /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
193 ///
194 /// The shuffle tree has 3 levels. Level 0 has 4 vectors (%2, %1, %0, %0, the
195 /// last one is duplicated to make the number of inputs even) so we compute the
196 /// interval for each vector:
197 ///
198 /// * intervalsPerLevel[0][0] = interval(%2) = [0,6]
199 /// * intervalsPerLevel[0][1] = interval(%1) = [1,7]
200 /// * intervalsPerLevel[0][2] = interval(%0) = [2,8]
201 /// * intervalsPerLevel[0][3] = interval(%0) = [2,8]
202 ///
203 /// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0
204 /// so we compute the intervals for each vector at level 1 as:
205 /// * intervalsPerLevel[1][0] = intervalsPerLevel[0][0] U
206 /// intervalsPerLevel[0][1] = [0,7]
207 /// * intervalsPerLevel[1][1] = intervalsPerLevel[0][2] U
208 /// intervalsPerLevel[0][3] = [2,8]
209 ///
210 /// Level 2 is the last level and only contains the output vector so the
211 /// interval should be the whole output vector:
212 /// * intervalsPerLevel[2][0] = intervalsPerLevel[1][0] U
213 /// intervalsPerLevel[1][1] = [0,8]
214 ///
215 void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() {
216  // Map `vector.to_elements` ops to their ordinal position in the
217  // `vector.from_elements` operand list. Make sure duplicated
218  // `vector.to_elements` ops are mapped to the its first occurrence.
219  DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal;
220  for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs))
221  toElemsToInputOrdinal.insert({toElemsOp, idx});
222 
223  // Compute intervals for each vector in the shuffle tree. The first
224  // level computation is special-cased to keep the implementation simpler.
225 
226  SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(),
227  {kMaxUnsigned, kMaxUnsigned});
228 
229  for (const auto &[idx, element] :
230  llvm::enumerate(fromElemsOp.getElements())) {
231  auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
232  unsigned inputIdx = toElemsToInputOrdinal[toElemsOp];
233  Interval &currentInterval = firstLevelIntervals[inputIdx];
234 
235  // Set lower bound to the first occurrence of the `vector.to_elements`.
236  if (currentInterval.first == kMaxUnsigned)
237  currentInterval.first = idx;
238 
239  // Set upper bound to the last occurrence of the `vector.to_elements`.
240  currentInterval.second = idx;
241  }
242 
243  duplicateLastIfOdd(toElemsDefs);
244  duplicateLastIfOdd(firstLevelIntervals);
245  intervalsPerLevel.push_back(std::move(firstLevelIntervals));
246 
247  // Compute intervals for the remaining levels.
248  for (unsigned level = 1; level < numLevels; ++level) {
249  bool isLastLevel = level == numLevels - 1;
250  const auto &prevLevelIntervals = intervalsPerLevel[level - 1];
251  SmallVector<Interval> currentLevelIntervals(
252  llvm::divideCeil(prevLevelIntervals.size(), 2),
253  {kMaxUnsigned, kMaxUnsigned});
254 
255  size_t currentNumLevels = currentLevelIntervals.size();
256  for (size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) {
257  auto &interval = currentLevelIntervals[inputIdx];
258  const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
259  const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
260 
261  // The interval of a vector at the current level is the union of the
262  // intervals of the two vectors from the previous level being shuffled at
263  // this level.
264  interval.first = prevLhsInterval.first;
265  interval.second =
266  std::max(prevLhsInterval.second, prevRhsInterval.second);
267  }
268 
269  // Duplicate the last interval if the number of intervals is odd, except for
270  // the last level as it only contains the output vector, which doesn't have
271  // to be shuffled.
272  if (!isLastLevel)
273  duplicateLastIfOdd(currentLevelIntervals);
274 
275  intervalsPerLevel.push_back(std::move(currentLevelIntervals));
276  }
277 }
278 
279 /// Compute the uniform vector size for each level of the shuffle tree, given
280 /// the intervals of the vectors at each level. The vector size of a level is
281 /// the size of the widest interval at that level.
282 ///
283 /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
284 ///
285 /// Intervals:
286 /// * Level 0: [0,6], [1,7], [2,8], [2,8]
287 /// * Level 1: [0,7], [2,8]
288 /// * Level 2: [0,8]
289 ///
290 /// Vector sizes:
291 /// * Level 0: Arbitrary sizes from input vectors.
292 /// * Level 1: max(size_of([0,7]) = 8, size_of([2,8]) = 7) = 8
293 /// * Level 2: max(size_of([0,8]) = 9) = 9
294 ///
295 void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() {
296  // Compute vector size for each level. There are two direct cases:
297  // * First level: the vector size depends on the actual size of the input
298  // vectors and it's allowed to be non-uniform. We set it to 0.
299  // * Last level: the vector size is the output vector size so it doesn't
300  // have to be computed using intervals.
301  vectorSizePerLevel.front() = 0;
302  vectorSizePerLevel.back() =
303  cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
304 
305  for (unsigned level = 1; level < numLevels - 1; ++level) {
306  const auto &currentLevelIntervals = intervalsPerLevel[level];
307  unsigned currentVectorSize = 1;
308  size_t numIntervals = currentLevelIntervals.size();
309  for (size_t i = 0; i < numIntervals; ++i) {
310  const auto &interval = currentLevelIntervals[i];
311  unsigned intervalSize = interval.second - interval.first + 1;
312  currentVectorSize = std::max(currentVectorSize, intervalSize);
313  }
314  assert(currentVectorSize > 0 && "vector size must be positive");
315  vectorSizePerLevel[level] = currentVectorSize;
316  }
317 }
318 
319 void VectorShuffleTreeBuilder::dump() {
320  LLVM_DEBUG({
321  unsigned indLv = 0;
322 
323  llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
324  ++indLv;
325  llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
326  ++indLv;
327  for (const auto &toElemsOp : toElemsDefs)
328  llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n";
329  llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n";
330  --indLv;
331 
332  llvm::dbgs() << llvm::indent(indLv, kIndScale)
333  << "* Total levels: " << numLevels << "\n";
334  llvm::dbgs() << llvm::indent(indLv, kIndScale)
335  << "* Vector sizes per level: ";
336  llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
337  llvm::dbgs() << "\n";
338  llvm::dbgs() << llvm::indent(indLv, kIndScale)
339  << "* Input intervals per level:\n";
340  ++indLv;
341  for (const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) {
342  llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
343  << ": ";
344  llvm::interleaveComma(intervals, llvm::dbgs(),
345  [](const Interval &interval) {
346  llvm::dbgs() << "[" << interval.first << ","
347  << interval.second << "]";
348  });
349  llvm::dbgs() << "\n";
350  }
351  });
352 }
353 
354 /// Compute the shuffle tree configuration for the given `vector.to_elements` +
355 /// `vector.from_elements` input sequence. This method builds a balanced binary
356 /// shuffle tree that combines pairs of vectors at each level.
357 ///
358 /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
359 ///
360 /// %0:5 = vector.to_elements %a : vector<5xf32>
361 /// %1:5 = vector.to_elements %b : vector<5xf32>
362 /// %2:5 = vector.to_elements %c : vector<5xf32>
363 /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
364 /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
365 ///
366 /// build a tree that looks like:
367 ///
368 /// %2 %1 %0 %0
369 /// \ / \ /
370 /// %2_1 = vector.shuffle %0_0 = vector.shuffle
371 /// \ /
372 /// %2_1_0_0 =vector.shuffle
373 ///
374 /// The actual representation of the shuffle tree configuration is based on
375 /// intervals of each vector at each level of the shuffle tree (i.e., %2, %1,
376 /// %0, %0, %2_1, %0_0 and %2_1_0_0) and the vector size for each level. For
377 /// further details on intervals and vector size computation, please, take a
378 /// look at the corresponding utility functions.
379 LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
380  // Initialize shuffle tree information based on its size. For the number of
381  // levels, we add one to account for the input `vector.to_elements` as one
382  // tree level. We need the std::max(1) to account for a single element input.
383  numLevels = 1u + std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
384  vectorSizePerLevel.resize(numLevels, 0);
385  intervalsPerLevel.reserve(numLevels);
386 
387  computeShuffleTreeIntervals();
388  computeShuffleTreeVectorSizes();
389  dump();
390 
391  return success();
392 }
393 
394 // ===---------------------------------------------------------------------===//
395 // Shuffle Tree Code Generation Utilities.
396 // ===---------------------------------------------------------------------===//
397 
398 /// Compute the permutation mask for shuffling two input `vector.to_elements`
399 /// ops. The permutation mask is the mapping of the vector elements to their
400 /// final position in the output vector, relative to the intermediate output
401 /// vector of the `vector.shuffle` operation combining the two inputs.
402 ///
403 /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
404 ///
405 /// %0:5 = vector.to_elements %a : vector<5xf32>
406 /// %1:5 = vector.to_elements %b : vector<5xf32>
407 /// %2:5 = vector.to_elements %c : vector<5xf32>
408 /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
409 /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
410 ///
411 /// =>
412 ///
413 /// // Level 1, vector length = 8
414 /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
415 /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
416 ///
417 /// TODO: Implement mask compression to reduce the number of intermediate poison
418 /// values.
419 static SmallVector<int64_t> computePermutationShuffleMask(
420  ToElementsOp toElementOp0, const Interval &interval0,
421  ToElementsOp toElementOp1, const Interval &interval1,
422  FromElementsOp fromElemsOp, unsigned outputVectorSize) {
423  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
424  unsigned inputVectorSize =
425  toElementOp0.getSource().getType().getNumElements();
426 
427  for (const auto &[inputIdx, element] :
428  llvm::enumerate(fromElemsOp.getElements())) {
429  auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
430  // Match `vector.from_elements` operands to the two input ops.
431  if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
432  continue;
433 
434  // The permutation value for a particular operand is the ordinal position of
435  // the operand in the `vector.to_elements` list of results.
436  unsigned permVal = cast<OpResult>(element).getResultNumber();
437  unsigned maskIdx = inputIdx;
438 
439  // The mask index is the ordinal position of the operand in
440  // `vector.from_elements` operand list. We make this position relative to
441  // the output interval resulting from combining the two input intervals.
442  if (currentToElemOp == toElementOp0) {
443  maskIdx -= interval0.first;
444  } else {
445  // currentToElemOp == toElementOp1
446  unsigned intervalOffset = interval1.first - interval0.first;
447  maskIdx += intervalOffset - interval1.first;
448  permVal += inputVectorSize;
449  }
450 
451  mask[maskIdx] = permVal;
452  }
453 
454  LLVM_DEBUG({
455  unsigned indLv = 1;
456  llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: [";
457  llvm::interleaveComma(mask, llvm::dbgs());
458  llvm::dbgs() << "]\n";
459  ++indLv;
460  llvm::dbgs() << llvm::indent(indLv, kIndScale)
461  << "* Combining: " << toElementOp0 << " and " << toElementOp1
462  << "\n";
463  });
464 
465  return mask;
466 }
467 
468 /// Compute the propagation shuffle mask for combining two intermediate shuffle
469 /// operations of the tree. The propagation shuffle mask is the mapping of the
470 /// intermediate vector elements, which have already been shuffled to their
471 /// relative output position using the mask generated by
472 /// `computePermutationShuffleMask`, to their next position in the tree.
473 ///
474 /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
475 ///
476 /// %0:5 = vector.to_elements %a : vector<5xf32>
477 /// %1:5 = vector.to_elements %b : vector<5xf32>
478 /// %2:5 = vector.to_elements %c : vector<5xf32>
479 /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
480 /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
481 ///
482 /// // Level 1, vector length = 8
483 /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
484 /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
485 ///
486 /// =>
487 ///
488 /// // Level 2, vector length = 9
489 /// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
490 ///
491 /// TODO: Implement mask compression to reduce the number of intermediate poison
492 /// values.
493 static SmallVector<int64_t> computePropagationShuffleMask(
494  ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
495  const Interval &rhsInterval, unsigned outputVectorSize) {
496  ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
497  ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
498  unsigned inputVectorSize = lhsShuffleMask.size();
499  assert(inputVectorSize == rhsShuffleMask.size() &&
500  "Expected both shuffle masks to have the same size");
501 
502  bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
503  unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
504  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
505 
506  // Propagate any element from the input mask that is not poison. For the RHS
507  // vector, offset mask index by the distance between the intervals.
508  for (unsigned i = 0; i < inputVectorSize; ++i) {
509  if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
510  mask[i] = i;
511 
512  if (hasSameInput)
513  continue;
514 
515  unsigned rhsIdx = i + lhsRhsOffset;
516  if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
517  assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
518  assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set");
519  mask[rhsIdx] = i + inputVectorSize;
520  }
521  }
522 
523  LLVM_DEBUG({
524  unsigned indLv = 1;
525  llvm::dbgs() << llvm::indent(indLv, kIndScale)
526  << "* Propagation shuffle mask computation:\n";
527  ++indLv;
528  llvm::dbgs() << llvm::indent(indLv, kIndScale)
529  << "* LHS shuffle op: " << lhsShuffleOp << "\n";
530  llvm::dbgs() << llvm::indent(indLv, kIndScale)
531  << "* RHS shuffle op: " << rhsShuffleOp << "\n";
532  llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [";
533  llvm::interleaveComma(mask, llvm::dbgs());
534  llvm::dbgs() << "]\n";
535  });
536 
537  return mask;
538 }
539 
540 /// Materialize the pre-computed shuffle tree configuration in the IR by
541 /// generating the corresponding `vector.shuffle` ops.
542 ///
543 /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
544 ///
545 /// %0:5 = vector.to_elements %a : vector<5xf32>
546 /// %1:5 = vector.to_elements %b : vector<5xf32>
547 /// %2:5 = vector.to_elements %c : vector<5xf32>
548 /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
549 /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
550 ///
551 /// with the pre-computed shuffle tree configuration:
552 ///
553 /// * Vector sizes per level: 0, 8, 9
554 /// * Input intervals per level:
555 /// * Level 0: [0,6], [1,7], [2,8], [2,8]
556 /// * Level 1: [0,7], [2,8]
557 /// * Level 2: [0,8]
558 ///
559 /// =>
560 ///
561 /// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6]
562 /// : vector<5xf32>, vector<5xf32>
563 /// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1]
564 /// : vector<5xf32>, vector<5xf32>
565 /// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
566 /// : vector<8xf32>, vector<8xf32>
567 ///
568 /// The code generation consists of combining pairs of vectors at each level of
569 /// the tree, using the pre-computed tree intervals and vector sizes. The
570 /// algorithm generates two kinds of shuffle masks:
571 /// * Permutation masks: computed for the first level of the tree and permute
572 /// the input vector elements to their relative position in the final
573 /// output.
574 /// * Propagation masks: computed for subsequent levels and propagate the
575 /// elements to the next level without permutation.
576 ///
577 /// For further details on the shuffle mask computation, please, take a look at
578 /// the corresponding `computePermutationShuffleMask` and
579 /// `computePropagationShuffleMask` functions.
580 ///
581 Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
582  LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n");
583 
584  // Initialize work list with the `vector.to_elements` sources.
585  SmallVector<Value> levelInputs;
586  llvm::transform(toElemsDefs, std::back_inserter(levelInputs),
587  [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); });
588 
589  // Build shuffle tree by combining pairs of vectors (represented by their
590  // corresponding intervals) in one level and producing a new vector with the
591  // next level's vector length. Skip the interval from the last tree level
592  // (actual shuffle tree output) as it doesn't have to be combined with
593  // anything else.
594  Location loc = fromElemsOp.getLoc();
595  unsigned currentLevel = 0;
596  for (const auto &[nextLevelVectorSize, intervals] :
597  llvm::zip_equal(ArrayRef(vectorSizePerLevel).drop_front(),
598  ArrayRef(intervalsPerLevel).drop_back())) {
599 
600  duplicateLastIfOdd(levelInputs);
601 
602  LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale)
603  << "* Processing level " << currentLevel
604  << " (output vector size: " << nextLevelVectorSize
605  << ", # inputs: " << levelInputs.size() << ")\n");
606 
607  // Process level input vectors in pairs.
608  SmallVector<Value> levelOutputs;
609  for (size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs;
610  i += 2) {
611  Value lhsVector = levelInputs[i];
612  Value rhsVector = levelInputs[i + 1];
613  const Interval &lhsInterval = intervals[i];
614  const Interval &rhsInterval = intervals[i + 1];
615 
616  // For the first level of the tree, permute the vector elements to their
617  // relative position in the final output. For subsequent levels, we
618  // propagate the elements to the next level without permutation.
619  SmallVector<int64_t> shuffleMask;
620  if (currentLevel == 0) {
621  shuffleMask = computePermutationShuffleMask(
622  toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval,
623  fromElemsOp, nextLevelVectorSize);
624  } else {
625  auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.getDefiningOp());
626  auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.getDefiningOp());
627  shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval,
628  rhsShuffleOp, rhsInterval,
629  nextLevelVectorSize);
630  }
631 
632  Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector,
633  rhsVector, shuffleMask);
634  levelOutputs.push_back(shuffleVal);
635  }
636 
637  levelInputs = std::move(levelOutputs);
638  ++currentLevel;
639  }
640 
641  assert(levelInputs.size() == 1 && "Should have exactly one result");
642  return levelInputs.front();
643 }
644 
645 /// Gather and unique all the `vector.to_elements` operations that feed the
646 /// `vector.from_elements` operation. The `vector.to_elements` operations are
647 /// returned in order of appearance in the `vector.from_elements`'s operand
648 /// list.
649 static LogicalResult
650 getToElementsDefiningOps(FromElementsOp fromElemsOp,
651  SmallVectorImpl<ToElementsOp> &toElemsDefs) {
652  SetVector<ToElementsOp> toElemsDefsSet;
653  for (Value element : fromElemsOp.getElements()) {
654  auto toElemsOp = element.getDefiningOp<ToElementsOp>();
655  if (!toElemsOp)
656  return failure();
657  toElemsDefsSet.insert(toElemsOp);
658  }
659 
660  toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end());
661  return success();
662 }
663 
664 /// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into
665 /// a tree of `vector.shuffle` operations. Only 1-D input vectors are supported
666 /// for now.
667 struct ToFromElementsToShuffleTreeRewrite final
668  : OpRewritePattern<vector::FromElementsOp> {
669 
670  using Base::Base;
671 
672  LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
673  PatternRewriter &rewriter) const override {
674  VectorType resultType = fromElemsOp.getType();
675  if (resultType.getRank() != 1)
676  return rewriter.notifyMatchFailure(
677  fromElemsOp,
678  "multi-dimensional output vectors are not supported yet");
679  if (resultType.isScalable())
680  return rewriter.notifyMatchFailure(
681  fromElemsOp,
682  "'vector.from_elements' does not support scalable vectors");
683 
684  // Gather all the `vector.to_elements` operations that feed the
685  // `vector.from_elements` operation. Other op definitions are not supported.
686  SmallVector<ToElementsOp> toElemsDefs;
687  if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs)))
688  return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources");
689 
690  if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
691  return toElemsOp.getSource().getType().getRank() != 1;
692  })) {
693  return rewriter.notifyMatchFailure(
694  fromElemsOp, "multi-dimensional input vectors are not supported yet");
695  }
696 
697  if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
698  return !toElemsOp.getSource().getType().hasRank();
699  })) {
700  return rewriter.notifyMatchFailure(fromElemsOp,
701  "0-D vectors are not supported");
702  }
703 
704  // Avoid generating a shuffle tree for trivial `vector.to_elements` ->
705  // `vector.from_elements` forwarding cases that do not require shuffling.
706  if (toElemsDefs.size() == 1) {
707  ToElementsOp toElemsOp0 = toElemsDefs.front();
708  if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) {
709  return rewriter.notifyMatchFailure(
710  fromElemsOp, "trivial forwarding case does not require shuffling");
711  }
712  }
713 
714  VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs);
715  if (failed(shuffleTreeBuilder.computeShuffleTree()))
716  return rewriter.notifyMatchFailure(fromElemsOp,
717  "failed to compute shuffle tree");
718 
719  Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
720  rewriter.replaceOp(fromElemsOp, finalShuffle);
721  return success();
722  }
723 };
724 
725 struct LowerVectorToFromElementsToShuffleTreePass
726  : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase<
727  LowerVectorToFromElementsToShuffleTreePass> {
728 
729  void runOnOperation() override {
732 
733  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
734  return signalPassFailure();
735  }
736 };
737 
738 } // namespace
739 
742  patterns.add<ToFromElementsToShuffleTreeRewrite>(patterns.getContext(),
743  benefit);
744 }
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314