20#include "llvm/ADT/DenseMap.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/MathExtras.h"
23#include "llvm/Support/raw_ostream.h"
28#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
29#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
34#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
42[[maybe_unused]]
constexpr unsigned kIndScale = 2;
45using Interval = std::pair<unsigned, unsigned>;
47constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
126class VectorShuffleTreeBuilder {
128 VectorShuffleTreeBuilder() =
delete;
129 VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
135 LogicalResult computeShuffleTree();
143 FromElementsOp fromElemsOp;
154 void computeShuffleTreeIntervals();
155 void computeShuffleTreeVectorSizes();
161VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
163 : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) {
164 assert(fromElemsOp &&
"from_elements op is required");
165 assert(!toElemsDefs.empty() &&
"At least one to_elements op is required");
174 if (values.size() % 2 != 0)
175 values.push_back(values.back());
215void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() {
220 for (
const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs))
221 toElemsToInputOrdinal.insert({toElemsOp, idx});
227 {kMaxUnsigned, kMaxUnsigned});
229 for (
const auto &[idx, element] :
230 llvm::enumerate(fromElemsOp.getElements())) {
231 auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
232 unsigned inputIdx = toElemsToInputOrdinal[toElemsOp];
233 Interval &currentInterval = firstLevelIntervals[inputIdx];
236 if (currentInterval.first == kMaxUnsigned)
237 currentInterval.first = idx;
240 currentInterval.second = idx;
243 duplicateLastIfOdd(toElemsDefs);
244 duplicateLastIfOdd(firstLevelIntervals);
245 intervalsPerLevel.push_back(std::move(firstLevelIntervals));
248 for (
unsigned level = 1; level < numLevels; ++level) {
249 bool isLastLevel = level == numLevels - 1;
250 const auto &prevLevelIntervals = intervalsPerLevel[level - 1];
252 llvm::divideCeil(prevLevelIntervals.size(), 2),
253 {kMaxUnsigned, kMaxUnsigned});
255 size_t currentNumLevels = currentLevelIntervals.size();
256 for (
size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) {
257 auto &interval = currentLevelIntervals[inputIdx];
258 const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
259 const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
264 interval.first = prevLhsInterval.first;
266 std::max(prevLhsInterval.second, prevRhsInterval.second);
273 duplicateLastIfOdd(currentLevelIntervals);
275 intervalsPerLevel.push_back(std::move(currentLevelIntervals));
295void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() {
301 vectorSizePerLevel.front() = 0;
302 vectorSizePerLevel.back() =
303 cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
305 for (
unsigned level = 1; level < numLevels - 1; ++level) {
306 const auto &currentLevelIntervals = intervalsPerLevel[level];
307 unsigned currentVectorSize = 1;
308 size_t numIntervals = currentLevelIntervals.size();
309 for (
size_t i = 0; i < numIntervals; ++i) {
310 const auto &interval = currentLevelIntervals[i];
311 unsigned intervalSize = interval.second - interval.first + 1;
312 currentVectorSize = std::max(currentVectorSize, intervalSize);
314 assert(currentVectorSize > 0 &&
"vector size must be positive");
315 vectorSizePerLevel[level] = currentVectorSize;
319void VectorShuffleTreeBuilder::dump() {
323 llvm::dbgs() <<
"VectorShuffleTreeBuilder Configuration:\n";
325 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Inputs:\n";
327 for (
const auto &toElemsOp : toElemsDefs)
328 llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp <<
"\n";
329 llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp <<
"\n\n";
332 llvm::dbgs() << llvm::indent(indLv, kIndScale)
333 <<
"* Total levels: " << numLevels <<
"\n";
334 llvm::dbgs() << llvm::indent(indLv, kIndScale)
335 <<
"* Vector sizes per level: ";
336 llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
337 llvm::dbgs() <<
"\n";
338 llvm::dbgs() << llvm::indent(indLv, kIndScale)
339 <<
"* Input intervals per level:\n";
341 for (
const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) {
342 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Level " << level
344 llvm::interleaveComma(intervals, llvm::dbgs(),
345 [](
const Interval &interval) {
346 llvm::dbgs() <<
"[" << interval.first <<
","
347 << interval.second <<
"]";
349 llvm::dbgs() <<
"\n";
379LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
383 numLevels = 1u + std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
384 vectorSizePerLevel.resize(numLevels, 0);
385 intervalsPerLevel.reserve(numLevels);
387 computeShuffleTreeIntervals();
388 computeShuffleTreeVectorSizes();
419static SmallVector<int64_t> computePermutationShuffleMask(
420 ToElementsOp toElementOp0,
const Interval &interval0,
421 ToElementsOp toElementOp1,
const Interval &interval1,
422 FromElementsOp fromElemsOp,
unsigned outputVectorSize) {
423 SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
424 unsigned inputVectorSize =
425 toElementOp0.getSource().getType().getNumElements();
427 for (
const auto &[inputIdx, element] :
428 llvm::enumerate(fromElemsOp.getElements())) {
429 auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
431 if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
436 unsigned permVal = cast<OpResult>(element).getResultNumber();
437 unsigned maskIdx = inputIdx;
442 if (currentToElemOp == toElementOp0) {
443 maskIdx -= interval0.first;
446 unsigned intervalOffset = interval1.first - interval0.first;
447 maskIdx += intervalOffset - interval1.first;
448 permVal += inputVectorSize;
451 mask[maskIdx] = permVal;
456 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Permutation mask: [";
457 llvm::interleaveComma(mask, llvm::dbgs());
458 llvm::dbgs() <<
"]\n";
460 llvm::dbgs() << llvm::indent(indLv, kIndScale)
461 <<
"* Combining: " << toElementOp0 <<
" and " << toElementOp1
493static SmallVector<int64_t> computePropagationShuffleMask(
494 ShuffleOp lhsShuffleOp,
const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
495 const Interval &rhsInterval,
unsigned outputVectorSize) {
496 ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
497 ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
498 unsigned inputVectorSize = lhsShuffleMask.size();
499 assert(inputVectorSize == rhsShuffleMask.size() &&
500 "Expected both shuffle masks to have the same size");
502 bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
503 unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
504 SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
508 for (
unsigned i = 0; i < inputVectorSize; ++i) {
509 if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
515 unsigned rhsIdx = i + lhsRhsOffset;
516 if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
517 assert(rhsIdx < outputVectorSize &&
"RHS index out of bounds");
518 assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex &&
"mask already set");
519 mask[rhsIdx] = i + inputVectorSize;
525 llvm::dbgs() << llvm::indent(indLv, kIndScale)
526 <<
"* Propagation shuffle mask computation:\n";
528 llvm::dbgs() << llvm::indent(indLv, kIndScale)
529 <<
"* LHS shuffle op: " << lhsShuffleOp <<
"\n";
530 llvm::dbgs() << llvm::indent(indLv, kIndScale)
531 <<
"* RHS shuffle op: " << rhsShuffleOp <<
"\n";
532 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Result mask: [";
533 llvm::interleaveComma(mask, llvm::dbgs());
534 llvm::dbgs() <<
"]\n";
581Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
582 LLVM_DEBUG(llvm::dbgs() <<
"VectorShuffleTreeBuilder Code Generation:\n");
585 SmallVector<Value> levelInputs;
586 llvm::transform(toElemsDefs, std::back_inserter(levelInputs),
587 [](ToElementsOp toElemsOp) {
return toElemsOp.getSource(); });
594 Location loc = fromElemsOp.getLoc();
595 unsigned currentLevel = 0;
596 for (
const auto &[nextLevelVectorSize, intervals] :
597 llvm::zip_equal(ArrayRef(vectorSizePerLevel).drop_front(),
598 ArrayRef(intervalsPerLevel).drop_back())) {
600 duplicateLastIfOdd(levelInputs);
602 LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale)
603 <<
"* Processing level " << currentLevel
604 <<
" (output vector size: " << nextLevelVectorSize
605 <<
", # inputs: " << levelInputs.size() <<
")\n");
608 SmallVector<Value> levelOutputs;
609 for (
size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs;
611 Value lhsVector = levelInputs[i];
612 Value rhsVector = levelInputs[i + 1];
613 const Interval &lhsInterval = intervals[i];
614 const Interval &rhsInterval = intervals[i + 1];
619 SmallVector<int64_t> shuffleMask;
620 if (currentLevel == 0) {
621 shuffleMask = computePermutationShuffleMask(
622 toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval,
623 fromElemsOp, nextLevelVectorSize);
625 auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.
getDefiningOp());
626 auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.
getDefiningOp());
627 shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval,
628 rhsShuffleOp, rhsInterval,
629 nextLevelVectorSize);
632 Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector,
633 rhsVector, shuffleMask);
634 levelOutputs.push_back(shuffleVal);
637 levelInputs = std::move(levelOutputs);
641 assert(levelInputs.size() == 1 &&
"Should have exactly one result");
642 return levelInputs.front();
650getToElementsDefiningOps(FromElementsOp fromElemsOp,
651 SmallVectorImpl<ToElementsOp> &toElemsDefs) {
653 for (Value element : fromElemsOp.getElements()) {
654 auto toElemsOp = element.getDefiningOp<ToElementsOp>();
657 toElemsDefsSet.insert(toElemsOp);
660 toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end());
667struct ToFromElementsToShuffleTreeRewrite final
668 : OpRewritePattern<vector::FromElementsOp> {
672 LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
673 PatternRewriter &rewriter)
const override {
674 VectorType resultType = fromElemsOp.getType();
675 if (resultType.getRank() != 1)
678 "multi-dimensional output vectors are not supported yet");
679 if (resultType.isScalable())
682 "'vector.from_elements' does not support scalable vectors");
686 SmallVector<ToElementsOp> toElemsDefs;
687 if (
failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs)))
690 if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
691 return toElemsOp.getSource().getType().getRank() != 1;
694 fromElemsOp,
"multi-dimensional input vectors are not supported yet");
697 if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
698 return !toElemsOp.getSource().getType().hasRank();
701 "0-D vectors are not supported");
706 if (toElemsDefs.size() == 1) {
707 ToElementsOp toElemsOp0 = toElemsDefs.front();
708 if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) {
710 fromElemsOp,
"trivial forwarding case does not require shuffling");
714 VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs);
715 if (
failed(shuffleTreeBuilder.computeShuffleTree()))
717 "failed to compute shuffle tree");
719 Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
720 rewriter.
replaceOp(fromElemsOp, finalShuffle);
725struct LowerVectorToFromElementsToShuffleTreePass
726 :
public vector::impl::LowerVectorToFromElementsToShuffleTreeBase<
727 LowerVectorToFromElementsToShuffleTreePass> {
729 void runOnOperation()
override {
734 return signalPassFailure();
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap