#define DEBUG_TYPE "vector-contract-lowering"
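// Editorial overview (summary inferred from the patterns in this listing):
// the rewrites below progressively lower vector.contract into simpler vector
// operations. Depending on the selected vector::VectorContractLowering
// option, a contraction becomes a sequence of vector.outerproduct ops
// (OuterProduct), a flattened vector.matrix_multiply (Matmul), per-element
// extract/multiply/vector.reduction chains (Dot), plain elementwise arith ops
// (ParallelArith), or is unrolled one dimension at a time by
// ContractionOpLowering until one of those forms applies.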
// Fragments of adjustIter / adjustMap: rebuild the iterator types and the
// indexing maps with one iteration index removed.
    int64_t idx = it.index();
    results.push_back(it.value());

    results.push_back(targetExpr);
// Fragment of reshapeLoad(loc, val, type, index, pos, rewriter): extracts the
// slice at position `pos` along dimension `index` of vector `val`.
                         int64_t index, int64_t pos,
    return rewriter.create<vector::ExtractOp>(loc, val, pos);
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);

// Fragment of reshapeStore(loc, val, result, type, index, pos, rewriter): the
// inverse of reshapeLoad, inserting `val` back at position `pos` along
// dimension `index` of `result`.
                          VectorType type, int64_t index, int64_t pos,
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
/// Helper to create the arithmetic operation associated with a kind of
/// contraction.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
  if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
    Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
  mul = rewriter.create<arith::MulFOp>(loc, x, y);
  return std::optional<Value>(mul);
// Fragment of getReductionIndex(map, iteratorTypes): returns the positions of
// the reduction dimensions in the given map.
                                              ArrayAttr iteratorTypes) {
      dimsIdx.push_back(i);

// createAdd: emits arith.addi when isInt, arith.addf otherwise.
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);

// createMul: emits arith.muli when isInt, arith.mulf otherwise.
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
class ContractionOpToMatmulOpLowering
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {

  ContractionOpToMatmulOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,

  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
class ContractionOpToOuterProductOpLowering
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {

  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,

  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
class ContractionOpToDotLowering
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {

  ContractionOpToDotLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,

  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
class ContractionOpLowering
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {

  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,

  vector::VectorContractLowering vectorContractLoweringOption;
  FilterConstraintType filter;
  // Lowers one parallel (batch or free) dimension by unrolling it.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lowers the remaining reduction dimension of a scalar-result contraction.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
struct UnrolledOuterProductGenerator
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();

  // t: transpose helper used to reorder operands (and the mask).
    return rewriter.create<vector::TransposeOp>(loc, v, perm);

  // promote: float- or sign-extends a value to the accumulator element type.
    auto vecType = dyn_cast<VectorType>(elementType);
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
    Type promotedType = dstElementType;
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);

  // outerProd: unrolls the reduction dimension into `reductionSize`
  // vector.outerproduct ops that feed the accumulator.
                       VectorType lhsType, int reductionSize,
                       std::optional<Value> maybeMask = std::nullopt) {
    if (mask && !maybeMask.has_value())
    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      if (maybeMask.has_value() && maybeMask.value())
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
          loc, res.getType(), extractA, extractB, res, kind);

  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // A scalable reduction dimension cannot be unrolled.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
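  // Editorial note on the routines below: matmat, matvec and tmatvec match
  // the contraction's indexing-map layout against the supported permutations.
  // Each case transposes the operands, and the mask when present, so that the
  // reduction dimension k becomes the outermost dimension, then calls
  // outerProd, which emits one vector.outerproduct per reduction step.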
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))

    // A(m, k) * B(k, n) -> C(m, n): permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
    // A(m, k) * B(n, k) -> C(m, n): permute both operands.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
    // A(k, m) * B(k, n) -> C(m, n): no permutation needed.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
    // A(k, m) * B(n, k) -> C(m, n): permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
    // Transposed-output variants: the roles of lhs and rhs are swapped.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))

    // A(m, k) * b(k) -> c(m): permute the lhs.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
    // A(k, m) * b(k) -> c(m): no permutation needed.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
    // b(k) * A(m, k) -> c(m): permute the rhs.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
    // b(k) * A(k, m) -> c(m): no permutation needed.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))

    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);

  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
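  // Illustrative sketch (editorial, not verbatim compiler output): for a
  // row-major f32 matmul with %lhs : vector<2x4xf32>, %rhs : vector<4x3xf32>
  // and %acc : vector<2x3xf32>, the OuterProduct strategy produces roughly
  //   %lhsT = vector.transpose %lhs, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
  //   %a0 = vector.extract %lhsT[0] ...
  //   %b0 = vector.extract %rhs[0] ...
  //   %r0 = vector.outerproduct %a0, %b0, %acc : vector<2xf32>, vector<3xf32>
  //   ... three more outer products for k = 1, 2, 3, each accumulating into
  //   the previous result.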
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
  if (failed(filter(op)))

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
  FailureOr<Value> tmatvecRes = e.tmatvec();
FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
  if (failed(filter(op)))
  if (vectorContractLowering != vector::VectorContractLowering::Dot)

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Value lhs = op.getLhs(), rhs = op.getRhs();

  auto infer = [&](MapList m) {
    // Canonicalize the operand layouts so that every result element can be
    // computed as a dot product of an lhs row with an rhs row.
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;

    // Matrix-vector cases.
    if (maps == infer({{m, n}, {n}, {m}})) {
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // Start from the zero vector and insert one reduced dot product per result
  // element.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  SmallVector<Value> extractedCols;
  extractedCols.reserve(dstColumns);
  for (unsigned r = 0; r < dstRows; ++r) {
    Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
          : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
        extractedCols.push_back(colRhs);
      Value extractedColRhs = extractedCols[c];
      Value product =
          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
      Value sum = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, product);
      res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);

  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
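// Editorial note: with both operands canonicalized above (lhs rows aligned
// with result rows, rhs rows holding what were originally columns), the Dot
// strategy computes each result element as an elementwise createMul of one
// lhs row with one rhs row followed by a vector.reduction <add>, and finally
// adds the original accumulator, if any.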
struct ContractOpToElementwise
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {

  ContractOpToElementwise(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
    if (failed(filter(contractOp)))
    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)

    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    // This lowering only applies when every reduction dimension has unit size.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    // Unit reduction dimensions are moved to the front so that they can be
    // dropped with a single vector.extract below.
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);

    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    // Broadcast the operands over any parallel result dimensions they are
    // missing, then transpose them into the result's dimension order.
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);

  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
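// Editorial note: the ParallelArith path above only fires when all reduction
// dimensions have unit size. It broadcasts each operand up to the result's
// parallel shape, transposes it into the result's dimension order, strips the
// unit reduction dimensions with vector.extract, and then folds the whole
// contraction into a single elementwise op built by createContractArithOp.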
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
  if (failed(filter(op)))

  // Mixed element types between the operands and the accumulator are not
  // supported.
  if (op.getLhsType().getElementType() !=
      getElementTypeOrSelf(op.getAccType()))

  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");

  // First try the specialized lowerings.
  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal1 =
      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal1))
  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal2 =
      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal2))
  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal3 =
      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal3))
  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal4 =
      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal4))

  // Vector mask setup.
  Value mask;
  if (maskOp)
    mask = maskOp.getMask();

  // Unroll any batch dimension first.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);

  // Collect the contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);

  // Unroll the free (non-contracting) dimensions of the lhs, then the rhs.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);

  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);

  // Finally lower the remaining reduction dimensions.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
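// Editorial note: when none of the specialized patterns above applies,
// ContractionOpLowering proceeds progressively. lowerParallel (below) peels
// one batch or free parallel dimension at a time, rebuilding a smaller
// vector.contract for every unrolled slice via reshapeLoad/reshapeStore, and
// lowerReduction collapses the remaining contraction dimension into
// multiplies and vector.reduction once the result is a scalar.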
FailureOr<Value>
ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                     vector::ContractionOp op, int64_t lhsIndex,
                                     int64_t rhsIndex, Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator index and the size of the dimension being unrolled.
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
    if (lhsType.getScalableDims()[lhsIndex])
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
    dimSize = rhsType.getDimSize(rhsIndex);
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";

  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";

  // Construct the new indexing maps (and iterator types) with iterIndex
  // removed.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};

  // Unroll the parallel dimension: slice the operands, emit a lower-rank
  // vector.contract per slice, and store the slices back into the result.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);
    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
FailureOr<Value>
ContractionOpLowering::lowerReduction(PatternRewriter &rewriter,
                                      vector::ContractionOp op,
                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  int64_t iterIndex = 0;
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
  if (!lookupRhs.has_value())
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;

  // Base case: rank-1 operands reduce to a single vector.reduction.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected also RHS to have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
            : rewriter.create<vector::ReductionOp>(loc, kind, m);

  // Recursive case: unroll the reduction dimension and contract the slices.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);
    Operation *newContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, result, lowAffine, lowIter);
/// Lowers vector.outerproduct by unrolling its leading dimension.
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
  LogicalResult matchAndRewrite(vector::OuterProductOp op,
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())

    auto loc = op.getLoc();
    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked()) {
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();

    // Rank-1 case: broadcast the scalar rhs and multiply elementwise.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())

    // Rank-2 case: one broadcast plus one fma/elementwise op per result row.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
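// Editorial note: masked outer products are handled here by extracting the
// matching row of the mask for each unrolled step and forwarding it to
// createContractArithOp, which builds the masked arithmetic for that row.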
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
  if (failed(filter(op)))

  auto iteratorTypes = op.getIteratorTypes().getValue();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
  Type elementType = op.getLhsType().getElementType();
  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)

  // Bring the operands into row-major (m x k) x (k x n) form.
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];

  // Flatten both operands, multiply them as 1-D vectors, and reshape back.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(

  // Add the accumulator, honoring a possibly transposed output map.
  auto accMap = op.getIndexingMapsArray()[2];
    llvm_unreachable("invalid contraction semantics");

  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
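// Editorial note: the Matmul strategy shape_casts both operands to 1-D,
// multiplies them with a single vector.matrix_multiply (which later maps onto
// the LLVM matrix intrinsics), shape_casts the flat result back to 2-D, and
// adds the accumulator with arith.addi/arith.addf.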
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns,
    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
    bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      vectorContractLoweringOption, patterns.getContext(), benefit);
}
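// Usage sketch (editorial example, not part of the original file): a pass
// body that applies these lowerings with the greedy driver. The header paths
// and the driver entry point follow the upstream MLIR API; older trees spell
// the driver `applyPatternsAndFoldGreedily` instead of `applyPatternsGreedily`,
// and `MyLowerContractPass` is a hypothetical pass class.
//
//   #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void MyLowerContractPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorContractLoweringPatterns(
//         patterns, vector::VectorContractLowering::OuterProduct,
//         /*benefit=*/1, /*disableOuterProductLowering=*/false);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }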