26#define DEBUG_TYPE "vector-contract-lowering"
48 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
52 results.push_back(it.value());
69 results.push_back(targetExpr);
84 return vector::ExtractOp::create(rewriter, loc, val, pos);
89 Value result = arith::ConstantOp::create(rewriter, loc, resType,
91 for (
int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
92 Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
109 return vector::InsertOp::create(rewriter, loc, val,
result, pos);
113 for (
int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
114 Value ext = vector::ExtractOp::create(rewriter, loc,
result, d);
115 Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
117 result = vector::InsertOp::create(rewriter, loc, sto,
result, d);
123static std::optional<Value>
127 using vector::CombiningKind;
131 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
132 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
135 mul = arith::MulIOp::create(rewriter, loc, x, y);
138 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
139 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
140 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
141 kind == CombiningKind::XOR)
145 if (
acc && isa<VectorType>(
acc.getType()) && kind == CombiningKind::ADD) {
146 Value fma = vector::FMAOp::create(rewriter, loc, x, y,
acc);
153 mul = arith::MulFOp::create(rewriter, loc, x, y);
157 return std::optional<Value>(
mul);
169 dimsIdx.push_back(i);
189 return arith::AddIOp::create(rewriter, loc, x, y);
190 return arith::AddFOp::create(rewriter, loc, x, y);
198 return arith::MulIOp::create(rewriter, loc, x, y);
199 return arith::MulFOp::create(rewriter, loc, x, y);
219class ContractionOpToOuterProductOpLowering
222 using MaskableOpRewritePattern::MaskableOpRewritePattern;
224 using FilterConstraintType =
225 std::function<LogicalResult(vector::ContractionOp op)>;
227 static LogicalResult defaultFilter(vector::ContractionOp op) {
231 ContractionOpToOuterProductOpLowering(
232 vector::VectorContractLowering vectorContractLowering,
233 MLIRContext *context, PatternBenefit benefit = 1,
234 FilterConstraintType constraint = defaultFilter)
235 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
236 vectorContractLowering(vectorContractLowering),
237 filter(std::move(constraint)) {}
241 PatternRewriter &rewriter)
const override;
245 vector::VectorContractLowering vectorContractLowering;
246 FilterConstraintType filter;
267class ContractionOpToDotLowering
270 using MaskableOpRewritePattern::MaskableOpRewritePattern;
272 using FilterConstraintType =
273 std::function<LogicalResult(vector::ContractionOp op)>;
275 static LogicalResult defaultFilter(vector::ContractionOp op) {
279 ContractionOpToDotLowering(
280 vector::VectorContractLowering vectorContractLowering,
281 MLIRContext *context, PatternBenefit benefit = 1,
282 const FilterConstraintType &constraint = defaultFilter)
283 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
284 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
288 PatternRewriter &rewriter)
const override;
292 vector::VectorContractLowering vectorContractLowering;
293 FilterConstraintType filter;
310class ContractionOpLowering
313 using MaskableOpRewritePattern::MaskableOpRewritePattern;
314 using FilterConstraintType =
315 std::function<LogicalResult(vector::ContractionOp op)>;
317 static LogicalResult defaultFilter(vector::ContractionOp op) {
321 ContractionOpLowering(
322 vector::VectorContractLowering vectorContractLoweringOption,
323 MLIRContext *context, PatternBenefit benefit = 1,
324 FilterConstraintType constraint = defaultFilter)
325 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
326 vectorContractLoweringOption(vectorContractLoweringOption),
327 filter(std::move(constraint)) {}
331 PatternRewriter &rewriter)
const override;
335 vector::VectorContractLowering vectorContractLoweringOption;
336 FilterConstraintType filter;
338 FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
339 vector::ContractionOp op, int64_t lhsIndex,
340 int64_t rhsIndex, Value mask)
const;
342 FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
343 vector::ContractionOp op, Value mask)
const;
348struct UnrolledOuterProductGenerator
350 UnrolledOuterProductGenerator(RewriterBase &
b, vector::ContractionOp op)
351 : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
b, op),
352 kind(op.getKind()),
lhs(op.getLhs()),
rhs(op.getRhs()),
353 res(op.getAcc()), lhsType(op.getLhsType()) {
354 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
355 if (maskableOp.isMasked())
356 mask = maskableOp.getMaskingOp().getMask();
359 Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
362 return vector::TransposeOp::create(rewriter, loc, v, perm);
365 Value
promote(Value v, Type dstElementType) {
366 Type elementType = v.
getType();
367 auto vecType = dyn_cast<VectorType>(elementType);
369 elementType = vecType.getElementType();
370 if (elementType == dstElementType)
372 Type promotedType = dstElementType;
374 promotedType = vecType.clone(promotedType);
375 if (isa<FloatType>(dstElementType))
376 return arith::ExtFOp::create(rewriter, loc, promotedType, v);
377 return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
380 FailureOr<Value> outerProd(Value
lhs, Value
rhs, Value res,
381 VectorType lhsType,
int reductionSize,
382 std::optional<Value> maybeMask = std::nullopt) {
384 if (mask && !maybeMask.has_value())
387 Type resElementType = cast<VectorType>(res.
getType()).getElementType();
388 for (int64_t k = 0; k < reductionSize; ++k) {
389 Value extractA = vector::ExtractOp::create(rewriter, loc,
lhs, k);
390 Value extractB = vector::ExtractOp::create(rewriter, loc,
rhs, k);
391 extractA =
promote(extractA, resElementType);
392 extractB =
promote(extractB, resElementType);
394 if (maybeMask.has_value() && maybeMask.value())
396 vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);
398 Operation *outerProdOp = vector::OuterProductOp::create(
399 rewriter, loc, res.
getType(), extractA, extractB, res, kind);
408 std::optional<int64_t> getReductionSize(VectorType vecType,
409 int64_t reductionDim) {
411 if (vecType.getScalableDims()[reductionDim])
413 int64_t reductionSize = vecType.getDimSize(reductionDim);
414 assert(reductionSize > 0 &&
415 "Reduction dim must be a known static size to allow unrolling");
416 return reductionSize;
420 FailureOr<Value> matmat() {
421 if (!
iters({Par(), Par(), Red()}))
428 if (
layout({{m, k}, {k, n}, {m, n}})) {
429 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
433 Value tMask = t(mask, {2, 0, 1});
434 return outerProd(tLhs,
rhs, res, lhsType, *reductionSize, tMask);
438 if (
layout({{m, k}, {n, k}, {m, n}})) {
439 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
442 Value tMask = t(mask, {2, 0, 1});
443 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
447 if (
layout({{k, m}, {k, n}, {m, n}})) {
448 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
449 Value tMask = t(mask, {2, 0, 1});
450 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, tMask);
454 if (
layout({{k, m}, {n, k}, {m, n}})) {
455 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
457 Value tMask = t(mask, {2, 0, 1});
458 return outerProd(
lhs, tRhs, res, lhsType, *reductionSize, tMask);
463 if (
layout({{m, k}, {k, n}, {n, m}})) {
464 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
466 Value tMask = t(mask, {2, 0, 1});
467 return outerProd(
rhs, tLhs, res, lhsType, *reductionSize, tMask);
471 if (
layout({{m, k}, {n, k}, {n, m}})) {
472 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
475 Value tMask = t(mask, {2, 0, 1});
476 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
479 if (
layout({{k, m}, {k, n}, {n, m}})) {
480 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
481 Value tMask = t(mask, {2, 0, 1});
482 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, tMask);
485 if (
layout({{k, m}, {n, k}, {n, m}})) {
486 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
488 Value tMask = t(mask, {2, 0, 1});
489 return outerProd(tRhs,
lhs, res, lhsType, *reductionSize, tMask);
500 FailureOr<Value> matvec() {
501 if (!
iters({Par(), Red()}))
507 if (
layout({{m, k}, {k}, {m}})) {
508 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
510 Value tMask = t(mask);
511 return outerProd(tLhs,
rhs, res, lhsType, *reductionSize, tMask);
515 if (
layout({{k, m}, {k}, {m}})) {
516 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
517 Value tMask = t(mask);
518 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, tMask);
522 if (
layout({{k}, {m, k}, {m}})) {
523 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
525 Value tMask = t(mask);
526 return outerProd(tRhs,
lhs, res, lhsType, *reductionSize, tMask);
530 if (
layout({{k}, {k, m}, {m}})) {
531 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
532 Value tMask = t(mask);
533 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, tMask);
543 FailureOr<Value> tmatvec() {
544 if (!
iters({Red(), Par()}))
550 if (
layout({{m, k}, {k}, {m}}))
551 if (
auto reductionSize = getReductionSize(lhsType, 1))
552 return outerProd(t(
lhs),
rhs, res, lhsType, *reductionSize, mask);
554 if (
layout({{k, m}, {k}, {m}}))
555 if (
auto reductionSize = getReductionSize(lhsType, 0))
556 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, mask);
558 if (
layout({{k}, {m, k}, {m}}))
559 if (
auto reductionSize = getReductionSize(lhsType, 0))
560 return outerProd(t(
rhs),
lhs, res, lhsType, *reductionSize, mask);
562 if (
layout({{k}, {k, m}, {m}}))
563 if (
auto reductionSize = getReductionSize(lhsType, 0))
564 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, mask);
569 vector::CombiningKind kind;
570 Value
lhs,
rhs, res, mask;
590ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
591 vector::ContractionOp op, MaskingOpInterface maskOp,
593 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
599 UnrolledOuterProductGenerator e(rewriter, op);
600 FailureOr<Value> matmatRes = e.matmat();
601 if (succeeded(matmatRes)) {
604 FailureOr<Value> matvecRes = e.matvec();
605 if (succeeded(matvecRes)) {
609 FailureOr<Value> tmatvecRes = e.tmatvec();
613FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
614 vector::ContractionOp op, MaskingOpInterface maskOp,
623 if (vectorContractLowering != vector::VectorContractLowering::Dot)
626 auto iteratorTypes = op.getIteratorTypes().getValue();
627 static constexpr std::array<int64_t, 2> perm = {1, 0};
632 auto infer = [&](MapList m) {
648 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
649 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
650 }
else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
652 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
653 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
654 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
655 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
656 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
657 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
660 lhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
662 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
664 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
666 lhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
667 rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
668 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
670 rhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
680 if (maps == infer({{m, n}, {n}, {m}})) {
682 }
else if (maps == infer({{n, m}, {n}, {m}})) {
683 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
684 }
else if (maps == infer({{n}, {m, n}, {m}})) {
686 }
else if (maps == infer({{n}, {n, m}, {m}})) {
688 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
696 VectorType dstType = cast<VectorType>(op.getResultType());
697 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
698 "Expected dst type of rank 1 or 2");
700 unsigned rank = dstType.getRank();
701 unsigned dstRows = dstType.getShape()[0];
702 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
705 Value res = arith::ConstantOp::create(rewriter, loc, dstType,
707 bool isInt = isa<IntegerType>(dstType.getElementType());
709 extractedCols.reserve(dstColumns);
710 for (
unsigned r = 0; r < dstRows; ++r) {
711 Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(),
lhs, r);
712 for (
unsigned c = 0; c < dstColumns; ++c) {
719 : vector::ExtractOp::create(rewriter, op.getLoc(),
rhs, c);
720 extractedCols.push_back(colRhs);
722 Value extractedColRhs = extractedCols[c];
724 createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
725 Value sum = vector::ReductionOp::create(
726 rewriter, op.getLoc(), vector::CombiningKind::ADD,
product);
730 res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
733 if (
auto acc = op.getAcc())
734 res =
createAdd(op.getLoc(), res,
acc, isInt, rewriter);
742 using MaskableOpRewritePattern::MaskableOpRewritePattern;
744 std::function<LogicalResult(vector::ContractionOp op)>;
749 vector::VectorContractLowering vectorContractLowering,
753 vectorContractLowering(vectorContractLowering), filter(
defaultFilter) {}
757 MaskingOpInterface maskOp,
763 if (failed(filter(contractOp)))
766 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
771 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
772 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
778 for (
int64_t dim : lhsReductionDims) {
779 if (lhsShape[dim] != 1)
782 for (
int64_t dim : rhsReductionDims) {
783 if (rhsShape[dim] != 1)
786 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
788 unsigned numLhsDimToBroadcast =
789 numParallelDims - (lhsMap.
getNumResults() - lhsReductionDims.size());
790 unsigned numRhsDimToBroadcast =
791 numParallelDims - (rhsMap.
getNumResults() - rhsReductionDims.size());
796 for (
int64_t dim : lhsReductionDims)
797 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
798 for (
int64_t dim : rhsReductionDims)
799 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
802 for (
unsigned i = 0; i < numParallelDims; i++) {
803 std::optional<unsigned> lhsDim =
806 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
810 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
811 lhsTranspose.push_back(lhsDims.size() - 1);
813 std::optional<unsigned> rhsDim =
816 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
820 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
821 rhsTranspose.push_back(rhsDims.size() - 1);
824 Value newLhs = contractOp.getLhs();
825 Value newRhs = contractOp.getRhs();
827 if (!lhsDims.empty()) {
828 lhsDims.append(lhsShape.begin(), lhsShape.end());
830 VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
831 newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
833 if (!rhsDims.empty()) {
834 rhsDims.append(rhsShape.begin(), rhsShape.end());
836 VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
837 newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
839 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
840 newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
841 newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
844 newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
845 newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
846 std::optional<Value>
result =
848 contractOp.getKind(), rewriter, isInt);
857 vector::VectorContractLowering vectorContractLowering;
878FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
879 vector::ContractionOp op, MaskingOpInterface maskOp,
885 if (op.getLhsType().getElementType() !=
892 if (op.getKind() != vector::CombiningKind::ADD) {
894 op,
"contractions other than 'add' not supported");
900 ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
901 FailureOr<Value> newVal1 =
902 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
906 ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
907 FailureOr<Value> newVal2 =
908 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
913 FailureOr<Value> newVal4 =
914 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
922 mask = maskOp.getMask();
924 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
925 if (!batchDimMap.empty()) {
926 int64_t lhsIndex = batchDimMap[0].first;
927 int64_t rhsIndex = batchDimMap[0].second;
928 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
935 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
936 op.getContractingDimMap();
939 for (
auto &dimPair : contractingDimMap) {
940 lhsContractingDimSet.insert(dimPair.first);
941 rhsContractingDimSet.insert(dimPair.second);
945 VectorType lhsType = op.getLhsType();
946 for (
int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
947 if (lhsContractingDimSet.count(lhsIndex) == 0) {
948 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
956 VectorType rhsType = op.getRhsType();
957 for (
int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
958 if (rhsContractingDimSet.count(rhsIndex) == 0) {
959 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
967 if (!contractingDimMap.empty()) {
968 auto newOp = lowerReduction(rewriter, op, mask);
980FailureOr<Value> ContractionOpLowering::lowerParallel(
PatternRewriter &rewriter,
981 vector::ContractionOp op,
985 VectorType lhsType = op.getLhsType();
986 VectorType rhsType = op.getRhsType();
987 VectorType resType = cast<VectorType>(op.getResultType());
993 iterIndex = iMap[0].getDimPosition(lhsIndex);
994 if (rhsIndex >= 0 && iterIndex != iMap[1].
getDimPosition(rhsIndex))
996 diag <<
"expected lhsIndex=" << lhsIndex <<
" and rhsIndex=" << rhsIndex
997 <<
" to map to the same dimension";
999 if (lhsType.getScalableDims()[lhsIndex])
1001 diag <<
"Unrolling scalable dimension (lhsIndex=" << lhsIndex
1002 <<
") is not supported yet";
1004 dimSize = lhsType.getDimSize(lhsIndex);
1005 }
else if (rhsIndex >= 0) {
1006 iterIndex = iMap[1].getDimPosition(rhsIndex);
1007 if (rhsType.getScalableDims()[rhsIndex])
1009 diag <<
"Unrolling scalable dimension (rhsIndex=" << rhsIndex
1010 <<
") is not supported yet";
1012 dimSize = rhsType.getDimSize(rhsIndex);
1016 diag <<
"expected either lhsIndex=" << lhsIndex
1017 <<
" or rhsIndex=" << rhsIndex <<
" to be nonnegative";
1027 if (resIndex == -1 && dimSize != 1)
1029 diag <<
"expected the dimension for iterIndex=" << iterIndex
1030 <<
" to either appear in the result map, or to be a unit dimension";
1034 std::array<AffineMap, 3> lowIndexingMaps = {
1035 adjustMap(iMap[0], iterIndex, rewriter),
1036 adjustMap(iMap[1], iterIndex, rewriter),
1037 adjustMap(iMap[2], iterIndex, rewriter)};
1043 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1046 for (
int64_t d = 0; d < dimSize; ++d) {
1047 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1048 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1049 auto acc =
reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1054 iterIndex, d, rewriter);
1056 Operation *lowContract = vector::ContractionOp::create(
1057 rewriter, loc,
lhs,
rhs,
acc, lowAffine, lowIter);
1058 lowContract =
maskOperation(rewriter, lowContract, lowMask);
1060 resIndex, d, rewriter);
1066FailureOr<Value> ContractionOpLowering::lowerReduction(
1068 auto loc = op.getLoc();
1069 VectorType lhsType = op.getLhsType();
1070 VectorType rhsType = op.getRhsType();
1071 Type resType = op.getResultType();
1072 if (isa<VectorType>(resType))
1074 "did not expect a VectorType result");
1075 bool isInt = isa<IntegerType>(resType);
1079 std::optional<int64_t> lookupLhs =
getResultIndex(iMap[0], iterIndex);
1080 std::optional<int64_t> lookupRhs =
getResultIndex(iMap[1], iterIndex);
1081 if (!lookupLhs.has_value())
1083 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a LHS dimension";
1085 if (!lookupRhs.has_value())
1087 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a RHS dimension";
1089 int64_t lhsIndex = *lookupLhs;
1090 int64_t rhsIndex = *lookupRhs;
1091 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1092 if (dimSize != rhsType.getDimSize(rhsIndex))
1094 diag <<
"expect LHS dimension " << lhsIndex
1095 <<
" to have the same size as RHS dimension " << rhsIndex;
1098 if (lhsType.getRank() == 1) {
1099 if (rhsType.getRank() != 1)
1101 op,
"When LHS has rank 1, expected also RHS to have rank 1");
1102 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1103 auto kind = vector::CombiningKind::ADD;
1107 acc ? vector::ReductionOp::create(rewriter, loc, kind, m,
acc)
1108 :
vector::ReductionOp::create(rewriter, loc, kind, m);
1112 std::array<AffineMap, 3> lowIndexingMaps = {
1113 adjustMap(iMap[0], iterIndex, rewriter),
1114 adjustMap(iMap[1], iterIndex, rewriter),
1115 adjustMap(iMap[2], iterIndex, rewriter)};
1124 for (
int64_t d = 0; d < dimSize; ++d) {
1125 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1126 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1130 iterIndex, d, rewriter);
1132 Operation *newContract = vector::ContractionOp::create(
1158 VectorType resType = op.getResultVectorType();
1159 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1162 auto loc = op.getLoc();
1164 VectorType lhsType = op.getOperandVectorTypeLHS();
1165 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1166 Type eltType = resType.getElementType();
1167 bool isInt = isa<IntegerType, IndexType>(eltType);
1169 vector::CombiningKind kind = op.getKind();
1173 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1176 if (maskableOp.isMasked()) {
1178 rootOp = maskableOp.getMaskingOp();
1179 mask = maskableOp.getMaskingOp().getMask();
1187 vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1189 loc, op.getLhs(),
b,
acc, kind, rewriter, isInt, mask);
1190 if (!mult.has_value())
1196 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1198 for (
int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1199 Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1200 Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1203 r = vector::ExtractOp::create(rewriter, loc,
acc, d);
1206 extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1209 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1212 result = vector::InsertOp::create(rewriter, loc, *m,
result, d);
1224 VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit,
1225 bool disableOuterProductLowering) {
1226 if (!disableOuterProductLowering)
1228 patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1229 vectorContractLoweringOption,
patterns.getContext(), benefit);
static int64_t product(ArrayRef< int64_t > vals)
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value())
Helper to create arithmetic operation associated with a kind of contraction.
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an arith::AddIOp if isInt is true; otherwise creates an arith::AddFOp, using operands x and y.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::string diag(const llvm::Value &value)
Progressive lowering of OuterProductOp.
LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Lowers vector.contract ops whose reduction dimensions all have size one to elementwise ops when possible.
FailureOr< Value > matchAndRewriteMaskableOp(vector::ContractionOp contractOp, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0