MLIR 22.0.0git
LowerVectorToFromElementsToShuffleTree.cpp
Go to the documentation of this file.
//===- LowerVectorToFromElementsToShuffleTree.cpp -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements pattern rewrites to lower sequences of
10// `vector.to_elements` and `vector.from_elements` operations into a tree of
11// `vector.shuffle` operations.
12//
13//===----------------------------------------------------------------------===//
14
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/MathExtras.h"
23#include "llvm/Support/raw_ostream.h"
24
25namespace mlir {
26namespace vector {
27
28#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
29#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
30
31} // namespace vector
32} // namespace mlir
33
34#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
35
36using namespace mlir;
37using namespace mlir::vector;
38
39namespace {
40
// Width, in spaces, of one nesting level in the debug output.
[[maybe_unused]] constexpr unsigned kIndScale = 2;

/// Closed interval [first, second] of positions; both ends are inclusive, so
/// [0, 7] covers 8 elements.
using Interval = std::pair<unsigned, unsigned>;
// Placeholder bound for an interval that has not been assigned yet.
constexpr unsigned kMaxUnsigned = ~0u;
48
49/// The VectorShuffleTreeBuilder builds a balanced binary tree of
50/// `vector.shuffle` operations from one or more `vector.to_elements`
51/// operations feeding a single `vector.from_elements` operation.
52///
53/// The implementation generates hardware-agnostic `vector.shuffle` operations
54/// that minimize both the number of shuffle operations and the length of
55/// intermediate vectors (to the extent possible). The tree has the
56/// following properties:
57///
58/// 1. Vectors are shuffled in pairs by order of appearance in
59/// the `vector.from_elements` operand list.
60/// 2. Each vector at each level is used only once.
61/// 3. The number of levels in the tree is:
62/// 1 (input vectors) + ceil(max(1,log2(# `vector.to_elements` ops))).
63/// 4. Vectors at each level of the tree have the same vector length.
64/// 5. Vector positions that do not need to be shuffled are represented with
65/// poison in the shuffle mask.
66///
/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
68///
69/// %0:4 = vector.to_elements %a : vector<4xf32>
70/// %1:4 = vector.to_elements %b : vector<4xf32>
71/// %2:4 = vector.to_elements %c : vector<4xf32>
72/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
73/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
74/// : vector<12xf32>
75/// =>
76///
77/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
78/// : vector<4xf32>, vector<4xf32>
79/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
80/// : vector<4xf32>, vector<4xf32>
81/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
82/// 6, 7, 8, 9, 10, 11]
83/// : vector<8xf32>, vector<8xf32>
84///
85/// Comments:
86/// * The shuffle tree has three levels:
87/// - Level 0 = (%a, %b, %c, %c)
88/// - Level 1 = (%shuffle0, %shuffle1)
89/// - Level 2 = (%result)
90/// * `%a` and `%b` are shuffled first because they appear first in the
91/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
/// * `%c` is shuffled with itself because the number of input vectors
/// (`vector.to_elements` ops) is odd.
/// * The vector length for level 1 and level 2 are 8 and 12, respectively.
95/// * `%shuffle1` uses poison values to match the vector length of its
96/// tree level (8).
97///
98///
99/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
100///
101/// %0:5 = vector.to_elements %a : vector<5xf32>
102/// %1:5 = vector.to_elements %b : vector<5xf32>
103/// %2:5 = vector.to_elements %c : vector<5xf32>
104/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
105/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
106/// =>
107///
/// %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
109/// : vector<5xf32>, vector<5xf32>
/// %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
111/// : vector<5xf32>, vector<5xf32>
112/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
113/// : vector<8xf32>, vector<8xf32>
114///
115/// Comments:
116/// * `%c` and `%b` are shuffled first because they appear first in the
117/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
/// * `%a` is shuffled with itself because the number of input vectors
/// (`vector.to_elements` ops) is odd.
120/// * The vector length for level 1 and level 2 are 8 and 9, respectively.
121/// * `%shuffle0` uses poison values to mark unused vector positions and
122/// match the vector length of its tree level (8).
123///
124/// TODO: Implement mask compression to reduce the number of intermediate poison
125/// values.
126class VectorShuffleTreeBuilder {
127public:
128 VectorShuffleTreeBuilder() = delete;
129 VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
130 ArrayRef<ToElementsOp> toElemDefs);
131
132 /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
133 /// and compute the shuffle tree configuration. This method does not generate
134 /// any IR.
135 LogicalResult computeShuffleTree();
136
137 /// Materialize the shuffle tree configuration computed by
138 /// `computeShuffleTree` in the IR.
139 Value generateShuffleTree(PatternRewriter &rewriter);
140
141private:
142 // IR input information.
143 FromElementsOp fromElemsOp;
144 SmallVector<ToElementsOp> toElemsDefs;
145
146 // Shuffle tree configuration.
147 unsigned numLevels;
148 SmallVector<unsigned> vectorSizePerLevel;
149 /// Holds the range of positions each vector in the tree contributes to in the
150 /// final output vector.
151 SmallVector<SmallVector<Interval>> intervalsPerLevel;
152
153 // Utility methods to compute the shuffle tree configuration.
154 void computeShuffleTreeIntervals();
155 void computeShuffleTreeVectorSizes();
156
157 /// Dump the shuffle tree configuration.
158 void dump();
159};
160
161VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
162 FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
163 : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) {
164 assert(fromElemsOp && "from_elements op is required");
165 assert(!toElemsDefs.empty() && "At least one to_elements op is required");
166}
168/// Duplicate the last operation, value or interval if the total number of them
169/// is odd. This is useful to simplify the shuffle tree algorithm given that
170/// vectors are shuffled in pairs and duplication would lead to the last shuffle
171/// to have a single (duplicated) input vector.
172template <typename T>
173static void duplicateLastIfOdd(SmallVectorImpl<T> &values) {
174 if (values.size() % 2 != 0)
175 values.push_back(values.back());
176}
178// ===---------------------------------------------------------------------===//
179// Shuffle Tree Analysis Utilities.
180// ===---------------------------------------------------------------------===//
181
182/// Compute the intervals for all the vectors in the shuffle tree. The interval
183/// of a vector is the range of positions that the vector contributes to in the
184/// final output vector.
185///
186/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
187///
188/// %0:5 = vector.to_elements %a : vector<5xf32>
189/// %1:5 = vector.to_elements %b : vector<5xf32>
190/// %2:5 = vector.to_elements %c : vector<5xf32>
191/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
192/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
193///
194/// The shuffle tree has 3 levels. Level 0 has 4 vectors (%2, %1, %0, %0, the
195/// last one is duplicated to make the number of inputs even) so we compute the
196/// interval for each vector:
197///
198/// * intervalsPerLevel[0][0] = interval(%2) = [0,6]
199/// * intervalsPerLevel[0][1] = interval(%1) = [1,7]
200/// * intervalsPerLevel[0][2] = interval(%0) = [2,8]
201/// * intervalsPerLevel[0][3] = interval(%0) = [2,8]
202///
203/// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0
204/// so we compute the intervals for each vector at level 1 as:
205/// * intervalsPerLevel[1][0] = intervalsPerLevel[0][0] U
206/// intervalsPerLevel[0][1] = [0,7]
207/// * intervalsPerLevel[1][1] = intervalsPerLevel[0][2] U
208/// intervalsPerLevel[0][3] = [2,8]
209///
210/// Level 2 is the last level and only contains the output vector so the
211/// interval should be the whole output vector:
212/// * intervalsPerLevel[2][0] = intervalsPerLevel[1][0] U
213/// intervalsPerLevel[1][1] = [0,8]
214///
215void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() {
216 // Map `vector.to_elements` ops to their ordinal position in the
217 // `vector.from_elements` operand list. Make sure duplicated
218 // `vector.to_elements` ops are mapped to the its first occurrence.
219 DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal;
220 for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs))
221 toElemsToInputOrdinal.insert({toElemsOp, idx});
222
223 // Compute intervals for each vector in the shuffle tree. The first
224 // level computation is special-cased to keep the implementation simpler.
225
226 SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(),
227 {kMaxUnsigned, kMaxUnsigned});
228
229 for (const auto &[idx, element] :
230 llvm::enumerate(fromElemsOp.getElements())) {
231 auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
232 unsigned inputIdx = toElemsToInputOrdinal[toElemsOp];
233 Interval &currentInterval = firstLevelIntervals[inputIdx];
234
235 // Set lower bound to the first occurrence of the `vector.to_elements`.
236 if (currentInterval.first == kMaxUnsigned)
237 currentInterval.first = idx;
238
239 // Set upper bound to the last occurrence of the `vector.to_elements`.
240 currentInterval.second = idx;
241 }
242
243 duplicateLastIfOdd(toElemsDefs);
244 duplicateLastIfOdd(firstLevelIntervals);
245 intervalsPerLevel.push_back(std::move(firstLevelIntervals));
246
247 // Compute intervals for the remaining levels.
248 for (unsigned level = 1; level < numLevels; ++level) {
249 bool isLastLevel = level == numLevels - 1;
250 const auto &prevLevelIntervals = intervalsPerLevel[level - 1];
251 SmallVector<Interval> currentLevelIntervals(
252 llvm::divideCeil(prevLevelIntervals.size(), 2),
253 {kMaxUnsigned, kMaxUnsigned});
254
255 size_t currentNumLevels = currentLevelIntervals.size();
256 for (size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) {
257 auto &interval = currentLevelIntervals[inputIdx];
258 const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
259 const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
260
261 // The interval of a vector at the current level is the union of the
262 // intervals of the two vectors from the previous level being shuffled at
263 // this level.
264 interval.first = prevLhsInterval.first;
265 interval.second =
266 std::max(prevLhsInterval.second, prevRhsInterval.second);
267 }
268
269 // Duplicate the last interval if the number of intervals is odd, except for
270 // the last level as it only contains the output vector, which doesn't have
271 // to be shuffled.
272 if (!isLastLevel)
273 duplicateLastIfOdd(currentLevelIntervals);
274
275 intervalsPerLevel.push_back(std::move(currentLevelIntervals));
276 }
277}
278
279/// Compute the uniform vector size for each level of the shuffle tree, given
280/// the intervals of the vectors at each level. The vector size of a level is
281/// the size of the widest interval at that level.
282///
283/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
284///
285/// Intervals:
286/// * Level 0: [0,6], [1,7], [2,8], [2,8]
287/// * Level 1: [0,7], [2,8]
288/// * Level 2: [0,8]
289///
290/// Vector sizes:
291/// * Level 0: Arbitrary sizes from input vectors.
292/// * Level 1: max(size_of([0,7]) = 8, size_of([2,8]) = 7) = 8
293/// * Level 2: max(size_of([0,8]) = 9) = 9
294///
295void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() {
296 // Compute vector size for each level. There are two direct cases:
297 // * First level: the vector size depends on the actual size of the input
298 // vectors and it's allowed to be non-uniform. We set it to 0.
299 // * Last level: the vector size is the output vector size so it doesn't
300 // have to be computed using intervals.
301 vectorSizePerLevel.front() = 0;
302 vectorSizePerLevel.back() =
303 cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
304
305 for (unsigned level = 1; level < numLevels - 1; ++level) {
306 const auto &currentLevelIntervals = intervalsPerLevel[level];
307 unsigned currentVectorSize = 1;
308 size_t numIntervals = currentLevelIntervals.size();
309 for (size_t i = 0; i < numIntervals; ++i) {
310 const auto &interval = currentLevelIntervals[i];
311 unsigned intervalSize = interval.second - interval.first + 1;
312 currentVectorSize = std::max(currentVectorSize, intervalSize);
313 }
314 assert(currentVectorSize > 0 && "vector size must be positive");
315 vectorSizePerLevel[level] = currentVectorSize;
316 }
317}
318
319void VectorShuffleTreeBuilder::dump() {
320 LLVM_DEBUG({
321 unsigned indLv = 0;
322
323 llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
324 ++indLv;
325 llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
326 ++indLv;
327 for (const auto &toElemsOp : toElemsDefs)
328 llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n";
329 llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n";
330 --indLv;
331
332 llvm::dbgs() << llvm::indent(indLv, kIndScale)
333 << "* Total levels: " << numLevels << "\n";
334 llvm::dbgs() << llvm::indent(indLv, kIndScale)
335 << "* Vector sizes per level: ";
336 llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
337 llvm::dbgs() << "\n";
338 llvm::dbgs() << llvm::indent(indLv, kIndScale)
339 << "* Input intervals per level:\n";
340 ++indLv;
341 for (const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) {
342 llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
343 << ": ";
344 llvm::interleaveComma(intervals, llvm::dbgs(),
345 [](const Interval &interval) {
346 llvm::dbgs() << "[" << interval.first << ","
347 << interval.second << "]";
348 });
349 llvm::dbgs() << "\n";
350 }
351 });
352}
353
354/// Compute the shuffle tree configuration for the given `vector.to_elements` +
355/// `vector.from_elements` input sequence. This method builds a balanced binary
356/// shuffle tree that combines pairs of vectors at each level.
357///
358/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
359///
360/// %0:5 = vector.to_elements %a : vector<5xf32>
361/// %1:5 = vector.to_elements %b : vector<5xf32>
362/// %2:5 = vector.to_elements %c : vector<5xf32>
363/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
364/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
365///
366/// build a tree that looks like:
367///
368/// %2 %1 %0 %0
369/// \ / \ /
370/// %2_1 = vector.shuffle %0_0 = vector.shuffle
371/// \ /
372/// %2_1_0_0 =vector.shuffle
373///
374/// The actual representation of the shuffle tree configuration is based on
375/// intervals of each vector at each level of the shuffle tree (i.e., %2, %1,
376/// %0, %0, %2_1, %0_0 and %2_1_0_0) and the vector size for each level. For
377/// further details on intervals and vector size computation, please, take a
378/// look at the corresponding utility functions.
379LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
380 // Initialize shuffle tree information based on its size. For the number of
381 // levels, we add one to account for the input `vector.to_elements` as one
382 // tree level. We need the std::max(1) to account for a single element input.
383 numLevels = 1u + std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
384 vectorSizePerLevel.resize(numLevels, 0);
385 intervalsPerLevel.reserve(numLevels);
386
387 computeShuffleTreeIntervals();
388 computeShuffleTreeVectorSizes();
389 dump();
390
391 return success();
392}
393
394// ===---------------------------------------------------------------------===//
395// Shuffle Tree Code Generation Utilities.
396// ===---------------------------------------------------------------------===//
397
398/// Compute the permutation mask for shuffling two input `vector.to_elements`
399/// ops. The permutation mask is the mapping of the vector elements to their
400/// final position in the output vector, relative to the intermediate output
401/// vector of the `vector.shuffle` operation combining the two inputs.
402///
403/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
404///
405/// %0:5 = vector.to_elements %a : vector<5xf32>
406/// %1:5 = vector.to_elements %b : vector<5xf32>
407/// %2:5 = vector.to_elements %c : vector<5xf32>
408/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
409/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
410///
411/// =>
412///
413/// // Level 1, vector length = 8
414/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
415/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
416///
417/// TODO: Implement mask compression to reduce the number of intermediate poison
418/// values.
419static SmallVector<int64_t> computePermutationShuffleMask(
420 ToElementsOp toElementOp0, const Interval &interval0,
421 ToElementsOp toElementOp1, const Interval &interval1,
422 FromElementsOp fromElemsOp, unsigned outputVectorSize) {
423 SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
424 unsigned inputVectorSize =
425 toElementOp0.getSource().getType().getNumElements();
426
427 for (const auto &[inputIdx, element] :
428 llvm::enumerate(fromElemsOp.getElements())) {
429 auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
430 // Match `vector.from_elements` operands to the two input ops.
431 if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
432 continue;
433
434 // The permutation value for a particular operand is the ordinal position of
435 // the operand in the `vector.to_elements` list of results.
436 unsigned permVal = cast<OpResult>(element).getResultNumber();
437 unsigned maskIdx = inputIdx;
438
439 // The mask index is the ordinal position of the operand in
440 // `vector.from_elements` operand list. We make this position relative to
441 // the output interval resulting from combining the two input intervals.
442 if (currentToElemOp == toElementOp0) {
443 maskIdx -= interval0.first;
444 } else {
445 // currentToElemOp == toElementOp1
446 unsigned intervalOffset = interval1.first - interval0.first;
447 maskIdx += intervalOffset - interval1.first;
448 permVal += inputVectorSize;
449 }
450
451 mask[maskIdx] = permVal;
452 }
453
454 LLVM_DEBUG({
455 unsigned indLv = 1;
456 llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: [";
457 llvm::interleaveComma(mask, llvm::dbgs());
458 llvm::dbgs() << "]\n";
459 ++indLv;
460 llvm::dbgs() << llvm::indent(indLv, kIndScale)
461 << "* Combining: " << toElementOp0 << " and " << toElementOp1
462 << "\n";
463 });
464
465 return mask;
466}
467
468/// Compute the propagation shuffle mask for combining two intermediate shuffle
469/// operations of the tree. The propagation shuffle mask is the mapping of the
470/// intermediate vector elements, which have already been shuffled to their
471/// relative output position using the mask generated by
472/// `computePermutationShuffleMask`, to their next position in the tree.
473///
474/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
475///
476/// %0:5 = vector.to_elements %a : vector<5xf32>
477/// %1:5 = vector.to_elements %b : vector<5xf32>
478/// %2:5 = vector.to_elements %c : vector<5xf32>
479/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
480/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
481///
482/// // Level 1, vector length = 8
483/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
484/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
485///
486/// =>
487///
488/// // Level 2, vector length = 9
489/// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
490///
491/// TODO: Implement mask compression to reduce the number of intermediate poison
492/// values.
493static SmallVector<int64_t> computePropagationShuffleMask(
494 ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
495 const Interval &rhsInterval, unsigned outputVectorSize) {
496 ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
497 ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
498 unsigned inputVectorSize = lhsShuffleMask.size();
499 assert(inputVectorSize == rhsShuffleMask.size() &&
500 "Expected both shuffle masks to have the same size");
501
502 bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
503 unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
504 SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
505
506 // Propagate any element from the input mask that is not poison. For the RHS
507 // vector, offset mask index by the distance between the intervals.
508 for (unsigned i = 0; i < inputVectorSize; ++i) {
509 if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
510 mask[i] = i;
511
512 if (hasSameInput)
513 continue;
514
515 unsigned rhsIdx = i + lhsRhsOffset;
516 if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
517 assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
518 assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set");
519 mask[rhsIdx] = i + inputVectorSize;
520 }
521 }
522
523 LLVM_DEBUG({
524 unsigned indLv = 1;
525 llvm::dbgs() << llvm::indent(indLv, kIndScale)
526 << "* Propagation shuffle mask computation:\n";
527 ++indLv;
528 llvm::dbgs() << llvm::indent(indLv, kIndScale)
529 << "* LHS shuffle op: " << lhsShuffleOp << "\n";
530 llvm::dbgs() << llvm::indent(indLv, kIndScale)
531 << "* RHS shuffle op: " << rhsShuffleOp << "\n";
532 llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [";
533 llvm::interleaveComma(mask, llvm::dbgs());
534 llvm::dbgs() << "]\n";
535 });
536
537 return mask;
538}
539
540/// Materialize the pre-computed shuffle tree configuration in the IR by
541/// generating the corresponding `vector.shuffle` ops.
542///
543/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
544///
545/// %0:5 = vector.to_elements %a : vector<5xf32>
546/// %1:5 = vector.to_elements %b : vector<5xf32>
547/// %2:5 = vector.to_elements %c : vector<5xf32>
548/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
549/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
550///
551/// with the pre-computed shuffle tree configuration:
552///
553/// * Vector sizes per level: 0, 8, 9
554/// * Input intervals per level:
555/// * Level 0: [0,6], [1,7], [2,8], [2,8]
556/// * Level 1: [0,7], [2,8]
557/// * Level 2: [0,8]
558///
559/// =>
560///
561/// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6]
562/// : vector<5xf32>, vector<5xf32>
563/// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1]
564/// : vector<5xf32>, vector<5xf32>
565/// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
566/// : vector<8xf32>, vector<8xf32>
567///
568/// The code generation consists of combining pairs of vectors at each level of
569/// the tree, using the pre-computed tree intervals and vector sizes. The
570/// algorithm generates two kinds of shuffle masks:
571/// * Permutation masks: computed for the first level of the tree and permute
572/// the input vector elements to their relative position in the final
573/// output.
574/// * Propagation masks: computed for subsequent levels and propagate the
575/// elements to the next level without permutation.
576///
577/// For further details on the shuffle mask computation, please, take a look at
578/// the corresponding `computePermutationShuffleMask` and
579/// `computePropagationShuffleMask` functions.
580///
581Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
582 LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n");
583
584 // Initialize work list with the `vector.to_elements` sources.
585 SmallVector<Value> levelInputs;
586 llvm::transform(toElemsDefs, std::back_inserter(levelInputs),
587 [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); });
588
589 // Build shuffle tree by combining pairs of vectors (represented by their
590 // corresponding intervals) in one level and producing a new vector with the
591 // next level's vector length. Skip the interval from the last tree level
592 // (actual shuffle tree output) as it doesn't have to be combined with
593 // anything else.
594 Location loc = fromElemsOp.getLoc();
595 unsigned currentLevel = 0;
596 for (const auto &[nextLevelVectorSize, intervals] :
597 llvm::zip_equal(ArrayRef(vectorSizePerLevel).drop_front(),
598 ArrayRef(intervalsPerLevel).drop_back())) {
599
600 duplicateLastIfOdd(levelInputs);
601
602 LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale)
603 << "* Processing level " << currentLevel
604 << " (output vector size: " << nextLevelVectorSize
605 << ", # inputs: " << levelInputs.size() << ")\n");
606
607 // Process level input vectors in pairs.
608 SmallVector<Value> levelOutputs;
609 for (size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs;
610 i += 2) {
611 Value lhsVector = levelInputs[i];
612 Value rhsVector = levelInputs[i + 1];
613 const Interval &lhsInterval = intervals[i];
614 const Interval &rhsInterval = intervals[i + 1];
615
616 // For the first level of the tree, permute the vector elements to their
617 // relative position in the final output. For subsequent levels, we
618 // propagate the elements to the next level without permutation.
619 SmallVector<int64_t> shuffleMask;
620 if (currentLevel == 0) {
621 shuffleMask = computePermutationShuffleMask(
622 toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval,
623 fromElemsOp, nextLevelVectorSize);
624 } else {
625 auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.getDefiningOp());
626 auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.getDefiningOp());
627 shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval,
628 rhsShuffleOp, rhsInterval,
629 nextLevelVectorSize);
630 }
631
632 Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector,
633 rhsVector, shuffleMask);
634 levelOutputs.push_back(shuffleVal);
635 }
636
637 levelInputs = std::move(levelOutputs);
638 ++currentLevel;
639 }
640
641 assert(levelInputs.size() == 1 && "Should have exactly one result");
642 return levelInputs.front();
643}
644
645/// Gather and unique all the `vector.to_elements` operations that feed the
646/// `vector.from_elements` operation. The `vector.to_elements` operations are
647/// returned in order of appearance in the `vector.from_elements`'s operand
648/// list.
649static LogicalResult
650getToElementsDefiningOps(FromElementsOp fromElemsOp,
651 SmallVectorImpl<ToElementsOp> &toElemsDefs) {
652 SetVector<ToElementsOp> toElemsDefsSet;
653 for (Value element : fromElemsOp.getElements()) {
654 auto toElemsOp = element.getDefiningOp<ToElementsOp>();
655 if (!toElemsOp)
656 return failure();
657 toElemsDefsSet.insert(toElemsOp);
658 }
659
660 toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end());
661 return success();
662}
663
664/// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into
665/// a tree of `vector.shuffle` operations. Only 1-D input vectors are supported
666/// for now.
667struct ToFromElementsToShuffleTreeRewrite final
668 : OpRewritePattern<vector::FromElementsOp> {
669
670 using Base::Base;
671
672 LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
673 PatternRewriter &rewriter) const override {
674 VectorType resultType = fromElemsOp.getType();
675 if (resultType.getRank() != 1)
676 return rewriter.notifyMatchFailure(
677 fromElemsOp,
678 "multi-dimensional output vectors are not supported yet");
679 if (resultType.isScalable())
680 return rewriter.notifyMatchFailure(
681 fromElemsOp,
682 "'vector.from_elements' does not support scalable vectors");
683
684 // Gather all the `vector.to_elements` operations that feed the
685 // `vector.from_elements` operation. Other op definitions are not supported.
686 SmallVector<ToElementsOp> toElemsDefs;
687 if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs)))
688 return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources");
689
690 if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
691 return toElemsOp.getSource().getType().getRank() != 1;
692 })) {
693 return rewriter.notifyMatchFailure(
694 fromElemsOp, "multi-dimensional input vectors are not supported yet");
695 }
696
697 if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
698 return !toElemsOp.getSource().getType().hasRank();
699 })) {
700 return rewriter.notifyMatchFailure(fromElemsOp,
701 "0-D vectors are not supported");
702 }
703
704 // Avoid generating a shuffle tree for trivial `vector.to_elements` ->
705 // `vector.from_elements` forwarding cases that do not require shuffling.
706 if (toElemsDefs.size() == 1) {
707 ToElementsOp toElemsOp0 = toElemsDefs.front();
708 if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) {
709 return rewriter.notifyMatchFailure(
710 fromElemsOp, "trivial forwarding case does not require shuffling");
711 }
712 }
713
714 VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs);
715 if (failed(shuffleTreeBuilder.computeShuffleTree()))
716 return rewriter.notifyMatchFailure(fromElemsOp,
717 "failed to compute shuffle tree");
718
719 Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
720 rewriter.replaceOp(fromElemsOp, finalShuffle);
721 return success();
722 }
723};
724
725struct LowerVectorToFromElementsToShuffleTreePass
726 : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase<
727 LowerVectorToFromElementsToShuffleTreePass> {
728
729 void runOnOperation() override {
730 RewritePatternSet patterns(&getContext());
732
733 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
734 return signalPassFailure();
735 }
736};
737
738} // namespace
739
742 patterns.add<ToFromElementsToShuffleTreeRewrite>(patterns.getContext(),
743 benefit);
744}
return success()
b getContext())
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126