36 #define DEBUG_TYPE "vector-contract-lowering"
59 int64_t idx = it.index();
62 results.push_back(it.value());
79 results.push_back(targetExpr);
87 int64_t index, int64_t pos,
94 return rewriter.
create<vector::ExtractOp>(loc, val, pos);
101 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
102 Value ext = rewriter.
create<vector::ExtractOp>(loc, val, d);
104 result = rewriter.
create<vector::InsertOp>(loc, load, result, d);
112 VectorType type, int64_t index, int64_t pos,
119 return rewriter.
create<vector::InsertOp>(loc, val, result, pos);
123 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
124 Value ext = rewriter.
create<vector::ExtractOp>(loc, result, d);
125 Value ins = rewriter.
create<vector::ExtractOp>(loc, val, d);
127 result = rewriter.
create<vector::InsertOp>(loc, sto, result, d);
133 static std::optional<Value>
137 using vector::CombiningKind;
141 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
142 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
145 mul = rewriter.
create<arith::MulIOp>(loc, x, y);
149 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
150 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
151 kind == CombiningKind::XOR)
155 if (acc && isa<VectorType>(acc.
getType()) && kind == CombiningKind::ADD) {
156 Value fma = rewriter.
create<vector::FMAOp>(loc, x, y, acc);
163 mul = rewriter.
create<arith::MulFOp>(loc, x, y);
167 return std::optional<Value>(mul);
175 ArrayAttr iteratorTypes) {
179 dimsIdx.push_back(i);
199 return rewriter.
create<arith::AddIOp>(loc, x, y);
200 return rewriter.
create<arith::AddFOp>(loc, x, y);
208 return rewriter.
create<arith::MulIOp>(loc, x, y);
209 return rewriter.
create<arith::MulFOp>(loc, x, y);
227 class ContractionOpToMatmulOpLowering
230 using MaskableOpRewritePattern::MaskableOpRewritePattern;
232 using FilterConstraintType =
235 static LogicalResult defaultFilter(vector::ContractionOp op) {
239 ContractionOpToMatmulOpLowering(
242 FilterConstraintType constraint = defaultFilter)
244 vectorTransformOptions(vectorTransformOptions),
245 filter(std::move(constraint)) {}
248 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
254 FilterConstraintType filter;
272 class ContractionOpToOuterProductOpLowering
275 using MaskableOpRewritePattern::MaskableOpRewritePattern;
277 using FilterConstraintType =
280 static LogicalResult defaultFilter(vector::ContractionOp op) {
284 ContractionOpToOuterProductOpLowering(
287 FilterConstraintType constraint = defaultFilter)
289 vectorTransformOptions(vectorTransformOptions),
290 filter(std::move(constraint)) {}
293 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
299 FilterConstraintType filter;
320 class ContractionOpToDotLowering
323 using MaskableOpRewritePattern::MaskableOpRewritePattern;
325 using FilterConstraintType =
328 static LogicalResult defaultFilter(vector::ContractionOp op) {
332 ContractionOpToDotLowering(
335 const FilterConstraintType &constraint = defaultFilter)
337 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
340 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
346 FilterConstraintType filter;
363 class ContractionOpLowering
366 using MaskableOpRewritePattern::MaskableOpRewritePattern;
367 using FilterConstraintType =
370 static LogicalResult defaultFilter(vector::ContractionOp op) {
376 FilterConstraintType constraint = defaultFilter)
378 vectorTransformOptions(vectorTransformOptions),
379 filter(std::move(constraint)) {}
382 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
388 FilterConstraintType filter;
391 vector::ContractionOp op, int64_t lhsIndex,
392 int64_t rhsIndex,
Value mask)
const;
395 vector::ContractionOp op,
Value mask)
const;
400 struct UnrolledOuterProductGenerator
402 UnrolledOuterProductGenerator(
RewriterBase &b, vector::ContractionOp op)
404 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
405 res(op.getAcc()), lhsType(op.getLhsType()) {
406 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
407 if (maskableOp.isMasked())
408 mask = maskableOp.getMaskingOp().getMask();
414 return rewriter.
create<vector::TransposeOp>(loc, v, perm);
419 auto vecType = dyn_cast<VectorType>(elementType);
421 elementType = vecType.getElementType();
422 if (elementType == dstElementType)
424 Type promotedType = dstElementType;
426 promotedType = vecType.clone(promotedType);
427 if (isa<FloatType>(dstElementType))
428 return rewriter.
create<arith::ExtFOp>(loc, promotedType, v);
429 return rewriter.
create<arith::ExtSIOp>(loc, promotedType, v);
433 VectorType lhsType,
int reductionSize,
434 std::optional<Value> maybeMask = std::nullopt) {
436 if (mask && !maybeMask.has_value())
439 Type resElementType = cast<VectorType>(res.
getType()).getElementType();
440 for (int64_t k = 0; k < reductionSize; ++k) {
441 Value extractA = rewriter.
create<vector::ExtractOp>(loc, lhs, k);
442 Value extractB = rewriter.
create<vector::ExtractOp>(loc, rhs, k);
443 extractA =
promote(extractA, resElementType);
444 extractB =
promote(extractB, resElementType);
446 if (maybeMask.has_value() && maybeMask.value())
448 rewriter.
create<vector::ExtractOp>(loc, maybeMask.value(), k);
451 loc, res.
getType(), extractA, extractB, res, kind);
460 std::optional<int64_t> getReductionSize(VectorType vecType,
461 int64_t reductionDim) {
463 if (vecType.getScalableDims()[reductionDim])
465 int64_t reductionSize = vecType.getDimSize(reductionDim);
466 assert(reductionSize > 0 &&
467 "Reduction dim must be a known static size to allow unrolling");
468 return reductionSize;
473 if (!iters({Par(), Par(), Red()}))
480 if (layout({{m, k}, {k, n}, {m, n}})) {
481 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
485 Value tMask = t(mask, {2, 0, 1});
486 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
490 if (layout({{m, k}, {n, k}, {m, n}})) {
491 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
494 Value tMask = t(mask, {2, 0, 1});
495 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
499 if (layout({{k, m}, {k, n}, {m, n}})) {
500 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
501 Value tMask = t(mask, {2, 0, 1});
502 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
506 if (layout({{k, m}, {n, k}, {m, n}})) {
507 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
509 Value tMask = t(mask, {2, 0, 1});
510 return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
515 if (layout({{m, k}, {k, n}, {n, m}})) {
516 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
518 Value tMask = t(mask, {2, 0, 1});
519 return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
523 if (layout({{m, k}, {n, k}, {n, m}})) {
524 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
527 Value tMask = t(mask, {2, 0, 1});
528 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
531 if (layout({{k, m}, {k, n}, {n, m}})) {
532 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
533 Value tMask = t(mask, {2, 0, 1});
534 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
537 if (layout({{k, m}, {n, k}, {n, m}})) {
538 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
540 Value tMask = t(mask, {2, 0, 1});
541 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
553 if (!iters({Par(), Red()}))
559 if (layout({{m, k}, {k}, {m}})) {
560 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
562 Value tMask = t(mask);
563 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
567 if (layout({{k, m}, {k}, {m}})) {
568 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
569 Value tMask = t(mask);
570 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
574 if (layout({{k}, {m, k}, {m}})) {
575 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
577 Value tMask = t(mask);
578 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
582 if (layout({{k}, {k, m}, {m}})) {
583 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
584 Value tMask = t(mask);
585 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
596 if (!iters({Red(), Par()}))
602 if (layout({{m, k}, {k}, {m}}))
603 if (
auto reductionSize = getReductionSize(lhsType, 1))
604 return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
606 if (layout({{k, m}, {k}, {m}}))
607 if (
auto reductionSize = getReductionSize(lhsType, 0))
608 return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
610 if (layout({{k}, {m, k}, {m}}))
611 if (
auto reductionSize = getReductionSize(lhsType, 0))
612 return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
614 if (layout({{k}, {k, m}, {m}}))
615 if (
auto reductionSize = getReductionSize(lhsType, 0))
616 return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
621 vector::CombiningKind kind;
622 Value lhs, rhs, res, mask;
642 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
643 vector::ContractionOp op, MaskingOpInterface maskOp,
645 if (vectorTransformOptions.vectorContractLowering !=
646 vector::VectorContractLowering::OuterProduct)
652 UnrolledOuterProductGenerator e(rewriter, op);
667 vector::ContractionOp op, MaskingOpInterface maskOp,
676 if (vectorTransformOptions.vectorContractLowering !=
677 vector::VectorContractLowering::Dot)
680 auto iteratorTypes = op.getIteratorTypes().getValue();
681 static constexpr std::array<int64_t, 2> perm = {1, 0};
683 Value lhs = op.getLhs(), rhs = op.getRhs();
686 auto infer = [&](MapList m) {
702 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
703 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
704 }
else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
706 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
707 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
708 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
709 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
710 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
711 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
714 lhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
716 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
718 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
720 lhs = rewriter.
create<vector::TransposeOp>(loc, rhs, perm);
721 rhs = rewriter.
create<vector::TransposeOp>(loc, tmp, perm);
722 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
724 rhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
734 if (maps == infer({{m, n}, {n}, {m}})) {
736 }
else if (maps == infer({{n, m}, {n}, {m}})) {
737 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
738 }
else if (maps == infer({{n}, {m, n}, {m}})) {
740 }
else if (maps == infer({{n}, {n, m}, {m}})) {
742 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, perm);
750 VectorType dstType = cast<VectorType>(op.getResultType());
751 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
752 "Expected dst type of rank 1 or 2");
754 unsigned rank = dstType.getRank();
755 unsigned dstRows = dstType.getShape()[0];
756 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
759 Value res = rewriter.
create<arith::ConstantOp>(loc, dstType,
761 bool isInt = isa<IntegerType>(dstType.getElementType());
762 for (
unsigned r = 0; r < dstRows; ++r) {
764 for (
unsigned c = 0; c < dstColumns; ++c) {
767 : rewriter.
create<vector::ExtractOp>(op.
getLoc(), rhs, c);
769 Value reduced = rewriter.
create<vector::ReductionOp>(
770 op.
getLoc(), vector::CombiningKind::ADD, m);
774 res = rewriter.
create<vector::InsertOp>(op.
getLoc(), reduced, res, pos);
777 if (
auto acc = op.getAcc())
784 struct ContractOpToElementwise
786 using MaskableOpRewritePattern::MaskableOpRewritePattern;
787 using FilterConstraintType =
789 static LogicalResult defaultFilter(vector::ContractionOp op) {
792 ContractOpToElementwise(
795 const FilterConstraintType &constraint = defaultFilter)
797 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
800 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
801 MaskingOpInterface maskOp,
807 if (
failed(filter(contractOp)))
810 if (vectorTransformOptions.vectorContractLowering !=
811 vector::VectorContractLowering::ParallelArith)
816 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
817 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
823 for (int64_t dim : lhsReductionDims) {
824 if (lhsShape[dim] != 1)
827 for (int64_t dim : rhsReductionDims) {
828 if (rhsShape[dim] != 1)
831 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
833 unsigned numLhsDimToBroadcast =
834 numParallelDims - (lhsMap.
getNumResults() - lhsReductionDims.size());
835 unsigned numRhsDimToBroadcast =
836 numParallelDims - (rhsMap.
getNumResults() - rhsReductionDims.size());
841 for (int64_t dim : lhsReductionDims)
842 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
843 for (int64_t dim : rhsReductionDims)
844 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
847 for (
unsigned i = 0; i < numParallelDims; i++) {
848 std::optional<unsigned> lhsDim =
851 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
855 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
856 lhsTranspose.push_back(lhsDims.size() - 1);
858 std::optional<unsigned> rhsDim =
861 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
865 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
866 rhsTranspose.push_back(rhsDims.size() - 1);
869 Value newLhs = contractOp.getLhs();
870 Value newRhs = contractOp.getRhs();
872 if (!lhsDims.empty()) {
873 lhsDims.append(lhsShape.begin(), lhsShape.end());
876 newLhs = rewriter.
create<vector::BroadcastOp>(loc, expandedType, newLhs);
878 if (!rhsDims.empty()) {
879 rhsDims.append(rhsShape.begin(), rhsShape.end());
882 newRhs = rewriter.
create<vector::BroadcastOp>(loc, expandedType, newRhs);
884 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
885 newLhs = rewriter.
create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
886 newRhs = rewriter.
create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
889 newLhs = rewriter.
create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
890 newRhs = rewriter.
create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
891 std::optional<Value> result =
893 contractOp.getKind(), rewriter, isInt);
903 FilterConstraintType filter;
924 vector::ContractionOp op, MaskingOpInterface maskOp,
930 if (op.getLhsType().getElementType() !=
937 if (op.getKind() != vector::CombiningKind::ADD) {
939 op,
"contractions other than 'add' not supported");
945 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
947 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
951 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
953 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
957 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
959 pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
963 ContractOpToElementwise pat4(vectorTransformOptions, ctx);
965 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
973 mask = maskOp.getMask();
975 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
976 if (!batchDimMap.empty()) {
977 int64_t lhsIndex = batchDimMap[0].first;
978 int64_t rhsIndex = batchDimMap[0].second;
979 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
986 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
987 op.getContractingDimMap();
990 for (
auto &dimPair : contractingDimMap) {
991 lhsContractingDimSet.insert(dimPair.first);
992 rhsContractingDimSet.insert(dimPair.second);
996 VectorType lhsType = op.getLhsType();
997 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
998 if (lhsContractingDimSet.count(lhsIndex) == 0) {
999 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
1007 VectorType rhsType = op.getRhsType();
1008 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1009 if (rhsContractingDimSet.count(rhsIndex) == 0) {
1010 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
1018 if (!contractingDimMap.empty()) {
1019 auto newOp = lowerReduction(rewriter, op, mask);
1032 vector::ContractionOp op,
1036 VectorType lhsType = op.getLhsType();
1037 VectorType rhsType = op.getRhsType();
1038 VectorType resType = cast<VectorType>(op.getResultType());
1041 int64_t iterIndex = -1;
1042 int64_t dimSize = -1;
1043 if (lhsIndex >= 0) {
1044 iterIndex = iMap[0].getDimPosition(lhsIndex);
1045 if (rhsIndex >= 0 && iterIndex != iMap[1].
getDimPosition(rhsIndex))
1047 diag <<
"expected lhsIndex=" << lhsIndex <<
" and rhsIndex=" << rhsIndex
1048 <<
" to map to the same dimension";
1050 if (lhsType.getScalableDims()[lhsIndex])
1052 diag <<
"Unrolling scalable dimension (lhsIndex=" << lhsIndex
1053 <<
") is not supported yet";
1055 dimSize = lhsType.getDimSize(lhsIndex);
1056 }
else if (rhsIndex >= 0) {
1057 iterIndex = iMap[1].getDimPosition(rhsIndex);
1058 if (rhsType.getScalableDims()[rhsIndex])
1060 diag <<
"Unrolling scalable dimension (rhsIndex=" << rhsIndex
1061 <<
") is not supported yet";
1063 dimSize = rhsType.getDimSize(rhsIndex);
1067 diag <<
"expected either lhsIndex=" << lhsIndex
1068 <<
" or rhsIndex=" << rhsIndex <<
" to be nonnegative";
1077 int64_t resIndex =
getResultIndex(iMap[2], iterIndex).value_or(-1);
1078 if (resIndex == -1 && dimSize != 1)
1080 diag <<
"expected the dimension for iterIndex=" << iterIndex
1081 <<
" to either appear in the result map, or to be a unit dimension";
1085 std::array<AffineMap, 3> lowIndexingMaps = {
1086 adjustMap(iMap[0], iterIndex, rewriter),
1087 adjustMap(iMap[1], iterIndex, rewriter),
1088 adjustMap(iMap[2], iterIndex, rewriter)};
1094 Value result = rewriter.
create<arith::ConstantOp>(
1097 for (int64_t d = 0; d < dimSize; ++d) {
1098 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1099 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1100 auto acc =
reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1105 iterIndex, d, rewriter);
1108 loc, lhs, rhs, acc, lowAffine, lowIter);
1109 lowContract =
maskOperation(rewriter, lowContract, lowMask);
1111 resIndex, d, rewriter);
1120 VectorType lhsType = op.getLhsType();
1121 VectorType rhsType = op.getRhsType();
1122 Type resType = op.getResultType();
1123 if (isa<VectorType>(resType))
1125 "did not expect a VectorType result");
1126 bool isInt = isa<IntegerType>(resType);
1128 int64_t iterIndex = 0;
1130 std::optional<int64_t> lookupLhs =
getResultIndex(iMap[0], iterIndex);
1131 std::optional<int64_t> lookupRhs =
getResultIndex(iMap[1], iterIndex);
1132 if (!lookupLhs.has_value())
1134 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a LHS dimension";
1136 if (!lookupRhs.has_value())
1138 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a RHS dimension";
1140 int64_t lhsIndex = *lookupLhs;
1141 int64_t rhsIndex = *lookupRhs;
1142 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1143 if (dimSize != rhsType.getDimSize(rhsIndex))
1145 diag <<
"expect LHS dimension " << lhsIndex
1146 <<
" to have the same size as RHS dimension " << rhsIndex;
1149 if (lhsType.getRank() == 1) {
1150 if (rhsType.getRank() != 1)
1152 op,
"When LHS has rank 1, expected also RHS to have rank 1");
1153 Value m =
createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1154 auto kind = vector::CombiningKind::ADD;
1156 Value acc = op.getAcc();
1158 acc ? rewriter.
create<vector::ReductionOp>(loc, kind, m, acc)
1159 : rewriter.
create<vector::ReductionOp>(loc, kind, m);
1163 std::array<AffineMap, 3> lowIndexingMaps = {
1164 adjustMap(iMap[0], iterIndex, rewriter),
1165 adjustMap(iMap[1], iterIndex, rewriter),
1166 adjustMap(iMap[2], iterIndex, rewriter)};
1174 Value result = op.getAcc();
1175 for (int64_t d = 0; d < dimSize; ++d) {
1176 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1177 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1181 iterIndex, d, rewriter);
1184 loc, lhs, rhs, result, lowAffine, lowIter);
1203 class OuterProductOpLowering :
public OpRewritePattern<vector::OuterProductOp> {
1209 VectorType resType = op.getResultVectorType();
1210 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1215 VectorType lhsType = op.getOperandVectorTypeLHS();
1216 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1217 Type eltType = resType.getElementType();
1218 bool isInt = isa<IntegerType, IndexType>(eltType);
1219 Value acc = op.getAcc();
1220 vector::CombiningKind kind = op.getKind();
1224 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1227 if (maskableOp.isMasked()) {
1229 rootOp = maskableOp.getMaskingOp();
1230 mask = maskableOp.getMaskingOp().getMask();
1237 Value b = rewriter.
create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1239 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1240 if (!mult.has_value())
1246 Value result = rewriter.
create<arith::ConstantOp>(
1248 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1249 Value x = rewriter.
create<vector::ExtractOp>(loc, op.getLhs(), d);
1250 Value a = rewriter.
create<vector::BroadcastOp>(loc, rhsType, x);
1253 r = rewriter.
create<vector::ExtractOp>(loc, acc, d);
1256 extrMask = rewriter.
create<vector::ExtractOp>(loc, mask, d);
1259 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1262 result = rewriter.
create<vector::InsertOp>(loc, *m, result, d);
1287 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1288 vector::ContractionOp op, MaskingOpInterface maskOp,
1294 if (vectorTransformOptions.vectorContractLowering !=
1295 vector::VectorContractLowering::Matmul)
1300 auto iteratorTypes = op.getIteratorTypes().getValue();
1306 Type elementType = op.getLhsType().getElementType();
1310 Type dstElementType = op.getType();
1311 if (
auto vecType = dyn_cast<VectorType>(dstElementType))
1312 dstElementType = vecType.getElementType();
1313 if (elementType != dstElementType)
1323 Value lhs = op.getLhs();
1324 auto lhsMap = op.getIndexingMapsArray()[0];
1331 Value rhs = op.getRhs();
1332 auto rhsMap = op.getIndexingMapsArray()[1];
1339 VectorType lhsType = cast<VectorType>(lhs.getType());
1340 VectorType rhsType = cast<VectorType>(rhs.getType());
1341 int64_t lhsRows = lhsType.getDimSize(0);
1342 int64_t lhsColumns = lhsType.getDimSize(1);
1343 int64_t rhsColumns = rhsType.getDimSize(1);
1345 Type flattenedLHSType =
1347 lhs = rew.
create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1349 Type flattenedRHSType =
1351 rhs = rew.
create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1353 Value mul = rew.
create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1355 mul = rew.
create<vector::ShapeCastOp>(
1362 auto accMap = op.getIndexingMapsArray()[2];
1366 llvm_unreachable(
"invalid contraction semantics");
1369 isa<IntegerType>(elementType)
1370 ?
static_cast<Value>(rew.
create<arith::AddIOp>(loc, op.getAcc(), mul))
1371 :
static_cast<Value>(
1372 rew.
create<arith::AddFOp>(loc, op.getAcc(), mul));
1381 if (!disableOuterProductLowering)
1382 patterns.
add<OuterProductOpLowering>(patterns.
getContext(), benefit);
1383 patterns.
add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1384 ContractionOpToOuterProductOpLowering>(
1390 patterns.
add<OuterProductOpLowering>(patterns.
getContext(), benefit);
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if `isInt` is true, otherwise creates a MulFOp, using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if `isInt` is true, otherwise creates an arith::AddFOp, using operands `x` and `y`.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map. AffineMaps are immutable, like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.