26#define DEBUG_TYPE "vector-contract-lowering"
48 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
52 results.push_back(it.value());
69 results.push_back(targetExpr);
94 return vector::ExtractOp::create(rewriter, loc, val, pos);
97 VectorType type = cast<VectorType>(val.
getType());
99 Value result = arith::ConstantOp::create(rewriter, loc, resType,
101 for (
int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
102 Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
127 return vector::InsertOp::create(rewriter, loc, val,
result, pos);
130 VectorType type = cast<VectorType>(
result.getType());
131 for (
int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
132 Value ext = vector::ExtractOp::create(rewriter, loc,
result, d);
133 Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
135 result = vector::InsertOp::create(rewriter, loc, sto,
result, d);
141static std::optional<Value>
145 arith::FastMathFlagsAttr fmf = {}) {
146 using vector::CombiningKind;
150 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
151 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
154 mul = arith::MulIOp::create(rewriter, loc, x, y);
157 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
158 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
159 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
160 kind == CombiningKind::XOR)
164 if (
acc && isa<VectorType>(
acc.getType()) && kind == CombiningKind::ADD) {
165 Value fma = vector::FMAOp::create(rewriter, loc, x, y,
acc);
172 mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
176 return std::optional<Value>(
mul);
187 dimsIdx.push_back(i);
206 arith::FastMathFlagsAttr fmf = {}) {
208 return arith::AddIOp::create(rewriter, loc, x, y);
209 return arith::AddFOp::create(rewriter, loc, x, y, fmf);
216 arith::FastMathFlagsAttr fmf = {}) {
218 return arith::MulIOp::create(rewriter, loc, x, y);
219 return arith::MulFOp::create(rewriter, loc, x, y, fmf);
239class ContractionOpToOuterProductOpLowering
242 using MaskableOpRewritePattern::MaskableOpRewritePattern;
244 using FilterConstraintType =
245 std::function<LogicalResult(vector::ContractionOp op)>;
247 static LogicalResult defaultFilter(vector::ContractionOp op) {
251 ContractionOpToOuterProductOpLowering(
252 vector::VectorContractLowering vectorContractLowering,
253 MLIRContext *context, PatternBenefit benefit = 1,
254 FilterConstraintType constraint = defaultFilter)
255 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
256 vectorContractLowering(vectorContractLowering),
257 filter(std::move(constraint)) {}
261 PatternRewriter &rewriter)
const override;
265 vector::VectorContractLowering vectorContractLowering;
266 FilterConstraintType filter;
287class ContractionOpToDotLowering
290 using MaskableOpRewritePattern::MaskableOpRewritePattern;
292 using FilterConstraintType =
293 std::function<LogicalResult(vector::ContractionOp op)>;
295 static LogicalResult defaultFilter(vector::ContractionOp op) {
299 ContractionOpToDotLowering(
300 vector::VectorContractLowering vectorContractLowering,
301 MLIRContext *context, PatternBenefit benefit = 1,
302 const FilterConstraintType &constraint = defaultFilter)
303 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
304 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
308 PatternRewriter &rewriter)
const override;
312 vector::VectorContractLowering vectorContractLowering;
313 FilterConstraintType filter;
330class ContractionOpLowering
333 using MaskableOpRewritePattern::MaskableOpRewritePattern;
334 using FilterConstraintType =
335 std::function<LogicalResult(vector::ContractionOp op)>;
337 static LogicalResult defaultFilter(vector::ContractionOp op) {
341 ContractionOpLowering(
342 vector::VectorContractLowering vectorContractLoweringOption,
343 MLIRContext *context, PatternBenefit benefit = 1,
344 FilterConstraintType constraint = defaultFilter)
345 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
346 vectorContractLoweringOption(vectorContractLoweringOption),
347 filter(std::move(constraint)) {}
351 PatternRewriter &rewriter)
const override;
355 vector::VectorContractLowering vectorContractLoweringOption;
356 FilterConstraintType filter;
/// Lowers one parallel dimension of `op` by unrolling it into lower-rank
/// contractions.
///
/// `lhsIndex` / `rhsIndex` select the dimension on the LHS / RHS operand;
/// callers in this file pass -1 for the side that does not carry the
/// dimension. The definition reports a match failure when:
///   - both indices are negative,
///   - both are given but map to different iteration dimensions,
///   - the selected operand dimension is scalable, or
///   - the dimension is absent from the result map and is not a unit dim.
/// `mask`, when set, is sliced per unrolled position and applied to each
/// lower-rank contraction (NOTE(review): inferred from the visible
/// `reshapeLoad` / `maskOperation` calls — confirm against the full body).
358 FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
359 vector::ContractionOp op, int64_t lhsIndex,
360 int64_t rhsIndex, Value mask)
const;
/// Lowers a reduction dimension of `op` whose result is not a vector.
///
/// The definition reports a match failure when the result is a VectorType,
/// when the reduction iterator does not map to both an LHS and an RHS
/// dimension, when those two dimensions differ in size, or when the LHS has
/// rank 1 but the RHS does not. For rank-1 operands it emits a multiply
/// followed by an ADD `vector.reduction`; otherwise it creates one
/// contraction per position of the reduction dimension, with adjusted
/// indexing maps. `mask`, when set, is sliced per position
/// (NOTE(review): inferred from the visible `reshapeLoad` call — confirm
/// against the full body).
362 FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
363 vector::ContractionOp op, Value mask)
const;
368struct UnrolledOuterProductGenerator
370 UnrolledOuterProductGenerator(RewriterBase &
b, vector::ContractionOp op)
371 : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
b, op),
372 kind(op.getKind()),
lhs(op.getLhs()),
rhs(op.getRhs()),
373 res(op.getAcc()), lhsType(op.getLhsType()) {
374 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
375 if (maskableOp.isMasked())
376 mask = maskableOp.getMaskingOp().getMask();
379 Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
382 return vector::TransposeOp::create(rewriter, loc, v, perm);
385 Value
promote(Value v, Type dstElementType) {
386 Type elementType = v.
getType();
387 auto vecType = dyn_cast<VectorType>(elementType);
389 elementType = vecType.getElementType();
390 if (elementType == dstElementType)
392 Type promotedType = dstElementType;
394 promotedType = vecType.clone(promotedType);
395 if (isa<FloatType>(dstElementType))
396 return arith::ExtFOp::create(rewriter, loc, promotedType, v);
397 return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
400 FailureOr<Value> outerProd(Value
lhs, Value
rhs, Value res,
401 VectorType lhsType,
int reductionSize,
402 std::optional<Value> maybeMask = std::nullopt) {
404 if (mask && !maybeMask.has_value())
407 Type resElementType = cast<VectorType>(res.
getType()).getElementType();
408 for (int64_t k = 0; k < reductionSize; ++k) {
409 Value extractA = vector::ExtractOp::create(rewriter, loc,
lhs, k);
410 Value extractB = vector::ExtractOp::create(rewriter, loc,
rhs, k);
411 extractA =
promote(extractA, resElementType);
412 extractB =
promote(extractB, resElementType);
414 if (maybeMask.has_value() && maybeMask.value())
416 vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);
418 Operation *outerProdOp = vector::OuterProductOp::create(
419 rewriter, loc, res.
getType(), extractA, extractB, res, kind);
428 std::optional<int64_t> getReductionSize(VectorType vecType,
429 int64_t reductionDim) {
431 if (vecType.getScalableDims()[reductionDim])
433 int64_t reductionSize = vecType.getDimSize(reductionDim);
434 assert(reductionSize > 0 &&
435 "Reduction dim must be a known static size to allow unrolling");
436 return reductionSize;
440 FailureOr<Value> matmat() {
441 if (!
iters({Par(), Par(), Red()}))
448 if (
layout({{m, k}, {k, n}, {m, n}})) {
449 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
453 Value tMask = t(mask, {2, 0, 1});
454 return outerProd(tLhs,
rhs, res, lhsType, *reductionSize, tMask);
458 if (
layout({{m, k}, {n, k}, {m, n}})) {
459 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
462 Value tMask = t(mask, {2, 0, 1});
463 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
467 if (
layout({{k, m}, {k, n}, {m, n}})) {
468 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
469 Value tMask = t(mask, {2, 0, 1});
470 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, tMask);
474 if (
layout({{k, m}, {n, k}, {m, n}})) {
475 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
477 Value tMask = t(mask, {2, 0, 1});
478 return outerProd(
lhs, tRhs, res, lhsType, *reductionSize, tMask);
483 if (
layout({{m, k}, {k, n}, {n, m}})) {
484 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
486 Value tMask = t(mask, {2, 0, 1});
487 return outerProd(
rhs, tLhs, res, lhsType, *reductionSize, tMask);
491 if (
layout({{m, k}, {n, k}, {n, m}})) {
492 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
495 Value tMask = t(mask, {2, 0, 1});
496 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
499 if (
layout({{k, m}, {k, n}, {n, m}})) {
500 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
501 Value tMask = t(mask, {2, 0, 1});
502 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, tMask);
505 if (
layout({{k, m}, {n, k}, {n, m}})) {
506 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
508 Value tMask = t(mask, {2, 0, 1});
509 return outerProd(tRhs,
lhs, res, lhsType, *reductionSize, tMask);
520 FailureOr<Value> matvec() {
521 if (!
iters({Par(), Red()}))
527 if (
layout({{m, k}, {k}, {m}})) {
528 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
530 Value tMask = t(mask);
531 return outerProd(tLhs,
rhs, res, lhsType, *reductionSize, tMask);
535 if (
layout({{k, m}, {k}, {m}})) {
536 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
537 Value tMask = t(mask);
538 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, tMask);
542 if (
layout({{k}, {m, k}, {m}})) {
543 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
545 Value tMask = t(mask);
546 return outerProd(tRhs,
lhs, res, lhsType, *reductionSize, tMask);
550 if (
layout({{k}, {k, m}, {m}})) {
551 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
552 Value tMask = t(mask);
553 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, tMask);
563 FailureOr<Value> tmatvec() {
564 if (!
iters({Red(), Par()}))
570 if (
layout({{m, k}, {k}, {m}}))
571 if (
auto reductionSize = getReductionSize(lhsType, 1))
572 return outerProd(t(
lhs),
rhs, res, lhsType, *reductionSize, mask);
574 if (
layout({{k, m}, {k}, {m}}))
575 if (
auto reductionSize = getReductionSize(lhsType, 0))
576 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, mask);
578 if (
layout({{k}, {m, k}, {m}}))
579 if (
auto reductionSize = getReductionSize(lhsType, 0))
580 return outerProd(t(
rhs),
lhs, res, lhsType, *reductionSize, mask);
582 if (
layout({{k}, {k, m}, {m}}))
583 if (
auto reductionSize = getReductionSize(lhsType, 0))
584 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, mask);
589 vector::CombiningKind kind;
590 Value
lhs,
rhs, res, mask;
610ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
611 vector::ContractionOp op, MaskingOpInterface maskOp,
613 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
619 UnrolledOuterProductGenerator e(rewriter, op);
620 FailureOr<Value> matmatRes = e.matmat();
621 if (succeeded(matmatRes)) {
624 FailureOr<Value> matvecRes = e.matvec();
625 if (succeeded(matvecRes)) {
629 FailureOr<Value> tmatvecRes = e.tmatvec();
633FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
634 vector::ContractionOp op, MaskingOpInterface maskOp,
643 if (vectorContractLowering != vector::VectorContractLowering::Dot)
646 auto iteratorTypes = op.getIteratorTypes().getValue();
647 static constexpr std::array<int64_t, 2> perm = {1, 0};
652 auto infer = [&](MapList m) {
668 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
669 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
670 }
else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
672 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
673 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
674 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
675 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
676 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
677 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
680 lhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
682 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
684 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
686 lhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
687 rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
688 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
690 rhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
700 if (maps == infer({{m, n}, {n}, {m}})) {
702 }
else if (maps == infer({{n, m}, {n}, {m}})) {
703 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
704 }
else if (maps == infer({{n}, {m, n}, {m}})) {
706 }
else if (maps == infer({{n}, {n, m}, {m}})) {
708 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
716 VectorType dstType = cast<VectorType>(op.getResultType());
717 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
718 "Expected dst type of rank 1 or 2");
720 unsigned rank = dstType.getRank();
721 unsigned dstRows = dstType.getShape()[0];
722 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
725 Value res = arith::ConstantOp::create(rewriter, loc, dstType,
727 bool isInt = isa<IntegerType>(dstType.getElementType());
728 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
730 extractedCols.reserve(dstColumns);
731 for (
unsigned r = 0; r < dstRows; ++r) {
732 Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(),
lhs, r);
733 for (
unsigned c = 0; c < dstColumns; ++c) {
740 : vector::ExtractOp::create(rewriter, op.getLoc(),
rhs, c);
741 extractedCols.push_back(colRhs);
743 Value extractedColRhs = extractedCols[c];
745 createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
746 Value sum = vector::ReductionOp::create(rewriter, op.getLoc(),
747 vector::CombiningKind::ADD,
752 res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
755 if (
auto acc = op.getAcc())
756 res =
createAdd(op.getLoc(), res,
acc, isInt, rewriter, fmf);
764 using MaskableOpRewritePattern::MaskableOpRewritePattern;
766 std::function<LogicalResult(vector::ContractionOp op)>;
771 vector::VectorContractLowering vectorContractLowering,
775 vectorContractLowering(vectorContractLowering), filter(
defaultFilter) {}
779 MaskingOpInterface maskOp,
785 if (failed(filter(contractOp)))
788 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
793 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
794 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
800 for (
int64_t dim : lhsReductionDims) {
801 if (lhsShape[dim] != 1)
804 for (
int64_t dim : rhsReductionDims) {
805 if (rhsShape[dim] != 1)
808 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
810 unsigned numLhsDimToBroadcast =
811 numParallelDims - (lhsMap.
getNumResults() - lhsReductionDims.size());
812 unsigned numRhsDimToBroadcast =
813 numParallelDims - (rhsMap.
getNumResults() - rhsReductionDims.size());
818 for (
int64_t dim : lhsReductionDims)
819 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
820 for (
int64_t dim : rhsReductionDims)
821 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
824 for (
unsigned i = 0; i < numParallelDims; i++) {
825 std::optional<unsigned> lhsDim =
828 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
832 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
833 lhsTranspose.push_back(lhsDims.size() - 1);
835 std::optional<unsigned> rhsDim =
838 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
842 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
843 rhsTranspose.push_back(rhsDims.size() - 1);
846 Value newLhs = contractOp.getLhs();
847 Value newRhs = contractOp.getRhs();
849 if (!lhsDims.empty()) {
850 lhsDims.append(lhsShape.begin(), lhsShape.end());
852 VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
853 newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
855 if (!rhsDims.empty()) {
856 rhsDims.append(rhsShape.begin(), rhsShape.end());
858 VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
859 newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
861 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
862 newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
863 newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
866 newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
867 newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
868 std::optional<Value>
result =
870 contractOp.getKind(), rewriter, isInt,
871 Value(), contractOp.getFastmathAttr());
880 vector::VectorContractLowering vectorContractLowering;
901FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
902 vector::ContractionOp op, MaskingOpInterface maskOp,
908 if (op.getLhsType().getElementType() !=
915 if (op.getKind() != vector::CombiningKind::ADD) {
917 op,
"contractions other than 'add' not supported");
923 ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
924 FailureOr<Value> newVal1 =
925 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
929 ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
930 FailureOr<Value> newVal2 =
931 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
936 FailureOr<Value> newVal4 =
937 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
945 mask = maskOp.getMask();
947 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
948 if (!batchDimMap.empty()) {
949 int64_t lhsIndex = batchDimMap[0].first;
950 int64_t rhsIndex = batchDimMap[0].second;
951 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
958 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
959 op.getContractingDimMap();
962 for (
auto &dimPair : contractingDimMap) {
963 lhsContractingDimSet.insert(dimPair.first);
964 rhsContractingDimSet.insert(dimPair.second);
968 VectorType lhsType = op.getLhsType();
969 for (
int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
970 if (lhsContractingDimSet.count(lhsIndex) == 0) {
971 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
979 VectorType rhsType = op.getRhsType();
980 for (
int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
981 if (rhsContractingDimSet.count(rhsIndex) == 0) {
982 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
990 if (!contractingDimMap.empty()) {
991 auto newOp = lowerReduction(rewriter, op, mask);
1003FailureOr<Value> ContractionOpLowering::lowerParallel(
PatternRewriter &rewriter,
1004 vector::ContractionOp op,
1008 VectorType lhsType = op.getLhsType();
1009 VectorType rhsType = op.getRhsType();
1010 VectorType resType = cast<VectorType>(op.getResultType());
1015 if (lhsIndex >= 0) {
1016 iterIndex = iMap[0].getDimPosition(lhsIndex);
1017 if (rhsIndex >= 0 && iterIndex != iMap[1].
getDimPosition(rhsIndex))
1019 diag <<
"expected lhsIndex=" << lhsIndex <<
" and rhsIndex=" << rhsIndex
1020 <<
" to map to the same dimension";
1022 if (lhsType.getScalableDims()[lhsIndex])
1024 diag <<
"Unrolling scalable dimension (lhsIndex=" << lhsIndex
1025 <<
") is not supported yet";
1027 dimSize = lhsType.getDimSize(lhsIndex);
1028 }
else if (rhsIndex >= 0) {
1029 iterIndex = iMap[1].getDimPosition(rhsIndex);
1030 if (rhsType.getScalableDims()[rhsIndex])
1032 diag <<
"Unrolling scalable dimension (rhsIndex=" << rhsIndex
1033 <<
") is not supported yet";
1035 dimSize = rhsType.getDimSize(rhsIndex);
1039 diag <<
"expected either lhsIndex=" << lhsIndex
1040 <<
" or rhsIndex=" << rhsIndex <<
" to be nonnegative";
1050 if (resIndex == -1 && dimSize != 1)
1052 diag <<
"expected the dimension for iterIndex=" << iterIndex
1053 <<
" to either appear in the result map, or to be a unit dimension";
1057 std::array<AffineMap, 3> lowIndexingMaps = {
1058 adjustMap(iMap[0], iterIndex, rewriter),
1059 adjustMap(iMap[1], iterIndex, rewriter),
1060 adjustMap(iMap[2], iterIndex, rewriter)};
1066 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1069 for (
int64_t d = 0; d < dimSize; ++d) {
1076 lowMask =
reshapeLoad(loc, mask, iterIndex, d, rewriter);
1079 vector::ContractionOp::create(rewriter, loc,
lhs,
rhs,
acc, lowAffine,
1080 lowIter, op.getKind(), op.getFastmath());
1081 lowContract =
maskOperation(rewriter, lowContract, lowMask);
1089FailureOr<Value> ContractionOpLowering::lowerReduction(
1091 auto loc = op.getLoc();
1092 VectorType lhsType = op.getLhsType();
1093 VectorType rhsType = op.getRhsType();
1094 Type resType = op.getResultType();
1095 if (isa<VectorType>(resType))
1097 "did not expect a VectorType result");
1098 bool isInt = isa<IntegerType>(resType);
1102 std::optional<int64_t> lookupLhs =
getResultIndex(iMap[0], iterIndex);
1103 std::optional<int64_t> lookupRhs =
getResultIndex(iMap[1], iterIndex);
1104 if (!lookupLhs.has_value())
1106 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a LHS dimension";
1108 if (!lookupRhs.has_value())
1110 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a RHS dimension";
1112 int64_t lhsIndex = *lookupLhs;
1113 int64_t rhsIndex = *lookupRhs;
1114 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1115 if (dimSize != rhsType.getDimSize(rhsIndex))
1117 diag <<
"expect LHS dimension " << lhsIndex
1118 <<
" to have the same size as RHS dimension " << rhsIndex;
1121 if (lhsType.getRank() == 1) {
1122 if (rhsType.getRank() != 1)
1124 op,
"When LHS has rank 1, expected also RHS to have rank 1");
1125 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
1126 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
1127 auto kind = vector::CombiningKind::ADD;
1131 acc ? vector::ReductionOp::create(rewriter, loc, kind, m,
acc,
1133 :
vector::ReductionOp::create(rewriter, loc, kind, m,
1138 std::array<AffineMap, 3> lowIndexingMaps = {
1139 adjustMap(iMap[0], iterIndex, rewriter),
1140 adjustMap(iMap[1], iterIndex, rewriter),
1141 adjustMap(iMap[2], iterIndex, rewriter)};
1150 for (
int64_t d = 0; d < dimSize; ++d) {
1155 newMask =
reshapeLoad(loc, mask, iterIndex, d, rewriter);
1157 Operation *newContract = vector::ContractionOp::create(
1158 rewriter, loc,
lhs,
rhs,
result, lowAffine, lowIter, op.getKind(),
1184 VectorType resType = op.getResultVectorType();
1185 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1188 auto loc = op.getLoc();
1190 VectorType lhsType = op.getOperandVectorTypeLHS();
1191 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1192 Type eltType = resType.getElementType();
1193 bool isInt = isa<IntegerType, IndexType>(eltType);
1195 vector::CombiningKind kind = op.getKind();
1199 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1202 if (maskableOp.isMasked()) {
1204 rootOp = maskableOp.getMaskingOp();
1205 mask = maskableOp.getMaskingOp().getMask();
1213 vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1215 loc, op.getLhs(),
b,
acc, kind, rewriter, isInt, mask);
1216 if (!mult.has_value())
1222 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1224 for (
int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1225 Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1226 Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1229 r = vector::ExtractOp::create(rewriter, loc,
acc, d);
1232 extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1235 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1238 result = vector::InsertOp::create(rewriter, loc, *m,
result, d);
1250 VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit,
1251 bool disableOuterProductLowering) {
1252 if (!disableOuterProductLowering)
1254 patterns.
add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1255 vectorContractLoweringOption, patterns.
getContext(), benefit);
static int64_t product(ArrayRef< int64_t > vals)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static Value reshapeStore(Location loc, Value val, Value result, int64_t index, int64_t pos, PatternRewriter &rewriter)
Inserts val into result at position pos along dimension index.
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter, arith::FastMathFlagsAttr fmf={})
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, int64_t index, int64_t pos, PatternRewriter &rewriter)
Returns val with the dimension at position index dropped by indexing that dimension with pos.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value(), arith::FastMathFlagsAttr fmf={})
Helper to create arithmetic operation associated with a kind of contraction.
static std::string diag(const llvm::Value &value)
Progressive lowering of OuterProductOp.
LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Lower vector.contract with all size one reduction dimensions to elementwise ops when possible.
FailureOr< Value > matchAndRewriteMaskableOp(vector::ContractionOp contractOp, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0