20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/MathExtras.h"
23 #include "llvm/Support/raw_ostream.h"
28 #define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
29 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
34 #define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
// Indentation scale for debug traces: passed as the `Scale` argument of
// `llvm::indent(level, kIndScale)`, i.e. the number of spaces emitted per
// nesting level. Every visible use is in debug output, so the constant can go
// unreferenced in some build configurations -- hence [[maybe_unused]].
[[maybe_unused]] constexpr unsigned kIndScale{2};

// Inclusive range [first, second] of element positions (its length is
// computed as `second - first + 1`). Judging by how intervals are populated,
// the positions index the operands of the `vector.from_elements` op.
using Interval = std::pair<unsigned, unsigned>;
126 class VectorShuffleTreeBuilder {
128 VectorShuffleTreeBuilder() =
delete;
129 VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
135 LogicalResult computeShuffleTree();
143 FromElementsOp fromElemsOp;
154 void computeShuffleTreeIntervals();
155 void computeShuffleTreeVectorSizes();
161 VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
163 : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) {
164 assert(fromElemsOp &&
"from_elements op is required");
165 assert(!toElemsDefs.empty() &&
"At least one to_elements op is required");
172 template <
typename T>
174 if (values.size() % 2 != 0)
175 values.push_back(values.back());
215 void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() {
221 toElemsToInputOrdinal.insert({toElemsOp, idx});
227 {kMaxUnsigned, kMaxUnsigned});
229 for (
const auto &[idx, element] :
231 auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
232 unsigned inputIdx = toElemsToInputOrdinal[toElemsOp];
233 Interval ¤tInterval = firstLevelIntervals[inputIdx];
236 if (currentInterval.first == kMaxUnsigned)
237 currentInterval.first = idx;
240 currentInterval.second = idx;
243 duplicateLastIfOdd(toElemsDefs);
244 duplicateLastIfOdd(firstLevelIntervals);
245 intervalsPerLevel.push_back(std::move(firstLevelIntervals));
248 for (
unsigned level = 1; level < numLevels; ++level) {
249 bool isLastLevel = level == numLevels - 1;
250 const auto &prevLevelIntervals = intervalsPerLevel[level - 1];
253 {kMaxUnsigned, kMaxUnsigned});
255 size_t currentNumLevels = currentLevelIntervals.size();
256 for (
size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) {
257 auto &interval = currentLevelIntervals[inputIdx];
258 const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
259 const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
264 interval.first = prevLhsInterval.first;
266 std::max(prevLhsInterval.second, prevRhsInterval.second);
273 duplicateLastIfOdd(currentLevelIntervals);
275 intervalsPerLevel.push_back(std::move(currentLevelIntervals));
295 void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() {
301 vectorSizePerLevel.front() = 0;
302 vectorSizePerLevel.back() =
303 cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
305 for (
unsigned level = 1; level < numLevels - 1; ++level) {
306 const auto ¤tLevelIntervals = intervalsPerLevel[level];
307 unsigned currentVectorSize = 1;
308 size_t numIntervals = currentLevelIntervals.size();
309 for (
size_t i = 0; i < numIntervals; ++i) {
310 const auto &interval = currentLevelIntervals[i];
311 unsigned intervalSize = interval.second - interval.first + 1;
312 currentVectorSize =
std::max(currentVectorSize, intervalSize);
314 assert(currentVectorSize > 0 &&
"vector size must be positive");
315 vectorSizePerLevel[level] = currentVectorSize;
319 void VectorShuffleTreeBuilder::dump() {
323 llvm::dbgs() <<
"VectorShuffleTreeBuilder Configuration:\n";
325 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Inputs:\n";
327 for (
const auto &toElemsOp : toElemsDefs)
328 llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp <<
"\n";
329 llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp <<
"\n\n";
332 llvm::dbgs() << llvm::indent(indLv, kIndScale)
333 <<
"* Total levels: " << numLevels <<
"\n";
334 llvm::dbgs() << llvm::indent(indLv, kIndScale)
335 <<
"* Vector sizes per level: ";
336 llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
337 llvm::dbgs() <<
"\n";
338 llvm::dbgs() << llvm::indent(indLv, kIndScale)
339 <<
"* Input intervals per level:\n";
341 for (
const auto &[level, intervals] :
llvm::enumerate(intervalsPerLevel)) {
342 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Level " << level
344 llvm::interleaveComma(intervals, llvm::dbgs(),
345 [](
const Interval &interval) {
346 llvm::dbgs() <<
"[" << interval.first <<
","
347 << interval.second <<
"]";
349 llvm::dbgs() <<
"\n";
379 LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
383 numLevels = 1u +
std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
384 vectorSizePerLevel.resize(numLevels, 0);
385 intervalsPerLevel.reserve(numLevels);
387 computeShuffleTreeIntervals();
388 computeShuffleTreeVectorSizes();
420 ToElementsOp toElementOp0,
const Interval &interval0,
421 ToElementsOp toElementOp1,
const Interval &interval1,
422 FromElementsOp fromElemsOp,
unsigned outputVectorSize) {
424 unsigned inputVectorSize =
425 toElementOp0.getSource().getType().getNumElements();
427 for (
const auto &[inputIdx, element] :
429 auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
431 if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
436 unsigned permVal = cast<OpResult>(element).getResultNumber();
437 unsigned maskIdx = inputIdx;
442 if (currentToElemOp == toElementOp0) {
443 maskIdx -= interval0.first;
446 unsigned intervalOffset = interval1.first - interval0.first;
447 maskIdx += intervalOffset - interval1.first;
448 permVal += inputVectorSize;
451 mask[maskIdx] = permVal;
456 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Permutation mask: [";
457 llvm::interleaveComma(mask, llvm::dbgs());
458 llvm::dbgs() <<
"]\n";
460 llvm::dbgs() << llvm::indent(indLv, kIndScale)
461 <<
"* Combining: " << toElementOp0 <<
" and " << toElementOp1
494 ShuffleOp lhsShuffleOp,
const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
495 const Interval &rhsInterval,
unsigned outputVectorSize) {
498 unsigned inputVectorSize = lhsShuffleMask.size();
499 assert(inputVectorSize == rhsShuffleMask.size() &&
500 "Expected both shuffle masks to have the same size");
502 bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
503 unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
508 for (
unsigned i = 0; i < inputVectorSize; ++i) {
509 if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
515 unsigned rhsIdx = i + lhsRhsOffset;
516 if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
517 assert(rhsIdx < outputVectorSize &&
"RHS index out of bounds");
518 assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex &&
"mask already set");
519 mask[rhsIdx] = i + inputVectorSize;
525 llvm::dbgs() << llvm::indent(indLv, kIndScale)
526 <<
"* Propagation shuffle mask computation:\n";
528 llvm::dbgs() << llvm::indent(indLv, kIndScale)
529 <<
"* LHS shuffle op: " << lhsShuffleOp <<
"\n";
530 llvm::dbgs() << llvm::indent(indLv, kIndScale)
531 <<
"* RHS shuffle op: " << rhsShuffleOp <<
"\n";
532 llvm::dbgs() << llvm::indent(indLv, kIndScale) <<
"* Result mask: [";
533 llvm::interleaveComma(mask, llvm::dbgs());
534 llvm::dbgs() <<
"]\n";
582 LLVM_DEBUG(llvm::dbgs() <<
"VectorShuffleTreeBuilder Code Generation:\n");
586 llvm::transform(toElemsDefs, std::back_inserter(levelInputs),
587 [](ToElementsOp toElemsOp) {
return toElemsOp.getSource(); });
594 Location loc = fromElemsOp.getLoc();
595 unsigned currentLevel = 0;
596 for (
const auto &[nextLevelVectorSize, intervals] :
597 llvm::zip_equal(
ArrayRef(vectorSizePerLevel).drop_front(),
598 ArrayRef(intervalsPerLevel).drop_back())) {
600 duplicateLastIfOdd(levelInputs);
602 LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale)
603 <<
"* Processing level " << currentLevel
604 <<
" (output vector size: " << nextLevelVectorSize
605 <<
", # inputs: " << levelInputs.size() <<
")\n");
609 for (
size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs;
611 Value lhsVector = levelInputs[i];
612 Value rhsVector = levelInputs[i + 1];
613 const Interval &lhsInterval = intervals[i];
614 const Interval &rhsInterval = intervals[i + 1];
620 if (currentLevel == 0) {
621 shuffleMask = computePermutationShuffleMask(
622 toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval,
623 fromElemsOp, nextLevelVectorSize);
625 auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.
getDefiningOp());
626 auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.
getDefiningOp());
627 shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval,
628 rhsShuffleOp, rhsInterval,
629 nextLevelVectorSize);
632 Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector,
633 rhsVector, shuffleMask);
634 levelOutputs.push_back(shuffleVal);
637 levelInputs = std::move(levelOutputs);
641 assert(levelInputs.size() == 1 &&
"Should have exactly one result");
642 return levelInputs.front();
650 getToElementsDefiningOps(FromElementsOp fromElemsOp,
653 for (
Value element : fromElemsOp.getElements()) {
654 auto toElemsOp = element.getDefiningOp<ToElementsOp>();
657 toElemsDefsSet.insert(toElemsOp);
660 toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end());
667 struct ToFromElementsToShuffleTreeRewrite final
672 LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
674 VectorType resultType = fromElemsOp.getType();
675 if (resultType.getRank() != 1)
678 "multi-dimensional output vectors are not supported yet");
679 if (resultType.isScalable())
682 "'vector.from_elements' does not support scalable vectors");
687 if (
failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs)))
690 if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
691 return toElemsOp.getSource().
getType().getRank() != 1;
694 fromElemsOp,
"multi-dimensional input vectors are not supported yet");
697 if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
698 return !toElemsOp.getSource().
getType().hasRank();
701 "0-D vectors are not supported");
706 if (toElemsDefs.size() == 1) {
707 ToElementsOp toElemsOp0 = toElemsDefs.front();
708 if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) {
710 fromElemsOp,
"trivial forwarding case does not require shuffling");
714 VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs);
715 if (
failed(shuffleTreeBuilder.computeShuffleTree()))
717 "failed to compute shuffle tree");
719 Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
720 rewriter.
replaceOp(fromElemsOp, finalShuffle);
725 struct LowerVectorToFromElementsToShuffleTreePass
726 :
public vector::impl::LowerVectorToFromElementsToShuffleTreeBase<
727 LowerVectorToFromElementsToShuffleTreePass> {
729 void runOnOperation()
override {
734 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
void populateVectorToFromElementsToShuffleTreePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns to rewrite sequences of vector.to_elements + vector.from_elements operations into a...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...