195 using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
// Op-type dispatch: each supported op kind is tried in turn with dyn_cast and
// the first match returns the result of its visit* transfer function.
// visitStep and visitConstant are handed only the result lattices; every
// other handler also receives the operand lattices.
200 if (
auto step = dyn_cast<vector::StepOp>(op))
201 return visitStep(step, results);
202 if (
auto cst = dyn_cast<arith::ConstantOp>(op))
203 return visitConstant(cst, results);
204 if (
auto bcast = dyn_cast<vector::BroadcastOp>(op))
205 return visitBroadcast(bcast, operands, results);
206 if (
auto sc = dyn_cast<vector::ShapeCastOp>(op))
207 return visitShapeCast(sc, operands, results);
208 if (
auto tp = dyn_cast<vector::TransposeOp>(op))
209 return visitTranspose(tp, operands, results);
// addi and subi share visitAddSub; the bool template argument is IsSub.
210 if (
auto add = dyn_cast<arith::AddIOp>(op))
211 return visitAddSub<
false>(
add, operands, results);
212 if (
auto sub = dyn_cast<arith::SubIOp>(op))
213 return visitAddSub<
true>(sub, operands, results);
214 if (
auto mul = dyn_cast<arith::MulIOp>(op))
215 return visitMul(
mul, operands, results);
// The four div/rem ops share visitDivRem<IsSigned, IsRem>.
216 if (
auto div = dyn_cast<arith::DivUIOp>(op))
217 return visitDivRem<
false,
false>(
div, operands,
219 if (
auto div = dyn_cast<arith::DivSIOp>(op))
220 return visitDivRem<
true,
false>(
div, operands,
222 if (
auto rem = dyn_cast<arith::RemUIOp>(op))
223 return visitDivRem<
false,
true>(
rem, operands,
225 if (
auto rem = dyn_cast<arith::RemSIOp>(op))
226 return visitDivRem<
true,
true>(
rem, operands,
228 if (
auto andi = dyn_cast<arith::AndIOp>(op))
229 return visitAndI(andi, operands, results);
// shli and shrui share visitShift; the bool template argument is IsLeft.
230 if (
auto shl = dyn_cast<arith::ShLIOp>(op))
231 return visitShift<
true>(shl, operands, results);
232 if (
auto shr = dyn_cast<arith::ShRUIOp>(op))
233 return visitShift<
false>(shr, operands, results);
234 if (
auto sel = dyn_cast<arith::SelectOp>(op))
235 return visitSelect(sel, operands, results);
// Both index-cast flavours are routed to the same visitPassThrough handler.
236 if (
auto cast = dyn_cast<arith::IndexCastOp>(op))
237 return visitPassThrough(cast, operands, results);
238 if (
auto cast = dyn_cast<arith::IndexCastUIOp>(op))
239 return visitPassThrough(cast, operands, results);
// No handler matched: fall back to setAllPessimistic for this op's results.
240 setAllPessimistic(op, results);
251 for (
auto [r, lat] : llvm::zip(op->
getResults(), results)) {
/// Transfer function for vector.step. Casts the op's result type to
/// VectorType (cast<>, not dyn_cast<>) and works from n, the total number of
/// elements of that type.
258 LogicalResult visitStep(vector::StepOp op,
260 auto vt = cast<VectorType>(op.getType());
261 int64_t n = vt.getNumElements();
/// Transfer function for arith.constant.
///
/// Handles an IntegerAttr value, a splat DenseIntElementsAttr, and a general
/// DenseIntElementsAttr whose elements are scanned row by row along the
/// innermost dimension to detect a single common stride. Values that cannot
/// be analysed are handed to setAllPessimistic.
///
/// \param op      the constant being visited.
/// \param results result lattices to update.
276 LogicalResult visitConstant(arith::ConstantOp op,
277 ArrayRef<AxisInfoLattice *> results) {
278 auto vt = dyn_cast<VectorType>(op.getType());
// Single integer attribute: c is its sign-extended value.
282 if (
auto intAttr = dyn_cast<IntegerAttr>(op.getValue())) {
283 int64_t c = intAttr.getValue().getSExtValue();
293 setAllPessimistic(op, results);
// Dense integer elements attribute.
296 auto dense = dyn_cast<DenseIntElementsAttr>(op.getValue());
298 setAllPessimistic(op, results);
301 auto shape = vt.getShape();
// Splat: all elements share the one sign-extended value c.
304 if (dense.isSplat()) {
305 int64_t c = dense.getSplatValue<APInt>().getSExtValue();
// View the elements as `outer` rows of `inner` values, where `inner` is the
// extent of the last dimension.
316 unsigned r = shape.size();
317 int64_t inner = shape.back();
318 int64_t outer = vt.getNumElements() / inner;
// values[1] is read when innerStride is initialised below, so at least two
// inner elements are required.
319 if (inner < 2 || outer < 1) {
328 auto values = llvm::to_vector(dense.getValues<APInt>());
// Start optimistic: full-row contiguity/constancy, and a stride taken from
// the first two elements.
329 int64_t innerCont = inner;
330 int64_t innerConst = inner;
331 int64_t innerStride = values[1].getSExtValue() - values[0].getSExtValue();
332 int64_t base = values[0].getSExtValue();
// Scan every row; `origin` is the first element of row o.
334 for (int64_t o = 0; o < outer; ++o) {
335 int64_t origin = values[o * inner].getSExtValue();
337 for (int64_t i = 1; i < inner; ++i) {
338 int64_t cur = values[o * inner + i].getSExtValue();
339 int64_t prev = values[o * inner + i - 1].getSExtValue();
340 int64_t diff = cur - prev;
// INT64_MIN is the sentinel for "no single common stride"; once set, every
// later diff also mismatches, so it sticks.
341 if (diff != innerStride)
342 innerStride = std::numeric_limits<int64_t>::min();
344 innerCont = std::min<int64_t>(innerCont, i);
346 innerConst = std::min<int64_t>(innerConst, i);
// Stride 1 -> record contiguity on the innermost dim; stride 0 -> record
// constancy on the innermost dim.
350 if (innerStride == 1)
351 v.contiguity[r - 1] = innerCont;
352 else if (innerStride == 0)
353 v.constancy[r - 1] = innerConst;
355 v.divisibility[r - 1] = baseDiv;
// Publish the stride only if the sentinel was never set.
356 if (innerStride != std::numeric_limits<int64_t>::min())
357 v.innerStride = innerStride;
/// Transfer function for vector.broadcast.
///
/// Builds per-dimension facts for the result vector from the lattice value of
/// the source operand (operands[0]). Calls setAllPessimistic on a path whose
/// guard is tied to the dyn_cast of the result type.
///
/// \param op       the broadcast being visited.
/// \param operands operand lattices; only operands[0] is read here.
/// \param results  result lattices to update.
365 LogicalResult visitBroadcast(vector::BroadcastOp op,
366 ArrayRef<const AxisInfoLattice *> operands,
367 ArrayRef<AxisInfoLattice *> results) {
368 auto resTy = dyn_cast<VectorType>(op.getType());
370 setAllPessimistic(op, results);
373 unsigned rRank = resTy.getRank();
374 AxisInfo src = operands[0]->getValue();
376 auto resShape = resTy.getShape();
// A source that is not a VectorType is treated as rank 0.
377 auto srcVt = dyn_cast<VectorType>(op.getSource().getType());
378 unsigned sRank = srcVt ? srcVt.getRank() : 0;
382 for (
unsigned d = 0; d < rRank; ++d) {
383 int64_t resExt = resShape[d];
// sIdx = d - (rRank - sRank): result dims are aligned with the trailing
// source dims.
386 int sIdx =
static_cast<int>(d) -
static_cast<int>(rRank - sRank);
388 v.constancy[d] = resExt;
390 v.divisibility[d] = src.isInitialized() ? src.divisibility.front() : 1;
// Source extent 1 stretched to resExt: the whole result dim is constant.
393 int64_t srcExt = srcVt.getShape()[sIdx];
394 if (srcExt == 1 && resExt > 1) {
395 v.constancy[d] = resExt;
397 v.divisibility[d] = src.isInitialized() ? src.divisibility[sIdx] : 1;
398 }
// Otherwise copy the source's facts for this dim, provided the source lattice
// is initialized.
else if (src.isInitialized()) {
399 v.contiguity[d] = src.contiguity[sIdx];
400 v.constancy[d] = src.constancy[sIdx];
401 v.divisibility[d] = src.divisibility[sIdx];
// A known constant value is carried through unchanged.
404 if (src.knownConstant)
405 v.knownConstant = src.knownConstant;
409 auto resShapeArr = resTy.getShape();
410 int64_t innerExt = resShapeArr.back();
412 static_cast<int>(rRank - 1) -
static_cast<int>(rRank - sRank);
// The source stride is inherited only when the innermost extent is unchanged
// and the source actually has a stride.
416 int64_t srcInner = srcVt.getShape()[sIdxInner];
417 if (srcInner == 1 && innerExt > 1)
419 else if (srcInner == innerExt && src.innerStride)
420 v.innerStride = src.innerStride;
/// Transfer function for vector.shape_cast.
///
/// Falls back to setAllPessimistic when either the source or the result type
/// is not a VectorType, or when the source lattice is uninitialized.
///
/// \param op       the shape_cast being visited.
/// \param operands operand lattices; only operands[0] is read here.
/// \param results  result lattices to update.
433 LogicalResult visitShapeCast(vector::ShapeCastOp op,
434 ArrayRef<const AxisInfoLattice *> operands,
435 ArrayRef<AxisInfoLattice *> results) {
436 auto srcTy = dyn_cast<VectorType>(op.getSource().getType());
437 auto dstTy = dyn_cast<VectorType>(op.getType());
438 if (!srcTy || !dstTy) {
439 setAllPessimistic(op, results);
442 AxisInfo src = operands[0]->getValue();
443 if (!src.isInitialized()) {
444 setAllPessimistic(op, results);
// Helper: returns the shape with its leading extent-1 dimensions dropped.
449 auto stripLeading = [](ArrayRef<int64_t> s) {
451 while (i < s.size() && s[i] == 1)
453 return s.drop_front(i);
455 auto sCore = stripLeading(srcTy.getShape());
456 auto dCore = stripLeading(dstTy.getShape());
457 unsigned dRank = dstTy.getRank();
// Shapes agree once leading unit dims are ignored: the remaining dims map
// one-to-one, offset by each side's count of leading unit dims.
459 if (sCore == dCore) {
460 unsigned sLead = srcTy.getRank() - sCore.size();
461 unsigned dLead = dRank - dCore.size();
462 for (
unsigned d = dLead; d < dRank; ++d) {
463 unsigned sIdx = sLead + (d - dLead);
464 v.contiguity[d] = src.contiguity[sIdx];
465 v.constancy[d] = src.constancy[sIdx];
466 v.divisibility[d] = src.divisibility[sIdx];
// Innermost result dim: use the minimum contiguity/constancy over all source
// dims, capped at the innermost result extent; divisibility is taken from
// the source's last dim.
474 int64_t innerExt = dstTy.getShape().back();
475 int64_t srcContig = std::numeric_limits<int64_t>::max();
476 int64_t srcConst = std::numeric_limits<int64_t>::max();
477 int64_t srcDiv = src.divisibility[src.getRank() - 1];
478 for (
unsigned d = 0; d < src.getRank(); ++d) {
479 srcContig = std::min(srcContig, src.contiguity[d]);
480 srcConst = std::min(srcConst, src.constancy[d]);
482 v.contiguity[dRank - 1] = std::min<int64_t>(srcContig, innerExt);
483 v.constancy[dRank - 1] = std::min<int64_t>(srcConst, innerExt);
484 v.divisibility[dRank - 1] = srcDiv;
// A known constant value is carried through unchanged.
486 if (src.knownConstant)
487 v.knownConstant = src.knownConstant;
491 v.innerStride = src.innerStride;
/// Transfer function for vector.transpose.
///
/// Result dimension d receives the contiguity, constancy and divisibility of
/// source dimension perm[d]. Calls setAllPessimistic when the source lattice
/// is uninitialized (and on a path guarded around the result-type dyn_cast).
///
/// \param op       the transpose being visited.
/// \param operands operand lattices; only operands[0] is read here.
/// \param results  result lattices to update.
499 LogicalResult visitTranspose(vector::TransposeOp op,
500 ArrayRef<const AxisInfoLattice *> operands,
501 ArrayRef<AxisInfoLattice *> results) {
502 auto resTy = dyn_cast<VectorType>(op.getType());
504 setAllPessimistic(op, results);
507 AxisInfo src = operands[0]->getValue();
508 if (!src.isInitialized()) {
509 setAllPessimistic(op, results);
512 ArrayRef<int64_t> perm = op.getPermutation();
513 unsigned r = resTy.getRank();
// Permute the per-dimension facts: result dim d <- source dim perm[d].
515 for (
unsigned d = 0; d < r; ++d) {
516 unsigned s =
static_cast<unsigned>(perm[d]);
517 v.contiguity[d] = src.contiguity[s];
518 v.constancy[d] = src.constancy[s];
519 v.divisibility[d] = src.divisibility[s];
// A known constant value is carried through unchanged.
521 if (src.knownConstant)
522 v.knownConstant = src.knownConstant;
// The stride is kept only when the source's innermost dim is still the
// innermost dim of the result.
525 if (src.innerStride && perm.back() == src.getRank() - 1)
526 v.innerStride = src.innerStride;
/// Shared transfer function for integer add and subtract.
///
/// IsSub selects subtraction (arith.subi is dispatched with IsSub = true,
/// arith.addi with IsSub = false). Calls setAllPessimistic when either
/// operand lattice is uninitialized.
///
/// \param op       the add/sub op being visited.
/// \param operands operand lattices: [0] = lhs, [1] = rhs.
/// \param results  result lattices to update.
531 template <
bool IsSub,
typename OpTy>
532 LogicalResult visitAddSub(OpTy op, ArrayRef<const AxisInfoLattice *> operands,
533 ArrayRef<AxisInfoLattice *> results) {
534 auto vt = dyn_cast<VectorType>(op.getType());
536 setAllPessimistic(op, results);
539 AxisInfo
lhs = operands[0]->getValue();
540 AxisInfo
rhs = operands[1]->getValue();
541 if (!
lhs.isInitialized() || !
rhs.isInitialized()) {
542 setAllPessimistic(op, results);
545 unsigned r = vt.getRank();
547 for (
unsigned d = 0; d < r; ++d) {
548 int64_t lhsCont =
lhs.contiguity[d];
549 int64_t rhsCont =
rhs.contiguity[d];
550 int64_t lhsConst =
lhs.constancy[d];
551 int64_t rhsConst =
rhs.constancy[d];
// Contiguity survives where one side is contiguous and the other constant.
// For subtraction only the (lhs contiguous, rhs constant) pairing is used;
// for addition the better of the two pairings is taken.
554 int64_t cont = IsSub ? std::min(lhsCont, rhsConst)
555 : std::
max(std::
min(lhsCont, rhsConst),
556 std::
min(rhsCont, lhsConst));
// Contiguity is clamped to at least 1; constancy is the smaller of the two
// operands'; divisibility is the gcd of the two operands'.
557 v.contiguity[d] = std::max<int64_t>(1, cont);
558 v.constancy[d] = std::min(lhsConst, rhsConst);
559 v.divisibility[d] = std::gcd(
lhs.divisibility[d],
rhs.divisibility[d]);
// An operand is "uniform" when its constancy covers the whole innermost
// extent.
564 auto isUniform = [&](
const AxisInfo &a) {
565 unsigned inner = vt.getRank() - 1;
566 return a.constancy[inner] >= vt.getShape()[inner];
// The stride of the strided side carries over when the other side is
// uniform; it is negated when the strided side is the rhs of a subtraction.
568 if (
lhs.innerStride && isUniform(
rhs)) {
569 v.innerStride = *
lhs.innerStride;
570 }
else if (
rhs.innerStride && isUniform(
lhs)) {
571 v.innerStride = IsSub ? -*
rhs.innerStride : *
rhs.innerStride;
/// Transfer function for arith.muli.
///
/// Calls setAllPessimistic when either operand lattice is uninitialized.
///
/// \param op       the multiply being visited.
/// \param operands operand lattices: [0] = lhs, [1] = rhs.
/// \param results  result lattices to update.
577 LogicalResult visitMul(arith::MulIOp op,
578 ArrayRef<const AxisInfoLattice *> operands,
579 ArrayRef<AxisInfoLattice *> results) {
580 auto vt = dyn_cast<VectorType>(op.getType());
582 setAllPessimistic(op, results);
585 AxisInfo
lhs = operands[0]->getValue();
586 AxisInfo
rhs = operands[1]->getValue();
587 if (!
lhs.isInitialized() || !
rhs.isInitialized()) {
588 setAllPessimistic(op, results);
591 unsigned r = vt.getRank();
592 auto shape = vt.getShape();
// True when `a` is the known constant 1 and constant across the full extent
// of dim d.
594 auto unitConstant = [](
const AxisInfo &a,
unsigned d, int64_t extent) {
595 return a.knownConstant && *a.knownConstant == 1 &&
596 a.constancy[d] >= extent;
598 for (
unsigned d = 0; d < r; ++d) {
// Constancy is the smallest of the dim extent and both operands' constancy.
599 v.constancy[d] = std::min({shape[d],
lhs.constancy[d],
rhs.constancy[d]});
600 v.divisibility[d] = std::min<int64_t>(
// Multiplying by a uniform 1 keeps the other operand's contiguity, capped at
// the dim extent.
603 if (unitConstant(
lhs, d, shape[d]))
604 v.contiguity[d] = std::min(
rhs.contiguity[d], shape[d]);
605 else if (unitConstant(
rhs, d, shape[d]))
606 v.contiguity[d] = std::min(
lhs.contiguity[d], shape[d]);
611 unsigned inner = vt.getRank() - 1;
612 auto isUniformInner = [&](
const AxisInfo &a) {
613 return a.constancy[inner] >= shape[inner];
// A strided operand times a known constant that is uniform along the
// innermost dim yields stride * constant.
615 if (
lhs.innerStride && isUniformInner(
rhs) &&
rhs.knownConstant) {
616 v.innerStride = *
lhs.innerStride * *
rhs.knownConstant;
617 }
else if (
rhs.innerStride && isUniformInner(
lhs) &&
lhs.knownConstant) {
618 v.innerStride = *
rhs.innerStride * *
lhs.knownConstant;
/// Shared transfer function for the integer division and remainder ops.
///
/// The dispatcher instantiates <IsSigned, IsRem> as <false,false> for divui,
/// <true,false> for divsi, <false,true> for remui and <true,true> for remsi.
/// Calls setAllPessimistic when either operand lattice is uninitialized.
///
/// \param op       the div/rem op being visited.
/// \param operands operand lattices: [0] = lhs, [1] = rhs (the divisor).
/// \param results  result lattices to update.
636 template <
bool IsSigned,
bool IsRem,
typename OpTy>
637 LogicalResult visitDivRem(OpTy op, ArrayRef<const AxisInfoLattice *> operands,
638 ArrayRef<AxisInfoLattice *> results) {
639 auto vt = dyn_cast<VectorType>(op.getType());
641 setAllPessimistic(op, results);
644 AxisInfo
lhs = operands[0]->getValue();
645 AxisInfo
rhs = operands[1]->getValue();
646 if (!
lhs.isInitialized() || !
rhs.isInitialized()) {
647 setAllPessimistic(op, results);
650 unsigned r = vt.getRank();
651 unsigned inner = r - 1;
652 auto shape = vt.getShape();
// The divisor must be constant across the whole innermost extent with a known
// value; that value is dereferenced as c below, so this check (including
// c > 0) guards the dereference.
655 bool rhsUniform =
rhs.constancy[inner] >= shape[inner] &&
rhs.knownConstant;
656 if (!rhsUniform || *
rhs.knownConstant <= 0) {
660 int64_t c = *
rhs.knownConstant;
// lhs.innerStride is dereferenced as s below; this check guards it.
662 if (!
lhs.innerStride) {
666 int64_t s = *
lhs.innerStride;
667 int64_t baseDivLhs =
lhs.divisibility[inner];
676 v.constancy[inner] = shape[inner];
681 if (baseDivLhs % c != 0) {
// newStride is the truncating integer quotient s / c.
685 int64_t newStride = s / c;
686 v.innerStride = newStride;
688 v.contiguity[inner] = shape[inner];
689 else if (newStride == 0)
690 v.constancy[inner] = shape[inner];
691 v.divisibility[inner] = baseDivLhs / c;
/// Transfer function for arith.andi.
///
/// Looks for an operand that is a known constant, uniform along the innermost
/// dim (the mask m), and treats the other operand as the masked value x.
/// Calls setAllPessimistic when either operand lattice is uninitialized.
///
/// \param op       the and op being visited.
/// \param operands operand lattices: [0] = lhs, [1] = rhs.
/// \param results  result lattices to update.
707 LogicalResult visitAndI(arith::AndIOp op,
708 ArrayRef<const AxisInfoLattice *> operands,
709 ArrayRef<AxisInfoLattice *> results) {
710 auto vt = dyn_cast<VectorType>(op.getType());
712 setAllPessimistic(op, results);
715 AxisInfo
lhs = operands[0]->getValue();
716 AxisInfo
rhs = operands[1]->getValue();
717 if (!
lhs.isInitialized() || !
rhs.isInitialized()) {
718 setAllPessimistic(op, results);
721 unsigned r = vt.getRank();
722 unsigned inner = r - 1;
723 auto shape = vt.getShape();
// Yields a's known constant when a is constant across the full innermost
// extent and that constant is known.
727 auto getUniformMask = [&](
const AxisInfo &a) -> std::optional<int64_t> {
728 if (a.constancy[inner] >= shape[inner] && a.knownConstant)
729 return a.knownConstant;
732 std::optional<int64_t> mLhs = getUniformMask(
lhs);
733 std::optional<int64_t> mRhs = getUniformMask(
rhs);
734 if (!mLhs && !mRhs) {
// x is the operand that is not the mask. When both sides qualify as masks,
// lhs supplies m and rhs is used as x.
738 const AxisInfo &x = mLhs ?
rhs :
lhs;
739 int64_t m = mLhs ? *mLhs : *mRhs;
744 v.constancy[inner] = shape[inner];
// m > 0 with m + 1 a power of two means m has the form 2^k - 1 (a low-bit
// mask).
751 if (m > 0 && llvm::isPowerOf2_64(
static_cast<uint64_t
>(m + 1))) {
753 if (x.innerStride && *x.innerStride % P == 0) {
755 v.constancy[inner] = shape[inner];
756 v.divisibility[inner] =
/// Shared transfer function for shifts by a uniform constant amount.
///
/// IsLeft selects the direction (arith.shli is dispatched with IsLeft = true,
/// arith.shrui with IsLeft = false). Calls setAllPessimistic when either
/// operand lattice is uninitialized.
///
/// \param op       the shift op being visited.
/// \param operands operand lattices: [0] = shifted value, [1] = shift amount.
/// \param results  result lattices to update.
771 template <
bool IsLeft,
typename OpTy>
772 LogicalResult visitShift(OpTy op, ArrayRef<const AxisInfoLattice *> operands,
773 ArrayRef<AxisInfoLattice *> results) {
774 auto vt = dyn_cast<VectorType>(op.getType());
776 setAllPessimistic(op, results);
779 AxisInfo
lhs = operands[0]->getValue();
780 AxisInfo
rhs = operands[1]->getValue();
781 if (!
lhs.isInitialized() || !
rhs.isInitialized()) {
782 setAllPessimistic(op, results);
785 unsigned r = vt.getRank();
786 unsigned inner = r - 1;
787 auto shape = vt.getShape();
// The shift amount must be constant across the whole innermost extent with a
// known value; it is dereferenced as k below.
790 if (
rhs.constancy[inner] < shape[inner] || !
rhs.knownConstant) {
794 int64_t k = *
rhs.knownConstant;
// Range check on k ahead of the 1LL << k computed below.
795 if (k < 0 || k >= 63) {
// factor = 2^k.
799 int64_t factor = 1LL << k;
// Stride multiplied by factor: a resulting stride of 1 marks the innermost
// dim contiguous, a stride of 0 marks it constant.
803 if (
lhs.innerStride) {
804 v.innerStride = *
lhs.innerStride * factor;
805 if (*v.innerStride == 1)
806 v.contiguity[inner] = shape[inner];
807 else if (*v.innerStride == 0)
808 v.constancy[inner] = shape[inner];
810 v.divisibility[inner] =
// Stride divided by factor: taken only when both the stride and the
// innermost divisibility are multiples of factor, so both divisions below
// are exact.
814 if (
lhs.innerStride && *
lhs.innerStride % factor == 0 &&
815 lhs.divisibility[inner] % factor == 0) {
816 int64_t newStride = *
lhs.innerStride / factor;
817 v.innerStride = newStride;
819 v.contiguity[inner] = shape[inner];
820 else if (newStride == 0)
821 v.constancy[inner] = shape[inner];
822 v.divisibility[inner] =
lhs.divisibility[inner] / factor;
/// Transfer function for arith.select.
///
/// Reads the lattice values of operands[1] (t) and operands[2] (f); the
/// lattice of operands[0] is not read in these lines. Calls setAllPessimistic
/// when either t or f is uninitialized.
///
/// \param op       the select being visited.
/// \param operands operand lattices; [1] and [2] are the two selected values.
/// \param results  result lattices to update.
831 LogicalResult visitSelect(arith::SelectOp op,
832 ArrayRef<const AxisInfoLattice *> operands,
833 ArrayRef<AxisInfoLattice *> results) {
834 auto vt = dyn_cast<VectorType>(op.getType());
836 setAllPessimistic(op, results);
840 AxisInfo t = operands[1]->getValue();
841 AxisInfo f = operands[2]->getValue();
842 if (!t.isInitialized() || !f.isInitialized()) {
843 setAllPessimistic(op, results);
851 template <
typename OpTy>
852 LogicalResult visitPassThrough(OpTy op,
853 ArrayRef<const AxisInfoLattice *> operands,
854 ArrayRef<AxisInfoLattice *> results) {
855 if (!isa<VectorType>(op.getType())) {
856 setAllPessimistic(op, results);