26 #define DEBUG_TYPE "vector-contract-lowering"
49 int64_t idx = it.index();
52 results.push_back(it.value());
69 results.push_back(targetExpr);
77 int64_t index, int64_t pos,
84 return vector::ExtractOp::create(rewriter, loc, val, pos);
89 Value result = arith::ConstantOp::create(rewriter, loc, resType,
91 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
92 Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
94 result = vector::InsertOp::create(rewriter, loc, load, result, d);
102 VectorType type, int64_t index, int64_t pos,
109 return vector::InsertOp::create(rewriter, loc, val, result, pos);
113 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
114 Value ext = vector::ExtractOp::create(rewriter, loc, result, d);
115 Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
117 result = vector::InsertOp::create(rewriter, loc, sto, result, d);
123 static std::optional<Value>
127 using vector::CombiningKind;
131 if (
kind == CombiningKind::MINNUMF ||
kind == CombiningKind::MAXNUMF ||
132 kind == CombiningKind::MINIMUMF ||
kind == CombiningKind::MAXIMUMF)
135 mul = arith::MulIOp::create(rewriter, loc, x, y);
139 kind == CombiningKind::MINSI ||
kind == CombiningKind::MAXUI ||
140 kind == CombiningKind::MAXSI ||
kind == CombiningKind::OR ||
141 kind == CombiningKind::XOR)
145 if (acc && isa<VectorType>(acc.
getType()) &&
kind == CombiningKind::ADD) {
146 Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc);
153 mul = arith::MulFOp::create(rewriter, loc, x, y);
157 return std::optional<Value>(mul);
165 ArrayAttr iteratorTypes) {
169 dimsIdx.push_back(i);
189 return arith::AddIOp::create(rewriter, loc, x, y);
190 return arith::AddFOp::create(rewriter, loc, x, y);
198 return arith::MulIOp::create(rewriter, loc, x, y);
199 return arith::MulFOp::create(rewriter, loc, x, y);
219 class ContractionOpToOuterProductOpLowering
222 using MaskableOpRewritePattern::MaskableOpRewritePattern;
224 using FilterConstraintType =
225 std::function<LogicalResult(vector::ContractionOp op)>;
227 static LogicalResult defaultFilter(vector::ContractionOp op) {
231 ContractionOpToOuterProductOpLowering(
232 vector::VectorContractLowering vectorContractLowering,
234 FilterConstraintType constraint = defaultFilter)
236 vectorContractLowering(vectorContractLowering),
237 filter(std::move(constraint)) {}
240 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
245 vector::VectorContractLowering vectorContractLowering;
246 FilterConstraintType filter;
267 class ContractionOpToDotLowering
270 using MaskableOpRewritePattern::MaskableOpRewritePattern;
272 using FilterConstraintType =
273 std::function<LogicalResult(vector::ContractionOp op)>;
275 static LogicalResult defaultFilter(vector::ContractionOp op) {
279 ContractionOpToDotLowering(
280 vector::VectorContractLowering vectorContractLowering,
282 const FilterConstraintType &constraint = defaultFilter)
284 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
287 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
292 vector::VectorContractLowering vectorContractLowering;
293 FilterConstraintType filter;
310 class ContractionOpLowering
313 using MaskableOpRewritePattern::MaskableOpRewritePattern;
314 using FilterConstraintType =
315 std::function<LogicalResult(vector::ContractionOp op)>;
317 static LogicalResult defaultFilter(vector::ContractionOp op) {
321 ContractionOpLowering(
322 vector::VectorContractLowering vectorContractLoweringOption,
324 FilterConstraintType constraint = defaultFilter)
326 vectorContractLoweringOption(vectorContractLoweringOption),
327 filter(std::move(constraint)) {}
330 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
335 vector::VectorContractLowering vectorContractLoweringOption;
336 FilterConstraintType filter;
339 vector::ContractionOp op, int64_t lhsIndex,
340 int64_t rhsIndex,
Value mask)
const;
343 vector::ContractionOp op,
Value mask)
const;
348 struct UnrolledOuterProductGenerator
350 UnrolledOuterProductGenerator(
RewriterBase &b, vector::ContractionOp op)
352 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
353 res(op.getAcc()), lhsType(op.getLhsType()) {
354 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
355 if (maskableOp.isMasked())
356 mask = maskableOp.getMaskingOp().getMask();
362 return vector::TransposeOp::create(rewriter, loc, v, perm);
367 auto vecType = dyn_cast<VectorType>(elementType);
369 elementType = vecType.getElementType();
370 if (elementType == dstElementType)
372 Type promotedType = dstElementType;
374 promotedType = vecType.clone(promotedType);
375 if (isa<FloatType>(dstElementType))
376 return arith::ExtFOp::create(rewriter, loc, promotedType, v);
377 return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
381 VectorType lhsType,
int reductionSize,
382 std::optional<Value> maybeMask = std::nullopt) {
384 if (mask && !maybeMask.has_value())
387 Type resElementType = cast<VectorType>(res.
getType()).getElementType();
388 for (int64_t k = 0; k < reductionSize; ++k) {
389 Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k);
390 Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k);
391 extractA =
promote(extractA, resElementType);
392 extractB =
promote(extractB, resElementType);
394 if (maybeMask.has_value() && maybeMask.value())
396 vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);
398 Operation *outerProdOp = vector::OuterProductOp::create(
399 rewriter, loc, res.
getType(), extractA, extractB, res,
kind);
408 std::optional<int64_t> getReductionSize(VectorType vecType,
409 int64_t reductionDim) {
411 if (vecType.getScalableDims()[reductionDim])
413 int64_t reductionSize = vecType.getDimSize(reductionDim);
414 assert(reductionSize > 0 &&
415 "Reduction dim must be a known static size to allow unrolling");
416 return reductionSize;
420 FailureOr<Value> matmat() {
421 if (!iters({Par(), Par(), Red()}))
428 if (layout({{m, k}, {k, n}, {m, n}})) {
429 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
433 Value tMask = t(mask, {2, 0, 1});
434 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
438 if (layout({{m, k}, {n, k}, {m, n}})) {
439 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
442 Value tMask = t(mask, {2, 0, 1});
443 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
447 if (layout({{k, m}, {k, n}, {m, n}})) {
448 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
449 Value tMask = t(mask, {2, 0, 1});
450 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
454 if (layout({{k, m}, {n, k}, {m, n}})) {
455 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
457 Value tMask = t(mask, {2, 0, 1});
458 return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
463 if (layout({{m, k}, {k, n}, {n, m}})) {
464 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
466 Value tMask = t(mask, {2, 0, 1});
467 return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
471 if (layout({{m, k}, {n, k}, {n, m}})) {
472 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
475 Value tMask = t(mask, {2, 0, 1});
476 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
479 if (layout({{k, m}, {k, n}, {n, m}})) {
480 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
481 Value tMask = t(mask, {2, 0, 1});
482 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
485 if (layout({{k, m}, {n, k}, {n, m}})) {
486 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
488 Value tMask = t(mask, {2, 0, 1});
489 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
500 FailureOr<Value> matvec() {
501 if (!iters({Par(), Red()}))
507 if (layout({{m, k}, {k}, {m}})) {
508 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
510 Value tMask = t(mask);
511 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
515 if (layout({{k, m}, {k}, {m}})) {
516 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
517 Value tMask = t(mask);
518 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
522 if (layout({{k}, {m, k}, {m}})) {
523 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
525 Value tMask = t(mask);
526 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
530 if (layout({{k}, {k, m}, {m}})) {
531 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
532 Value tMask = t(mask);
533 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
543 FailureOr<Value> tmatvec() {
544 if (!iters({Red(), Par()}))
550 if (layout({{m, k}, {k}, {m}}))
551 if (
auto reductionSize = getReductionSize(lhsType, 1))
552 return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
554 if (layout({{k, m}, {k}, {m}}))
555 if (
auto reductionSize = getReductionSize(lhsType, 0))
556 return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
558 if (layout({{k}, {m, k}, {m}}))
559 if (
auto reductionSize = getReductionSize(lhsType, 0))
560 return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
562 if (layout({{k}, {k, m}, {m}}))
563 if (
auto reductionSize = getReductionSize(lhsType, 0))
564 return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
569 vector::CombiningKind
kind;
570 Value lhs, rhs, res, mask;
590 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
591 vector::ContractionOp op, MaskingOpInterface maskOp,
593 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
599 UnrolledOuterProductGenerator e(rewriter, op);
600 FailureOr<Value> matmatRes = e.matmat();
601 if (succeeded(matmatRes)) {
604 FailureOr<Value> matvecRes = e.matvec();
605 if (succeeded(matvecRes)) {
609 FailureOr<Value> tmatvecRes = e.tmatvec();
613 FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
614 vector::ContractionOp op, MaskingOpInterface maskOp,
623 if (vectorContractLowering != vector::VectorContractLowering::Dot)
626 auto iteratorTypes = op.getIteratorTypes().getValue();
627 static constexpr std::array<int64_t, 2> perm = {1, 0};
629 Value lhs = op.getLhs(), rhs = op.getRhs();
632 auto infer = [&](MapList m) {
648 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
649 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
650 }
else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
652 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
653 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
654 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
655 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
656 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
657 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
660 lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
662 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
664 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
666 lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
667 rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
668 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
670 rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
680 if (maps == infer({{m, n}, {n}, {m}})) {
682 }
else if (maps == infer({{n, m}, {n}, {m}})) {
683 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
684 }
else if (maps == infer({{n}, {m, n}, {m}})) {
686 }
else if (maps == infer({{n}, {n, m}, {m}})) {
688 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
696 VectorType dstType = cast<VectorType>(op.getResultType());
697 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
698 "Expected dst type of rank 1 or 2");
700 unsigned rank = dstType.getRank();
701 unsigned dstRows = dstType.getShape()[0];
702 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
705 Value res = arith::ConstantOp::create(rewriter, loc, dstType,
707 bool isInt = isa<IntegerType>(dstType.getElementType());
709 extractedCols.reserve(dstColumns);
710 for (
unsigned r = 0; r < dstRows; ++r) {
711 Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r);
712 for (
unsigned c = 0; c < dstColumns; ++c) {
719 : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c);
720 extractedCols.push_back(colRhs);
722 Value extractedColRhs = extractedCols[c];
724 createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
725 Value sum = vector::ReductionOp::create(
726 rewriter, op.getLoc(), vector::CombiningKind::ADD,
product);
730 res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
733 if (
auto acc = op.getAcc())
734 res =
createAdd(op.getLoc(), res, acc, isInt, rewriter);
740 struct ContractOpToElementwise
742 using MaskableOpRewritePattern::MaskableOpRewritePattern;
743 using FilterConstraintType =
744 std::function<LogicalResult(vector::ContractionOp op)>;
745 static LogicalResult defaultFilter(vector::ContractionOp op) {
748 ContractOpToElementwise(
749 vector::VectorContractLowering vectorContractLowering,
751 const FilterConstraintType &constraint = defaultFilter)
753 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
756 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
757 MaskingOpInterface maskOp,
763 if (
failed(filter(contractOp)))
766 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
771 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
772 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
778 for (int64_t dim : lhsReductionDims) {
779 if (lhsShape[dim] != 1)
782 for (int64_t dim : rhsReductionDims) {
783 if (rhsShape[dim] != 1)
786 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
788 unsigned numLhsDimToBroadcast =
789 numParallelDims - (lhsMap.
getNumResults() - lhsReductionDims.size());
790 unsigned numRhsDimToBroadcast =
791 numParallelDims - (rhsMap.
getNumResults() - rhsReductionDims.size());
796 for (int64_t dim : lhsReductionDims)
797 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
798 for (int64_t dim : rhsReductionDims)
799 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
802 for (
unsigned i = 0; i < numParallelDims; i++) {
803 std::optional<unsigned> lhsDim =
806 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
810 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
811 lhsTranspose.push_back(lhsDims.size() - 1);
813 std::optional<unsigned> rhsDim =
816 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
820 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
821 rhsTranspose.push_back(rhsDims.size() - 1);
824 Value newLhs = contractOp.getLhs();
825 Value newRhs = contractOp.getRhs();
827 if (!lhsDims.empty()) {
828 lhsDims.append(lhsShape.begin(), lhsShape.end());
831 newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
833 if (!rhsDims.empty()) {
834 rhsDims.append(rhsShape.begin(), rhsShape.end());
837 newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
839 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
840 newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
841 newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
844 newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
845 newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
846 std::optional<Value> result =
848 contractOp.getKind(), rewriter, isInt);
857 vector::VectorContractLowering vectorContractLowering;
858 FilterConstraintType filter;
878 FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
879 vector::ContractionOp op, MaskingOpInterface maskOp,
885 if (op.getLhsType().getElementType() !=
892 if (op.getKind() != vector::CombiningKind::ADD) {
894 op,
"contractions other than 'add' not supported");
900 ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
901 FailureOr<Value> newVal1 =
902 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
906 ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
907 FailureOr<Value> newVal2 =
908 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
912 ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
913 FailureOr<Value> newVal4 =
914 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
922 mask = maskOp.getMask();
924 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
925 if (!batchDimMap.empty()) {
926 int64_t lhsIndex = batchDimMap[0].first;
927 int64_t rhsIndex = batchDimMap[0].second;
928 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
935 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
936 op.getContractingDimMap();
939 for (
auto &dimPair : contractingDimMap) {
940 lhsContractingDimSet.insert(dimPair.first);
941 rhsContractingDimSet.insert(dimPair.second);
945 VectorType lhsType = op.getLhsType();
946 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
947 if (lhsContractingDimSet.count(lhsIndex) == 0) {
948 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
956 VectorType rhsType = op.getRhsType();
957 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
958 if (rhsContractingDimSet.count(rhsIndex) == 0) {
959 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
967 if (!contractingDimMap.empty()) {
968 auto newOp = lowerReduction(rewriter, op, mask);
980 FailureOr<Value> ContractionOpLowering::lowerParallel(
PatternRewriter &rewriter,
981 vector::ContractionOp op,
985 VectorType lhsType = op.getLhsType();
986 VectorType rhsType = op.getRhsType();
987 VectorType resType = cast<VectorType>(op.getResultType());
990 int64_t iterIndex = -1;
991 int64_t dimSize = -1;
993 iterIndex = iMap[0].getDimPosition(lhsIndex);
994 if (rhsIndex >= 0 && iterIndex != iMap[1].
getDimPosition(rhsIndex))
996 diag <<
"expected lhsIndex=" << lhsIndex <<
" and rhsIndex=" << rhsIndex
997 <<
" to map to the same dimension";
999 if (lhsType.getScalableDims()[lhsIndex])
1001 diag <<
"Unrolling scalable dimension (lhsIndex=" << lhsIndex
1002 <<
") is not supported yet";
1004 dimSize = lhsType.getDimSize(lhsIndex);
1005 }
else if (rhsIndex >= 0) {
1006 iterIndex = iMap[1].getDimPosition(rhsIndex);
1007 if (rhsType.getScalableDims()[rhsIndex])
1009 diag <<
"Unrolling scalable dimension (rhsIndex=" << rhsIndex
1010 <<
") is not supported yet";
1012 dimSize = rhsType.getDimSize(rhsIndex);
1016 diag <<
"expected either lhsIndex=" << lhsIndex
1017 <<
" or rhsIndex=" << rhsIndex <<
" to be nonnegative";
1026 int64_t resIndex =
getResultIndex(iMap[2], iterIndex).value_or(-1);
1027 if (resIndex == -1 && dimSize != 1)
1029 diag <<
"expected the dimension for iterIndex=" << iterIndex
1030 <<
" to either appear in the result map, or to be a unit dimension";
1034 std::array<AffineMap, 3> lowIndexingMaps = {
1035 adjustMap(iMap[0], iterIndex, rewriter),
1036 adjustMap(iMap[1], iterIndex, rewriter),
1037 adjustMap(iMap[2], iterIndex, rewriter)};
1043 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1046 for (int64_t d = 0; d < dimSize; ++d) {
1047 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1048 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1049 auto acc =
reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1054 iterIndex, d, rewriter);
1056 Operation *lowContract = vector::ContractionOp::create(
1057 rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
1058 lowContract =
maskOperation(rewriter, lowContract, lowMask);
1060 resIndex, d, rewriter);
1066 FailureOr<Value> ContractionOpLowering::lowerReduction(
1068 auto loc = op.getLoc();
1069 VectorType lhsType = op.getLhsType();
1070 VectorType rhsType = op.getRhsType();
1071 Type resType = op.getResultType();
1072 if (isa<VectorType>(resType))
1074 "did not expect a VectorType result");
1075 bool isInt = isa<IntegerType>(resType);
1077 int64_t iterIndex = 0;
1079 std::optional<int64_t> lookupLhs =
getResultIndex(iMap[0], iterIndex);
1080 std::optional<int64_t> lookupRhs =
getResultIndex(iMap[1], iterIndex);
1081 if (!lookupLhs.has_value())
1083 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a LHS dimension";
1085 if (!lookupRhs.has_value())
1087 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a RHS dimension";
1089 int64_t lhsIndex = *lookupLhs;
1090 int64_t rhsIndex = *lookupRhs;
1091 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1092 if (dimSize != rhsType.getDimSize(rhsIndex))
1094 diag <<
"expect LHS dimension " << lhsIndex
1095 <<
" to have the same size as RHS dimension " << rhsIndex;
1098 if (lhsType.getRank() == 1) {
1099 if (rhsType.getRank() != 1)
1101 op,
"When LHS has rank 1, expected also RHS to have rank 1");
1102 Value m =
createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1103 auto kind = vector::CombiningKind::ADD;
1105 Value acc = op.getAcc();
1107 acc ? vector::ReductionOp::create(rewriter, loc,
kind, m, acc)
1108 : vector::ReductionOp::create(rewriter, loc,
kind, m);
1112 std::array<AffineMap, 3> lowIndexingMaps = {
1113 adjustMap(iMap[0], iterIndex, rewriter),
1114 adjustMap(iMap[1], iterIndex, rewriter),
1115 adjustMap(iMap[2], iterIndex, rewriter)};
1123 Value result = op.getAcc();
1124 for (int64_t d = 0; d < dimSize; ++d) {
1125 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1126 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1130 iterIndex, d, rewriter);
1132 Operation *newContract = vector::ContractionOp::create(
1133 rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
1152 class OuterProductOpLowering :
public OpRewritePattern<vector::OuterProductOp> {
1156 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1158 VectorType resType = op.getResultVectorType();
1159 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1162 auto loc = op.getLoc();
1164 VectorType lhsType = op.getOperandVectorTypeLHS();
1165 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1166 Type eltType = resType.getElementType();
1167 bool isInt = isa<IntegerType, IndexType>(eltType);
1168 Value acc = op.getAcc();
1169 vector::CombiningKind
kind = op.getKind();
1173 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1176 if (maskableOp.isMasked()) {
1178 rootOp = maskableOp.getMaskingOp();
1179 mask = maskableOp.getMaskingOp().getMask();
1187 vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1189 loc, op.getLhs(), b, acc,
kind, rewriter, isInt, mask);
1190 if (!mult.has_value())
1196 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1198 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1199 Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1200 Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1203 r = vector::ExtractOp::create(rewriter, loc, acc, d);
1206 extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1209 loc, a, op.getRhs(), r,
kind, rewriter, isInt, extrMask);
1212 result = vector::InsertOp::create(rewriter, loc, *m, result, d);
1224 VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit,
1225 bool disableOuterProductLowering) {
1226 if (!disableOuterProductLowering)
1228 patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1229 vectorContractLoweringOption,
patterns.getContext(), benefit);
static int64_t product(ArrayRef< int64_t > vals)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true otherwise create an MulFOp using operands x andy`.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static std::string diag(const llvm::Value &value)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.