36 #define DEBUG_TYPE "vector-contract-lowering"
60 int64_t idx = it.index();
63 results.push_back(it.value());
80 results.push_back(targetExpr);
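// reshapeLoad / reshapeStore (below): reshapeLoad extracts from `val` the slice
// at position `pos` along dimension `index`, recursing over the leading
// dimensions; reshapeStore is the inverse, inserting such a slice back into `result`.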
88 int64_t index, int64_t pos,
95 return rewriter.create<vector::ExtractOp>(loc, val, pos);
102 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
103 Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
105 result = rewriter.create<vector::InsertOp>(loc, load, result, d);
113 VectorType type, int64_t index, int64_t pos,
120 return rewriter.create<vector::InsertOp>(loc, val, result, pos);
124 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
125 Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
126 Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
128 result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
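// createContractArithOp (below): builds the arithmetic for one contraction step,
// multiplying x and y (arith.muli for integers, arith.mulf for floats, or a single
// vector.fma when the kind is ADD with a vector accumulator) and combining the
// product with acc for the requested CombiningKind via makeArithReduction.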
134 static std::optional<Value>
138 using vector::CombiningKind;
142 if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF ||
143 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
146 mul = rewriter.create<arith::MulIOp>(loc, x, y);
149 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
150 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
151 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
152 kind == CombiningKind::XOR)
156 if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
157 Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
164 mul = rewriter.create<arith::MulFOp>(loc, x, y);
168 return std::optional<Value>(mul);
175 ArrayAttr iteratorTypes) {
179 dimsIdx.push_back(i);
199 return rewriter.create<arith::AddIOp>(loc, x, y);
200 return rewriter.create<arith::AddFOp>(loc, x, y);
208 return rewriter.create<arith::MulIOp>(loc, x, y);
209 return rewriter.create<arith::MulFOp>(loc, x, y);
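// The rewrite patterns below implement the different vector.contract lowering
// strategies: a single vector.matmul intrinsic (ContractionOpToMatmulOpLowering),
// a chain of vector.outerproduct ops (ContractionOpToOuterProductOpLowering),
// per-element dot products (ContractionOpToDotLowering), purely elementwise
// arithmetic (ContractOpToElementwise), and a progressive fallback
// (ContractionOpLowering). Each specialized pattern takes an optional filter
// constraint and only fires for the strategy selected in VectorTransformsOptions.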
227 class ContractionOpToMatmulOpLowering
232 using FilterConstraintType =
235 static LogicalResult defaultFilter(vector::ContractionOp op) {
239 ContractionOpToMatmulOpLowering(
242 FilterConstraintType constraint = defaultFilter)
244 vectorTransformOptions(vectorTransformOptions),
245 filter(std::move(constraint)) {}
253 FilterConstraintType filter;
271 class ContractionOpToOuterProductOpLowering
276 using FilterConstraintType =
279 static LogicalResult defaultFilter(vector::ContractionOp op) {
283 ContractionOpToOuterProductOpLowering(
286 FilterConstraintType constraint = defaultFilter)
288 vectorTransformOptions(vectorTransformOptions),
289 filter(std::move(constraint)) {}
297 FilterConstraintType filter;
318 class ContractionOpToDotLowering
323 using FilterConstraintType =
326 static LogicalResult defaultFilter(vector::ContractionOp op) {
330 ContractionOpToDotLowering(
333 const FilterConstraintType &constraint = defaultFilter)
335 vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
343 FilterConstraintType filter;
360 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
363 using FilterConstraintType =
366 static LogicalResult defaultFilter(vector::ContractionOp op) {
372 FilterConstraintType constraint = defaultFilter)
374 vectorTransformOptions(vectorTransformOptions),
375 filter(std::move(constraint)) {}
383 FilterConstraintType filter;
386 vector::ContractionOp op, int64_t lhsIndex,
387 int64_t rhsIndex, Value mask) const;
390 vector::ContractionOp op, Value mask) const;
395 struct UnrolledOuterProductGenerator
397 UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
399 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
400 res(op.getAcc()), lhsType(op.getLhsType()) {
401 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
402 if (maskableOp.isMasked())
403 mask = maskableOp.getMaskingOp().getMask();
409 return rewriter.create<vector::TransposeOp>(loc, v, perm);
414 auto vecType = dyn_cast<VectorType>(elementType);
416 elementType = vecType.getElementType();
417 if (elementType == dstElementType)
419 Type promotedType = dstElementType;
422 if (isa<FloatType>(dstElementType))
423 return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
424 return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
428 std::optional<Value> maybeMask = std::nullopt) {
429 assert(reductionSize > 0);
431 if (mask && !maybeMask.has_value())
434 Type resElementType = cast<VectorType>(res.getType()).getElementType();
435 for (int64_t k = 0; k < reductionSize; ++k) {
436 Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
437 Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
438 extractA = promote(extractA, resElementType);
439 extractB = promote(extractB, resElementType);
441 if (maybeMask.has_value() && maybeMask.value())
443 rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
446 loc, res.getType(), extractA, extractB, res, kind);
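// matmat: match one of the eight combinations of transposed/plain A, B and C
// layouts and unroll the reduction dimension k into outerProd calls; the mask,
// if present, is transposed so that k becomes its leading dimension.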
454 if (!iters({Par(), Par(), Red()}))
459 Value transposedMask = t(mask, {2, 0, 1});
461 if (layout({{m, k}, {k, n}, {m, n}}))
462 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
464 if (layout({{m, k}, {n, k}, {m, n}})) {
466 return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
470 if (layout({{k, m}, {k, n}, {m, n}}))
471 return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
473 if (layout({{k, m}, {n, k}, {m, n}}))
474 return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
477 if (layout({{m, k}, {k, n}, {n, m}}))
478 return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
480 if (layout({{m, k}, {n, k}, {n, m}})) {
482 return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
485 if (layout({{k, m}, {k, n}, {n, m}}))
486 return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
487 if (layout({{k, m}, {n, k}, {n, m}}))
488 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
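// matvec: matrix-vector contraction with (parallel, reduction) iterators; each
// supported layout is reduced to outer products of matrix slices with the vector,
// after transposing the 2-D mask so the reduction dimension leads.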
498 if (!iters({Par(), Red()}))
502 Value transposedMask = t(mask);
505 if (layout({{m, k}, {k}, {m}}))
506 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
508 if (layout({{k, m}, {k}, {m}}))
509 return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
511 if (layout({{k}, {m, k}, {m}}))
512 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
514 if (layout({{k}, {k, m}, {m}}))
515 return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
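// tmatvec: same layouts as matvec but with (reduction, parallel) iterators, so
// the mask already has the reduction dimension leading and is used unchanged.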
524 if (!iters({Red(), Par()}))
530 if (layout({{m, k}, {k}, {m}}))
531 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
533 if (layout({{k, m}, {k}, {m}}))
534 return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
536 if (layout({{k}, {m, k}, {m}}))
537 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
539 if (layout({{k}, {k, m}, {m}}))
540 return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
545 vector::CombiningKind kind;
546 Value lhs, rhs, res, mask;
565 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
567 if (vectorTransformOptions.vectorContractLowering !=
568 vector::VectorContractLowering::OuterProduct)
576 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
578 if (maskableOp.isMasked()) {
580 rootOp = maskableOp.getMaskingOp();
585 UnrolledOuterProductGenerator e(rewriter, op);
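// Dot lowering: canonicalize the operands with vector.transpose so that every
// element of the result is a 1-D dot product, i.e. an elementwise multiply of
// two extracted rows followed by a vector.reduction <add>.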
606 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
609 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
610 if (maskableOp.isMasked())
616 if (vectorTransformOptions.vectorContractLowering !=
617 vector::VectorContractLowering::Dot)
620 auto iteratorTypes = op.getIteratorTypes().getValue();
621 static constexpr std::array<int64_t, 2> perm = {1, 0};
623 Value lhs = op.getLhs(), rhs = op.getRhs();
640 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
641 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
642 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
644 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
645 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
646 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
647 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
648 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
649 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
652 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
654 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
656 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
658 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
659 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
660 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
662 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
672 if (maps == infer({{m, n}, {n}, {m}})) {
674 } else if (maps == infer({{n, m}, {n}, {m}})) {
675 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
676 } else if (maps == infer({{n}, {m, n}, {m}})) {
678 } else if (maps == infer({{n}, {n, m}, {m}})) {
680 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
688 VectorType dstType = cast<VectorType>(op.getResultType());
689 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
690 "Expected dst type of rank 1 or 2");
692 unsigned rank = dstType.getRank();
693 unsigned dstRows = dstType.getShape()[0];
694 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
697 Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
699 bool isInt = isa<IntegerType>(dstType.getElementType());
700 for (unsigned r = 0; r < dstRows; ++r) {
702 for (unsigned c = 0; c < dstColumns; ++c) {
705 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
707 Value reduced = rewriter.create<vector::ReductionOp>(
708 op.getLoc(), vector::CombiningKind::ADD, m);
712 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
715 if (auto acc = op.getAcc())
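// ParallelArith lowering: applies only when every reduction dimension has size 1.
// The operands are broadcast and transposed into the shape of the result and then
// combined with a single elementwise arithmetic op via createContractArithOp.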
723 struct ContractOpToElementwise
726 using FilterConstraintType =
728 static LogicalResult defaultFilter(vector::ContractionOp op) {
731 ContractOpToElementwise(
734 const FilterConstraintType &constraint = defaultFilter)
736 vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
738 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
741 auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
742 if (maskableOp.isMasked())
745 if (failed(filter(contractOp)))
748 if (vectorTransformOptions.vectorContractLowering !=
749 vector::VectorContractLowering::ParallelArith)
754 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
755 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
761 for (int64_t dim : lhsReductionDims) {
762 if (lhsShape[dim] != 1)
765 for (int64_t dim : rhsReductionDims) {
766 if (rhsShape[dim] != 1)
769 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
771 unsigned numLhsDimToBroadcast =
772 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
773 unsigned numRhsDimToBroadcast =
774 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
779 for (int64_t dim : lhsReductionDims)
780 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
781 for (int64_t dim : rhsReductionDims)
782 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
785 for (unsigned i = 0; i < numParallelDims; i++) {
786 std::optional<unsigned> lhsDim =
789 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
793 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
794 lhsTranspose.push_back(lhsDims.size() - 1);
796 std::optional<unsigned> rhsDim =
799 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
803 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
804 rhsTranspose.push_back(rhsDims.size() - 1);
807 Value newLhs = contractOp.getLhs();
808 Value newRhs = contractOp.getRhs();
810 if (!lhsDims.empty()) {
811 lhsDims.append(lhsShape.begin(), lhsShape.end());
814 newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
816 if (!rhsDims.empty()) {
817 rhsDims.append(rhsShape.begin(), rhsShape.end());
820 newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
822 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
823 newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
824 newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
827 newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
828 newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
829 std::optional<Value> result =
831 contractOp.getKind(), rewriter, isInt);
832 rewriter.replaceOp(contractOp, {*result});
839 FilterConstraintType filter;
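// Progressive lowering of vector.contract: first try the specialized patterns
// (matmul intrinsic, outer products, dot, elementwise); if none applies, peel off
// batch and free parallel dimensions one at a time with lowerParallel and lower
// the remaining contraction dimensions with lowerReduction.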
860 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
866 if (op.getLhsType().getElementType() !=
873 if (op.getKind() != vector::CombiningKind::ADD) {
875 op, "contractions other than 'add' not supported");
880 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
881 if (succeeded(pat1.matchAndRewrite(op, rewriter)))
883 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
884 if (succeeded(pat2.matchAndRewrite(op, rewriter)))
886 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
887 if (succeeded(pat3.matchAndRewrite(op, rewriter)))
889 ContractOpToElementwise pat4(vectorTransformOptions, ctx);
890 if (succeeded(pat4.matchAndRewrite(op, rewriter)))
899 rootOp = op.getMaskingOp();
900 mask = op.getMaskingOp().getMask();
904 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
905 if (!batchDimMap.empty()) {
906 int64_t lhsIndex = batchDimMap[0].first;
907 int64_t rhsIndex = batchDimMap[0].second;
908 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
916 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
917 op.getContractingDimMap();
920 for (auto &dimPair : contractingDimMap) {
921 lhsContractingDimSet.insert(dimPair.first);
922 rhsContractingDimSet.insert(dimPair.second);
926 VectorType lhsType = op.getLhsType();
927 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
928 if (lhsContractingDimSet.count(lhsIndex) == 0) {
929 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
938 VectorType rhsType = op.getRhsType();
939 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
940 if (rhsContractingDimSet.count(rhsIndex) == 0) {
941 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
950 if (!contractingDimMap.empty()) {
951 auto newOp = lowerReduction(rewriter, op, mask);
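// lowerParallel: unroll one parallel (batch or free) dimension. For each position
// d along that dimension, slice the operands and accumulator with reshapeLoad,
// emit a lower-rank vector.contract with maps/iterators adjusted by adjustMap and
// adjustIter, mask it if needed, and insert the partial result with reshapeStore.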
965 vector::ContractionOp op,
969 VectorType lhsType = op.getLhsType();
970 VectorType rhsType = op.getRhsType();
971 VectorType resType = cast<VectorType>(op.getResultType());
974 int64_t iterIndex = -1;
975 int64_t dimSize = -1;
977 iterIndex = iMap[0].getDimPosition(lhsIndex);
978 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
980 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
981 << " to map to the same dimension";
983 dimSize = lhsType.getDimSize(lhsIndex);
984 } else if (rhsIndex >= 0) {
985 iterIndex = iMap[1].getDimPosition(rhsIndex);
986 dimSize = rhsType.getDimSize(rhsIndex);
990 diag << "expected either lhsIndex=" << lhsIndex
991 << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1000 int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1001 if (resIndex == -1 && dimSize != 1)
1003 diag << "expected the dimension for iterIndex=" << iterIndex
1004 << " to either appear in the result map, or to be a unit dimension";
1008 std::array<AffineMap, 3> lowIndexingMaps = {
1009 adjustMap(iMap[0], iterIndex, rewriter),
1010 adjustMap(iMap[1], iterIndex, rewriter),
1011 adjustMap(iMap[2], iterIndex, rewriter)};
1017 Value result = rewriter.create<arith::ConstantOp>(
1020 for (int64_t d = 0; d < dimSize; ++d) {
1021 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1022 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1023 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1028 iterIndex, d, rewriter);
1031 loc, lhs, rhs, acc, lowAffine, lowIter);
1032 lowContract = maskOperation(rewriter, lowContract, lowMask);
1034 resIndex, d, rewriter);
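// lowerReduction: lower a contraction dimension. For rank-1 operands this is a
// createMul followed by a vector.reduction <add> (folding in the accumulator when
// present); for higher ranks it recurses on slices obtained with reshapeLoad.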
1043 VectorType lhsType = op.getLhsType();
1044 VectorType rhsType = op.getRhsType();
1045 Type resType = op.getResultType();
1046 if (isa<VectorType>(resType))
1048 "did not expect a VectorType result");
1049 bool isInt = isa<IntegerType>(resType);
1051 int64_t iterIndex = 0;
1053 std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1054 std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1055 if (!lookupLhs.has_value())
1057 diag << "expected iterIndex=" << iterIndex << " to map to a LHS dimension";
1059 if (!lookupRhs.has_value())
1061 diag << "expected iterIndex=" << iterIndex << " to map to a RHS dimension";
1063 int64_t lhsIndex = *lookupLhs;
1064 int64_t rhsIndex = *lookupRhs;
1065 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1066 if (dimSize != rhsType.getDimSize(rhsIndex))
1068 diag << "expect LHS dimension " << lhsIndex
1069 << " to have the same size as RHS dimension " << rhsIndex;
1072 if (lhsType.getRank() == 1) {
1073 if (rhsType.getRank() != 1)
1075 op, "When LHS has rank 1, expected also RHS to have rank 1");
1076 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1077 auto kind = vector::CombiningKind::ADD;
1079 Value acc = op.getAcc();
1081 acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1082 : rewriter.create<vector::ReductionOp>(loc, kind, m);
1086 std::array<AffineMap, 3> lowIndexingMaps = {
1087 adjustMap(iMap[0], iterIndex, rewriter),
1088 adjustMap(iMap[1], iterIndex, rewriter),
1089 adjustMap(iMap[2], iterIndex, rewriter)};
1097 Value result = op.getAcc();
1098 for (int64_t d = 0; d < dimSize; ++d) {
1099 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1100 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1104 iterIndex, d, rewriter);
1107 loc, lhs, rhs, result, lowAffine, lowIter);
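// OuterProductOpLowering: unrolls vector.outerproduct. A scalar RHS (the AXPY
// case) is broadcast to the LHS type and combined in one createContractArithOp;
// otherwise each LHS element is extracted, broadcast across the RHS, combined
// with the matching accumulator row (and mask row, if any), and inserted into
// row d of the result.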
1126 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1132 VectorType resType = op.getResultVectorType();
1133 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1138 VectorType lhsType = op.getOperandVectorTypeLHS();
1139 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1140 Type eltType = resType.getElementType();
1141 bool isInt = isa<IntegerType, IndexType>(eltType);
1142 Value acc = op.getAcc();
1143 vector::CombiningKind kind = op.getKind();
1147 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1150 if (maskableOp.isMasked()) {
1152 rootOp = maskableOp.getMaskingOp();
1153 mask = maskableOp.getMaskingOp().getMask();
1160 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1162 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1163 if (!mult.has_value())
1169 Value result = rewriter.create<arith::ConstantOp>(
1171 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1172 Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
1173 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1176 r = rewriter.create<vector::ExtractOp>(loc, acc, d);
1179 extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
1182 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1185 result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
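// Matmul lowering: requires the operand and result element types to match, brings
// the operands into row-major (m, k) x (k, n) form, flattens them to 1-D vectors
// with vector.shape_cast, multiplies them with a single vector.matrix_multiply
// (vector::MatmulOp), casts the product back to 2-D, and adds the accumulator.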
1211 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1214 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
1215 if (maskableOp.isMasked())
1218 if (vectorTransformOptions.vectorContractLowering !=
1219 vector::VectorContractLowering::Matmul)
1224 auto iteratorTypes = op.getIteratorTypes().getValue();
1230 Type elementType = op.getLhsType().getElementType();
1234 Type dstElementType = op.getType();
1235 if (auto vecType = dyn_cast<VectorType>(dstElementType))
1236 dstElementType = vecType.getElementType();
1237 if (elementType != dstElementType)
1247 Value lhs = op.getLhs();
1248 auto lhsMap = op.getIndexingMapsArray()[0];
1255 Value rhs = op.getRhs();
1256 auto rhsMap = op.getIndexingMapsArray()[1];
1263 VectorType lhsType = cast<VectorType>(lhs.getType());
1264 VectorType rhsType = cast<VectorType>(rhs.getType());
1265 int64_t lhsRows = lhsType.getDimSize(0);
1266 int64_t lhsColumns = lhsType.getDimSize(1);
1267 int64_t rhsColumns = rhsType.getDimSize(1);
1269 Type flattenedLHSType =
1271 lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1273 Type flattenedRHSType =
1275 rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1277 Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1279 mul = rew.create<vector::ShapeCastOp>(
1286 auto accMap = op.getIndexingMapsArray()[2];
1290 llvm_unreachable("invalid contraction semantics");
1293 isa<IntegerType>(elementType)
1294 ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1295 : static_cast<Value>(
1296 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1306 if (!disableOuterProductLowering)
1307 patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1308 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1309 ContractionOpToOuterProductOpLowering>(
1315 patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true, otherwise creates a MulFOp, using operands x and y.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true, otherwise creates an arith::AddFOp, using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
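As a usage sketch (not part of this file), the snippet below shows one way a client could register and apply these lowering patterns. The helper name lowerContractions and the include paths are assumptions that may differ across MLIR versions; populateVectorContractLoweringPatterns, populateVectorOuterProductLoweringPatterns, VectorTransformsOptions, and applyPatternsAndFoldGreedily are existing MLIR entry points (the first two are the functions defined at the end of this listing).

// Sketch only: drive the vector.contract lowering patterns over a module.
// Include paths and the lowerContractions helper are assumptions, not part of
// this file.
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult lowerContractions(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Select the strategy checked by ContractionOpToOuterProductOpLowering.
  vector::VectorTransformsOptions options;
  options.vectorContractLowering = vector::VectorContractLowering::OuterProduct;

  // Register the contract and outer-product lowering patterns defined above.
  RewritePatternSet patterns(ctx);
  vector::populateVectorContractLoweringPatterns(patterns, options);
  vector::populateVectorOuterProductLoweringPatterns(patterns);

  // Greedily apply the patterns until a fixpoint is reached.
  return applyPatternsAndFoldGreedily(module.getOperation(), std::move(patterns));
}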