26#define DEBUG_TYPE "vector-contract-lowering"
48 for (
const auto &it : llvm::enumerate(iteratorTypes)) {
52 results.push_back(it.value());
69 results.push_back(targetExpr);
84 return vector::ExtractOp::create(rewriter, loc, val, pos);
89 Value result = arith::ConstantOp::create(rewriter, loc, resType,
91 for (
int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
92 Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
109 return vector::InsertOp::create(rewriter, loc, val,
result, pos);
113 for (
int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
114 Value ext = vector::ExtractOp::create(rewriter, loc,
result, d);
115 Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
117 result = vector::InsertOp::create(rewriter, loc, sto,
result, d);
123static std::optional<Value>
127 arith::FastMathFlagsAttr fmf = {}) {
128 using vector::CombiningKind;
132 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
133 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
136 mul = arith::MulIOp::create(rewriter, loc, x, y);
139 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
140 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
141 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
142 kind == CombiningKind::XOR)
146 if (
acc && isa<VectorType>(
acc.getType()) && kind == CombiningKind::ADD) {
147 Value fma = vector::FMAOp::create(rewriter, loc, x, y,
acc);
154 mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
158 return std::optional<Value>(
mul);
169 dimsIdx.push_back(i);
188 arith::FastMathFlagsAttr fmf = {}) {
190 return arith::AddIOp::create(rewriter, loc, x, y);
191 return arith::AddFOp::create(rewriter, loc, x, y, fmf);
198 arith::FastMathFlagsAttr fmf = {}) {
200 return arith::MulIOp::create(rewriter, loc, x, y);
201 return arith::MulFOp::create(rewriter, loc, x, y, fmf);
221class ContractionOpToOuterProductOpLowering
224 using MaskableOpRewritePattern::MaskableOpRewritePattern;
226 using FilterConstraintType =
227 std::function<LogicalResult(vector::ContractionOp op)>;
229 static LogicalResult defaultFilter(vector::ContractionOp op) {
/// Constructor for the outer-product lowering pattern.
///
/// `context` and `benefit` are forwarded to the
/// MaskableOpRewritePattern<vector::ContractionOp> base. The requested
/// lowering strategy is stored in `vectorContractLowering`; this class's
/// matchAndRewriteMaskableOp compares it against
/// VectorContractLowering::OuterProduct. `constraint` is taken by value and
/// moved into the `filter` member; it defaults to `defaultFilter`.
233 ContractionOpToOuterProductOpLowering(
234 vector::VectorContractLowering vectorContractLowering,
235 MLIRContext *context, PatternBenefit benefit = 1,
236 FilterConstraintType constraint = defaultFilter)
237 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
238 vectorContractLowering(vectorContractLowering),
239 filter(std::move(constraint)) {}
243 PatternRewriter &rewriter)
const override;
247 vector::VectorContractLowering vectorContractLowering;
248 FilterConstraintType filter;
269class ContractionOpToDotLowering
272 using MaskableOpRewritePattern::MaskableOpRewritePattern;
274 using FilterConstraintType =
275 std::function<LogicalResult(vector::ContractionOp op)>;
277 static LogicalResult defaultFilter(vector::ContractionOp op) {
281 ContractionOpToDotLowering(
282 vector::VectorContractLowering vectorContractLowering,
283 MLIRContext *context, PatternBenefit benefit = 1,
284 const FilterConstraintType &constraint = defaultFilter)
285 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
286 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
290 PatternRewriter &rewriter)
const override;
294 vector::VectorContractLowering vectorContractLowering;
295 FilterConstraintType filter;
312class ContractionOpLowering
315 using MaskableOpRewritePattern::MaskableOpRewritePattern;
316 using FilterConstraintType =
317 std::function<LogicalResult(vector::ContractionOp op)>;
319 static LogicalResult defaultFilter(vector::ContractionOp op) {
/// Constructor for the generic contraction lowering pattern.
///
/// `context` and `benefit` are forwarded to the
/// MaskableOpRewritePattern<vector::ContractionOp> base. The lowering option
/// is stored in `vectorContractLoweringOption`; matchAndRewriteMaskableOp
/// passes it on when it constructs the more specialized contraction patterns
/// it tries. `constraint` is taken by value and moved into the `filter`
/// member; it defaults to `defaultFilter`.
323 ContractionOpLowering(
324 vector::VectorContractLowering vectorContractLoweringOption,
325 MLIRContext *context, PatternBenefit benefit = 1,
326 FilterConstraintType constraint = defaultFilter)
327 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
328 vectorContractLoweringOption(vectorContractLoweringOption),
329 filter(std::move(constraint)) {}
333 PatternRewriter &rewriter)
const override;
337 vector::VectorContractLowering vectorContractLoweringOption;
338 FilterConstraintType filter;
/// Helper used by matchAndRewriteMaskableOp to lower the contraction along
/// one non-contracting dimension.
///
/// `lhsIndex` / `rhsIndex` select the dimension in the LHS / RHS operand
/// type. Callers in this file pass -1 for an operand that does not carry the
/// dimension; the definition's diagnostics indicate that at least one of the
/// two is expected to be nonnegative. `mask` is the mask taken from the
/// enclosing masking op by the caller.
/// NOTE(review): exact failure conditions are in the definition — confirm
/// there before relying on them.
340 FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
341 vector::ContractionOp op, int64_t lhsIndex,
342 int64_t rhsIndex, Value mask)
const;
/// Helper used by matchAndRewriteMaskableOp to lower a contracting
/// (reduction) dimension; it is invoked when the op's contracting-dim map is
/// non-empty. The definition rejects ops whose result type is a VectorType
/// ("did not expect a VectorType result"). `mask` is the mask taken from the
/// enclosing masking op by the caller.
344 FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
345 vector::ContractionOp op, Value mask)
const;
350struct UnrolledOuterProductGenerator
352 UnrolledOuterProductGenerator(RewriterBase &
b, vector::ContractionOp op)
353 : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
b, op),
354 kind(op.getKind()),
lhs(op.getLhs()),
rhs(op.getRhs()),
355 res(op.getAcc()), lhsType(op.getLhsType()) {
356 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
357 if (maskableOp.isMasked())
358 mask = maskableOp.getMaskingOp().getMask();
361 Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
364 return vector::TransposeOp::create(rewriter, loc, v, perm);
367 Value
promote(Value v, Type dstElementType) {
368 Type elementType = v.
getType();
369 auto vecType = dyn_cast<VectorType>(elementType);
371 elementType = vecType.getElementType();
372 if (elementType == dstElementType)
374 Type promotedType = dstElementType;
376 promotedType = vecType.clone(promotedType);
377 if (isa<FloatType>(dstElementType))
378 return arith::ExtFOp::create(rewriter, loc, promotedType, v);
379 return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
382 FailureOr<Value> outerProd(Value
lhs, Value
rhs, Value res,
383 VectorType lhsType,
int reductionSize,
384 std::optional<Value> maybeMask = std::nullopt) {
386 if (mask && !maybeMask.has_value())
389 Type resElementType = cast<VectorType>(res.
getType()).getElementType();
390 for (int64_t k = 0; k < reductionSize; ++k) {
391 Value extractA = vector::ExtractOp::create(rewriter, loc,
lhs, k);
392 Value extractB = vector::ExtractOp::create(rewriter, loc,
rhs, k);
393 extractA =
promote(extractA, resElementType);
394 extractB =
promote(extractB, resElementType);
396 if (maybeMask.has_value() && maybeMask.value())
398 vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);
400 Operation *outerProdOp = vector::OuterProductOp::create(
401 rewriter, loc, res.
getType(), extractA, extractB, res, kind);
410 std::optional<int64_t> getReductionSize(VectorType vecType,
411 int64_t reductionDim) {
413 if (vecType.getScalableDims()[reductionDim])
415 int64_t reductionSize = vecType.getDimSize(reductionDim);
416 assert(reductionSize > 0 &&
417 "Reduction dim must be a known static size to allow unrolling");
418 return reductionSize;
422 FailureOr<Value> matmat() {
423 if (!
iters({Par(), Par(), Red()}))
430 if (
layout({{m, k}, {k, n}, {m, n}})) {
431 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
435 Value tMask = t(mask, {2, 0, 1});
436 return outerProd(tLhs,
rhs, res, lhsType, *reductionSize, tMask);
440 if (
layout({{m, k}, {n, k}, {m, n}})) {
441 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
444 Value tMask = t(mask, {2, 0, 1});
445 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
449 if (
layout({{k, m}, {k, n}, {m, n}})) {
450 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
451 Value tMask = t(mask, {2, 0, 1});
452 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, tMask);
456 if (
layout({{k, m}, {n, k}, {m, n}})) {
457 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
459 Value tMask = t(mask, {2, 0, 1});
460 return outerProd(
lhs, tRhs, res, lhsType, *reductionSize, tMask);
465 if (
layout({{m, k}, {k, n}, {n, m}})) {
466 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
468 Value tMask = t(mask, {2, 0, 1});
469 return outerProd(
rhs, tLhs, res, lhsType, *reductionSize, tMask);
473 if (
layout({{m, k}, {n, k}, {n, m}})) {
474 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
477 Value tMask = t(mask, {2, 0, 1});
478 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
481 if (
layout({{k, m}, {k, n}, {n, m}})) {
482 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
483 Value tMask = t(mask, {2, 0, 1});
484 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, tMask);
487 if (
layout({{k, m}, {n, k}, {n, m}})) {
488 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
490 Value tMask = t(mask, {2, 0, 1});
491 return outerProd(tRhs,
lhs, res, lhsType, *reductionSize, tMask);
502 FailureOr<Value> matvec() {
503 if (!
iters({Par(), Red()}))
509 if (
layout({{m, k}, {k}, {m}})) {
510 if (
auto reductionSize = getReductionSize(lhsType, 1)) {
512 Value tMask = t(mask);
513 return outerProd(tLhs,
rhs, res, lhsType, *reductionSize, tMask);
517 if (
layout({{k, m}, {k}, {m}})) {
518 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
519 Value tMask = t(mask);
520 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, tMask);
524 if (
layout({{k}, {m, k}, {m}})) {
525 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
527 Value tMask = t(mask);
528 return outerProd(tRhs,
lhs, res, lhsType, *reductionSize, tMask);
532 if (
layout({{k}, {k, m}, {m}})) {
533 if (
auto reductionSize = getReductionSize(lhsType, 0)) {
534 Value tMask = t(mask);
535 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, tMask);
545 FailureOr<Value> tmatvec() {
546 if (!
iters({Red(), Par()}))
552 if (
layout({{m, k}, {k}, {m}}))
553 if (
auto reductionSize = getReductionSize(lhsType, 1))
554 return outerProd(t(
lhs),
rhs, res, lhsType, *reductionSize, mask);
556 if (
layout({{k, m}, {k}, {m}}))
557 if (
auto reductionSize = getReductionSize(lhsType, 0))
558 return outerProd(
lhs,
rhs, res, lhsType, *reductionSize, mask);
560 if (
layout({{k}, {m, k}, {m}}))
561 if (
auto reductionSize = getReductionSize(lhsType, 0))
562 return outerProd(t(
rhs),
lhs, res, lhsType, *reductionSize, mask);
564 if (
layout({{k}, {k, m}, {m}}))
565 if (
auto reductionSize = getReductionSize(lhsType, 0))
566 return outerProd(
rhs,
lhs, res, lhsType, *reductionSize, mask);
571 vector::CombiningKind kind;
572 Value
lhs,
rhs, res, mask;
592ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
593 vector::ContractionOp op, MaskingOpInterface maskOp,
595 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
601 UnrolledOuterProductGenerator e(rewriter, op);
602 FailureOr<Value> matmatRes = e.matmat();
603 if (succeeded(matmatRes)) {
606 FailureOr<Value> matvecRes = e.matvec();
607 if (succeeded(matvecRes)) {
611 FailureOr<Value> tmatvecRes = e.tmatvec();
615FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
616 vector::ContractionOp op, MaskingOpInterface maskOp,
625 if (vectorContractLowering != vector::VectorContractLowering::Dot)
628 auto iteratorTypes = op.getIteratorTypes().getValue();
629 static constexpr std::array<int64_t, 2> perm = {1, 0};
634 auto infer = [&](MapList m) {
650 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
651 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
652 }
else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
654 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
655 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
656 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
657 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
658 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
659 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
662 lhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
664 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
666 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
668 lhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
669 rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
670 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
672 rhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
682 if (maps == infer({{m, n}, {n}, {m}})) {
684 }
else if (maps == infer({{n, m}, {n}, {m}})) {
685 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
686 }
else if (maps == infer({{n}, {m, n}, {m}})) {
688 }
else if (maps == infer({{n}, {n, m}, {m}})) {
690 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
698 VectorType dstType = cast<VectorType>(op.getResultType());
699 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
700 "Expected dst type of rank 1 or 2");
702 unsigned rank = dstType.getRank();
703 unsigned dstRows = dstType.getShape()[0];
704 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
707 Value res = arith::ConstantOp::create(rewriter, loc, dstType,
709 bool isInt = isa<IntegerType>(dstType.getElementType());
710 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
712 extractedCols.reserve(dstColumns);
713 for (
unsigned r = 0; r < dstRows; ++r) {
714 Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(),
lhs, r);
715 for (
unsigned c = 0; c < dstColumns; ++c) {
722 : vector::ExtractOp::create(rewriter, op.getLoc(),
rhs, c);
723 extractedCols.push_back(colRhs);
725 Value extractedColRhs = extractedCols[c];
727 createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
728 Value sum = vector::ReductionOp::create(rewriter, op.getLoc(),
729 vector::CombiningKind::ADD,
734 res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
737 if (
auto acc = op.getAcc())
738 res =
createAdd(op.getLoc(), res,
acc, isInt, rewriter, fmf);
746 using MaskableOpRewritePattern::MaskableOpRewritePattern;
748 std::function<LogicalResult(vector::ContractionOp op)>;
753 vector::VectorContractLowering vectorContractLowering,
757 vectorContractLowering(vectorContractLowering), filter(
defaultFilter) {}
761 MaskingOpInterface maskOp,
767 if (failed(filter(contractOp)))
770 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
775 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
776 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
782 for (
int64_t dim : lhsReductionDims) {
783 if (lhsShape[dim] != 1)
786 for (
int64_t dim : rhsReductionDims) {
787 if (rhsShape[dim] != 1)
790 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
792 unsigned numLhsDimToBroadcast =
793 numParallelDims - (lhsMap.
getNumResults() - lhsReductionDims.size());
794 unsigned numRhsDimToBroadcast =
795 numParallelDims - (rhsMap.
getNumResults() - rhsReductionDims.size());
800 for (
int64_t dim : lhsReductionDims)
801 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
802 for (
int64_t dim : rhsReductionDims)
803 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
806 for (
unsigned i = 0; i < numParallelDims; i++) {
807 std::optional<unsigned> lhsDim =
810 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
814 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
815 lhsTranspose.push_back(lhsDims.size() - 1);
817 std::optional<unsigned> rhsDim =
820 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
824 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
825 rhsTranspose.push_back(rhsDims.size() - 1);
828 Value newLhs = contractOp.getLhs();
829 Value newRhs = contractOp.getRhs();
831 if (!lhsDims.empty()) {
832 lhsDims.append(lhsShape.begin(), lhsShape.end());
834 VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
835 newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
837 if (!rhsDims.empty()) {
838 rhsDims.append(rhsShape.begin(), rhsShape.end());
840 VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
841 newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
843 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
844 newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
845 newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
848 newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
849 newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
850 std::optional<Value>
result =
852 contractOp.getKind(), rewriter, isInt,
853 Value(), contractOp.getFastmathAttr());
862 vector::VectorContractLowering vectorContractLowering;
883FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
884 vector::ContractionOp op, MaskingOpInterface maskOp,
890 if (op.getLhsType().getElementType() !=
897 if (op.getKind() != vector::CombiningKind::ADD) {
899 op,
"contractions other than 'add' not supported");
905 ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
906 FailureOr<Value> newVal1 =
907 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
911 ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
912 FailureOr<Value> newVal2 =
913 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
918 FailureOr<Value> newVal4 =
919 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
927 mask = maskOp.getMask();
929 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
930 if (!batchDimMap.empty()) {
931 int64_t lhsIndex = batchDimMap[0].first;
932 int64_t rhsIndex = batchDimMap[0].second;
933 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
940 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
941 op.getContractingDimMap();
944 for (
auto &dimPair : contractingDimMap) {
945 lhsContractingDimSet.insert(dimPair.first);
946 rhsContractingDimSet.insert(dimPair.second);
950 VectorType lhsType = op.getLhsType();
951 for (
int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
952 if (lhsContractingDimSet.count(lhsIndex) == 0) {
953 auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
961 VectorType rhsType = op.getRhsType();
962 for (
int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
963 if (rhsContractingDimSet.count(rhsIndex) == 0) {
964 auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
972 if (!contractingDimMap.empty()) {
973 auto newOp = lowerReduction(rewriter, op, mask);
985FailureOr<Value> ContractionOpLowering::lowerParallel(
PatternRewriter &rewriter,
986 vector::ContractionOp op,
990 VectorType lhsType = op.getLhsType();
991 VectorType rhsType = op.getRhsType();
992 VectorType resType = cast<VectorType>(op.getResultType());
998 iterIndex = iMap[0].getDimPosition(lhsIndex);
999 if (rhsIndex >= 0 && iterIndex != iMap[1].
getDimPosition(rhsIndex))
1001 diag <<
"expected lhsIndex=" << lhsIndex <<
" and rhsIndex=" << rhsIndex
1002 <<
" to map to the same dimension";
1004 if (lhsType.getScalableDims()[lhsIndex])
1006 diag <<
"Unrolling scalable dimension (lhsIndex=" << lhsIndex
1007 <<
") is not supported yet";
1009 dimSize = lhsType.getDimSize(lhsIndex);
1010 }
else if (rhsIndex >= 0) {
1011 iterIndex = iMap[1].getDimPosition(rhsIndex);
1012 if (rhsType.getScalableDims()[rhsIndex])
1014 diag <<
"Unrolling scalable dimension (rhsIndex=" << rhsIndex
1015 <<
") is not supported yet";
1017 dimSize = rhsType.getDimSize(rhsIndex);
1021 diag <<
"expected either lhsIndex=" << lhsIndex
1022 <<
" or rhsIndex=" << rhsIndex <<
" to be nonnegative";
1032 if (resIndex == -1 && dimSize != 1)
1034 diag <<
"expected the dimension for iterIndex=" << iterIndex
1035 <<
" to either appear in the result map, or to be a unit dimension";
1039 std::array<AffineMap, 3> lowIndexingMaps = {
1040 adjustMap(iMap[0], iterIndex, rewriter),
1041 adjustMap(iMap[1], iterIndex, rewriter),
1042 adjustMap(iMap[2], iterIndex, rewriter)};
1048 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1051 for (
int64_t d = 0; d < dimSize; ++d) {
1052 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1053 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1054 auto acc =
reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1059 iterIndex, d, rewriter);
1062 vector::ContractionOp::create(rewriter, loc,
lhs,
rhs,
acc, lowAffine,
1063 lowIter, op.getKind(), op.getFastmath());
1064 lowContract =
maskOperation(rewriter, lowContract, lowMask);
1066 resIndex, d, rewriter);
1072FailureOr<Value> ContractionOpLowering::lowerReduction(
1074 auto loc = op.getLoc();
1075 VectorType lhsType = op.getLhsType();
1076 VectorType rhsType = op.getRhsType();
1077 Type resType = op.getResultType();
1078 if (isa<VectorType>(resType))
1080 "did not expect a VectorType result");
1081 bool isInt = isa<IntegerType>(resType);
1085 std::optional<int64_t> lookupLhs =
getResultIndex(iMap[0], iterIndex);
1086 std::optional<int64_t> lookupRhs =
getResultIndex(iMap[1], iterIndex);
1087 if (!lookupLhs.has_value())
1089 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a LHS dimension";
1091 if (!lookupRhs.has_value())
1093 diag <<
"expected iterIndex=" << iterIndex <<
"to map to a RHS dimension";
1095 int64_t lhsIndex = *lookupLhs;
1096 int64_t rhsIndex = *lookupRhs;
1097 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1098 if (dimSize != rhsType.getDimSize(rhsIndex))
1100 diag <<
"expect LHS dimension " << lhsIndex
1101 <<
" to have the same size as RHS dimension " << rhsIndex;
1104 if (lhsType.getRank() == 1) {
1105 if (rhsType.getRank() != 1)
1107 op,
"When LHS has rank 1, expected also RHS to have rank 1");
1108 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
1109 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
1110 auto kind = vector::CombiningKind::ADD;
1114 acc ? vector::ReductionOp::create(rewriter, loc, kind, m,
acc,
1116 :
vector::ReductionOp::create(rewriter, loc, kind, m,
1121 std::array<AffineMap, 3> lowIndexingMaps = {
1122 adjustMap(iMap[0], iterIndex, rewriter),
1123 adjustMap(iMap[1], iterIndex, rewriter),
1124 adjustMap(iMap[2], iterIndex, rewriter)};
1133 for (
int64_t d = 0; d < dimSize; ++d) {
1134 auto lhs =
reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1135 auto rhs =
reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1139 iterIndex, d, rewriter);
1141 Operation *newContract = vector::ContractionOp::create(
1142 rewriter, loc,
lhs,
rhs,
result, lowAffine, lowIter, op.getKind(),
1168 VectorType resType = op.getResultVectorType();
1169 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1172 auto loc = op.getLoc();
1174 VectorType lhsType = op.getOperandVectorTypeLHS();
1175 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1176 Type eltType = resType.getElementType();
1177 bool isInt = isa<IntegerType, IndexType>(eltType);
1179 vector::CombiningKind kind = op.getKind();
1183 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1186 if (maskableOp.isMasked()) {
1188 rootOp = maskableOp.getMaskingOp();
1189 mask = maskableOp.getMaskingOp().getMask();
1197 vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1199 loc, op.getLhs(),
b,
acc, kind, rewriter, isInt, mask);
1200 if (!mult.has_value())
1206 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1208 for (
int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1209 Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1210 Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1213 r = vector::ExtractOp::create(rewriter, loc,
acc, d);
1216 extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1219 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1222 result = vector::InsertOp::create(rewriter, loc, *m,
result, d);
1234 VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit,
1235 bool disableOuterProductLowering) {
1236 if (!disableOuterProductLowering)
1238 patterns.
add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1239 vectorContractLoweringOption, patterns.
getContext(), benefit);
static int64_t product(ArrayRef< int64_t > vals)
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter, arith::FastMathFlagsAttr fmf={})
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value(), arith::FastMathFlagsAttr fmf={})
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::string diag(const llvm::Value &value)
Progressive lowering of OuterProductOp.
LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Lower vector.contract with all size one reduction dimensions to elementwise ops when possible.
FailureOr< Value > matchAndRewriteMaskableOp(vector::ContractionOp contractOp, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e. nested inside a vector.mask op region).
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0