#include <type_traits>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

    int64_t idx = it.index();
    results.push_back(it.value());
    results.push_back(targetExpr);
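// Helpers used by the contraction lowering further below: reshapeLoad
// extracts the slice of `val` at dimension `index`, position `pos`, and
// reshapeStore inserts such a slice back. Both recurse over leading
// dimensions.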
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return result;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}
template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
// Helper to create the arithmetic operation associated with a contraction's
// combining kind.
static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
                                             Value acc,
                                             vector::CombiningKind kind,
                                             PatternRewriter &rewriter,
                                             bool isInt) {
  using vector::CombiningKind;
  Value mul;
  if (isInt) {
    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return Optional<Value>();
    // Special case for fused multiply-add.
    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }
  return Optional<Value>(mul);
}

// Helper that collects (into dimsIdx) the result positions of a map whose
// iterator type matches the requested kind.
//   ... ArrayAttr iteratorTypes) {
      dimsIdx.push_back(i);
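// ShapeCastOpFolder: folds a pair of vector.shape_cast ops that exactly
// invert each other back to the original value.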
  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if the source operand is itself a shape cast.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.getSource().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if the two shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
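// Progressive lowering of vector.broadcast on high-D vectors: broadcasts are
// unrolled into extract/insert pairs and lower-rank broadcasts (see
// populateVectorBroadcastLoweringPatterns at the end of the file).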
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Scalar-in-vector to 1-D broadcast: extract the element and splat it.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      // ...
    }

    // Rank mismatch: broadcast to the trailing dims, then duplicate along the
    // leading dim.
    if (srcRank < dstRank) {
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      // ...
    }

    // Same rank: find the first stretched (non-matching) dimension m.
    assert(srcRank == dstRank);
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    if (m == 0) {
      // Stretch at the start: broadcast the single extracted slice.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch further down: recurse on each extracted slice.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
// Prune the rightmost dimensions of `transpose` that are not transposed.
static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                                   SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }
  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}
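// Progressive lowering of vector.transpose: by default unrolled into
// extract/insert pairs; for 2-D transposes, the options can request lowering
// to vector.flat_transpose (LLVM matrix intrinsics) or to vector.shuffle
// (handled by the pattern after this one).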
        vectorTransformOptions(vectorTransformOptions) {}

    auto loc = op.getLoc();
    Value input = op.getVector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    // 2-D shuffle lowering is handled by a separate pattern.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // 2-D flat transpose: shape-cast to 1-D and use vector.flat_transpose.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      // ...
    }

    // Generic case: prune trailing non-transposed dims, then unroll the
    // remaining transposed shape into extract/insert pairs.
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }
/// Lower a 2-D transpose to a single vector.shuffle.
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return failure();

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return failure();

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);
    rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
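// Progressive lowering of vector.outerproduct: unrolled along the LHS
// dimension into broadcasts combined with elementwise multiply/FMA ops
// (via createContractArithOp above).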
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
    vector::CombiningKind kind = op.getKind();

    // Vector-scalar ("axpy") case: broadcast the scalar and combine once.
    if (!rhsType) {
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      Optional<Value> mult =
          createContractArithOp(loc, op.getLhs(), b, acc, kind, rewriter, isInt);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(op, mult.value());
      return success();
    }

    // General case: unroll along the LHS dimension.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x =
          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m =
          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
                                                 result, pos);
    }
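// ContractOpToElementwise: when every reduction dimension of the contraction
// has size 1, the op can be rewritten as broadcasts/transposes followed by a
// single elementwise multiply-add (no actual reduction is needed).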
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: implement masks.
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();
    if (failed(filter(contractOp)))
      return failure();

    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    // All reduction dimensions must have size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    // Unit reduction dimensions are transposed to the front.
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Parallel dimensions are either transposed into place or, if missing on
    // one side, broadcast to the result dimension size.
    for (unsigned i = 0; i < numParallelDims; i++) {
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        lhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        rhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }

    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    // Drop the leading unit reduction dimensions.
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    Optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    rewriter.replaceOp(contractOp, {*result});
    return success();
  }

private:
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};
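// Mask creation lowerings: vector.constant_mask and vector.create_mask on
// n-D vectors are progressively lowered to lower-rank masks combined with
// insert/select ops (the 0-D and 1-D cases are handled separately, see
// VectorCreateMaskOpConversion near the end of the file).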
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
  // ...
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    // 0-D case.
    assert(dimSizes.size() == 1 &&
           "Expected exactly one dim size for a 0-D vector");
    bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;

    // Scalable vectors cannot be unrolled element by element.
    if (dstType.cast<VectorType>().isScalable()) {
      // ...
    }

    // 1-D case: set the leading `trueDim` elements to true.
    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());
    for (int64_t d = 0; d < trueDim; d++)
      // ...

    // Higher rank: build a lower-rank constant mask and insert it `trueDim`
    // times along the leading dimension.
    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
  // ...
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
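// vector.shape_cast lowerings: 2-D <-> 1-D casts use strided slice
// insert/extract; all other ranks fall back to a full element-by-element
// extract/insert loop (ShapeCastOpRewritePattern below).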
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
  // ...
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
  // ...
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize, /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
  // ...
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // 2-D <-> 1-D casts are handled by the patterns above.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }

private:
  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};
/// Convert a vector.multi_reduction of an elementwise multiply into a
/// vector.contract.
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Non-reduced dimensions become the parallel (result) dims of the
    // contraction.
    if (!isReduceDim.value()) {
      // ...
    }
    // ...
        0, exprs, reduceOp.getContext());
/// Fold transposes of the lhs/rhs into the indexing maps of a vector.contract.
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t, 4> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
/// Fold broadcasts of the lhs/rhs into the indexing maps of a vector.contract.
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // vector.contract can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType ||
          srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Bail on broadcasts of inner (non-leading) dimensions.
      if (innerDimBroadcast)
        continue;

      // The broadcast dims that will be dropped must not be non-unit
      // reduction dims; dropping those would change the result.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getVectorType().getDimSize(i) != 1 &&
            isReductionIterator(/*...*/)) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap = AffineMap::get(
          broadcast.getVectorType().getRank(), 0, originalDims,
          contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();

    // Compress the maps/iterators over the dims that became unused.
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // A valid contraction needs at least one reduction dim used on both
    // sides.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      // ...
      hasReductionIteratorApplyingOnBothSides = true;
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps),
        rewriter.getArrayAttr(iterators));
    return success();
  }
};
/// Reorder elementwise cast ops and vector.broadcast: cast the (smaller)
/// source first, then broadcast the result.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  // ...
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
/// Reorder elementwise ops and vector.transpose: apply the elementwise op on
/// the transpose sources and transpose the result once.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  // ...
    // Collect the transposition maps of all vector.transpose operands; they
    // must all agree.
    SmallVector<ArrayAttr, 4> transposeMaps;
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getTransp());
        srcType = transposeOp.getVectorType();
      }
    }
    if (transposeMaps.empty())
      return failure();
    if (!llvm::is_splat(transposeMaps))
      return failure();

    // Invert the common permutation.
    auto order = extractVector<unsigned>(transposeMaps.front());
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    // Use the transpose sources directly; transpose remaining operands with
    // invOrder so that everything is in source order.
    SmallVector<Value, 4> srcValues;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), operand, rewriter.getI64ArrayAttr(invOrder)));
      }
    }
    // Rebuild the elementwise op on srcValues and transpose its result back
    // with transposeMaps.front().
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}
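// ContractionOpToMatmulOpLowering (excerpt): flattens the operands of a
// row-major matmul-like vector.contract and lowers it to
// vector.matrix_multiply, which maps 1-1 to the LLVM matrix intrinsics.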
    // TODO: implement masks.
    if (llvm::size(op.getMasks()) != 0)
      return failure();
    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::Matmul)
      return failure();

    auto iteratorTypes = op.getIteratorTypes().getValue();
    Type elementType = op.getLhsType().getElementType();
    Type dstElementType = op.getType();
    if (auto vecType = dstElementType.dyn_cast<VectorType>())
      dstElementType = vecType.getElementType();
    if (elementType != dstElementType)
      return failure();

    // Perform lhs/rhs transpositions to conform to matmul row-major semantics.
    Value lhs = op.getLhs();
    auto lhsMap = op.getIndexingMapsArray()[0];
    Value rhs = op.getRhs();
    auto rhsMap = op.getIndexingMapsArray()[1];

    VectorType lhsType = lhs.getType().cast<VectorType>();
    VectorType rhsType = rhs.getType().cast<VectorType>();
    int64_t lhsRows = lhsType.getDimSize(0);
    int64_t lhsColumns = lhsType.getDimSize(1);
    int64_t rhsColumns = rhsType.getDimSize(1);

    Type flattenedLHSType =
        VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
    lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

    Type flattenedRHSType =
        VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
    rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

    Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                             rhsColumns);
    mul = rew.create<vector::ShapeCastOp>(
        loc,
        VectorType::get({lhsRows, rhsColumns},
                        getElementTypeOrSelf(op.getAcc().getType())),
        mul);

    auto accMap = op.getIndexingMapsArray()[2];
    // ...
      llvm_unreachable("invalid contraction semantics");

    Value res = elementType.isa<IntegerType>()
                    ? static_cast<Value>(
                          rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
                    : static_cast<Value>(
                          rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  // Helper used by the StructuredGenerator machinery below to classify
  // iterator type attributes.
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
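// UnrolledOuterProductGenerator: emits a chain of vector.outerproduct ops,
// one per reduction step, transposing the operands as needed so that the
// reduction dimension becomes the leading (unrolled) dimension.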
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {
  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    if (!v)
      return v;
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = elementType.dyn_cast<VectorType>();
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = VectorType::get(vecType.getShape(), promotedType);
    if (dstElementType.isa<FloatType>())
      return builder.create<arith::ExtFOp>(loc, promotedType, v);
    return builder.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    Type resElementType = res.getType().cast<VectorType>().getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      a = promote(a, resElementType);
      b = promote(b, resElementType);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two parallel and one reduction iterator (matmat).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One parallel followed by one reduction iterator (matvec).
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);
    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One reduction followed by one parallel iterator (tmatvec).
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
    // ContractionOpToOuterProductOpLowering (excerpt).
    // TODO: implement masks.
    if (llvm::size(op.getMasks()) != 0)
      return failure();
    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::OuterProduct)
      return failure();

    UnrolledOuterProductGenerator e(rewriter, op);
    // ... (try e.matmat(), e.matvec(), e.tmatvec() and replace on success)
    // Dot lowering: rearrange the operands so the reduction dimension is
    // innermost, then emit one vector.reduction per output element.
    // TODO: implement masks.
    if (llvm::size(op.getMasks()) != 0)
      return failure();
    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::Dot)
      return failure();

    auto iteratorTypes = op.getIteratorTypes().getValue();
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    if (isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])) {
      // Matmat case.
      if (maps == infer({{m, k}, {k, n}, {m, n}})) {
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
        // No need to permute anything.
      } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
        lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
        lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
        Value tmp = lhs;
        lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
        rhs = tmp;
      } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
        std::swap(lhs, rhs);
      } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
        Value tmp = lhs;
        lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
        rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
      } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
        Value tmp = rhs;
        rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
        lhs = tmp;
      } else {
        return failure();
      }
    } else if (isParallelIterator(iteratorTypes[0]) &&
               isReductionIterator(iteratorTypes[1])) {
      // Matvec case.
      if (maps == infer({{m, n}, {n}, {m}})) {
        // No need to permute anything.
      } else if (maps == infer({{n, m}, {n}, {m}})) {
        lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      } else if (maps == infer({{n}, {m, n}, {m}})) {
        std::swap(lhs, rhs);
      } else if (maps == infer({{n}, {n, m}, {m}})) {
        std::swap(lhs, rhs);
        lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      } else {
        return failure();
      }
    } else {
      return failure();
    }

    VectorType dstType = op.getResultType().cast<VectorType>();
    assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
           "Expected dst type of rank 1 or 2");

    unsigned rank = dstType.getRank();
    unsigned dstRows = dstType.getShape()[0];
    unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

    // Unroll into scalar dot products.
    Value res = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    bool isInt = dstType.getElementType().isa<IntegerType>();
    for (unsigned r = 0; r < dstRows; ++r) {
      Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
      for (unsigned c = 0; c < dstColumns; ++c) {
        Value b = rank == 1
                      ? rhs
                      : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
        Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
        Value reduced = rewriter.create<vector::ReductionOp>(
            op.getLoc(), vector::CombiningKind::ADD, m);
        res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
      }
    }
    if (auto acc = op.getAcc())
      res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
    // ContractionOpLowering: progressively lower to finer-grained contraction
    // ops and, eventually, dot products.
    // TODO: implement masks.
    if (llvm::size(op.getMasks()) != 0)
      return failure();

    if (op.getLhsType().getElementType() !=
        getElementTypeOrSelf(op.getAccType()))
      return failure();

    // Try the more specialized elementwise lowering first.
    ContractOpToElementwise pat4(vectorTransformOptions, ctx);
    if (succeeded(pat4.matchAndRewrite(op, rewriter)))
      return success();

    // Find the first batch dimension in LHS/RHS, and lower when found.
    std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
    if (!batchDimMap.empty()) {
      int64_t lhsIndex = batchDimMap[0].first;
      int64_t rhsIndex = batchDimMap[0].second;
      auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(op, *newOp);
      return success();
    }

    // Collect contracting dimensions.
    std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
        op.getContractingDimMap();
    DenseSet<int64_t> lhsContractingDimSet;
    DenseSet<int64_t> rhsContractingDimSet;
    for (auto &dimPair : contractingDimMap) {
      lhsContractingDimSet.insert(dimPair.first);
      rhsContractingDimSet.insert(dimPair.second);
    }

    // Find the first free (non-contracting) LHS dimension, and lower it.
    VectorType lhsType = op.getLhsType();
    for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
      if (lhsContractingDimSet.count(lhsIndex) == 0) {
        auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
        // ... (replace op with *newOp as above)
      }
    }

    // Find the first free (non-contracting) RHS dimension, and lower it.
    VectorType rhsType = op.getRhsType();
    for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
      if (rhsContractingDimSet.count(rhsIndex) == 0) {
        auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
        // ... (replace op with *newOp as above)
      }
    }

    // Lower the first remaining reduction dimension.
    if (!contractingDimMap.empty()) {
      auto newOp = lowerReduction(op, rewriter);
      // ... (replace op with *newOp as above)
    }
FailureOr<Value>
ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                                     int64_t rhsIndex,
                                     PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  // Find the iterator index and dimension size of the parallel dimension.
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    dimSize = rhsType.getDimSize(rhsIndex);
  } else {
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  }
  // Locate the corresponding dimension in the result map.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct the lower-dimensional indexing maps and iterator types.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower-dimensional contractions.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
    Value lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    result =
        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
  }
  return result;
}
FailureOr<Value>
ContractionOpLowering::lowerReduction(vector::ContractionOp op,
                                      PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (resType.isa<VectorType>())
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = resType.isa<IntegerType>();
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
    });
  int64_t lhsIndex = lookupLhs.value();
  int64_t rhsIndex = lookupRhs.value();
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case: rank-1 reduction becomes a vector.reduction of a multiply.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected also RHS to have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;
    if (auto acc = op.getAcc())
      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
          .getResult();
    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
  }
  // Recursive case: unroll the reduction dimension into lower-rank
  // contractions.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
                                                    lowAffine, lowIter);
  }
  return result;
}
  // distributPointwiseVectorOp (excerpt): validate the requested multiplicity
  // against the vector shape, then wrap the op in extract_map/insert_map.
  unsigned multiplictyCount = 0;
  if (!affinExp || affinExp.getPosition() >= type.getRank() ||
      type.getDimSize(affinExp.getPosition()) %
              multiplicity[multiplictyCount++] !=
          0)
    return {};
  ops.extract =
      builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
  ops.insert =
      builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
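// Transfer lowering patterns: in-bounds, minor-identity
// vector.transfer_read/write ops on memrefs are rewritten to plain
// vector.load / vector.store (or their masked variants), and
// 0-d/single-element transfers degrade to memref.load/store.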
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferReadOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
      return failure();

    SmallVector<unsigned> broadcastedDims;
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();

    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // If broadcasting is involved, first load the unbroadcasted vector.
    SmallVector<int64_t> unbroadcastedVectorShape(
        read.getVectorType().getShape());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // vector.load only supports vector element types when they match the
    // resulting vector type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
      return failure();
    // Otherwise the element types of memref and vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != read.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return failure();

    Operation *loadOp;
    if (read.getMask()) {
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcast if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    }
    // Replace a 0-d/single-element vector.load with memref.load +
    // vector.broadcast.
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);

    // Replace a 0-d/single-element vector.store with vector.extractelement +
    // memref.store.
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    Value extracted;
    if (vecType.getRank() == 0) {
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
      return failure();
    if (!write.getPermutationMap().isMinorIdentity())
      return failure();
    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
      return failure();
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != write.getVectorType().getElementType())
      return failure();
    if (write.hasOutOfBoundsDim())
      return failure();
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}
    // Bubble a vector.bitcast below vector.extract: only scalar extracts are
    // supported for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail if there is only one element in the cast source.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only accept bitcasts to a larger number of (narrower) elements.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // getFirstIntValue helper:
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();

    // Get the single scalar (as a vector) that packs the desired element.
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.getSource(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired element type.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getType(), castedValue,
        rewriter.getI64ArrayAttr(index % expandRatio));
    // Bubble a vector.bitcast below vector.extract_strided_slice.
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = extractOp.getVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // Scale down the last dimension's offsets/sizes, since we now extract
    // from fewer (wider) elements.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    // Bubble a vector.bitcast above vector.insert_strided_slice.
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert op to have equal source/destination ranks for now.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
// Build the vector comparison `indices < bound` used to materialize masks.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If all indices fit in 32 bits, compare in 32-bit for more SIMD
  // parallelism.
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    // 0-D case with 32-bit indices.
  } else if (dim == 0) {
    // 0-D case with 64-bit indices.
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value ov = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               indices.getType(), *off);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
template <typename ConcreteOp>
class MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
      : mlir::OpRewritePattern<ConcreteOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.getIndices()) == 0)
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask from the last index and dimension size.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getNumScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect with the mask provided as an op operand.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
    });
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

    auto dstType = op.getType();
    if (dstType.cast<VectorType>().isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    rewriter.replaceOp(op, buildVectorComparison(
                               rewriter, op, force32BitVectorIndices,
                               rank == 0 ? 0 : dstType.getDimSize(0),
                               op.getOperand(0)));

  const bool force32BitVectorIndices;
};
  // Collapse contiguous innermost dimensions of a vector.transfer_read by
  // taking a memref.subview and reading a lower-rank vector.
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // TODO: support mask.
    if (readOp.getMask())
      return failure();
    auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
    if (!srcType || !srcType.hasStaticShape())
      return failure();
    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();
    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    // Count how many inner dims have a unit stride.
    SmallVector<int64_t> srcStrides;
    int64_t srcOffset;
    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
      return failure();
    size_t dimsToDrop = 0;
    for (size_t i = 1; i < srcStrides.size(); ++i) {
      int dim = srcType.getRank() - i - 1;
      if (srcStrides[dim] == 1) {
        dimsToDrop++;
      } else {
        break;
      }
    }
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    MemRefType resultMemrefType;
    if (srcType.getLayout().getAffineMap().isIdentity()) {
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          {}, srcType.getMemorySpaceAsInt());
    } else {
      AffineMap map = srcType.getLayout().getAffineMap();
      for (size_t i = 0; i < dimsToDrop; ++i) {
        int dim = srcType.getRank() - i - 1;
        // ... (drop the layout result corresponding to `dim` from `map`)
      }
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          map, srcType.getMemorySpaceAsInt());
    }

    auto loc = readOp.getLoc();
    SmallVector<int64_t> offsets(srcType.getRank(), 0);
    SmallVector<int64_t> strides(srcType.getRank(), 1);

    ArrayAttr inBoundsAttr =
        readOp.getInBounds()
            ? rewriter.getArrayAttr(
                  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
        strides);
    auto permMap = getTransferMinorIdentityMap(
        rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        /*mask=*/Value(), inBoundsAttr);
  static bool isValidKind(bool isInt, vector::CombiningKind kind) {
    using vector::CombiningKind;
    enum class KindType { FLOAT, INT, INVALID };
    KindType type{KindType::INVALID};
    switch (kind) {
    case CombiningKind::MINF:
    case CombiningKind::MAXF:
      type = KindType::FLOAT;
      break;
    case CombiningKind::MINUI:
    case CombiningKind::MINSI:
    case CombiningKind::MAXUI:
    case CombiningKind::MAXSI:
    case CombiningKind::AND:
    case CombiningKind::OR:
    case CombiningKind::XOR:
      type = KindType::INT;
      break;
    case CombiningKind::ADD:
    case CombiningKind::MUL:
      type = isInt ? KindType::INT : KindType::FLOAT;
      break;
    }
    bool isValidIntKind = (type == KindType::INT) && isInt;
    bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
    return (isValidIntKind || isValidFloatKind);
  }
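// vector.scan lowering: genOperator combines two slices according to the
// scan's combining kind, and the pattern below unrolls the scan along the
// reduction dimension, slice by slice.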
static Value genOperator(Location loc, Value x, Value y,
                         vector::CombiningKind kind,
                         PatternRewriter &rewriter) {
  using vector::CombiningKind;

  bool isInt = elType.isIntOrIndex();

  Value combinedResult{nullptr};
  switch (kind) {
  case CombiningKind::ADD:
    if (isInt)
      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
    break;
  case CombiningKind::MUL:
    if (isInt)
      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
    break;
  case CombiningKind::MINUI:
    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
    break;
  case CombiningKind::MINSI:
    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
    break;
  case CombiningKind::MAXUI:
    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
    break;
  case CombiningKind::MAXSI:
    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
    break;
  case CombiningKind::AND:
    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
    break;
  case CombiningKind::OR:
    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
    break;
  case CombiningKind::XOR:
    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
    break;
  case CombiningKind::MINF:
    combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
    break;
  case CombiningKind::MAXF:
    combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
    break;
  }
  return combinedResult;
}
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCast cannot express a scalar-to-vector cast; use broadcast.
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          }
        }
      } else {
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices);
}

void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext());
}

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      patterns.getContext());
}

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<OuterProductOpLowering>(patterns.getContext());
  // ...
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnTranspose>(patterns.getContext());
}
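// Illustrative sketch (not part of the original file): how a pass might
// populate and apply the lowering patterns above. The surrounding pass and
// its getOperation()/getContext() hooks are hypothetical; the populate
// functions, VectorTransformsOptions::vectorContractLowering, and
// applyPatternsAndFoldGreedily are the APIs referenced in this file.
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   void runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     vector::VectorTransformsOptions options;
//     options.vectorContractLowering =
//         vector::VectorContractLowering::OuterProduct;
//     vector::populateVectorBroadcastLoweringPatterns(patterns);
//     vector::populateVectorContractLoweringPatterns(patterns, options);
//     vector::populateVectorMaskOpLoweringPatterns(patterns);
//     vector::populateVectorShapeCastLoweringPatterns(patterns);
//     vector::populateVectorTransposeLoweringPatterns(patterns, options);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }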