35 #define DEBUG_TYPE "vector-contract-lowering"
58 int64_t idx = it.index();
61 results.push_back(it.value());
78 results.push_back(targetExpr);
86 int64_t index, int64_t pos,
93 return rewriter.
create<vector::ExtractOp>(loc, val, pos);
100 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
101 Value ext = rewriter.
create<vector::ExtractOp>(loc, val, d);
103 result = rewriter.
create<vector::InsertOp>(loc, load, result, d);
111 VectorType type, int64_t index, int64_t pos,
118 return rewriter.
create<vector::InsertOp>(loc, val, result, pos);
122 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
123 Value ext = rewriter.
create<vector::ExtractOp>(loc, result, d);
124 Value ins = rewriter.
create<vector::ExtractOp>(loc, val, d);
126 result = rewriter.
create<vector::InsertOp>(loc, sto, result, d);
132 static std::optional<Value>
136 using vector::CombiningKind;
140 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
141 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
144 mul = rewriter.
create<arith::MulIOp>(loc, x, y);
148 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
149 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
150 kind == CombiningKind::XOR)
154 if (acc && isa<VectorType>(acc.
getType()) && kind == CombiningKind::ADD) {
155 Value fma = rewriter.
create<vector::FMAOp>(loc, x, y, acc);
162 mul = rewriter.
create<arith::MulFOp>(loc, x, y);
166 return std::optional<Value>(mul);
174 ArrayAttr iteratorTypes) {
178 dimsIdx.push_back(i);
198 return rewriter.
create<arith::AddIOp>(loc, x, y);
199 return rewriter.
create<arith::AddFOp>(loc, x, y);
207 return rewriter.
create<arith::MulIOp>(loc, x, y);
208 return rewriter.
create<arith::MulFOp>(loc, x, y);
226 class ContractionOpToMatmulOpLowering
229 using MaskableOpRewritePattern::MaskableOpRewritePattern;
231 using FilterConstraintType =
232 std::function<LogicalResult(vector::ContractionOp op)>;
234 static LogicalResult defaultFilter(vector::ContractionOp op) {
238 ContractionOpToMatmulOpLowering(
241 FilterConstraintType constraint = defaultFilter)
243 vectorTransformOptions(vectorTransformOptions),
244 filter(std::move(constraint)) {}
247 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
253 FilterConstraintType filter;
271 class ContractionOpToOuterProductOpLowering
274 using MaskableOpRewritePattern::MaskableOpRewritePattern;
276 using FilterConstraintType =
277 std::function<LogicalResult(vector::ContractionOp op)>;
279 static LogicalResult defaultFilter(vector::ContractionOp op) {
283 ContractionOpToOuterProductOpLowering(
286 FilterConstraintType constraint = defaultFilter)
288 vectorTransformOptions(vectorTransformOptions),
289 filter(std::move(constraint)) {}
292 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
298 FilterConstraintType filter;
319 class ContractionOpToDotLowering
322 using MaskableOpRewritePattern::MaskableOpRewritePattern;
324 using FilterConstraintType =
325 std::function<LogicalResult(vector::ContractionOp op)>;
327 static LogicalResult defaultFilter(vector::ContractionOp op) {
// Constructor. `constraint` defaults to `defaultFilter` and is kept in `filter`,
// which matchAndRewriteMaskableOp checks via `failed(filter(op))`.
// Store the caller's constraint rather than `defaultFilter`, so that a custom
// filter passed at construction actually takes effect.
331 ContractionOpToDotLowering(
334 const FilterConstraintType &constraint = defaultFilter)
336 vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
339 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
345 FilterConstraintType filter;
362 class ContractionOpLowering
365 using MaskableOpRewritePattern::MaskableOpRewritePattern;
366 using FilterConstraintType =
367 std::function<LogicalResult(vector::ContractionOp op)>;
369 static LogicalResult defaultFilter(vector::ContractionOp op) {
375 FilterConstraintType constraint = defaultFilter)
377 vectorTransformOptions(vectorTransformOptions),
378 filter(std::move(constraint)) {}
381 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
387 FilterConstraintType filter;
390 vector::ContractionOp op, int64_t lhsIndex,
391 int64_t rhsIndex,
Value mask)
const;
394 vector::ContractionOp op,
Value mask)
const;
399 struct UnrolledOuterProductGenerator
401 UnrolledOuterProductGenerator(
RewriterBase &b, vector::ContractionOp op)
403 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
404 res(op.getAcc()), lhsType(op.getLhsType()) {
405 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
406 if (maskableOp.isMasked())
407 mask = maskableOp.getMaskingOp().getMask();
413 return rewriter.
create<vector::TransposeOp>(loc, v, perm);
418 auto vecType = dyn_cast<VectorType>(elementType);
420 elementType = vecType.getElementType();
421 if (elementType == dstElementType)
423 Type promotedType = dstElementType;
425 promotedType = vecType.clone(promotedType);
426 if (isa<FloatType>(dstElementType))
427 return rewriter.
create<arith::ExtFOp>(loc, promotedType, v);
428 return rewriter.
create<arith::ExtSIOp>(loc, promotedType, v);
432 VectorType lhsType,
int reductionSize,
433 std::optional<Value> maybeMask = std::nullopt) {
435 if (mask && !maybeMask.has_value())
438 Type resElementType = cast<VectorType>(res.
getType()).getElementType();
439 for (int64_t k = 0; k < reductionSize; ++k) {
440 Value extractA = rewriter.
create<vector::ExtractOp>(loc, lhs, k);
441 Value extractB = rewriter.
create<vector::ExtractOp>(loc, rhs, k);
442 extractA =
promote(extractA, resElementType);
443 extractB =
promote(extractB, resElementType);
445 if (maybeMask.has_value() && maybeMask.value())
447 rewriter.
create<vector::ExtractOp>(loc, maybeMask.value(), k);
450 loc, res.
getType(), extractA, extractB, res, kind);
459 std::optional<int64_t> getReductionSize(VectorType vecType,
460 int64_t reductionDim) {
462 if (vecType.getScalableDims()[reductionDim])
464 int64_t reductionSize = vecType.getDimSize(reductionDim);
465 assert(reductionSize > 0 &&
466 "Reduction dim must be a known static size to allow unrolling");
467 return reductionSize;
471 FailureOr<Value> matmat() {
472 if (!iters({Par(), Par(), Red()}))
479 if (layout({{m, k}, {k, n}, {m, n}})) {
480 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
484 Value tMask = t(mask, {2, 0, 1});
485 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
489 if (layout({{m, k}, {n, k}, {m, n}})) {
490 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
493 Value tMask = t(mask, {2, 0, 1});
494 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
498 if (layout({{k, m}, {k, n}, {m, n}})) {
499 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
500 Value tMask = t(mask, {2, 0, 1});
501 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
505 if (layout({{k, m}, {n, k}, {m, n}})) {
506 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
508 Value tMask = t(mask, {2, 0, 1});
509 return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
514 if (layout({{m, k}, {k, n}, {n, m}})) {
515 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
517 Value tMask = t(mask, {2, 0, 1});
518 return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
522 if (layout({{m, k}, {n, k}, {n, m}})) {
523 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
526 Value tMask = t(mask, {2, 0, 1});
527 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
530 if (layout({{k, m}, {k, n}, {n, m}})) {
531 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
532 Value tMask = t(mask, {2, 0, 1});
533 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
536 if (layout({{k, m}, {n, k}, {n, m}})) {
537 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
539 Value tMask = t(mask, {2, 0, 1});
540 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
551 FailureOr<Value> matvec() {
552 if (!iters({Par(), Red()}))
558 if (layout({{m, k}, {k}, {m}})) {
559 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
561 Value tMask = t(mask);
562 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
566 if (layout({{k, m}, {k}, {m}})) {
567 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
568 Value tMask = t(mask);
569 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
573 if (layout({{k}, {m, k}, {m}})) {
574 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
576 Value tMask = t(mask);
577 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
581 if (layout({{k}, {k, m}, {m}})) {
582 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
583 Value tMask = t(mask);
584 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
594 FailureOr<Value> tmatvec() {
595 if (!iters({Red(), Par()}))
601 if (layout({{m, k}, {k}, {m}}))
602 if (
auto reductionSize = getReductionSize(lhsType, 1))
603 return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
605 if (layout({{k, m}, {k}, {m}}))
606 if (
auto reductionSize = getReductionSize(lhsType, 0))
607 return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
609 if (layout({{k}, {m, k}, {m}}))
610 if (
auto reductionSize = getReductionSize(lhsType, 0))
611 return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
613 if (layout({{k}, {k, m}, {m}}))
614 if (
auto reductionSize = getReductionSize(lhsType, 0))
615 return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
620 vector::CombiningKind kind;
621 Value lhs, rhs, res, mask;
641 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
642 vector::ContractionOp op, MaskingOpInterface maskOp,
644 if (vectorTransformOptions.vectorContractLowering !=
645 vector::VectorContractLowering::OuterProduct)
648 if (failed(filter(op)))
651 UnrolledOuterProductGenerator e(rewriter, op);
652 FailureOr<Value> matmatRes = e.matmat();
653 if (succeeded(matmatRes)) {
656 FailureOr<Value> matvecRes = e.matvec();
657 if (succeeded(matvecRes)) {
661 FailureOr<Value> tmatvecRes = e.tmatvec();
665 FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
666 vector::ContractionOp op, MaskingOpInterface maskOp,
672 if (failed(filter(op)))
675 if (vectorTransformOptions.vectorContractLowering !=
676 vector::VectorContractLowering::Dot)
679 auto iteratorTypes = op.getIteratorTypes().getValue();
680 static constexpr std::array<int64_t, 2> perm = {1, 0};
682 Value lhs = op.getLhs(), rhs = op.getRhs();
685 auto infer = [&](MapList m) {
701 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
702 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
703 }
else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
705 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
706 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
707 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
708 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
709 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
710 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
713 lhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
715 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
717 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
719 lhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
720 rhs = rewriter.
create<vector::TransposeOp>(loc, tmp, perm);
721 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
723 rhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
733 if (maps == infer({{m, n}, {n}, {m}})) {
735 }
else if (maps == infer({{n, m}, {n}, {m}})) {
736 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
737 }
else if (maps == infer({{n}, {m, n}, {m}})) {
739 }
else if (maps == infer({{n}, {n, m}, {m}})) {
741 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
749 VectorType dstType = cast<VectorType>(op.getResultType());
750 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
751 "Expected dst type of rank 1 or 2");
753 unsigned rank = dstType.getRank();
754 unsigned dstRows = dstType.getShape()[0];
755 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
758 Value res = rewriter.
create<arith::ConstantOp>(loc, dstType,
760 bool isInt = isa<IntegerType>(dstType.getElementType());
761 for (
unsigned r = 0; r < dstRows; ++r) {
762 Value a = rewriter.
create<vector::ExtractOp>(op.getLoc(), lhs, r);
763 for (
unsigned c = 0; c < dstColumns; ++c) {
766 : rewriter.
create<vector::ExtractOp>(op.getLoc(), rhs, c);
768 Value reduced = rewriter.
create<vector::ReductionOp>(
769 op.getLoc(), vector::CombiningKind::ADD, m);
773 res = rewriter.
create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
776 if (
auto acc = op.getAcc())
777 res =
createAdd(op.getLoc(), res, acc, isInt, rewriter);
783 struct ContractOpToElementwise
785 using MaskableOpRewritePattern::MaskableOpRewritePattern;
786 using FilterConstraintType =
787 std::function<LogicalResult(vector::ContractionOp op)>;
788 static LogicalResult defaultFilter(vector::ContractionOp op) {
// Constructor. `constraint` defaults to `defaultFilter` and is kept in `filter`,
// which matchAndRewriteMaskableOp checks via `failed(filter(contractOp))`.
// Store the caller's constraint rather than `defaultFilter`, so that a custom
// filter passed at construction actually takes effect.
791 ContractOpToElementwise(
794 const FilterConstraintType &constraint = defaultFilter)
796 vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
799 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
800 MaskingOpInterface maskOp,
806 if (failed(filter(contractOp)))
809 if (vectorTransformOptions.vectorContractLowering !=
810 vector::VectorContractLowering::ParallelArith)
815 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
816 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
822 for (int64_t dim : lhsReductionDims) {
823 if (lhsShape[dim] != 1)
826 for (int64_t dim : rhsReductionDims) {
827 if (rhsShape[dim] != 1)
830 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
832 unsigned numLhsDimToBroadcast =
833 numParallelDims - (lhsMap.
getNumResults() - lhsReductionDims.size());
834 unsigned numRhsDimToBroadcast =
835 numParallelDims - (rhsMap.
getNumResults() - rhsReductionDims.size());
840 for (int64_t dim : lhsReductionDims)
841 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
842 for (int64_t dim : rhsReductionDims)
843 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
846 for (
unsigned i = 0; i < numParallelDims; i++) {
847 std::optional<unsigned> lhsDim =
850 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
854 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
855 lhsTranspose.push_back(lhsDims.size() - 1);
857 std::optional<unsigned> rhsDim =
860 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
864 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
865 rhsTranspose.push_back(rhsDims.size() - 1);
868 Value newLhs = contractOp.getLhs();
869 Value newRhs = contractOp.getRhs();
871 if (!lhsDims.empty()) {
872 lhsDims.append(lhsShape.begin(), lhsShape.end());
875 newLhs = rewriter.
create<vector::BroadcastOp>(loc, expandedType, newLhs);
877 if (!rhsDims.empty()) {
878 rhsDims.append(rhsShape.begin(), rhsShape.end());
881 newRhs = rewriter.
create<vector::BroadcastOp>(loc, expandedType, newRhs);
883 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
884 newLhs = rewriter.
create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
885 newRhs = rewriter.
create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
888 newLhs = rewriter.
create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
889 newRhs = rewriter.
create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
890 std::optional<Value> result =
892 contractOp.getKind(), rewriter, isInt);
902 FilterConstraintType filter;
922 FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
923 vector::ContractionOp op, MaskingOpInterface maskOp,
925 if (failed(filter(op)))
929 if (op.getLhsType().getElementType() !=
936 if (op.getKind() != vector::CombiningKind::ADD) {
938 op,
"contractions other than 'add' not supported");
944 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
945 FailureOr<Value> newVal1 =
946 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
947 if (!failed(newVal1))
950 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
951 FailureOr<Value> newVal2 =
952 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
953 if (!failed(newVal2))
956 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
957 FailureOr<Value> newVal3 =
958 pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
959 if (!failed(newVal3))
962 ContractOpToElementwise pat4(vectorTransformOptions, ctx);
963 FailureOr<Value> newVal4 =
964 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
965 if (!failed(newVal4))
972 mask = maskOp.getMask();
974 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
975 if (!batchDimMap.empty()) {
976 int64_t lhsIndex = batchDimMap[0].first;
977 int64_t rhsIndex = batchDimMap[0].second;
978 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
985 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
986 op.getContractingDimMap();
989 for (
auto &dimPair : contractingDimMap) {
990 lhsContractingDimSet.insert(dimPair.first);
991 rhsContractingDimSet.insert(dimPair.second);
995 VectorType lhsType = op.getLhsType();
996 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
997 if (lhsContractingDimSet.count(lhsIndex) == 0) {
998 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
1006 VectorType rhsType = op.getRhsType();
1007 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1008 if (rhsContractingDimSet.count(rhsIndex) == 0) {
1009 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
1017 if (!contractingDimMap.empty()) {
1018 auto newOp = lowerReduction(rewriter, op, mask);
1030 FailureOr<Value> ContractionOpLowering::lowerParallel(
PatternRewriter &rewriter,
1031 vector::ContractionOp op,
1035 VectorType lhsType = op.getLhsType();
1036 VectorType rhsType = op.getRhsType();
1037 VectorType resType = cast<VectorType>(op.getResultType());
1040 int64_t iterIndex = -1;
1041 int64_t dimSize = -1;
1042 if (lhsIndex >= 0) {
1043 iterIndex = iMap[0].getDimPosition(lhsIndex);
1044 if (rhsIndex >= 0 && iterIndex != iMap[1].
getDimPosition(rhsIndex))
1046 diag <<
"expected lhsIndex=" << lhsIndex <<
" and rhsIndex=" << rhsIndex
1047 <<
" to map to the same dimension";
1049 if (lhsType.getScalableDims()[lhsIndex])
1051 diag <<
"Unrolling scalable dimension (lhsIndex=" << lhsIndex
1052 <<
") is not supported yet";
1054 dimSize = lhsType.getDimSize(lhsIndex);
1055 }
else if (rhsIndex >= 0) {
1056 iterIndex = iMap[1].getDimPosition(rhsIndex);
1057 if (rhsType.getScalableDims()[rhsIndex])
1059 diag <<
"Unrolling scalable dimension (rhsIndex=" << rhsIndex
1060 <<
") is not supported yet";
1062 dimSize = rhsType.getDimSize(rhsIndex);
1066 diag <<
"expected either lhsIndex=" << lhsIndex
1067 <<
" or rhsIndex=" << rhsIndex <<
" to be nonnegative";
1076 int64_t resIndex =
getResultIndex(iMap[2], iterIndex).value_or(-1);
1077 if (resIndex == -1 && dimSize != 1)
1079 diag <<
"expected the dimension for iterIndex=" << iterIndex
1080 <<
" to either appear in the result map, or to be a unit dimension";
1084 std::array<AffineMap, 3> lowIndexingMaps = {
1085 adjustMap(iMap[0], iterIndex, rewriter),
1086 adjustMap(iMap[1], iterIndex, rewriter),
1087 adjustMap(iMap[2], iterIndex, rewriter)};
1093 Value result = rewriter.
create<arith::ConstantOp>(
1096 for (int64_t d = 0; d < dimSize; ++d) {
1097 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1098 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1099 auto acc =
reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1104 iterIndex, d, rewriter);
1107 loc, lhs, rhs, acc, lowAffine, lowIter);
1108 lowContract =
maskOperation(rewriter, lowContract, lowMask);
1110 resIndex, d, rewriter);
1116 FailureOr<Value> ContractionOpLowering::lowerReduction(
1119 VectorType lhsType = op.getLhsType();
1120 VectorType rhsType = op.getRhsType();
1121 Type resType = op.getResultType();
1122 if (isa<VectorType>(resType))
1124 "did not expect a VectorType result");
1125 bool isInt = isa<IntegerType>(resType);
1127 int64_t iterIndex = 0;
1129 std::optional<int64_t> lookupLhs =
getResultIndex(iMap[0], iterIndex);
1130 std::optional<int64_t> lookupRhs =
getResultIndex(iMap[1], iterIndex);
1131 if (!lookupLhs.has_value())
1133 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a LHS dimension";
1135 if (!lookupRhs.has_value())
1137 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a RHS dimension";
1139 int64_t lhsIndex = *lookupLhs;
1140 int64_t rhsIndex = *lookupRhs;
1141 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1142 if (dimSize != rhsType.getDimSize(rhsIndex))
1144 diag <<
"expect LHS dimension " << lhsIndex
1145 <<
" to have the same size as RHS dimension " << rhsIndex;
1148 if (lhsType.getRank() == 1) {
1149 if (rhsType.getRank() != 1)
1151 op,
"When LHS has rank 1, expected also RHS to have rank 1");
1152 Value m =
createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1153 auto kind = vector::CombiningKind::ADD;
1155 Value acc = op.getAcc();
1157 acc ? rewriter.
create<vector::ReductionOp>(loc, kind, m, acc)
1158 : rewriter.
create<vector::ReductionOp>(loc, kind, m);
1162 std::array<AffineMap, 3> lowIndexingMaps = {
1163 adjustMap(iMap[0], iterIndex, rewriter),
1164 adjustMap(iMap[1], iterIndex, rewriter),
1165 adjustMap(iMap[2], iterIndex, rewriter)};
1173 Value result = op.getAcc();
1174 for (int64_t d = 0; d < dimSize; ++d) {
1175 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1176 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1180 iterIndex, d, rewriter);
1183 loc, lhs, rhs, result, lowAffine, lowIter);
1202 class OuterProductOpLowering :
public OpRewritePattern<vector::OuterProductOp> {
1206 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1208 VectorType resType = op.getResultVectorType();
1209 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1212 auto loc = op.getLoc();
1214 VectorType lhsType = op.getOperandVectorTypeLHS();
1215 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1216 Type eltType = resType.getElementType();
1217 bool isInt = isa<IntegerType, IndexType>(eltType);
1218 Value acc = op.getAcc();
1219 vector::CombiningKind kind = op.getKind();
1223 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1226 if (maskableOp.isMasked()) {
1228 rootOp = maskableOp.getMaskingOp();
1229 mask = maskableOp.getMaskingOp().getMask();
1236 Value b = rewriter.
create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1238 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1239 if (!mult.has_value())
1245 Value result = rewriter.
create<arith::ConstantOp>(
1247 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1248 Value x = rewriter.
create<vector::ExtractOp>(loc, op.getLhs(), d);
1249 Value a = rewriter.
create<vector::BroadcastOp>(loc, rhsType, x);
1252 r = rewriter.
create<vector::ExtractOp>(loc, acc, d);
1255 extrMask = rewriter.
create<vector::ExtractOp>(loc, mask, d);
1258 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1261 result = rewriter.
create<vector::InsertOp>(loc, *m, result, d);
1288 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1289 vector::ContractionOp op, MaskingOpInterface maskOp,
1295 if (vectorTransformOptions.vectorContractLowering !=
1296 vector::VectorContractLowering::Matmul)
1298 if (failed(filter(op)))
1301 auto iteratorTypes = op.getIteratorTypes().getValue();
1307 Type opResType = op.getType();
1308 VectorType vecType = dyn_cast<VectorType>(opResType);
1309 if (vecType && vecType.isScalable()) {
1314 Type elementType = op.getLhsType().getElementType();
1318 Type dstElementType = vecType ? vecType.getElementType() : opResType;
1319 if (elementType != dstElementType)
1329 Value lhs = op.getLhs();
1330 auto lhsMap = op.getIndexingMapsArray()[0];
1337 Value rhs = op.getRhs();
1338 auto rhsMap = op.getIndexingMapsArray()[1];
1345 VectorType lhsType = cast<VectorType>(lhs.getType());
1346 VectorType rhsType = cast<VectorType>(rhs.getType());
1347 int64_t lhsRows = lhsType.getDimSize(0);
1348 int64_t lhsColumns = lhsType.getDimSize(1);
1349 int64_t rhsColumns = rhsType.getDimSize(1);
1351 Type flattenedLHSType =
1353 lhs = rew.
create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1355 Type flattenedRHSType =
1357 rhs = rew.
create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1359 Value mul = rew.
create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1361 mul = rew.
create<vector::ShapeCastOp>(
1368 auto accMap = op.getIndexingMapsArray()[2];
1372 llvm_unreachable(
"invalid contraction semantics");
1375 isa<IntegerType>(elementType)
1376 ?
static_cast<Value>(rew.
create<arith::AddIOp>(loc, op.getAcc(), mul))
1377 :
static_cast<Value>(
1378 rew.
create<arith::AddFOp>(loc, op.getAcc(), mul));
1387 if (!disableOuterProductLowering)
1389 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1390 ContractionOpToOuterProductOpLowering>(
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true, otherwise creates a MulFOp, using operands x and y.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true, otherwise creates an arith::AddFOp, using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.