34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/ADT/StringSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/ADT/bit.h"
45 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
47 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
68 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
70 for (
bool b : denseElts.getValues<
bool>())
73 else if (!b && val <= 0)
86 ArrayAttr masks = m.getMaskDimSizes();
87 auto shape = m.getType().getShape();
90 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
91 int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
108 builder.
create<vector::YieldOp>(loc);
114 switch (combiningKind) {
115 case CombiningKind::ADD:
116 case CombiningKind::MUL:
118 case CombiningKind::MINUI:
119 case CombiningKind::MINSI:
120 case CombiningKind::MAXUI:
121 case CombiningKind::MAXSI:
122 case CombiningKind::AND:
123 case CombiningKind::OR:
124 case CombiningKind::XOR:
126 case CombiningKind::MINF:
127 case CombiningKind::MAXF:
128 return llvm::isa<FloatType>(elementType);
139 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
143 VectorType vectorType) {
144 int64_t elementVectorRank = 0;
145 VectorType elementVectorType =
146 llvm::dyn_cast<VectorType>(shapedType.getElementType());
147 if (elementVectorType)
148 elementVectorRank += elementVectorType.getRank();
151 if (shapedType.getRank() == 0 &&
157 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
158 shapedType.getContext());
162 vector::TransferReadOp read) {
163 return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
164 !read.getMask() && defWrite.getIndices() == read.getIndices() &&
165 defWrite.getVectorType() == read.getVectorType() &&
166 defWrite.getPermutationMap() == read.getPermutationMap();
170 vector::TransferWriteOp priorWrite) {
171 return priorWrite.getIndices() == write.getIndices() &&
172 priorWrite.getMask() == write.getMask() &&
173 priorWrite.getVectorType() == write.getVectorType() &&
174 priorWrite.getPermutationMap() == write.getPermutationMap();
178 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
180 if (transferA.getVectorType() != transferB.getVectorType())
182 unsigned rankOffset = transferA.getLeadingShapedRank();
183 for (
unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
184 auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
185 auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
187 if (!indexA || !indexB)
190 if (i < rankOffset) {
193 if (llvm::cast<IntegerAttr>(indexA.getValue()).getInt() !=
194 llvm::cast<IntegerAttr>(indexB.getValue()).getInt())
200 std::abs(llvm::cast<IntegerAttr>(indexA.getValue()).getInt() -
201 llvm::cast<IntegerAttr>(indexB.getValue()).getInt());
202 if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
210 VectorTransferOpInterface transferB) {
211 if (transferA.source() != transferB.source())
223 for (
auto [posInDim, dimSize, offsetInDim] :
224 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
226 if (posInDim < dimSize + offsetInDim)
230 posInDim = offsetInDim;
266 void VectorDialect::initialize() {
268 #define GET_ATTRDEF_LIST
269 #include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
274 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
283 return arith::ConstantOp::materialize(builder, value, type, loc);
299 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
302 CombiningKind kind) {
306 reductionDims.push_back(en.index());
307 build(builder, result, kind, source, acc,
311 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
313 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
318 std::optional<SmallVector<int64_t, 4>>
319 MultiDimReductionOp::getShapeForUnroll() {
320 return llvm::to_vector<4>(getSourceVectorType().
getShape());
325 Type inferredReturnType;
327 if (!llvm::any_of(getReductionDims().getValue(), [&](
Attribute attr) {
328 return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
330 targetShape.push_back(it.value());
332 if (targetShape.empty())
333 inferredReturnType = getSourceVectorType().getElementType();
337 if (getType() != inferredReturnType)
338 return emitOpError() <<
"destination type " << getType()
339 <<
" is incompatible with source type "
340 << getSourceVectorType();
346 Type MultiDimReductionOp::getExpectedMaskType() {
347 auto vecType = getSourceVectorType();
358 struct ElideUnitDimsInMultiDimReduction
362 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
365 for (
const auto &dim :
enumerate(shape)) {
366 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
374 if (reductionOp.isMasked()) {
376 rootOp = reductionOp.getMaskingOp();
377 mask = reductionOp.getMaskingOp().getMask();
379 rootOp = reductionOp;
382 Location loc = reductionOp.getLoc();
383 Value acc = reductionOp.getAcc();
385 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
387 VectorType newMaskType =
389 mask = rewriter.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
391 cast = rewriter.
create<vector::ShapeCastOp>(
392 loc, reductionOp.getDestType(), reductionOp.getSource());
401 cast = rewriter.
create<vector::ExtractOp>(
402 loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
406 rewriter, loc, reductionOp.getKind(), acc, cast, mask);
413 void MultiDimReductionOp::getCanonicalizationPatterns(
415 results.
add<ElideUnitDimsInMultiDimReduction>(context);
423 CombiningKind kind,
Value vector) {
424 build(builder, result, kind, vector,
Value());
428 CombiningKind kind,
Value vector,
Value acc) {
429 build(builder, result,
430 llvm::cast<VectorType>(vector.
getType()).getElementType(), kind, vector,
436 int64_t rank = getSourceVectorType().getRank();
438 return emitOpError(
"unsupported reduction rank: ") << rank;
441 Type eltType = getDest().getType();
443 return emitOpError(
"unsupported reduction type '")
444 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
454 CombiningKindAttr kindAttr;
460 (!operandsInfo.empty() &&
462 (operandsInfo.size() > 1 &&
466 if (operandsInfo.empty() || operandsInfo.size() > 2)
468 "unsupported number of operands");
474 getKindAttr().print(p);
475 p <<
", " << getVector();
477 p <<
", " << getAcc();
478 p <<
" : " << getVector().getType() <<
" into " << getDest().getType();
484 Type ReductionOp::getExpectedMaskType() {
485 auto vecType = getSourceVectorType();
486 return vecType.cloneWith(std::nullopt,
494 case arith::AtomicRMWKind::addf:
495 case arith::AtomicRMWKind::addi:
496 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
497 CombiningKind::ADD, vector);
498 case arith::AtomicRMWKind::mulf:
499 case arith::AtomicRMWKind::muli:
500 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
501 CombiningKind::MUL, vector);
502 case arith::AtomicRMWKind::minf:
503 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
504 CombiningKind::MINF, vector);
505 case arith::AtomicRMWKind::mins:
506 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
507 CombiningKind::MINSI, vector);
508 case arith::AtomicRMWKind::minu:
509 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
510 CombiningKind::MINUI, vector);
511 case arith::AtomicRMWKind::maxf:
512 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
513 CombiningKind::MAXF, vector);
514 case arith::AtomicRMWKind::maxs:
515 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
516 CombiningKind::MAXSI, vector);
517 case arith::AtomicRMWKind::maxu:
518 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
519 CombiningKind::MAXUI, vector);
520 case arith::AtomicRMWKind::andi:
521 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
522 CombiningKind::AND, vector);
523 case arith::AtomicRMWKind::ori:
524 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
525 CombiningKind::OR, vector);
534 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
535 return llvm::to_vector<4>(getSourceVectorType().
getShape());
547 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
550 if (maskableOp.isMasked()) {
552 rootOp = maskableOp.getMaskingOp();
553 mask = maskableOp.getMaskingOp().getMask();
555 rootOp = reductionOp;
558 auto vectorType = reductionOp.getSourceVectorType();
559 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
562 Location loc = reductionOp.getLoc();
564 if (vectorType.getRank() == 0) {
566 mask = rewriter.
create<ExtractElementOp>(loc, mask);
567 result = rewriter.
create<ExtractElementOp>(loc, reductionOp.getVector());
573 result = rewriter.
create<ExtractOp>(loc, reductionOp.getType(),
574 reductionOp.getVector(),
578 if (
Value acc = reductionOp.getAcc())
590 results.
add<ElideSingleElementReduction>(context);
607 getIteratorTypesAttrName(result.
name),
610 return IteratorTypeAttr::get(builder.getContext(), t);
616 ArrayAttr indexingMaps,
617 ArrayAttr iteratorTypes) {
618 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
619 ContractionOp::getDefaultKind());
624 ArrayAttr indexingMaps,
625 ArrayAttr iteratorTypes, CombiningKind kind) {
642 DictionaryAttr dictAttr;
657 dictAttr.getValue().end());
663 ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
668 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
669 auto maybeIteratorType = symbolizeIteratorType(s);
670 if (!maybeIteratorType.has_value())
671 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
673 iteratorTypeAttrs.push_back(
681 getKindAttrName(result.
name),
683 ContractionOp::getDefaultKind()));
685 if (masksInfo.empty())
687 if (masksInfo.size() != 2)
689 "expected zero or exactly 2 vector mask operands");
690 auto lhsType = llvm::cast<VectorType>(types[0]);
691 auto rhsType = llvm::cast<VectorType>(types[1]);
693 std::array<Type, 2> maskTypes = {
703 auto attrNames = getTraitAttrNames();
705 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
707 for (
auto attr : (*this)->getAttrs()) {
708 if (attr.getName() == getIteratorTypesAttrName()) {
710 llvm::cast<ArrayAttr>(attr.getValue())
711 .getAsValueRange<IteratorTypeAttr, IteratorType>();
717 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
721 attrs.emplace_back(getIteratorTypesAttrName(),
723 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
724 attrs.push_back(attr);
728 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
729 p << getRhs() <<
", " << getAcc();
732 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
737 const std::vector<std::pair<int64_t, int64_t>> &map) {
738 for (
auto &dimPair : map) {
739 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
740 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
741 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
748 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
750 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
751 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
754 for (
auto &dimPair : contractingDimMap) {
755 lhsContractingDimSet.insert(dimPair.first);
756 rhsContractingDimSet.insert(dimPair.second);
759 for (
auto &dimPair : batchDimMap)
760 rhsBatchDimSet.insert(dimPair.second);
764 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
765 if (lhsContractingDimSet.count(i) > 0)
767 expectedResultDims.push_back(lhsType.getDimSize(i));
771 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
772 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
774 expectedResultDims.push_back(rhsType.getDimSize(i));
778 if (expectedResultDims.empty()) {
780 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
781 return op.
emitOpError(
"invalid accumulator/result vector shape");
784 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
785 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
786 if (!resVectorType || !accVectorType)
787 return op.
emitOpError(
"invalid accumulator/result vector shape");
793 AffineMap lhsMap = op.getIndexingMapsArray()[0];
794 AffineMap rhsMap = op.getIndexingMapsArray()[1];
797 "expected all dimensions to be either a LHS or a RHS dimension");
800 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
801 VectorType v = pair.first;
802 auto map = pair.second;
803 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
804 unsigned pos = map.getDimPosition(idx);
809 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
810 return op.
emitOpError(
"expected all dimensions to get an extent as "
811 "either a LHS or a RHS dimension");
813 AffineMap resMap = op.getIndexingMapsArray()[2];
820 [](
AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
821 "expected constant extent along all dimensions.");
823 auto expectedShape = llvm::to_vector<4>(
825 return e.cast<AffineConstantExpr>().getValue();
829 if (resVectorType != expected || accVectorType != expected)
831 "invalid accumulator/result vector shape, expected: ")
838 VectorType lhsType = getLhsType();
839 VectorType rhsType = getRhsType();
840 Type accType = getAccType();
841 Type resType = getResultType();
843 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
844 if (!lhsType.getElementType().isSignlessInteger())
845 return emitOpError(
"only supports signless integer types");
849 if (getIndexingMapsArray().size() != 3)
850 return emitOpError(
"expected an indexing map for each vector operand");
855 unsigned numIterators = getIteratorTypes().getValue().size();
857 auto index = it.index();
858 auto map = it.value();
859 if (map.getNumSymbols() != 0)
860 return emitOpError(
"expected indexing map ")
861 << index <<
" to have no symbols";
862 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
863 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
866 if (map.getNumDims() != numIterators)
867 return emitOpError(
"expected indexing map ")
868 << index <<
" to have " << numIterators <<
" number of inputs";
869 if (map.getNumResults() != rank)
870 return emitOpError(
"expected indexing map ")
871 << index <<
" to have " << rank <<
" number of outputs";
872 if (!map.isProjectedPermutation())
873 return emitOpError(
"expected indexing map ")
874 << index <<
" to be a projected permutation of its inputs";
877 auto contractingDimMap = getContractingDimMap();
878 auto batchDimMap = getBatchDimMap();
881 if (contractingDimMap.empty())
882 return emitOpError(
"expected at least one contracting dimension pair");
886 return emitOpError(
"invalid contracting dimension map");
890 return emitOpError(
"invalid batch dimension map");
894 contractingDimMap, batchDimMap)))
898 auto vectorType = llvm::dyn_cast<VectorType>(resType);
899 auto elementType = vectorType ? vectorType.getElementType() : resType;
901 return emitOpError(
"unsupported contraction type");
910 Type ContractionOp::getExpectedMaskType() {
911 auto indexingMaps = this->getIndexingMapsArray();
914 VectorType lhsType = this->getLhsType();
915 VectorType rhsType = this->getRhsType();
927 assert(!ShapedType::isDynamicShape(maskShape) &&
928 "Mask shape couldn't be computed");
936 getIteratorTypesAttrName(), getKindAttrName()};
946 static std::vector<std::pair<int64_t, int64_t>>
948 IteratorType targetIteratorType,
MLIRContext *context) {
949 std::vector<std::pair<int64_t, int64_t>> dimMap;
951 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
952 if (iteratorType != targetIteratorType)
958 if (lhsDim >= 0 && rhsDim >= 0)
959 dimMap.emplace_back(lhsDim, rhsDim);
964 void ContractionOp::getIterationBounds(
966 auto lhsShape = getLhsType().getShape();
967 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
973 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
974 if (iteratorType == IteratorType::reduction) {
976 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
977 assert(lhsDimIndex >= 0);
978 iterationBounds.push_back(lhsShape[lhsDimIndex]);
982 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
983 assert(resDimIndex >= 0);
984 assert(resVectorType !=
nullptr);
985 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
989 void ContractionOp::getIterationIndexMap(
991 unsigned numMaps = getIndexingMapsArray().size();
992 iterationIndexMap.resize(numMaps);
994 auto index = it.index();
995 auto map = it.value();
996 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
998 iterationIndexMap[index][dim.getPosition()] = i;
1003 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1005 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1009 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1011 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1015 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1017 getIterationBounds(shape);
1039 template <
typename AddOpType>
1045 auto canonicalize = [&](
Value maybeContraction,
1046 Value otherOperand) -> vector::ContractionOp {
1047 vector::ContractionOp contractionOp =
1048 dyn_cast_or_null<vector::ContractionOp>(
1051 return vector::ContractionOp();
1052 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1053 contractionOp.getAcc().getDefiningOp())) {
1054 if (maybeZero.getValue() ==
1055 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1057 bvm.
map(contractionOp.getAcc(), otherOperand);
1058 auto newContraction =
1059 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1060 rewriter.
replaceOp(addOp, newContraction.getResult());
1061 return newContraction;
1064 return vector::ContractionOp();
1067 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1068 vector::ContractionOp
contract = canonicalize(a, b);
1087 result.
addTypes(llvm::cast<VectorType>(source.
getType()).getElementType());
1091 VectorType vectorType = getSourceVectorType();
1092 if (vectorType.getRank() == 0) {
1094 return emitOpError(
"expected position to be empty with 0-D vector");
1097 if (vectorType.getRank() != 1)
1098 return emitOpError(
"unexpected >1 vector rank");
1100 return emitOpError(
"expected position for 1-D vector");
1104 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1106 if (!adaptor.getPosition())
1113 if (
auto splat = getVector().getDefiningOp<vector::SplatOp>())
1114 return splat.getInput();
1117 if (
auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1124 auto srcElements = llvm::cast<DenseElementsAttr>(src).getValues<
Attribute>();
1126 auto attr = llvm::dyn_cast<IntegerAttr>(pos);
1127 uint64_t posIdx = attr.getInt();
1129 return srcElements[posIdx];
1145 llvm::to_vector<4>(llvm::map_range(position, [](
Value pos) {
1148 build(builder, result, source, positionConstants);
1152 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1153 ValueRange operands, DictionaryAttr attributes,
1156 ExtractOp::Adaptor op(operands, attributes, properties);
1157 auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
1158 if (
static_cast<int64_t
>(op.getPosition().size()) == vectorType.getRank()) {
1159 inferredReturnTypes.push_back(vectorType.getElementType());
1162 std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
1164 vectorType.getShape().drop_front(n), vectorType.getElementType()));
1172 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1173 return vectorType && vectorType.getShape().equals({1}) &&
1174 vectorType.getElementType() == r.front();
1176 if (l.size() == 1 && r.size() == 1 &&
1177 (isCompatible(l, r) || isCompatible(r, l)))
1183 auto positionAttr = getPosition().getValue();
1184 if (positionAttr.size() >
1185 static_cast<unsigned>(getSourceVectorType().getRank()))
1187 "expected position attribute of rank smaller than vector rank");
1189 auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
1190 if (!attr || attr.getInt() < 0 ||
1191 attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
1192 return emitOpError(
"expected position attribute #")
1194 <<
" to be a non-negative integer smaller than the corresponding "
1200 template <
typename IntType>
1202 return llvm::to_vector<4>(llvm::map_range(
1203 arrayAttr.getAsRange<IntegerAttr>(),
1204 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1210 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1214 ExtractOp currentOp = extractOp;
1215 auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1216 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1217 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1219 auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1220 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1222 extractOp.setOperand(currentOp.getVector());
1225 std::reverse(globalPosition.begin(), globalPosition.end());
1226 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1239 class ExtractFromInsertTransposeChainState {
1241 ExtractFromInsertTransposeChainState(ExtractOp e);
/// Returns true if `a` is a prefix of `b`: `a` has no more elements than `b`
/// and the first a.size() elements of both ranges compare equal. Requires
/// random-access containers (uses `begin() + size()`). Reconstructed: the
/// closing brace was lost in the truncated text.
template <typename ContainerA, typename ContainerB>
bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
  return a.size() <= b.size() &&
         std::equal(a.begin(), a.begin() + a.size(), b.begin());
}
1260 template <
typename ContainerA,
typename ContainerB>
1261 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
1262 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
1263 if (elemA < 0 || elemB < 0)
1278 void updateStateForNextIteration(
Value v) {
1308 Value tryToFoldExtractOpInPlace(
Value source);
1310 ExtractOp extractOp;
1312 int64_t extractedRank;
1314 InsertOp nextInsertOp;
1315 TransposeOp nextTransposeOp;
1330 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1332 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1333 extractedRank(extractOp.getPosition().size()) {
1334 assert(vectorRank >= extractedRank &&
"extracted pos overflow");
1335 sentinels.reserve(vectorRank - extractedRank);
1336 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1337 sentinels.push_back(-(i + 1));
1344 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1345 if (!nextTransposeOp)
1347 auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
1349 AffineMap::getPermutationMap(permutation, extractOp.getContext()));
1356 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1358 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1363 res = nextInsertOp.getSource();
1372 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1373 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1383 res = nextInsertOp.getSource();
1391 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1394 bool nothingToFold = (source == extractOp.getVector());
1395 if (nothingToFold || !canFold())
1400 extractOp.getPositionAttrName(),
1402 extractOp.getVectorMutable().assign(source);
1403 return extractOp.getResult();
1407 Value ExtractFromInsertTransposeChainState::fold() {
1408 Value valueToExtractFrom = extractOp.getVector();
1409 updateStateForNextIteration(valueToExtractFrom);
1410 while (nextInsertOp || nextTransposeOp) {
1414 valueToExtractFrom = nextTransposeOp.getVector();
1415 updateStateForNextIteration(valueToExtractFrom);
1421 if (
succeeded(handleInsertOpWithMatchingPos(result)))
1426 if (
succeeded(handleInsertOpWithPrefixPos(result)))
1427 return tryToFoldExtractOpInPlace(result);
1431 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1437 valueToExtractFrom = nextInsertOp.getDest();
1438 updateStateForNextIteration(valueToExtractFrom);
1441 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1446 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1447 auto vecType = dyn_cast<VectorType>(type);
1448 return vecType && vecType.getRank() == 0;
1457 Operation *defOp = extractOp.getVector().getDefiningOp();
1458 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1467 if (extractOp.getType() == source.
getType())
1469 auto getRank = [](
Type type) {
1470 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1474 unsigned broadcastSrcRank = getRank(source.
getType());
1475 if (broadcastSrcRank == 0)
1478 unsigned extractResultRank = getRank(extractOp.getType());
1479 if (extractResultRank >= broadcastSrcRank)
1482 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1483 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.
getType());
1484 if (extractVecType && broadcastVecType &&
1485 extractVecType.getShape() !=
1486 broadcastVecType.getShape().take_back(extractResultRank))
1489 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1490 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1495 broadcastOp.computeBroadcastedUnitDims();
1496 auto extractPos = extractVector<int64_t>(extractOp.getPosition());
1497 for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
1498 if (broadcastedUnitDims.contains(i))
1502 extractPos.erase(extractPos.begin(),
1503 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1506 extractOp.setOperand(source);
1507 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1509 return extractOp.getResult();
1514 auto shapeCastOp = extractOp.getVector().
getDefiningOp<vector::ShapeCastOp>();
1524 auto getDimReverse = [](VectorType type, int64_t n) {
1525 return type.getShape().take_back(n + 1).front();
1527 int64_t destinationRank =
1528 llvm::isa<VectorType>(extractOp.getType())
1529 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1531 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1533 if (destinationRank > 0) {
1534 auto destinationType =
1535 llvm::cast<VectorType>(extractOp.getResult().getType());
1536 for (int64_t i = 0; i < destinationRank; i++) {
1540 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1541 getDimReverse(destinationType, i))
1547 auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1548 std::reverse(extractedPos.begin(), extractedPos.end());
1551 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1552 strides.push_back(stride);
1554 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1557 int64_t position =
linearize(extractedPos, strides);
1561 int64_t numDimension =
1562 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1564 for (int64_t i = 0; i < numDimension; i++) {
1565 newStrides.push_back(stride);
1567 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1569 std::reverse(newStrides.begin(), newStrides.end());
1573 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1575 extractOp.setOperand(shapeCastOp.getSource());
1576 return extractOp.getResult();
1581 auto extractStridedSliceOp =
1582 extractOp.getVector().
getDefiningOp<vector::ExtractStridedSliceOp>();
1583 if (!extractStridedSliceOp)
1592 if (extractStridedSliceOp.hasNonUnitStrides())
1597 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1598 while (!sliceOffsets.empty()) {
1599 size_t lastOffset = sliceOffsets.size() - 1;
1600 if (sliceOffsets.back() != 0 ||
1601 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1602 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1604 sliceOffsets.pop_back();
1606 unsigned destinationRank = 0;
1607 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1608 destinationRank = vecType.getRank();
1611 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1612 sliceOffsets.size())
1614 auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1615 assert(extractedPos.size() >= sliceOffsets.size());
1616 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1617 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1618 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1621 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1623 return extractOp.getResult();
1628 int64_t destinationRank =
1629 llvm::isa<VectorType>(extractOp.getType())
1630 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1632 auto insertOp = extractOp.getVector().
getDefiningOp<InsertStridedSliceOp>();
1642 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1643 insertOp.getSourceVectorType().getRank();
1644 if (destinationRank > insertOp.getSourceVectorType().getRank())
1646 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1647 auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
1649 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1650 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1653 bool disjoint =
false;
1655 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1656 int64_t start = insertOffsets[dim];
1658 (dim < insertRankDiff)
1660 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1661 int64_t end = start + size;
1662 int64_t offset = extractOffsets[dim];
1664 if (start <= offset && offset < end) {
1665 if (dim >= insertRankDiff)
1666 offsetDiffs.push_back(offset - start);
1676 int64_t srcRankDiff =
1677 insertOp.getSourceVectorType().getRank() - destinationRank;
1678 for (int64_t i = 0; i < destinationRank; i++) {
1679 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1680 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1684 extractOp.getVectorMutable().assign(insertOp.getSource());
1687 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1689 return extractOp.getResult();
1693 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1699 if (getPosition().empty())
1703 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
1725 Operation *defOp = extractOp.getVector().getDefiningOp();
1726 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1730 if (extractOp.getType() == source.
getType())
1732 auto getRank = [](
Type type) {
1733 return llvm::isa<VectorType>(type)
1734 ? llvm::cast<VectorType>(type).getRank()
1737 unsigned broadcastSrcRank = getRank(source.
getType());
1738 unsigned extractResultRank = getRank(extractOp.getType());
1742 if (extractResultRank < broadcastSrcRank)
1746 if (extractResultRank == 0) {
1747 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.
getType()));
1752 extractOp, extractOp.getType(), source);
1758 class ExtractOpSplatConstantFolder final :
public OpRewritePattern<ExtractOp> {
1766 Value sourceVector = extractOp.getVector();
1770 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1773 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
1774 if (
auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1782 class ExtractOpNonSplatConstantFolder final
1791 Value sourceVector = extractOp.getVector();
1796 auto vecTy = llvm::cast<VectorType>(sourceVector.
getType());
1797 if (vecTy.isScalable())
1801 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
1802 if (!dense || dense.isSplat())
1809 int64_t elemBeginPosition =
1811 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
1814 if (
auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
1816 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
1819 newAttr = *denseValuesBegin;
1831 results.
add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
1832 ExtractOpFromBroadcast>(context);
1837 for (
auto attr : arrayAttr)
1838 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1845 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1858 int64_t rankDiff = dstShape.size() - srcShape.size();
1859 int64_t dstDim = rankDiff;
1861 for (
auto [s1, s2] :
1862 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
1864 assert(s1 == 1 &&
"expected dim-1 broadcasting");
1874 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
1893 Value BroadcastOp::createOrFoldBroadcastOp(
1896 assert(!dstShape.empty() &&
"unexpected empty dst shape");
1900 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
1901 if (broadcastedDims.contains(i))
1903 checkShape.push_back(dstShape[i]);
1905 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
1906 "ill-formed broadcastedDims contains values not confined to "
1911 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
1915 if (!srcVectorType) {
1916 assert(checkShape.empty() &&
1917 "ill-formed createOrFoldBroadcastOp arguments");
1918 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
1921 assert(srcVectorType.getShape().equals(checkShape) &&
1922 "ill-formed createOrFoldBroadcastOp arguments");
1933 broadcastShape.reserve(dstShape.size());
1949 int64_t nextSrcShapeDim = broadcastedDims.size();
1950 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
1951 if (broadcastedDims.contains(i)) {
1956 broadcastShape.push_back(dstShape[i]);
1957 permutation[i] = broadcastShape.size() - 1;
1963 permutation[i] = nextSrcShapeDim++;
1967 llvm::append_range(broadcastShape, srcVectorType.getShape());
1972 "unexpected dim-1 broadcast");
1974 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
1976 vector::BroadcastableToResult::Success &&
1977 "must be broadcastable");
1981 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
1982 if (permutation[i] != i)
1983 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
1990 std::pair<int, int> *mismatchingDims) {
1994 return BroadcastableToResult::Success;
1996 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
1998 return BroadcastableToResult::SourceTypeNotAVector;
2000 int64_t srcRank = srcVectorType.getRank();
2001 int64_t dstRank = dstVectorType.getRank();
2002 if (srcRank > dstRank)
2003 return BroadcastableToResult::SourceRankHigher;
2006 int64_t lead = dstRank - srcRank;
2007 for (int64_t r = 0; r < srcRank; ++r) {
2008 int64_t srcDim = srcVectorType.getDimSize(r);
2009 int64_t dstDim = dstVectorType.getDimSize(lead + r);
2010 if (srcDim != 1 && srcDim != dstDim) {
2011 if (mismatchingDims) {
2012 mismatchingDims->first = srcDim;
2013 mismatchingDims->second = dstDim;
2015 return BroadcastableToResult::DimensionMismatch;
2019 return BroadcastableToResult::Success;
2023 std::pair<int, int> mismatchingDims;
2025 getSourceType(), getResultVectorType(), &mismatchingDims);
2026 if (res == BroadcastableToResult::Success)
2028 if (res == BroadcastableToResult::SourceRankHigher)
2029 return emitOpError(
"source rank higher than destination rank");
2030 if (res == BroadcastableToResult::DimensionMismatch)
2031 return emitOpError(
"dimension mismatch (")
2032 << mismatchingDims.first <<
" vs. " << mismatchingDims.second <<
")";
2033 if (res == BroadcastableToResult::SourceTypeNotAVector)
2034 return emitOpError(
"source type is not a vector");
2035 llvm_unreachable(
"unexpected vector.broadcast op error");
2039 if (getSourceType() == getResultVectorType())
2041 if (!adaptor.getSource())
2043 auto vectorType = getResultVectorType();
2044 if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2046 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2059 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2063 broadcastOp.getResultVectorType(),
2064 srcBroadcast.getSource());
2074 results.
add<BroadcastFolder>(context);
2087 VectorType resultType = getResultVectorType();
2088 VectorType v1Type = getV1VectorType();
2089 VectorType v2Type = getV2VectorType();
2091 int64_t resRank = resultType.getRank();
2092 int64_t v1Rank = v1Type.getRank();
2093 int64_t v2Rank = v2Type.getRank();
2094 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2095 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2096 if (!wellFormed0DCase && !wellFormedNDCase)
2097 return emitOpError(
"rank mismatch");
2100 for (int64_t r = 1; r < v1Rank; ++r) {
2101 int64_t resDim = resultType.getDimSize(r);
2102 int64_t v1Dim = v1Type.getDimSize(r);
2103 int64_t v2Dim = v2Type.getDimSize(r);
2104 if (resDim != v1Dim || v1Dim != v2Dim)
2105 return emitOpError(
"dimension mismatch");
2108 auto maskAttr = getMask().getValue();
2109 int64_t maskLength = maskAttr.size();
2110 if (maskLength <= 0)
2111 return emitOpError(
"invalid mask length");
2112 if (maskLength != resultType.getDimSize(0))
2113 return emitOpError(
"mask length mismatch");
2115 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2116 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2118 auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2119 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2120 return emitOpError(
"mask index #") << (en.index() + 1) <<
" out of range";
2126 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2127 ValueRange operands, DictionaryAttr attributes,
2130 ShuffleOp::Adaptor op(operands, attributes, properties);
2131 auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
2132 auto v1Rank = v1Type.getRank();
2136 shape.reserve(v1Rank);
2137 shape.push_back(std::max<size_t>(1, op.getMask().size()));
2140 llvm::append_range(shape, v1Type.getShape().drop_front());
2141 inferredReturnTypes.push_back(
2147 uint64_t expected = begin;
2148 return idxArr.size() == width &&
2149 llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2150 [&expected](
auto attr) {
2151 return attr.getZExtValue() == expected++;
2155 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2156 VectorType v1Type = getV1VectorType();
2159 if (v1Type.getRank() == 0)
2163 if (!v1Type.isScalable() &&
2167 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2169 getV2VectorType().getDimSize(0)))
2172 Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2177 llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2180 if (lhsType.getRank() != 1)
2182 int64_t lhsSize = lhsType.getDimSize(0);
2185 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<
Attribute>();
2186 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<
Attribute>();
2187 for (
const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2188 int64_t i = index.getZExtValue();
2190 results.push_back(rhsElements[i - lhsSize]);
2192 results.push_back(lhsElements[i]);
2208 VectorType v1VectorType = shuffleOp.getV1VectorType();
2209 ArrayAttr mask = shuffleOp.getMask();
2210 if (v1VectorType.getRank() > 0)
2212 if (mask.size() != 1)
2215 if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2232 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2233 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2235 if (!v1Splat || !v2Splat)
2238 if (v1Splat.getInput() != v2Splat.getInput())
2250 results.
add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2259 build(builder, result, source, dest, {});
2263 auto dstVectorType = getDestVectorType();
2264 if (dstVectorType.getRank() == 0) {
2266 return emitOpError(
"expected position to be empty with 0-D vector");
2269 if (dstVectorType.getRank() != 1)
2270 return emitOpError(
"unexpected >1 vector rank");
2272 return emitOpError(
"expected position for 1-D vector");
2276 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2278 if (!adaptor.getPosition())
2284 if (!src || !dst || !pos)
2287 auto dstElements = llvm::cast<DenseElementsAttr>(dst).getValues<
Attribute>();
2291 auto attr = llvm::dyn_cast<IntegerAttr>(pos);
2292 uint64_t posIdx = attr.getInt();
2294 results[posIdx] = src;
2308 result.
addAttribute(getPositionAttrStrName(), positionAttr);
2315 llvm::to_vector<4>(llvm::map_range(position, [](
Value pos) {
2318 build(builder, result, source, dest, positionConstants);
2322 auto positionAttr = getPosition().getValue();
2323 auto destVectorType = getDestVectorType();
2324 if (positionAttr.size() >
static_cast<unsigned>(destVectorType.getRank()))
2326 "expected position attribute of rank smaller than dest vector rank");
2327 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2328 if (srcVectorType &&
2329 (
static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
2330 static_cast<unsigned>(destVectorType.getRank())))
2331 return emitOpError(
"expected position attribute rank + source rank to "
2332 "match dest vector rank");
2333 if (!srcVectorType &&
2334 (positionAttr.size() !=
static_cast<unsigned>(destVectorType.getRank())))
2336 "expected position attribute rank to match the dest vector rank");
2338 auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2339 if (!attr || attr.getInt() < 0 ||
2340 attr.getInt() >= destVectorType.getDimSize(en.index()))
2341 return emitOpError(
"expected position attribute #")
2343 <<
" to be a non-negative integer smaller than the corresponding "
2344 "dest vector dimension";
2359 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2360 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2361 srcVecType.getNumElements())
2364 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2376 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2377 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2379 if (!srcSplat || !dstSplat)
2382 if (srcSplat.getInput() != dstSplat.getInput())
2397 static constexpr int64_t vectorSizeFoldThreshold = 256;
2408 VectorType destTy = destVector.getType();
2409 if (destTy.isScalable())
2413 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2414 !destVector.hasOneUse())
2417 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
2419 Value sourceValue = op.getSource();
2428 int64_t insertBeginPosition =
2432 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2433 llvm::append_range(insertedValues, denseSource.getValues<
Attribute>());
2435 insertedValues.push_back(sourceCst);
2437 auto allValues = llvm::to_vector(denseDest.getValues<
Attribute>());
2438 copy(insertedValues, allValues.begin() + insertBeginPosition);
2450 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2451 InsertOpConstantFolder>(context);
2457 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2458 if (getPosition().empty())
2475 result.
addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2476 result.
addAttribute(getStridesAttrStrName(), stridesAttr);
2480 template <
typename OpType>
2482 ArrayAttr arrayAttr,
2484 StringRef attrName) {
2485 if (arrayAttr.size() > shape.size())
2487 << attrName <<
" attribute of rank smaller than vector rank";
2494 template <
typename OpType>
2497 int64_t
max, StringRef attrName,
2498 bool halfOpen =
true) {
2499 for (
auto attr : arrayAttr) {
2500 auto val = llvm::cast<IntegerAttr>(attr).getInt();
2504 if (val < min || val >= upper)
2505 return op.
emitOpError(
"expected ") << attrName <<
" to be confined to ["
2506 <<
min <<
", " << upper <<
")";
2514 template <
typename OpType>
2518 bool halfOpen =
true, int64_t
min = 0) {
2519 for (
auto [index, attrDimPair] :
2521 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
2522 int64_t
max = std::get<1>(attrDimPair);
2525 if (val < min || val >=
max)
2527 << attrName <<
" dimension " << index <<
" to be confined to ["
2528 <<
min <<
", " <<
max <<
")";
2536 template <
typename OpType>
2538 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2540 bool halfOpen =
true, int64_t
min = 1) {
2541 assert(arrayAttr1.size() <= shape.size());
2542 assert(arrayAttr2.size() <= shape.size());
2543 for (
auto [index, it] :
2545 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
2546 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
2547 int64_t
max = std::get<2>(it);
2550 if (val1 + val2 < 0 || val1 + val2 >=
max)
2552 << attrName1 <<
", " << attrName2 <<
") dimension " << index
2553 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
2560 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
2567 auto sourceVectorType = getSourceVectorType();
2568 auto destVectorType = getDestVectorType();
2569 auto offsets = getOffsetsAttr();
2570 auto strides = getStridesAttr();
2571 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
2573 "expected offsets of same size as destination vector rank");
2574 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
2575 return emitOpError(
"expected strides of same size as source vector rank");
2576 if (sourceVectorType.getRank() > destVectorType.getRank())
2578 "expected source rank to be smaller than destination rank");
2580 auto sourceShape = sourceVectorType.getShape();
2581 auto destShape = destVectorType.getShape();
2583 destShape.size() - sourceShape.size(), 0);
2584 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2585 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2586 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2595 offName,
"source vector shape",
2605 class FoldInsertStridedSliceSplat final
2613 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2615 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2617 if (!srcSplatOp || !destSplatOp)
2620 if (srcSplatOp.getInput() != destSplatOp.getInput())
2623 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2630 class FoldInsertStridedSliceOfExtract final
2637 auto extractStridedSliceOp =
2638 insertStridedSliceOp.getSource()
2639 .getDefiningOp<vector::ExtractStridedSliceOp>();
2641 if (!extractStridedSliceOp)
2644 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2648 if (extractStridedSliceOp.getStrides() !=
2649 insertStridedSliceOp.getStrides() ||
2650 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2653 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2660 class InsertStridedSliceConstantFolder final
2667 static constexpr int64_t vectorSizeFoldThreshold = 256;
2678 VectorType destTy = destVector.getType();
2679 if (destTy.isScalable())
2683 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2684 !destVector.hasOneUse())
2687 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
2695 if (op.hasNonUnitStrides())
2698 VectorType sliceVecTy = sourceValue.getType();
2700 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
2710 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
2711 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
2712 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
2715 currDestPosition.begin() + rankDifference, currDestPosition.end());
2719 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
2720 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
2721 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
2722 "Invalid slice element");
2723 newValues[linearizedPosition] = *sliceValuesIt;
2736 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
2738 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
2739 InsertStridedSliceConstantFolder>(context);
2742 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
2743 if (getSourceVectorType() == getDestVectorType())
2760 p <<
" " << getLhs() <<
", " << getRhs();
2761 if (!getAcc().empty()) {
2762 p <<
", " << getAcc();
2765 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
2776 if (operandsInfo.size() < 2)
2778 "expected at least 2 operands");
2779 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
2780 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
2783 "expected vector type for operand #1");
2785 unsigned numScalableDims = vLHS.getNumScalableDims();
2788 numScalableDims += vRHS.getNumScalableDims();
2790 vLHS.getElementType(), numScalableDims);
2793 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
2797 if (!result.
attributes.
get(OuterProductOp::getKindAttrStrName())) {
2799 OuterProductOp::getKindAttrStrName(),
2801 OuterProductOp::getDefaultKind()));
2807 (operandsInfo.size() > 2 &&
2813 Type tRHS = getOperandTypeRHS();
2814 VectorType vLHS = getOperandVectorTypeLHS(),
2815 vRHS = llvm::dyn_cast<VectorType>(tRHS),
2816 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
2818 if (vLHS.getRank() != 1)
2819 return emitOpError(
"expected 1-d vector for operand #1");
2823 if (vRHS.getRank() != 1)
2824 return emitOpError(
"expected 1-d vector for operand #2");
2825 if (vRES.getRank() != 2)
2826 return emitOpError(
"expected 2-d vector result");
2827 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2828 return emitOpError(
"expected #1 operand dim to match result dim #1");
2829 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2830 return emitOpError(
"expected #2 operand dim to match result dim #2");
2831 if (vRHS.isScalable() != vLHS.isScalable())
2832 return emitOpError(
"expected either all or none of vector operands #1 "
2833 "and #2 to be scalable");
2836 if (vRES.getRank() != 1)
2837 return emitOpError(
"expected 1-d vector result");
2838 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2839 return emitOpError(
"expected #1 operand dim to match result dim #1");
2842 if (vACC && vACC != vRES)
2843 return emitOpError(
"expected operand #3 of same type as result type");
2847 return emitOpError(
"unsupported outerproduct type");
2856 Type OuterProductOp::getExpectedMaskType() {
2857 auto vecType = this->getResultVectorType();
2868 auto inputVectorType = getInputVectorType();
2869 auto outputVectorType = getOutputVectorType();
2870 int64_t inputShapeRank = getNumInputShapeSizes();
2871 int64_t outputShapeRank = getNumOutputShapeSizes();
2873 getFixedVectorSizes(fixedVectorSizes);
2874 int64_t numFixedVectorSizes = fixedVectorSizes.size();
2876 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2877 return emitError(
"invalid input shape for vector type ") << inputVectorType;
2879 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2880 return emitError(
"invalid output shape for vector type ")
2881 << outputVectorType;
2885 unsigned inputVectorRank = inputVectorType.getRank();
2886 for (
unsigned i = 0; i < numFixedVectorSizes; ++i) {
2887 unsigned index = inputVectorRank - numFixedVectorSizes - i;
2888 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2889 return emitError(
"fixed vector size must match input vector for dim ")
2893 unsigned outputVectorRank = outputVectorType.getRank();
2894 for (
unsigned i = 0; i < numFixedVectorSizes; ++i) {
2895 unsigned index = outputVectorRank - numFixedVectorSizes - i;
2896 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2897 return emitError(
"fixed vector size must match output vector for dim ")
2903 auto isDefByConstant = [](
Value operand) {
2904 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2906 if (llvm::all_of(getInputShape(), isDefByConstant) &&
2907 llvm::all_of(getOutputShape(), isDefByConstant)) {
2908 int64_t numInputElements = 1;
2909 for (
auto operand : getInputShape())
2911 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2912 int64_t numOutputElements = 1;
2913 for (
auto operand : getOutputShape())
2914 numOutputElements *=
2915 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2916 if (numInputElements != numOutputElements)
2917 return emitError(
"product of input and output shape sizes must match");
2934 ArrayAttr offsets, ArrayAttr sizes,
2935 ArrayAttr strides) {
2936 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2938 shape.reserve(vectorType.getRank());
2940 for (
unsigned e = offsets.size(); idx < e; ++idx)
2941 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
2942 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2943 shape.push_back(vectorType.getShape()[idx]);
2958 offsetsAttr, sizesAttr, stridesAttr));
2959 result.
addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2961 result.
addAttribute(getStridesAttrStrName(), stridesAttr);
2965 auto type = getSourceVectorType();
2966 auto offsets = getOffsetsAttr();
2967 auto sizes = getSizesAttr();
2968 auto strides = getStridesAttr();
2969 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
2971 "expected offsets, sizes and strides attributes of same size");
2973 auto shape = type.getShape();
2974 auto offName = getOffsetsAttrName();
2975 auto sizesName = getSizesAttrName();
2976 auto stridesName = getStridesAttrName();
2992 shape, offName, sizesName,
2997 offsets, sizes, strides);
2998 if (getResult().getType() != resultType)
2999 return emitOpError(
"expected result type to be ") << resultType;
3010 auto getElement = [](ArrayAttr array,
int idx) {
3011 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3013 ArrayAttr extractOffsets = op.getOffsets();
3015 ArrayAttr extractSizes = op.getSizes();
3016 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3018 if (op.getSourceVectorType().getRank() !=
3019 insertOp.getSourceVectorType().getRank())
3021 ArrayAttr insertOffsets = insertOp.getOffsets();
3022 ArrayAttr insertStrides = insertOp.getStrides();
3025 if (extractOffsets.size() > insertOffsets.size())
3027 bool patialoverlap =
false;
3028 bool disjoint =
false;
3030 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3031 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3033 int64_t start = getElement(insertOffsets, dim);
3034 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3035 int64_t offset = getElement(extractOffsets, dim);
3036 int64_t size = getElement(extractSizes, dim);
3038 if (start <= offset && offset < end) {
3041 if (offset + size > end)
3042 patialoverlap =
true;
3043 offsetDiffs.push_back(offset - start);
3050 if (!disjoint && !patialoverlap) {
3054 op->
setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
3061 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3071 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3072 if (getSourceVectorType() == getResult().getType())
3087 class StridedSliceConstantMaskFolder final
3096 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3097 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3098 if (!constantMaskOp)
3101 if (extractStridedSliceOp.hasNonUnitStrides())
3115 sliceMaskDimSizes.reserve(maskDimSizes.size());
3116 for (
auto [maskDimSize, sliceOffset, sliceSize] :
3117 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3118 int64_t sliceMaskDimSize =
std::max(
3119 static_cast<int64_t
>(0),
3120 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3121 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3124 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3125 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3126 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3129 if (llvm::is_contained(sliceMaskDimSizes, 0))
3130 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3135 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3142 class StridedSliceSplatConstantFolder final
3151 Value sourceVector = extractStridedSliceOp.getVector();
3156 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3170 class StridedSliceNonSplatConstantFolder final
3179 Value sourceVector = extractStridedSliceOp.getVector();
3185 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3186 if (!dense || dense.isSplat())
3190 if (extractStridedSliceOp.hasNonUnitStrides())
3193 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3197 VectorType sliceVecTy = extractStridedSliceOp.getType();
3199 int64_t sliceRank = sliceVecTy.getRank();
3211 auto denseValuesBegin = dense.value_begin<
Attribute>();
3213 sliceValues.reserve(sliceVecTy.getNumElements());
3216 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3217 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3219 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3223 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3224 sliceVecTy.getNumElements() &&
3225 "Invalid number of slice elements");
3235 class StridedSliceBroadcast final
3242 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3247 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3248 auto dstVecType = llvm::cast<VectorType>(op.getType());
3249 unsigned dstRank = dstVecType.getRank();
3250 unsigned rankDiff = dstRank - srcRank;
3254 bool lowerDimMatch =
true;
3255 for (
unsigned i = 0; i < srcRank; i++) {
3256 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3257 lowerDimMatch =
false;
3266 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3267 if (!lowerDimMatch && !isScalarSrc) {
3268 source = rewriter.
create<ExtractStridedSliceOp>(
3280 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3286 auto splat = op.getVector().getDefiningOp<SplatOp>();
3296 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3300 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3301 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3302 StridedSliceSplat>(context);
3311 VectorType vectorType,
Value source,
3312 ValueRange indices, AffineMapAttr permutationMapAttr,
3313 ArrayAttr inBoundsAttr) {
3314 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3315 Value padding = builder.
create<arith::ConstantOp>(
3317 build(builder, result, vectorType, source, indices, permutationMapAttr,
3318 padding,
Value(), inBoundsAttr);
3323 VectorType vectorType,
Value source,
3327 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3330 build(builder, result, vectorType, source, indices, permutationMapAttr,
3336 VectorType vectorType,
Value source,
3340 llvm::cast<ShapedType>(source.
getType()), vectorType);
3342 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3345 build(builder, result, vectorType, source, indices, permutationMapAttr,
3347 Value(), inBoundsAttr);
3353 VectorType vectorType,
Value source,
3356 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3357 Value padding = builder.
create<arith::ConstantOp>(
3359 build(builder, result, vectorType, source, indices, padding, inBounds);
3362 template <
typename EmitFun>
3364 EmitFun emitOpError) {
3366 for (
auto expr : permutationMap.
getResults()) {
3370 if (zero.getValue() != 0) {
3372 "requires a projected permutation_map (at most one dim or the zero "
3373 "constant can appear in each result)");
3378 return emitOpError(
"requires a projected permutation_map (at most one "
3379 "dim or the zero constant can appear in each result)");
3381 if (seen[dim.getPosition()]) {
3383 "requires a permutation_map that is a permutation (found one dim "
3384 "used more than once)");
3386 seen[dim.getPosition()] =
true;
3393 VectorType vectorType, VectorType maskType,
3394 VectorType inferredMaskType,
AffineMap permutationMap,
3395 ArrayAttr inBounds) {
3397 return op->
emitOpError(
"masked attribute has been removed. "
3398 "Use in_bounds instead.");
3401 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3403 "requires source to be a memref or ranked tensor type");
3405 auto elementType = shapedType.getElementType();
3406 DataLayout dataLayout = DataLayout::closest(op);
3407 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3409 unsigned sourceVecSize =
3411 vectorElementType.getShape().back();
3412 unsigned resultVecSize =
3414 vectorType.getShape().back();
3415 if (resultVecSize % sourceVecSize != 0)
3417 "requires the bitwidth of the minor 1-D vector to be an integral "
3418 "multiple of the bitwidth of the minor 1-D vector of the source");
3420 unsigned sourceVecEltRank = vectorElementType.getRank();
3421 unsigned resultVecRank = vectorType.getRank();
3422 if (sourceVecEltRank > resultVecRank)
3424 "requires source vector element and vector result ranks to match.");
3425 unsigned rankOffset = resultVecRank - sourceVecEltRank;
3428 return op->
emitOpError(
"requires a permutation_map with result dims of "
3429 "the same rank as the vector type");
3432 return op->
emitOpError(
"does not support masks with vector element type");
3435 unsigned minorSize =
3436 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3437 unsigned resultVecSize =
3441 "requires the bitwidth of the minor 1-D vector to be an integral "
3442 "multiple of the bitwidth of the source element type");
3446 return op->
emitOpError(
"requires a permutation_map with result dims of "
3447 "the same rank as the vector type");
3451 return op->
emitOpError(
"requires permutation_map without symbols");
3453 if (permutationMap.
getNumInputs() != shapedType.getRank())
3454 return op->
emitOpError(
"requires a permutation_map with input dims of the "
3455 "same rank as the source type");
3457 if (maskType && maskType != inferredMaskType)
3459 << inferredMaskType <<
") and mask operand type (" << maskType
3463 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
3464 return op->
emitOpError(
"expects the optional in_bounds attr of same rank "
3465 "as permutation_map results: ")
3467 <<
" vs inBounds of size: " << inBounds.size();
3468 for (
unsigned int i = 0; i < permutationMap.
getNumResults(); ++i)
3470 !llvm::cast<BoolAttr>(inBounds.getValue()[i]).
getValue())
3471 return op->
emitOpError(
"requires broadcast dimensions to be in-bounds");
3479 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3480 if (op.permutation_map().isMinorIdentity())
3481 elidedAttrs.push_back(op.getPermutationMapAttrStrName());
3482 bool elideInBounds =
true;
3483 if (
auto inBounds = op.in_bounds()) {
3484 for (
auto attr : *inBounds) {
3485 if (llvm::cast<BoolAttr>(attr).getValue()) {
3486 elideInBounds =
false;
3492 elidedAttrs.push_back(op.getInBoundsAttrStrName());
3497 p <<
" " << getSource() <<
"[" <<
getIndices() <<
"], " << getPadding();
3499 p <<
", " << getMask();
3512 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
3538 if (types.size() != 2)
3539 return parser.
emitError(typesLoc,
"requires two types");
3541 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
3542 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
3543 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
3544 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
3546 return parser.
emitError(typesLoc,
"requires vector type");
3547 auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
3554 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
3562 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
3564 maskInfo.
location,
"does not support masks with vector element type");
3571 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3573 {1, static_cast<int32_t>(indexInfo.size()), 1,
3574 static_cast<int32_t>(hasMask.succeeded())}));
3580 ShapedType shapedType = getShapedType();
3582 VectorType maskType = getMaskType();
3583 auto paddingType = getPadding().getType();
3584 auto permutationMap = getPermutationMap();
3585 VectorType inferredMaskType =
3588 auto sourceElementType = shapedType.getElementType();
3590 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
3591 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
3594 shapedType, vectorType, maskType,
3595 inferredMaskType, permutationMap,
3596 getInBounds() ? *getInBounds() : ArrayAttr())))
3599 if (
auto sourceVectorElementType =
3600 llvm::dyn_cast<VectorType>(sourceElementType)) {
3603 if (sourceVectorElementType != paddingType)
3605 "requires source element type and padding type to match.");
3609 if (!VectorType::isValidElementType(paddingType))
3610 return emitOpError(
"requires valid padding vector elemental type");
3613 if (paddingType != sourceElementType)
3615 "requires formal padding and source of the same elemental type");
3619 [&](Twine t) {
return emitOpError(t); });
3626 Type TransferReadOp::getExpectedMaskType() {
3630 template <
typename TransferOp>
3631 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3634 if (op.getShapedType().isDynamicDim(indicesIdx))
3636 Value index = op.getIndices()[indicesIdx];
3641 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3642 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3644 return cstOp.value() + vectorSize <= sourceSize;
3647 template <
typename TransferOp>
3651 if (op.getTransferRank() == 0)
3653 AffineMap permutationMap = op.getPermutationMap();
3654 bool changed =
false;
3656 newInBounds.reserve(op.getTransferRank());
3657 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
3659 if (op.isDimInBounds(i)) {
3660 newInBounds.push_back(
true);
3666 assert(dimExpr &&
"Broadcast dims must be in-bounds");
3669 newInBounds.push_back(inBounds);
3671 changed |= inBounds;
3677 op->
setAttr(TransferOp::getInBoundsAttrStrName(),
3692 static Value foldRAW(TransferReadOp readOp) {
3693 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
3695 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
3698 return defWrite.getVector();
3700 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3701 cast<VectorTransferOpInterface>(readOp.getOperation())))
3703 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3709 if (
Value vec = foldRAW(*
this))
3721 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
3725 void TransferReadOp::getEffects(
3728 if (llvm::isa<MemRefType>(getShapedType()))
3756 struct TransferReadAfterWriteToBroadcast
3762 if (readOp.hasOutOfBoundsDim() ||
3763 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
3765 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3771 if (readOp.getIndices() == defWrite.getIndices() &&
3772 readOp.getMask() == defWrite.getMask()) {
3776 if (writeDims == readDims)
3777 vec = defWrite.getVector();
3801 broadcastShape[pos.value()] = destShape[pos.index()];
3803 broadcastShape, defWrite.getVectorType().getElementType());
3804 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
3815 results.
add<TransferReadAfterWriteToBroadcast>(context);
3825 AffineMapAttr permutationMapAttr,
3827 ArrayAttr inBoundsAttr) {
3828 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
3829 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3830 mask, inBoundsAttr);
3836 AffineMapAttr permutationMapAttr,
3837 ArrayAttr inBoundsAttr) {
3838 build(builder, result, vector, dest, indices, permutationMapAttr,
3839 Value(), inBoundsAttr);
3849 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3852 build(builder, result, vector, dest, indices, permutationMapAttr,
3853 Value(), inBoundsAttr);
3861 auto vectorType = llvm::cast<VectorType>(vector.
getType());
3863 llvm::cast<ShapedType>(dest.
getType()), vectorType);
3864 build(builder, result, vector, dest, indices, permutationMap, inBounds);
3885 if (types.size() != 2)
3886 return parser.
emitError(typesLoc,
"requires two types");
3888 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
3890 return parser.
emitError(typesLoc,
"requires vector type");
3891 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
3892 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
3893 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
3894 auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3901 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
3908 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
3910 maskInfo.
location,
"does not support masks with vector element type");
3915 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
3917 {1, 1, static_cast<int32_t>(indexInfo.size()),
3918 static_cast<int32_t>(hasMask.succeeded())}));
3919 return failure(llvm::isa<RankedTensorType>(shapedType) &&
3924 p <<
" " << getVector() <<
", " << getSource() <<
"[" <<
getIndices() <<
"]";
3926 p <<
", " << getMask();
3933 ShapedType shapedType = getShapedType();
3935 VectorType maskType = getMaskType();
3936 auto permutationMap = getPermutationMap();
3937 VectorType inferredMaskType =
3941 if (llvm::size(
getIndices()) != shapedType.getRank())
3942 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
3946 if (hasBroadcastDim())
3947 return emitOpError(
"should not have broadcast dimensions");
3950 shapedType, vectorType, maskType,
3951 inferredMaskType, permutationMap,
3952 getInBounds() ? *getInBounds() : ArrayAttr())))
3956 [&](Twine t) {
return emitOpError(t); });
3963 Type TransferWriteOp::getExpectedMaskType() {
3984 static LogicalResult foldReadInitWrite(TransferWriteOp write,
3988 if (write.getTransferRank() == 0)
3990 auto rankedTensorType =
3991 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
3993 if (!rankedTensorType)
3996 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4000 if (read.getTransferRank() == 0)
4003 if (!read.getPermutationMap().isMinorIdentity() ||
4004 !write.getPermutationMap().isMinorIdentity())
4007 if (read.getTransferRank() != write.getTransferRank())
4010 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4013 if (read.getSource().getType() != rankedTensorType)
4016 if (read.getVectorType() != write.getVectorType())
4019 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4022 auto isNotConstantZero = [](
Value v) {
4023 auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
4024 return !cstOp || cstOp.value() != 0;
4026 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4027 llvm::any_of(write.getIndices(), isNotConstantZero))
4030 results.push_back(read.getSource());
4034 static bool checkSameValueWAR(vector::TransferReadOp read,
4035 vector::TransferWriteOp write) {
4036 return read.getSource() == write.getSource() &&
4037 read.getIndices() == write.getIndices() &&
4038 read.getPermutationMap() == write.getPermutationMap() &&
4039 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4058 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4060 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4064 if (!checkSameValueWAR(read, write))
4066 results.push_back(read.getSource());
4072 if (
succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4081 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4085 void TransferWriteOp::getEffects(
4088 if (llvm::isa<MemRefType>(getShapedType()))
4123 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4125 vector::TransferWriteOp writeToModify = writeOp;
4128 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4132 writeToModify.getSourceMutable().assign(defWrite.getSource());
4137 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4138 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4142 if (!defWrite->hasOneUse())
4144 writeToModify = defWrite;
4145 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4174 struct SwapExtractSliceOfTransferWrite
4181 if (!insertOp.hasUnitStride())
4184 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4185 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4187 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4188 if (!transferOp || !transferOp->hasOneUse())
4193 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4195 "use-def chain is rank-reducing");
4199 if (!extractOp.hasZeroOffset()) {
4201 "ExtractSliceOp has non-zero offset");
4205 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
4209 "TranferWriteOp has non-zero offset");
4213 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4215 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
4218 for (
auto [insertSize, extractSize] :
4219 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4222 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
4227 assert(transferOp.getVectorType().hasStaticShape() &&
4228 "expected vector to have a static shape");
4231 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4232 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
4234 insertOp,
"TransferWriteOp may not write the full tensor.");
4239 transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
4241 for (
const auto &en :
enumerate(newResultShape))
4242 newInBounds.push_back(en.value() ==
vectorShape[en.index()]);
4243 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
4244 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4245 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4246 insertOp.getMixedStrides());
4247 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
4248 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4249 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4252 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4262 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4270 MemRefType memRefTy) {
4272 return op->
emitOpError(
"most minor memref dim must have unit stride");
4280 if (
failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4284 Type memElemTy = memRefTy.getElementType();
4285 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4286 if (memVecTy != resVecTy)
4287 return emitOpError(
"base memref and result vector types should match");
4288 memElemTy = memVecTy.getElementType();
4291 if (resVecTy.getElementType() != memElemTy)
4292 return emitOpError(
"base and result element types should match");
4293 if (llvm::size(
getIndices()) != memRefTy.getRank())
4294 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4312 if (
failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4316 Type memElemTy = memRefTy.getElementType();
4317 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4318 if (memVecTy != valueVecTy)
4320 "base memref and valueToStore vector types should match");
4321 memElemTy = memVecTy.getElementType();
4324 if (valueVecTy.getElementType() != memElemTy)
4325 return emitOpError(
"base and valueToStore element type should match");
4326 if (llvm::size(
getIndices()) != memRefTy.getRank())
4327 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4341 VectorType maskVType = getMaskVectorType();
4342 VectorType passVType = getPassThruVectorType();
4346 if (resVType.getElementType() != memType.getElementType())
4347 return emitOpError(
"base and result element type should match");
4348 if (llvm::size(
getIndices()) != memType.getRank())
4349 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4350 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4351 return emitOpError(
"expected result dim to match mask dim");
4352 if (resVType != passVType)
4353 return emitOpError(
"expected pass_thru of same type as result type");
4366 load, load.getType(), load.getBase(), load.getIndices());
4369 rewriter.
replaceOp(load, load.getPassThru());
4374 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
4381 results.
add<MaskedLoadFolder>(context);
4395 VectorType maskVType = getMaskVectorType();
4399 if (valueVType.getElementType() != memType.getElementType())
4400 return emitOpError(
"base and valueToStore element type should match");
4401 if (llvm::size(
getIndices()) != memType.getRank())
4402 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4403 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4404 return emitOpError(
"expected valueToStore dim to match mask dim");
4417 store, store.getValueToStore(), store.getBase(), store.getIndices());
4425 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
4432 results.
add<MaskedStoreFolder>(context);
4445 VectorType indVType = getIndexVectorType();
4446 VectorType maskVType = getMaskVectorType();
4448 ShapedType baseType = getBaseType();
4450 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4451 return emitOpError(
"requires base to be a memref or ranked tensor type");
4453 if (resVType.getElementType() != baseType.getElementType())
4454 return emitOpError(
"base and result element type should match");
4455 if (llvm::size(
getIndices()) != baseType.getRank())
4456 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
4457 if (resVType.getShape() != indVType.getShape())
4458 return emitOpError(
"expected result dim to match indices dim");
4459 if (resVType.getShape() != maskVType.getShape())
4460 return emitOpError(
"expected result dim to match mask dim");
4461 if (resVType != getPassThruVectorType())
4462 return emitOpError(
"expected pass_thru of same type as result type");
4470 Type GatherOp::getExpectedMaskType() {
4471 auto vecType = this->getIndexVectorType();
4476 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4490 rewriter.
replaceOp(gather, gather.getPassThru());
4495 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
4502 results.
add<GatherFolder>(context);
4510 VectorType indVType = getIndexVectorType();
4511 VectorType maskVType = getMaskVectorType();
4515 if (valueVType.getElementType() != memType.getElementType())
4516 return emitOpError(
"base and valueToStore element type should match");
4517 if (llvm::size(
getIndices()) != memType.getRank())
4518 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4519 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4520 return emitOpError(
"expected valueToStore dim to match indices dim");
4521 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4522 return emitOpError(
"expected valueToStore dim to match mask dim");
4541 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
4548 results.
add<ScatterFolder>(context);
4556 VectorType maskVType = getMaskVectorType();
4557 VectorType passVType = getPassThruVectorType();
4561 if (resVType.getElementType() != memType.getElementType())
4562 return emitOpError(
"base and result element type should match");
4563 if (llvm::size(
getIndices()) != memType.getRank())
4564 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4565 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4566 return emitOpError(
"expected result dim to match mask dim");
4567 if (resVType != passVType)
4568 return emitOpError(
"expected pass_thru of same type as result type");
4581 expand, expand.getType(), expand.getBase(), expand.getIndices());
4584 rewriter.
replaceOp(expand, expand.getPassThru());
4589 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
4596 results.
add<ExpandLoadFolder>(context);
4604 VectorType maskVType = getMaskVectorType();
4608 if (valueVType.getElementType() != memType.getElementType())
4609 return emitOpError(
"base and valueToStore element type should match");
4610 if (llvm::size(
getIndices()) != memType.getRank())
4611 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4612 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4613 return emitOpError(
"expected valueToStore dim to match mask dim");
4618 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
4626 compress, compress.getValueToStore(), compress.getBase(),
4627 compress.getIndices());
4635 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
4642 results.
add<CompressStoreFolder>(context);
4652 unsigned rankA = a.size();
4653 unsigned rankB = b.size();
4654 assert(rankA < rankB);
4656 auto isOne = [](int64_t v) {
return v == 1; };
4660 if (rankA == 0 && llvm::all_of(b, isOne))
4665 while (i < rankA &&
j < rankB) {
4666 int64_t dimA = a[i];
4668 while (dimB < dimA &&
j < rankB)
4676 if (i < rankA && llvm::all_of(a.slice(i), isOne))
4678 if (
j < rankB && llvm::all_of(b.slice(
j), isOne))
4682 return i == rankA &&
j == rankB;
4686 VectorType sourceVectorType,
4687 VectorType resultVectorType) {
4689 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
4690 return op->
emitOpError(
"source/result vectors must have same element type");
4691 auto sourceShape = sourceVectorType.getShape();
4692 auto resultShape = resultVectorType.getShape();
4695 int64_t sourceDimProduct = std::accumulate(
4696 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
4697 int64_t resultDimProduct = std::accumulate(
4698 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
4699 if (sourceDimProduct != resultDimProduct)
4700 return op->
emitOpError(
"source/result number of elements must match");
4703 unsigned sourceRank = sourceVectorType.getRank();
4704 unsigned resultRank = resultVectorType.getRank();
4705 if (sourceRank < resultRank) {
4706 if (!isValidShapeCast(sourceShape, resultShape))
4708 }
else if (sourceRank > resultRank) {
4709 if (!isValidShapeCast(resultShape, sourceShape))
4716 auto sourceVectorType =
4717 llvm::dyn_cast_or_null<VectorType>(getSource().getType());
4718 auto resultVectorType =
4719 llvm::dyn_cast_or_null<VectorType>(getResult().getType());
4722 if (sourceVectorType && resultVectorType)
4723 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
4730 if (getSource().getType() == getResult().getType())
4734 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
4735 if (getResult().getType() == otherOp.getSource().getType())
4736 return otherOp.getSource();
4739 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
4740 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
4741 if (srcType.getRank() < resultType.getRank()) {
4742 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
4744 }
else if (srcType.getRank() > resultType.getRank()) {
4745 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
4751 setOperand(otherOp.getSource());
4756 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4757 if (bcastOp.getSourceType() == getType())
4758 return bcastOp.getSource();
4766 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
4773 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
4777 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
4792 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
4799 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
4803 auto broadcastSourceVectorType =
4804 llvm::dyn_cast<VectorType>(broadcastOp.getSourceType());
4805 auto broadcastSourceShape = broadcastSourceVectorType
4806 ? broadcastSourceVectorType.getShape()
4808 auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
4811 bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
4812 broadcastSourceShape.size()));
4817 shapeCastOp, shapeCastOp.getResultVectorType(),
4818 broadcastOp.getSource());
4827 results.
add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
4835 auto sourceVectorType = getSourceVectorType();
4836 auto resultVectorType = getResultVectorType();
4838 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
4839 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4840 return emitOpError(
"dimension size mismatch at: ") << i;
4843 DataLayout dataLayout = DataLayout::closest(*
this);
4844 auto sourceElementBits =
4846 auto resultElementBits =
4849 if (sourceVectorType.getRank() == 0) {
4850 if (sourceElementBits != resultElementBits)
4851 return emitOpError(
"source/result bitwidth of the 0-D vector element "
4852 "types must be equal");
4853 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
4854 resultElementBits * resultVectorType.getShape().back()) {
4856 "source/result bitwidth of the minor 1-D vectors must be equal");
4864 if (getSource().getType() == getResult().getType())
4868 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
4869 if (getResult().getType() == otherOp.getSource().getType())
4870 return otherOp.getSource();
4872 setOperand(otherOp.getSource());
4876 Attribute sourceConstant = adaptor.getSource();
4877 if (!sourceConstant)
4880 Type srcElemType = getSourceVectorType().getElementType();
4881 Type dstElemType = getResultVectorType().getElementType();
4883 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
4884 if (floatPack.isSplat()) {
4885 auto splat = floatPack.getSplatValue<FloatAttr>();
4888 if (srcElemType.
isF16() && dstElemType.
isF32()) {
4889 uint32_t bits =
static_cast<uint32_t
>(
4890 splat.getValue().bitcastToAPInt().getZExtValue());
4892 bits = (bits << 16) | (bits & 0xffff);
4893 APInt intBits(32, bits);
4894 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
4900 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
4901 if (intPack.isSplat()) {
4902 auto splat = intPack.getSplatValue<IntegerAttr>();
4904 if (llvm::isa<IntegerType>(dstElemType)) {
4909 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
4910 APInt intBits = splat.getValue().zext(dstBitWidth);
4913 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
4914 intBits = (intBits << srcBitWidth) | intBits;
4929 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
4931 memRefType.getShape().end());
4933 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4942 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
4943 VectorType vectorType =
4947 memRefType.getMemorySpace()));
4952 if (!canonicalType.getLayout().isIdentity())
4953 return emitOpError(
"expects operand to be a memref with identity layout");
4954 if (!getResultMemRefType().getLayout().isIdentity())
4955 return emitOpError(
"expects result to be a memref with identity layout");
4956 if (getResultMemRefType().getMemorySpace() !=
4958 return emitOpError(
"expects result in same memory space");
4961 auto resultType = getResultMemRefType();
4965 "expects result and operand with same underlying scalar type: ")
4967 if (extractShape(sourceType) != extractShape(resultType))
4969 "expects concatenated result and operand shapes to be equal: ")
4980 VectorType vt = llvm::cast<VectorType>(vector.
getType());
4982 for (
unsigned i = 0; i < transp.size(); ++i)
4983 transposedShape[i] = vt.getShape()[transp[i]];
4990 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
4992 if (
auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
4994 return attr.reshape(getResultVectorType());
5003 for (int64_t i = 0, e = transp.size(); i < e; i++) {
5012 VectorType vectorType = getSourceVectorType();
5013 VectorType resultType = getResultVectorType();
5014 int64_t rank = resultType.getRank();
5015 if (vectorType.getRank() != rank)
5016 return emitOpError(
"vector result rank mismatch: ") << rank;
5018 auto transpAttr = getTransp().getValue();
5019 int64_t size = transpAttr.size();
5021 return emitOpError(
"transposition length mismatch: ") << size;
5024 int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
5025 if (i < 0 || i >= rank)
5026 return emitOpError(
"transposition index out of range: ") << i;
5028 return emitOpError(
"duplicate position index: ") << i;
5030 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
5031 return emitOpError(
"dimension size mismatch at: ") << i;
5036 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5037 return llvm::to_vector<4>(getResultVectorType().
getShape());
5043 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
5050 auto getPermutation = [](vector::TransposeOp transpose) {
5052 transpose.getTransp(permutation);
5060 for (
auto index : permutation2)
5061 result.push_back(permutation1[index]);
5066 vector::TransposeOp parentTransposeOp =
5067 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5068 if (!parentTransposeOp)
5072 getPermutation(parentTransposeOp), getPermutation(transposeOp));
5075 transposeOp, transposeOp.getResult().getType(),
5076 parentTransposeOp.getVector(),
5083 struct FoldTransposedScalarBroadcast final
5089 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5093 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5094 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5096 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5111 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5116 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5122 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
5128 Value transposeSrc = transpOp.getVector();
5129 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
5130 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
5131 if (!createMaskOp && !constantMaskOp)
5137 transpOp.getTransp(permutation);
5140 auto maskOperands = createMaskOp.getOperands();
5145 transpOp, transpOp.getResultVectorType(), newOperands);
5150 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5155 transpOp, transpOp.getResultVectorType(),
5163 void vector::TransposeOp::getCanonicalizationPatterns(
5165 results.
add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5166 TransposeFolder, FoldTransposeSplat>(context);
5178 auto resultType = llvm::cast<VectorType>(getResult().getType());
5180 if (resultType.getRank() == 0) {
5181 if (getMaskDimSizes().size() != 1)
5182 return emitError(
"array attr must have length 1 for 0-D vectors");
5183 auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5184 if (dim != 0 && dim != 1)
5185 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
5190 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
5192 "must specify array attr of size equal vector result rank");
5195 auto resultShape = resultType.getShape();
5198 int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
5199 if (attrValue < 0 || attrValue > resultShape[it.index()])
5201 "array attr of size out of bounds of vector result dimension size");
5202 maskDimSizes.push_back(attrValue);
5206 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5207 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
5208 if (anyZeros && !allZeros)
5209 return emitOpError(
"expected all mask dim sizes to be zeros, "
5210 "as a result of conjunction with zero mask dim");
5216 if (resultType.isScalable() &&
5217 llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
5218 return emitOpError(
"expected mask dim sizes for scalable masks to be 0");
5231 build(builder, result, type, operands);
5235 auto vectorType = llvm::cast<VectorType>(getResult().getType());
5237 if (vectorType.getRank() == 0) {
5238 if (getNumOperands() != 1)
5240 "must specify exactly one operand for 0-D create_mask");
5241 }
else if (getNumOperands() !=
5242 llvm::cast<VectorType>(getResult().getType()).getRank()) {
5244 "must specify an operand for each result vector dimension");
5259 auto isNotDefByConstant = [](
Value operand) {
5260 return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
5262 if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
5267 if (
auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
5268 if (vType.isScalable())
5269 for (
auto opDim : createMaskOp.getOperands()) {
5272 intVal.isStrictlyPositive())
5279 maskDimSizes.reserve(createMaskOp->getNumOperands());
5280 for (
auto [operand, maxDimSize] : llvm::zip_equal(
5281 createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5282 Operation *defOp = operand.getDefiningOp();
5283 int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
5284 dimSize =
std::min(dimSize, maxDimSize);
5287 maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5290 maskDimSizes.push_back(dimSize);
5294 createMaskOp, createMaskOp.getResult().getType(),
5304 results.
add<CreateMaskFolder>(context);
5315 assert(maskRegionBuilder &&
5316 "builder callback for 'maskRegion' must be present");
5322 maskRegionBuilder(builder, maskableOp);
5329 build(builder, result, resultTypes, mask,
Value(), maskableOp,
5337 build(builder, result, mask, maskableOp, maskRegionBuilder);
5365 MaskOp::ensureTerminator(maskRegion, builder, result.
location);
5379 result.
types.append(resultTypes);
5393 p <<
" " << getMask();
5395 p <<
", " << getPassthru();
5399 Block *singleBlock = &getMaskRegion().getBlocks().
front();
5406 p <<
" : " << getMask().getType();
5407 if (getNumResults() > 0)
5408 p <<
" -> " << getResultTypes();
5413 MaskOp>::ensureTerminator(region, builder, loc);
5425 assert(isa<vector::YieldOp>(oldYieldOp) &&
"Expected vector::YieldOp");
5428 if (maskedOp == oldYieldOp)
5431 opBuilder.setInsertionPoint(oldYieldOp);
5432 opBuilder.create<vector::YieldOp>(loc, maskedOp->
getResults());
5434 oldYieldOp->
erase();
5439 Block &block = getMaskRegion().getBlocks().
front();
5441 return emitOpError(
"expects a terminator within the mask region");
5443 return emitOpError(
"expects only one operation to mask");
5446 auto terminator = dyn_cast<vector::YieldOp>(block.
back());
5448 return emitOpError(
"expects a terminator within the mask region");
5450 if (terminator->getNumOperands() != getNumResults())
5452 "expects number of results to match mask region yielded values");
5454 auto maskableOp = dyn_cast<MaskableOpInterface>(block.
front());
5461 return emitOpError(
"expects number of results to match maskable operation "
5462 "number of results");
5464 if (!llvm::equal(maskableOp->
getResultTypes(), getResultTypes()))
5466 "expects result type to match maskable operation result type");
5469 [](
Type t) { return llvm::isa<VectorType>(t); }) > 1)
5470 return emitOpError(
"multiple vector results not supported");
5473 Type expectedMaskType = maskableOp.getExpectedMaskType();
5474 if (getMask().getType() != expectedMaskType)
5475 return emitOpError(
"expects a ")
5476 << expectedMaskType <<
" mask for the maskable operation";
5479 Value passthru = getPassthru();
5481 if (!maskableOp.supportsPassthru())
5483 "doesn't expect a passthru argument for this maskable operation");
5486 return emitOpError(
"expects result when passthru argument is provided");
5489 return emitOpError(
"expects passthru type to match result type");
5506 Operation *maskableOp = getMaskableOp();
5510 results.push_back(maskableOp->
getResult(0));
5522 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
5523 if (maskingOp.getMaskableOp())
5526 if (!maskOp.isEmpty())
5529 Block *block = maskOp.getMaskBlock();
5530 auto terminator = cast<vector::YieldOp>(block->
front());
5531 if (terminator.getNumOperands() == 0)
5534 rewriter.
replaceOp(maskOp, terminator.getOperands());
5542 results.
add<ElideEmptyMaskOp>(context);
5549 Block *block = getMaskBlock();
5553 return &block->
front();
// Returns true iff the optional `passthru` operand was provided on this
// vector.mask op (i.e. the operand is a non-null Value).
5557 bool MaskOp::hasPassthru() {
return getPassthru() !=
Value(); }
5564 VectorType srcType = getSourceType();
5565 VectorType initialType = getInitialValueType();
5567 int64_t srcRank = srcType.getRank();
5568 int64_t reductionDim = getReductionDim();
5569 if (reductionDim >= srcRank)
5570 return emitOpError(
"reduction dimension ")
5571 << reductionDim <<
" has to be less than " << srcRank;
5574 int64_t initialValueRank = initialType.getRank();
5575 if (initialValueRank != srcRank - 1)
5576 return emitOpError(
"initial value rank ")
5577 << initialValueRank <<
" has to be equal to " << srcRank - 1;
5583 for (
int i = 0; i < srcRank; i++) {
5584 if (i != reductionDim)
5585 expectedShape.push_back(srcShape[i]);
5587 if (!llvm::equal(initialValueShapes, expectedShape)) {
5588 return emitOpError(
"incompatible input/initial value shapes");
5592 Type eltType = getDestType().getElementType();
5594 return emitOpError(
"unsupported reduction type ")
5595 << eltType <<
" for kind '" << stringifyCombiningKind(getKind())
5604 .
add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
5605 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
5606 StridedSliceConstantMaskFolder, TransposeFolder>(
5615 auto constOperand = adaptor.getInput();
5616 if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
5628 p <<
"(" << getLaneid() <<
")";
5631 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
5632 p <<
"[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() <<
"]";
5634 if (!getArgs().empty())
5635 p <<
" args(" << getArgs() <<
" : " << getArgs().getTypes() <<
")";
5636 if (!getResults().empty())
5637 p <<
" -> (" << getResults().getTypes() <<
')';
5641 !getResults().empty());
5671 llvm::SMLoc inputsOperandsLoc;
5683 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
5694 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.
location);
5702 void WarpExecuteOnLane0Op::getSuccessorRegions(
5717 build(builder, result, resultTypes, laneId, warpSize,
5718 std::nullopt, std::nullopt);
5730 assert(args.size() == blockArgTypes.size());
5734 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
5743 if (expanded == distributed)
5745 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
5746 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
5747 if (!expandedVecType || !distributedVecType)
5748 return op->
emitOpError(
"expected vector type for distributed operands.");
5749 if (expandedVecType.getRank() != distributedVecType.getRank() ||
5750 expandedVecType.getElementType() != distributedVecType.getElementType())
5752 "expected distributed vectors to have same rank and element type.");
5753 bool foundDistributedDim =
false;
5754 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
5755 if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
5757 if (expandedVecType.getDimSize(i) ==
5758 distributedVecType.getDimSize(i) * warpSize) {
5759 if (foundDistributedDim)
5761 <<
"expected only one dimension to be distributed from "
5762 << expandedVecType <<
" to " << distributedVecType;
5763 foundDistributedDim =
true;
5766 return op->
emitOpError() <<
"incompatible distribution dimensions from "
5767 << expandedVecType <<
" to " << distributedVecType;
5773 if (getArgs().size() != getWarpRegion().getNumArguments())
5775 "expected same number op arguments and block arguments.");
5777 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
5778 if (yield.getNumOperands() != getNumResults())
5780 "expected same number of yield operands and return values.");
5781 int64_t warpSize = getWarpSize();
5782 for (
auto [regionArg, arg] :
5783 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
5784 if (
failed(verifyDistributedType(regionArg.getType(), arg.getType(),
5785 warpSize, getOperation())))
5788 for (
auto [yieldOperand, result] :
5789 llvm::zip_equal(yield.getOperands(), getResults())) {
5790 if (
failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
5791 warpSize, getOperation())))
5797 bool WarpExecuteOnLane0Op::areTypesCompatible(
Type lhs,
Type rhs) {
5799 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
5810 case CombiningKind::ADD:
5813 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
5816 llvm_unreachable(
"invalid value types for ADD reduction");
5818 case CombiningKind::AND:
5822 case CombiningKind::MAXF:
5823 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
5824 "expected float values");
5827 case CombiningKind::MINF:
5828 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
5829 "expected float values");
5832 case CombiningKind::MAXSI:
5836 case CombiningKind::MINSI:
5840 case CombiningKind::MAXUI:
5844 case CombiningKind::MINUI:
5848 case CombiningKind::MUL:
5851 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
5854 llvm_unreachable(
"invalid value types for MUL reduction");
5856 case CombiningKind::OR:
5860 case CombiningKind::XOR:
5866 assert(result &&
"unknown CombiningKind");
5878 assert(maskableOp->
getBlock() &&
"MaskableOp must be inserted into a block");