40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
73 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75 for (
bool b : denseElts.getValues<
bool>())
78 else if (!b && val <= 0)
92 auto shape = m.getType().getShape();
95 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96 if (maskIdx < dimSize)
109 auto maskOperands = m.getOperands();
110 for (
Value operand : maskOperands) {
111 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
126 builder.
create<vector::YieldOp>(loc);
// NOTE(review): fragment of a CombiningKind-vs-element-type support check.
// The excerpt is non-contiguous (original lines 135-136 and 143 are missing,
// presumably the returns for the ADD/MUL and integer-kind groups).
132 switch (combiningKind) {
133 case CombiningKind::ADD:
134 case CombiningKind::MUL:
// Presumably integer-only kinds (signed/unsigned min/max, bitwise ops);
// their return statement falls in a line missing from this excerpt.
137 case CombiningKind::MINSI:
138 case CombiningKind::MAXUI:
139 case CombiningKind::MAXSI:
140 case CombiningKind::AND:
141 case CombiningKind::OR:
142 case CombiningKind::XOR:
// Floating-point kinds: only supported when the element type is a FloatType.
144 case CombiningKind::MINNUMF:
145 case CombiningKind::MAXNUMF:
146 case CombiningKind::MINIMUMF:
147 case CombiningKind::MAXIMUMF:
148 return llvm::isa<FloatType>(elementType);
// NOTE(review): fragment of a helper building a minor-identity permutation map
// for a shaped-type/vector-type pair (signature start and several interior
// lines are outside this excerpt — verify against the full file).
154 VectorType vectorType) {
155 int64_t elementVectorRank = 0;
// If the shaped type's elements are themselves vectors, those trailing vector
// dims are not covered by the permutation map, so subtract their rank below.
156 VectorType elementVectorType =
157 llvm::dyn_cast<VectorType>(shapedType.getElementType());
158 if (elementVectorType)
159 elementVectorRank += elementVectorType.getRank();
162 if (shapedType.getRank() == 0 &&
168 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169 shapedType.getContext());
176 vector::TransferReadOp read) {
177 auto readMask = read.getMask();
178 auto writeMask = write.getMask();
184 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185 if (!couldBeSameSplat)
190 m_Constant<DenseElementsAttr>(&splatAttr)) ||
// NOTE(review): fragment of a read-after-write forwarding check — a prior
// transfer_write can feed a later transfer_read only when both agree on
// indices, vector type, and permutation map, the write has no out-of-bounds
// dims, and (per the visible clause) neither op is masked. The signature's
// first line and the tail of the mask condition are outside this excerpt.
202 vector::TransferReadOp read) {
203 return !defWrite.hasOutOfBoundsDim() &&
204 defWrite.getIndices() == read.getIndices() &&
205 defWrite.getVectorType() == read.getVectorType() &&
206 defWrite.getPermutationMap() == read.getPermutationMap() &&
207 ((!defWrite.getMask() && !read.getMask()) ||
// NOTE(review): fragment of a write-after-write equivalence check — the
// earlier write is fully shadowed when the later write targets the same
// indices with the same mask, vector type, and permutation map. The function
// signature's first line and closing brace fall outside this excerpt.
212 vector::TransferWriteOp priorWrite) {
213 return priorWrite.getIndices() == write.getIndices() &&
214 priorWrite.getMask() == write.getMask() &&
215 priorWrite.getVectorType() == write.getVectorType() &&
216 priorWrite.getPermutationMap() == write.getPermutationMap();
// NOTE(review): fragment of a per-index disjointness test for two transfer
// ops of identical vector type. Many interior lines (early returns, the calls
// producing `delta`/`testEqual`/`computeDelta`) are missing from this excerpt,
// so the exact bound-computation helpers cannot be confirmed here.
220 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221 bool testDynamicValueUsingBounds) {
223 if (transferA.getVectorType() != transferB.getVectorType())
225 unsigned rankOffset = transferA.getLeadingShapedRank();
226 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227 Value indexA = transferA.getIndices()[i];
228 Value indexB = transferB.getIndices()[i];
// Leading (non-transferred) dims: any provable difference between the two
// indices makes the accesses disjoint; equality checks follow below.
232 if (i < rankOffset) {
235 if (cstIndexA.has_value() && cstIndexB.has_value()) {
236 if (*cstIndexA != *cstIndexB)
// Optionally use value-bounds analysis on dynamic indices as well.
240 if (testDynamicValueUsingBounds) {
243 FailureOr<uint64_t> delta =
245 if (succeeded(delta) && *delta != 0)
248 FailureOr<bool> testEqual =
250 if (succeeded(testEqual) && !testEqual.value())
// Transferred dims: accesses whose start indices are at least the vector's
// dimension size apart cannot overlap along this dim.
256 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257 if (cstIndexA.has_value() && cstIndexB.has_value()) {
258 int64_t distance =
std::abs(*cstIndexA - *cstIndexB);
259 if (distance >= vectorDim)
263 if (testDynamicValueUsingBounds) {
266 FailureOr<int64_t> delta =
268 if (succeeded(delta) &&
std::abs(*delta) >= vectorDim)
271 FailureOr<int64_t> computeDelta =
273 if (succeeded(computeDelta)) {
274 if (
std::abs(computeDelta.value()) >= vectorDim)
// NOTE(review): fragment of the source-aware disjointness wrapper — transfers
// on different source buffers/vectors cannot be compared here (the visible
// guard), otherwise it defers to the per-index check above. First signature
// line and the return's head are outside this excerpt.
284 VectorTransferOpInterface transferB,
285 bool testDynamicValueUsingBounds) {
286 if (transferA.getSource() != transferB.getSource())
289 testDynamicValueUsingBounds);
299 for (
auto [posInDim, dimSize, offsetInDim] :
300 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302 if (posInDim < dimSize + offsetInDim)
306 posInDim = offsetInDim;
316 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
318 assert(constOp &&
"Unexpected non-constant index");
319 return constOp.value();
329 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
330 assert(isa<Attribute>(foldResult) &&
"Unexpected non-constant index");
331 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
341 llvm::transform(foldResults, std::back_inserter(values),
343 if (
auto attr = foldResult.dyn_cast<
Attribute>())
346 loc, cast<IntegerAttr>(attr).getInt())
349 return cast<Value>(foldResult);
360 auto lhs = mul.getLhs();
361 auto rhs = mul.getRhs();
362 if (lhs.getDefiningOp<vector::VectorScaleOp>())
364 if (rhs.getDefiningOp<vector::VectorScaleOp>())
412 void VectorDialect::initialize() {
414 #define GET_ATTRDEF_LIST
415 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
420 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
423 addInterfaces<VectorInlinerInterface>();
425 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
426 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
428 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
430 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
431 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
432 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
440 if (isa<ub::PoisonAttrInterface>(value))
443 return arith::ConstantOp::materialize(builder, value, type, loc);
459 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
462 CombiningKind kind) {
466 reductionDims.push_back(en.index());
467 build(builder, result, kind, source, acc, reductionDims);
// NOTE(review): fold() fragment — the visible guard matches a rank-1 source
// whose single dim is NOT reduced; the fold result it returns (presumably the
// source itself) is on lines missing from this excerpt — confirm upstream.
470 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
472 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
// Unrolling iterates over the full source vector shape.
// (Closing brace falls outside this excerpt.)
477 std::optional<SmallVector<int64_t, 4>>
478 MultiDimReductionOp::getShapeForUnroll() {
479 return llvm::to_vector<4>(getSourceVectorType().
getShape());
485 Type inferredReturnType;
486 auto sourceScalableDims = getSourceVectorType().getScalableDims();
487 for (
auto [dimIdx, dimSize] :
489 if (!llvm::any_of(getReductionDims(),
490 [dimIdx = dimIdx](int64_t reductionDimIdx) {
491 return reductionDimIdx ==
static_cast<int64_t
>(dimIdx);
493 targetShape.push_back(dimSize);
494 scalableDims.push_back(sourceScalableDims[dimIdx]);
497 if (targetShape.empty())
498 inferredReturnType = getSourceVectorType().getElementType();
501 targetShape, getSourceVectorType().
getElementType(), scalableDims);
502 if (
getType() != inferredReturnType)
503 return emitOpError() <<
"destination type " <<
getType()
504 <<
" is incompatible with source type "
505 << getSourceVectorType();
511 Type MultiDimReductionOp::getExpectedMaskType() {
512 auto vecType = getSourceVectorType();
515 vecType.getScalableDims());
524 struct ElideUnitDimsInMultiDimReduction
528 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
531 for (
const auto &dim :
enumerate(shape)) {
532 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
540 if (reductionOp.isMasked()) {
542 rootOp = reductionOp.getMaskingOp();
543 mask = reductionOp.getMaskingOp().getMask();
545 rootOp = reductionOp;
548 Location loc = reductionOp.getLoc();
549 Value acc = reductionOp.getAcc();
551 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
553 VectorType newMaskType =
555 dstVecType.getScalableDims());
556 mask = rewriter.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
558 cast = rewriter.
create<vector::ShapeCastOp>(
559 loc, reductionOp.getDestType(), reductionOp.getSource());
565 mask = rewriter.
create<vector::ExtractOp>(loc, mask, zeroIdx);
566 cast = rewriter.
create<vector::ExtractOp>(loc, reductionOp.getSource(),
572 cast,
nullptr, mask);
579 void MultiDimReductionOp::getCanonicalizationPatterns(
581 results.
add<ElideUnitDimsInMultiDimReduction>(context);
589 CombiningKind kind,
Value vector,
590 arith::FastMathFlags fastMathFlags) {
591 build(builder, result, kind, vector,
Value(), fastMathFlags);
596 arith::FastMathFlags fastMathFlags) {
597 build(builder, result,
598 llvm::cast<VectorType>(vector.
getType()).getElementType(), kind, vector,
604 int64_t rank = getSourceVectorType().getRank();
606 return emitOpError(
"unsupported reduction rank: ") << rank;
609 Type eltType = getDest().getType();
611 return emitOpError(
"unsupported reduction type '")
612 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
621 Type ReductionOp::getExpectedMaskType() {
622 auto vecType = getSourceVectorType();
625 vecType.getScalableDims());
632 case arith::AtomicRMWKind::addf:
633 case arith::AtomicRMWKind::addi:
634 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
635 CombiningKind::ADD, vector);
636 case arith::AtomicRMWKind::mulf:
637 case arith::AtomicRMWKind::muli:
638 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
639 CombiningKind::MUL, vector);
640 case arith::AtomicRMWKind::minimumf:
641 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
642 CombiningKind::MINIMUMF, vector);
643 case arith::AtomicRMWKind::mins:
644 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
645 CombiningKind::MINSI, vector);
646 case arith::AtomicRMWKind::minu:
647 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
649 case arith::AtomicRMWKind::maximumf:
650 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
651 CombiningKind::MAXIMUMF, vector);
652 case arith::AtomicRMWKind::maxs:
653 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
654 CombiningKind::MAXSI, vector);
655 case arith::AtomicRMWKind::maxu:
656 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
657 CombiningKind::MAXUI, vector);
658 case arith::AtomicRMWKind::andi:
659 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
660 CombiningKind::AND, vector);
661 case arith::AtomicRMWKind::ori:
662 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
663 CombiningKind::OR, vector);
// Unrolling iterates over the full source vector shape.
// (Closing brace falls outside this excerpt.)
672 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
673 return llvm::to_vector<4>(getSourceVectorType().
getShape());
680 LogicalResult matchAndRewrite(ReductionOp reductionOp,
685 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
688 if (maskableOp.isMasked()) {
690 rootOp = maskableOp.getMaskingOp();
691 mask = maskableOp.getMaskingOp().getMask();
693 rootOp = reductionOp;
696 auto vectorType = reductionOp.getSourceVectorType();
697 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
700 Location loc = reductionOp.getLoc();
702 if (vectorType.getRank() == 0) {
704 mask = rewriter.
create<ExtractElementOp>(loc, mask);
705 result = rewriter.
create<ExtractElementOp>(loc, reductionOp.getVector());
708 mask = rewriter.
create<ExtractOp>(loc, mask, 0);
709 result = rewriter.
create<ExtractOp>(loc, reductionOp.getVector(), 0);
712 if (
Value acc = reductionOp.getAcc())
715 reductionOp.getFastmathAttr(), mask);
725 results.
add<ElideSingleElementReduction>(context);
739 getIndexingMapsAttrName(result.
name),
743 getIteratorTypesAttrName(result.
name),
746 return IteratorTypeAttr::get(builder.getContext(), t);
752 ArrayAttr indexingMaps,
753 ArrayAttr iteratorTypes) {
754 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
755 ContractionOp::getDefaultKind());
760 ArrayAttr indexingMaps,
761 ArrayAttr iteratorTypes, CombiningKind kind) {
778 DictionaryAttr dictAttr;
793 dictAttr.getValue().end());
799 ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
804 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
805 auto maybeIteratorType = symbolizeIteratorType(s);
806 if (!maybeIteratorType.has_value())
807 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
809 iteratorTypeAttrs.push_back(
817 getKindAttrName(result.
name),
819 ContractionOp::getDefaultKind()));
821 if (masksInfo.empty())
823 if (masksInfo.size() != 2)
825 "expected zero or exactly 2 vector mask operands");
826 auto lhsType = llvm::cast<VectorType>(types[0]);
827 auto rhsType = llvm::cast<VectorType>(types[1]);
829 std::array<VectorType, 2> maskTypes = {
839 auto attrNames = getTraitAttrNames();
841 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
843 for (
auto attr : (*this)->getAttrs()) {
844 if (attr.getName() == getIteratorTypesAttrName()) {
846 llvm::cast<ArrayAttr>(attr.getValue())
847 .getAsValueRange<IteratorTypeAttr, IteratorType>();
853 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
857 attrs.emplace_back(getIteratorTypesAttrName(),
859 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
860 attrs.push_back(attr);
864 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
865 p << getRhs() <<
", " << getAcc();
868 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
873 const std::vector<std::pair<int64_t, int64_t>> &map) {
874 for (
auto &dimPair : map) {
875 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
876 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
877 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
884 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
886 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
887 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
890 for (
auto &dimPair : contractingDimMap) {
891 lhsContractingDimSet.insert(dimPair.first);
892 rhsContractingDimSet.insert(dimPair.second);
895 for (
auto &dimPair : batchDimMap)
896 rhsBatchDimSet.insert(dimPair.second);
900 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
901 if (lhsContractingDimSet.count(i) > 0)
903 expectedResultDims.push_back(lhsType.getDimSize(i));
907 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
908 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
910 expectedResultDims.push_back(rhsType.getDimSize(i));
914 if (expectedResultDims.empty()) {
916 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
917 return op.emitOpError(
"invalid accumulator/result vector shape");
920 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
921 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
922 if (!resVectorType || !accVectorType)
923 return op.emitOpError(
"invalid accumulator/result vector shape");
929 AffineMap lhsMap = op.getIndexingMapsArray()[0];
930 AffineMap rhsMap = op.getIndexingMapsArray()[1];
932 return op.emitOpError(
933 "expected all dimensions to be either a LHS or a RHS dimension");
936 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
937 VectorType v = pair.first;
938 auto map = pair.second;
939 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
940 unsigned pos = map.getDimPosition(idx);
945 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
946 return op.emitOpError(
"expected all dimensions to get an extent as "
947 "either a LHS or a RHS dimension");
949 AffineMap resMap = op.getIndexingMapsArray()[2];
955 llvm::IsaPred<AffineConstantExpr>) &&
956 "expected constant extent along all dimensions.");
958 auto expectedShape = llvm::to_vector<4>(
960 return cast<AffineConstantExpr>(e).getValue();
964 resVectorType.getScalableDims());
965 if (resVectorType != expected || accVectorType != expected)
966 return op.emitOpError(
967 "invalid accumulator/result vector shape, expected: ")
974 VectorType lhsType = getLhsType();
975 VectorType rhsType = getRhsType();
976 Type accType = getAccType();
977 Type resType = getResultType();
979 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
980 if (!lhsType.getElementType().isSignlessInteger())
981 return emitOpError(
"only supports signless integer types");
985 if (getIndexingMapsArray().size() != 3)
986 return emitOpError(
"expected an indexing map for each vector operand");
991 unsigned numIterators = getIteratorTypes().getValue().size();
993 auto index = it.index();
994 auto map = it.value();
995 if (map.getNumSymbols() != 0)
996 return emitOpError(
"expected indexing map ")
997 << index <<
" to have no symbols";
998 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).
getType());
999 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1002 if (map.getNumDims() != numIterators)
1003 return emitOpError(
"expected indexing map ")
1004 << index <<
" to have " << numIterators <<
" number of inputs";
1005 if (map.getNumResults() != rank)
1006 return emitOpError(
"expected indexing map ")
1007 << index <<
" to have " << rank <<
" number of outputs";
1008 if (!map.isProjectedPermutation())
1009 return emitOpError(
"expected indexing map ")
1010 << index <<
" to be a projected permutation of its inputs";
1013 auto contractingDimMap = getContractingDimMap();
1014 auto batchDimMap = getBatchDimMap();
1017 if (contractingDimMap.empty())
1018 return emitOpError(
"expected at least one contracting dimension pair");
1021 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1022 return emitOpError(
"invalid contracting dimension map");
1026 return emitOpError(
"invalid batch dimension map");
1030 contractingDimMap, batchDimMap)))
1034 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1035 auto elementType = vectorType ? vectorType.getElementType() : resType;
1037 return emitOpError(
"unsupported contraction type");
1046 Type ContractionOp::getExpectedMaskType() {
1047 auto indexingMaps = this->getIndexingMapsArray();
1050 VectorType lhsType = this->getLhsType();
1051 VectorType rhsType = this->getRhsType();
1053 unsigned numVecDims = lhsIdxMap.
getNumDims();
1062 lhsType.getScalableDims()[dimIdx];
1067 rhsType.getScalableDims()[dimIdx];
1070 assert(!ShapedType::isDynamicShape(maskShape) &&
1071 "Mask shape couldn't be computed");
1075 maskShapeScalableDims);
1080 getIteratorTypesAttrName(), getKindAttrName()};
1090 static std::vector<std::pair<int64_t, int64_t>>
1092 IteratorType targetIteratorType,
MLIRContext *context) {
1093 std::vector<std::pair<int64_t, int64_t>> dimMap;
1095 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1096 if (iteratorType != targetIteratorType)
1102 if (lhsDim >= 0 && rhsDim >= 0)
1103 dimMap.emplace_back(lhsDim, rhsDim);
1108 void ContractionOp::getIterationBounds(
1110 auto lhsShape = getLhsType().getShape();
1111 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1117 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1118 if (iteratorType == IteratorType::reduction) {
1120 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
1121 assert(lhsDimIndex >= 0);
1122 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1126 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
1127 assert(resDimIndex >= 0);
1128 assert(resVectorType !=
nullptr);
1129 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1133 void ContractionOp::getIterationIndexMap(
1135 unsigned numMaps = getIndexingMapsArray().size();
1136 iterationIndexMap.resize(numMaps);
1138 auto index = it.index();
1139 auto map = it.value();
1140 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1141 auto dim = cast<AffineDimExpr>(map.getResult(i));
1142 iterationIndexMap[index][dim.getPosition()] = i;
1147 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1149 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1153 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1155 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1159 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1161 getIterationBounds(shape);
1183 template <
typename AddOpType>
1189 auto canonicalize = [&](
Value maybeContraction,
1190 Value otherOperand) -> vector::ContractionOp {
1191 vector::ContractionOp contractionOp =
1192 dyn_cast_or_null<vector::ContractionOp>(
1195 return vector::ContractionOp();
1196 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1197 contractionOp.getAcc().getDefiningOp())) {
1198 if (maybeZero.getValue() ==
1199 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1201 bvm.
map(contractionOp.getAcc(), otherOperand);
1202 auto newContraction =
1203 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1204 rewriter.
replaceOp(addOp, newContraction.getResult());
1205 return newContraction;
1208 return vector::ContractionOp();
1211 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1212 vector::ContractionOp
contract = canonicalize(a, b);
1214 return contract ? success() : failure();
1230 setResultRanges(getResult(), argRanges.front());
1236 result.
addTypes(llvm::cast<VectorType>(source.
getType()).getElementType());
1240 VectorType vectorType = getSourceVectorType();
1241 if (vectorType.getRank() == 0) {
1243 return emitOpError(
"expected position to be empty with 0-D vector");
1246 if (vectorType.getRank() != 1)
1247 return emitOpError(
"unexpected >1 vector rank");
1249 return emitOpError(
"expected position for 1-D vector");
1253 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1255 if (!adaptor.getPosition())
1259 if (
auto splat = getVector().getDefiningOp<vector::SplatOp>())
1260 return splat.getInput();
1263 if (
auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1267 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1268 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1272 auto srcElements = src.getValues<
Attribute>();
1274 uint64_t posIdx = pos.getInt();
1275 if (posIdx >= srcElements.size())
1278 return srcElements[posIdx];
1285 return index == poisonValue || (index >= 0 && index < maxIndex);
1294 setResultRanges(getResult(), argRanges.front());
1298 Value source, int64_t position) {
1318 build(builder, result, source, dynamicPos,
1323 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1324 ExtractOp::Adaptor adaptor,
1326 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1327 if (
static_cast<int64_t
>(adaptor.getStaticPosition().size()) ==
1328 vectorType.getRank()) {
1329 inferredReturnTypes.push_back(vectorType.getElementType());
1331 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1332 vectorType.getRank());
1334 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1335 vectorType.getScalableDims().drop_front(n)));
1343 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1344 return vectorType && vectorType.getShape().equals({1}) &&
1345 vectorType.getElementType() == r.front();
1347 if (l.size() == 1 && r.size() == 1 &&
1348 (isCompatible(l, r) || isCompatible(r, l)))
1355 auto dynamicMarkersCount =
1356 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1357 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1359 "mismatch between dynamic and static positions (kDynamic marker but no "
1360 "corresponding dynamic position) -- this can only happen due to an "
1361 "incorrect fold/rewrite");
1362 auto position = getMixedPosition();
1363 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1365 "expected position attribute of rank no greater than vector rank");
1367 if (
auto attr = dyn_cast<Attribute>(pos)) {
1368 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1370 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1371 return emitOpError(
"expected position attribute #")
1373 <<
" to be a non-negative integer smaller than the "
1374 "corresponding vector dimension or poison (-1)";
1381 template <
typename IntType>
1383 return llvm::to_vector<4>(llvm::map_range(
1384 arrayAttr.getAsRange<IntegerAttr>(),
1385 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1391 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1395 if (extractOp.hasDynamicPosition())
1399 ExtractOp currentOp = extractOp;
1401 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1402 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1405 if (currentOp.hasDynamicPosition())
1408 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1410 extractOp.setOperand(0, currentOp.getVector());
1413 std::reverse(globalPosition.begin(), globalPosition.end());
1414 extractOp.setStaticPosition(globalPosition);
1426 class ExtractFromInsertTransposeChainState {
1428 ExtractFromInsertTransposeChainState(ExtractOp e);
// Returns true iff `a` is a prefix of `b`: a.size() <= b.size() and the first
// a.size() elements of `b` equal `a` elementwise. (Closing brace falls
// outside this excerpt.)
1437 template <
typename ContainerA,
typename ContainerB>
1438 bool isContainedWithin(
const ContainerA &a,
const ContainerB &b) {
1439 return a.size() <= b.size() &&
1440 std::equal(a.begin(), a.begin() + a.size(), b.begin());
// NOTE(review): fragment — walks `a` and `b` pairwise (llvm::zip truncates to
// the shorter range), skipping pairs where either element is negative
// (negative values appear to act as "don't care" sentinels here). The
// decisive comparison and return fall on lines missing from this excerpt.
1447 template <
typename ContainerA,
typename ContainerB>
1448 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
1449 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
1450 if (elemA < 0 || elemB < 0)
1465 void updateStateForNextIteration(
Value v) {
1472 LogicalResult handleTransposeOp();
1475 LogicalResult handleInsertOpWithMatchingPos(
Value &res);
1490 LogicalResult handleInsertOpWithPrefixPos(
Value &res);
1495 Value tryToFoldExtractOpInPlace(
Value source);
1497 ExtractOp extractOp;
1499 int64_t extractedRank;
1501 InsertOp nextInsertOp;
1502 TransposeOp nextTransposeOp;
1517 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1519 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1520 extractedRank(extractOp.getNumIndices()) {
1521 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1522 sentinels.reserve(vectorRank - extractedRank);
1523 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1524 sentinels.push_back(-(i + 1));
1526 extractOp.getStaticPosition().end());
1532 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1534 if (extractOp.hasDynamicPosition())
1537 if (!nextTransposeOp)
1540 nextTransposeOp.getPermutation(), extractOp.getContext()));
1547 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1550 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1557 res = nextInsertOp.getSource();
1559 return success(canFold());
1566 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1568 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1581 res = nextInsertOp.getSource();
1589 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1592 if (extractOp.hasDynamicPosition())
1596 bool nothingToFold = (source == extractOp.getVector());
1597 if (nothingToFold || !canFold())
1602 extractOp.setStaticPosition(
1604 extractOp.getVectorMutable().assign(source);
1605 return extractOp.getResult();
1609 Value ExtractFromInsertTransposeChainState::fold() {
1611 if (extractOp.hasDynamicPosition())
1614 Value valueToExtractFrom = extractOp.getVector();
1615 updateStateForNextIteration(valueToExtractFrom);
1616 while (nextInsertOp || nextTransposeOp) {
1619 if (succeeded(handleTransposeOp())) {
1620 valueToExtractFrom = nextTransposeOp.getVector();
1621 updateStateForNextIteration(valueToExtractFrom);
1627 if (succeeded(handleInsertOpWithMatchingPos(result)))
1632 if (succeeded(handleInsertOpWithPrefixPos(result)))
1633 return tryToFoldExtractOpInPlace(result);
1643 valueToExtractFrom = nextInsertOp.getDest();
1644 updateStateForNextIteration(valueToExtractFrom);
1647 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1652 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1653 auto vecType = dyn_cast<VectorType>(type);
1654 return vecType && vecType.getRank() == 0;
1664 if (extractOp.hasDynamicPosition())
1667 Operation *defOp = extractOp.getVector().getDefiningOp();
1668 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1672 if (extractOp.getType() == source.
getType())
1674 auto getRank = [](
Type type) {
1675 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1680 unsigned broadcastSrcRank = getRank(source.
getType());
1681 if (broadcastSrcRank == 0 && source.
getType() == extractOp.getType())
1684 unsigned extractResultRank = getRank(extractOp.getType());
1685 if (extractResultRank >= broadcastSrcRank)
1688 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1689 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.
getType());
1690 if (extractVecType && broadcastVecType &&
1691 extractVecType.getShape() !=
1692 broadcastVecType.getShape().take_back(extractResultRank))
1695 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1696 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1702 broadcastOp.computeBroadcastedUnitDims();
1704 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1705 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1706 if (broadcastedUnitDims.contains(i))
1710 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1711 extractPos.erase(extractPos.begin(),
1712 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1715 extractOp.setOperand(0, source);
1716 extractOp.setStaticPosition(extractPos);
1717 return extractOp.getResult();
1733 if (extractOp.hasDynamicPosition())
1736 auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1741 if (shuffleOp.getResultVectorType().getRank() != 1)
1744 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1745 auto shuffleMask = shuffleOp.getMask();
1746 int64_t extractIdx = extractOp.getStaticPosition()[0];
1747 int64_t shuffleIdx = shuffleMask[extractIdx];
1750 if (shuffleIdx < inputVecSize) {
1751 extractOp.setOperand(0, shuffleOp.getV1());
1752 extractOp.setStaticPosition({shuffleIdx});
1754 extractOp.setOperand(0, shuffleOp.getV2());
1755 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1758 return extractOp.getResult();
1764 if (extractOp.hasDynamicPosition())
1767 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1772 auto getDimReverse = [](VectorType type, int64_t n) {
1773 return type.getShape().take_back(n + 1).front();
1775 int64_t destinationRank =
1776 llvm::isa<VectorType>(extractOp.getType())
1777 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1779 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1781 if (destinationRank > 0) {
1782 auto destinationType =
1783 llvm::cast<VectorType>(extractOp.getResult().getType());
1784 for (int64_t i = 0; i < destinationRank; i++) {
1788 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1789 getDimReverse(destinationType, i))
1796 std::reverse(extractedPos.begin(), extractedPos.end());
1799 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1800 strides.push_back(stride);
1802 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1805 int64_t position =
linearize(extractedPos, strides);
1809 int64_t numDimension =
1810 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1812 for (int64_t i = 0; i < numDimension; i++) {
1813 newStrides.push_back(stride);
1815 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1817 std::reverse(newStrides.begin(), newStrides.end());
1821 extractOp.setStaticPosition(newPosition);
1822 extractOp.setOperand(0, shapeCastOp.getSource());
1823 return extractOp.getResult();
1829 if (extractOp.hasDynamicPosition())
1832 auto extractStridedSliceOp =
1833 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1834 if (!extractStridedSliceOp)
1843 if (extractStridedSliceOp.hasNonUnitStrides())
1848 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1849 while (!sliceOffsets.empty()) {
1850 size_t lastOffset = sliceOffsets.size() - 1;
1851 if (sliceOffsets.back() != 0 ||
1852 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1853 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1855 sliceOffsets.pop_back();
1857 unsigned destinationRank = 0;
1858 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1859 destinationRank = vecType.getRank();
1862 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1863 sliceOffsets.size())
1867 assert(extractedPos.size() >= sliceOffsets.size());
1868 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1869 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1870 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1874 extractOp.setStaticPosition(extractedPos);
1875 return extractOp.getResult();
1881 if (extractOp.hasDynamicPosition())
1884 int64_t destinationRank =
1885 llvm::isa<VectorType>(extractOp.getType())
1886 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1888 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1898 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1899 insertOp.getSourceVectorType().getRank();
1900 if (destinationRank > insertOp.getSourceVectorType().getRank())
1902 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1905 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1906 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1909 bool disjoint =
false;
1911 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1912 int64_t start = insertOffsets[dim];
1914 (dim < insertRankDiff)
1916 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1917 int64_t end = start + size;
1918 int64_t offset = extractOffsets[dim];
1920 if (start <= offset && offset < end) {
1921 if (dim >= insertRankDiff)
1922 offsetDiffs.push_back(offset - start);
1932 int64_t srcRankDiff =
1933 insertOp.getSourceVectorType().getRank() - destinationRank;
1934 for (int64_t i = 0; i < destinationRank; i++) {
1935 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1936 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1940 extractOp.getVectorMutable().assign(insertOp.getSource());
1943 extractOp.setStaticPosition(offsetDiffs);
1944 return extractOp.getResult();
1948 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1961 if (extractOp.hasDynamicPosition())
1965 auto fromElementsOp = extractOp.getVector().
getDefiningOp<FromElementsOp>();
1966 if (!fromElementsOp)
1970 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1971 if (vecType.isScalable())
1975 int64_t rank = vecType.getRank();
1977 if (extractOp.getType() != vecType.getElementType())
1979 assert(
static_cast<int64_t
>(indices.size()) == rank &&
1980 "unexpected number of indices");
1985 for (
int i = rank - 1; i >= 0; --i) {
1986 flatIndex += indices[i] * stride;
1987 stride *= vecType.getDimSize(i);
1989 return fromElementsOp.getElements()[flatIndex];
1994 template <
typename OpType,
typename AdaptorType>
1997 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
1998 OperandRange dynamicPosition = op.getDynamicPosition();
2002 if (!dynamicPosition.size())
2009 bool opChange =
false;
2010 for (
unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2011 if (!ShapedType::isDynamic(staticPosition[i]))
2013 Attribute positionAttr = dynamicPositionAttr[index];
2014 Value position = dynamicPosition[index++];
2015 if (
auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2016 staticPosition[i] = attr.getInt();
2020 operands.push_back(position);
2024 op.setStaticPosition(staticPosition);
2025 op.getOperation()->setOperands(operands);
2026 return op.getResult();
2035 int64_t poisonVal) {
2036 if (!llvm::is_contained(staticPos, poisonVal))
2044 if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2054 if (getNumIndices() == 0 && getVector().
getType() == getResult().
getType())
2057 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2063 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
2092 Operation *defOp = extractOp.getVector().getDefiningOp();
2093 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2097 if (extractOp.getType() == source.
getType())
2099 auto getRank = [](
Type type) {
2100 return llvm::isa<VectorType>(type)
2101 ? llvm::cast<VectorType>(type).getRank()
2104 unsigned broadcastSrcRank = getRank(source.
getType());
2105 unsigned extractResultRank = getRank(extractOp.getType());
2109 if (extractResultRank < broadcastSrcRank)
2113 if (extractResultRank == 0) {
2114 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.
getType()));
2119 extractOp, extractOp.getType(), source);
2125 class ExtractOpSplatConstantFolder final :
public OpRewritePattern<ExtractOp> {
2133 Value sourceVector = extractOp.getVector();
2137 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2140 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2141 if (
auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2149 class ExtractOpNonSplatConstantFolder final
2157 if (extractOp.hasDynamicPosition())
2162 Value sourceVector = extractOp.getVector();
2167 auto vecTy = llvm::cast<VectorType>(sourceVector.
getType());
2168 if (vecTy.isScalable())
2172 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2173 if (!dense || dense.isSplat())
2179 copy(extractOp.getStaticPosition(), completePositions.begin());
2180 int64_t elemBeginPosition =
2182 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2185 if (
auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2187 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2190 newAttr = *denseValuesBegin;
2206 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2210 VectorType extractedMaskType =
2211 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2213 if (!extractedMaskType)
2216 auto maskOperands = createMaskOp.getOperands();
2218 VectorType maskType = createMaskOp.getVectorType();
2220 bool containsUnknownDims =
false;
2223 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2225 int64_t pos = extractOpPos[dimIdx];
2226 Value operand = maskOperands[dimIdx];
2227 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2230 containsUnknownDims =
true;
2234 int64_t createMaskBound =
2235 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2237 if (pos != ShapedType::kDynamic) {
2240 allFalse |= pos >= createMaskBound;
2241 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2245 containsUnknownDims =
true;
2252 }
else if (!containsUnknownDims) {
2254 extractOp, extractedMaskType,
2255 maskOperands.drop_front(extractOpPos.size()));
2265 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2267 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2271 VectorType sourceType = castOp.getSourceVectorType();
2272 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2276 if (sourceType.getNumElements() != targetType.getNumElements())
2280 castOp.getSource());
2290 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2293 if (extractOp.hasDynamicPosition())
2297 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2302 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2303 if (!fromElementsOp)
2305 VectorType inputType = fromElementsOp.getType();
2308 if (resultType.isScalable() || inputType.isScalable())
2314 llvm::to_vector(extractOp.getStaticPosition());
2315 firstElementPos.append(resultType.getRank(), 0);
2318 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2319 flatIndex += firstElementPos[i] * stride;
2320 stride *= inputType.getDimSize(i);
2325 extractOp, resultType,
2326 fromElementsOp.getElements().slice(flatIndex,
2327 resultType.getNumElements()));
2335 results.
add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2336 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2337 results.
add(foldExtractFromShapeCastToShapeCast);
2338 results.
add(foldExtractFromFromElements);
2343 for (
auto attr : arrayAttr)
2344 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2351 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2366 if (!llvm::all_equal(fromElementsOp.getElements()))
2369 fromElementsOp.getElements().front());
2384 setResultRanges(getResult(), argRanges.front());
2392 int64_t rankDiff = dstShape.size() - srcShape.size();
2393 int64_t dstDim = rankDiff;
2395 for (
auto [s1, s2] :
2396 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2398 assert(s1 == 1 &&
"expected \"dim-1\" broadcasting");
2408 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2427 Value BroadcastOp::createOrFoldBroadcastOp(
2430 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2434 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2435 if (broadcastedDims.contains(i))
2437 checkShape.push_back(dstShape[i]);
2439 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2440 "ill-formed broadcastedDims contains values not confined to "
2445 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2449 if (!srcVectorType) {
2450 assert(checkShape.empty() &&
2451 "ill-formed createOrFoldBroadcastOp arguments");
2452 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2455 assert(srcVectorType.getShape().equals(checkShape) &&
2456 "ill-formed createOrFoldBroadcastOp arguments");
2467 broadcastShape.reserve(dstShape.size());
2483 int64_t nextSrcShapeDim = broadcastedDims.size();
2484 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2485 if (broadcastedDims.contains(i)) {
2490 broadcastShape.push_back(dstShape[i]);
2491 permutation[i] = broadcastShape.size() - 1;
2497 permutation[i] = nextSrcShapeDim++;
2501 llvm::append_range(broadcastShape, srcVectorType.getShape());
2506 "unexpected \"dim-1\" broadcast");
2508 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
2510 vector::BroadcastableToResult::Success &&
2511 "must be broadcastable");
2515 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2516 if (permutation[i] != i)
2517 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
2523 Type srcType, VectorType dstVectorType,
2524 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2528 return BroadcastableToResult::Success;
2530 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2532 return BroadcastableToResult::SourceTypeNotAVector;
2534 int64_t srcRank = srcVectorType.getRank();
2535 int64_t dstRank = dstVectorType.getRank();
2536 if (srcRank > dstRank)
2537 return BroadcastableToResult::SourceRankHigher;
2540 int64_t lead = dstRank - srcRank;
2541 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2544 bool foundMismatchingDims =
false;
2547 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2548 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2549 if (srcDim != 1 && srcDim != dstDim)
2550 foundMismatchingDims =
true;
2553 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2554 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2555 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2558 (srcDimScalableFlag != dstDimScalableFlag &&
2559 (srcDim != 1 || srcDimScalableFlag)))
2560 foundMismatchingDims =
true;
2562 if (foundMismatchingDims) {
2563 if (mismatchingDims !=
nullptr) {
2564 mismatchingDims->first.dim = srcDim;
2565 mismatchingDims->first.isScalable = srcDimScalableFlag;
2567 mismatchingDims->second.dim = dstDim;
2568 mismatchingDims->second.isScalable = dstDimScalableFlag;
2570 return BroadcastableToResult::DimensionMismatch;
2574 return BroadcastableToResult::Success;
2578 std::pair<VectorDim, VectorDim> mismatchingDims;
2580 getSourceType(), getResultVectorType(), &mismatchingDims);
2581 if (res == BroadcastableToResult::Success)
2583 if (res == BroadcastableToResult::SourceRankHigher)
2584 return emitOpError(
"source rank higher than destination rank");
2585 if (res == BroadcastableToResult::DimensionMismatch) {
2586 return emitOpError(
"dimension mismatch (")
2587 << (mismatchingDims.first.isScalable ?
"[" :
"")
2588 << mismatchingDims.first.dim
2589 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
2590 << (mismatchingDims.second.isScalable ?
"[" :
"")
2591 << mismatchingDims.second.dim
2592 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
2594 if (res == BroadcastableToResult::SourceTypeNotAVector)
2595 return emitOpError(
"source type is not a vector");
2596 llvm_unreachable(
"unexpected vector.broadcast op error");
2600 if (getSourceType() == getResultVectorType())
2602 if (!adaptor.getSource())
2604 auto vectorType = getResultVectorType();
2605 if (
auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2606 if (vectorType.getElementType() != attr.getType())
2610 if (
auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2611 if (vectorType.getElementType() != attr.getType())
2615 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2628 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2632 broadcastOp.getResultVectorType(),
2633 srcBroadcast.getSource());
2643 results.
add<BroadcastFolder>(context);
2651 VectorType resultType = getResultVectorType();
2652 VectorType v1Type = getV1VectorType();
2653 VectorType v2Type = getV2VectorType();
2655 int64_t resRank = resultType.getRank();
2656 int64_t v1Rank = v1Type.getRank();
2657 int64_t v2Rank = v2Type.getRank();
2658 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2659 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2660 if (!wellFormed0DCase && !wellFormedNDCase)
2661 return emitOpError(
"rank mismatch");
2664 for (int64_t r = 1; r < v1Rank; ++r) {
2665 int64_t resDim = resultType.getDimSize(r);
2666 int64_t v1Dim = v1Type.getDimSize(r);
2667 int64_t v2Dim = v2Type.getDimSize(r);
2668 if (resDim != v1Dim || v1Dim != v2Dim)
2669 return emitOpError(
"dimension mismatch");
2673 int64_t maskLength = mask.size();
2674 if (maskLength <= 0)
2675 return emitOpError(
"invalid mask length");
2676 if (maskLength != resultType.getDimSize(0))
2677 return emitOpError(
"mask length mismatch");
2679 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2680 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2683 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
2689 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2690 ShuffleOp::Adaptor adaptor,
2692 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2693 auto v1Rank = v1Type.getRank();
2697 shape.reserve(v1Rank);
2698 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2701 llvm::append_range(shape, v1Type.getShape().drop_front());
2702 inferredReturnTypes.push_back(
2707 template <
typename T>
2710 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2711 return value == expected++;
2715 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2716 auto v1Type = getV1VectorType();
2717 auto v2Type = getV2VectorType();
2719 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2720 "Vector shuffle does not support scalable vectors");
2724 if (v1Type.getRank() == 0)
2728 auto mask = getMask();
2735 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2736 if (!v1Attr || !v2Attr)
2740 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2741 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2742 if (isV1Poison && isV2Poison)
2747 if (v1Type.getRank() != 1)
2757 to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2758 poisonElement = v2Elements[0];
2762 to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2763 poisonElement = v1Elements[0];
2767 int64_t v1Size = v1Type.getDimSize(0);
2768 for (int64_t maskIdx : mask) {
2771 if (maskIdx == ShuffleOp::kPoisonIndex) {
2772 indexedElm = poisonElement;
2774 if (maskIdx < v1Size)
2775 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2777 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2780 results.push_back(indexedElm);
2795 VectorType v1VectorType = shuffleOp.getV1VectorType();
2797 if (v1VectorType.getRank() > 0)
2799 if (mask.size() != 1)
2819 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2820 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2822 if (!v1Splat || !v2Splat)
2825 if (v1Splat.getInput() != v2Splat.getInput())
2841 VectorType resultType = op.getResultVectorType();
2842 if (resultType.isScalable())
2844 op,
"ShuffleOp can't represent a scalable interleave");
2846 if (resultType.getRank() != 1)
2848 op,
"ShuffleOp can't represent an n-D interleave");
2850 VectorType sourceType = op.getV1VectorType();
2851 if (sourceType != op.getV2VectorType() ||
2852 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2854 op,
"ShuffleOp types don't match an interleave");
2858 int64_t resultVectorSize = resultType.getNumElements();
2859 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2860 int64_t maskValueA = shuffleMask[i * 2];
2861 int64_t maskValueB = shuffleMask[(i * 2) + 1];
2862 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2864 "ShuffleOp mask not interleaving");
2876 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2886 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2891 build(builder, result, source, dest, {});
2895 auto dstVectorType = getDestVectorType();
2896 if (dstVectorType.getRank() == 0) {
2898 return emitOpError(
"expected position to be empty with 0-D vector");
2901 if (dstVectorType.getRank() != 1)
2902 return emitOpError(
"unexpected >1 vector rank");
2904 return emitOpError(
"expected position for 1-D vector");
2908 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2910 if (!adaptor.getPosition())
2913 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2914 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2915 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2916 if (!src || !dst || !pos)
2922 auto dstElements = dst.getValues<
Attribute>();
2926 uint64_t posIdx = pos.getInt();
2927 if (posIdx >= results.size())
2929 results[posIdx] = src;
2940 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2944 Value source,
Value dest, int64_t position) {
2957 posVals.reserve(position.size());
2958 llvm::transform(position, std::back_inserter(posVals),
2960 build(builder, result, source, dest, posVals);
2969 build(builder, result, source, dest, dynamicPos,
2975 auto destVectorType = getDestVectorType();
2976 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
2978 "expected position attribute of rank no greater than dest vector rank");
2979 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2980 if (srcVectorType &&
2981 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2982 static_cast<unsigned>(destVectorType.getRank())))
2983 return emitOpError(
"expected position attribute rank + source rank to "
2984 "match dest vector rank");
2985 if (!srcVectorType &&
2986 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
2988 "expected position attribute rank to match the dest vector rank");
2990 if (
auto attr = pos.dyn_cast<
Attribute>()) {
2991 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2993 destVectorType.getDimSize(idx))) {
2994 return emitOpError(
"expected position attribute #")
2996 <<
" to be a non-negative integer smaller than the "
2998 "dest vector dimension";
3015 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
3016 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3017 srcVecType.getNumElements())
3020 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
3032 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
3033 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3035 if (!srcSplat || !dstSplat)
3038 if (srcSplat.getInput() != dstSplat.getInput())
3053 static constexpr int64_t vectorSizeFoldThreshold = 256;
3058 if (op.hasDynamicPosition())
3067 auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3071 VectorType destTy = destVector.getType();
3072 if (destTy.isScalable())
3076 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3077 !destVector.hasOneUse())
3080 Value sourceValue = op.getSource();
3088 copy(op.getStaticPosition(), completePositions.begin());
3089 int64_t insertBeginPosition =
3093 Type destEltType = destTy.getElementType();
3098 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3099 for (
auto value : denseSource.getValues<
Attribute>())
3105 auto allValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3106 copy(insertedValues, allValues.begin() + insertBeginPosition);
3117 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3118 if (intAttr.getType() != expectedType)
3129 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3130 InsertOpConstantFolder>(context);
3133 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3137 if (getNumIndices() == 0 && getSourceType() ==
getType())
3143 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3168 template <
typename OpType>
3170 ArrayAttr arrayAttr,
3172 StringRef attrName) {
3173 if (arrayAttr.size() > shape.size())
3174 return op.emitOpError(
"expected ")
3175 << attrName <<
" attribute of rank no greater than vector rank";
3182 template <
typename OpType>
3183 static LogicalResult
3185 int64_t
max, StringRef attrName,
3186 bool halfOpen =
true) {
3187 for (
auto attr : arrayAttr) {
3188 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3192 if (val < min || val >= upper)
3193 return op.emitOpError(
"expected ") << attrName <<
" to be confined to ["
3194 <<
min <<
", " << upper <<
")";
3202 template <
typename OpType>
3203 static LogicalResult
3206 bool halfOpen =
true, int64_t
min = 0) {
3207 for (
auto [index, attrDimPair] :
3209 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3210 int64_t
max = std::get<1>(attrDimPair);
3213 if (val < min || val >=
max)
3214 return op.emitOpError(
"expected ")
3215 << attrName <<
" dimension " << index <<
" to be confined to ["
3216 <<
min <<
", " <<
max <<
")";
3226 template <
typename OpType>
3228 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3230 bool halfOpen =
true, int64_t
min = 1) {
3231 assert(arrayAttr1.size() <= shape.size());
3232 assert(arrayAttr2.size() <= shape.size());
3233 for (
auto [index, it] :
3235 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3236 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3237 int64_t
max = std::get<2>(it);
3240 if (val1 + val2 < 0 || val1 + val2 >=
max)
3241 return op.emitOpError(
"expected sum(")
3242 << attrName1 <<
", " << attrName2 <<
") dimension " << index
3243 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3250 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
3257 auto sourceVectorType = getSourceVectorType();
3258 auto destVectorType = getDestVectorType();
3259 auto offsets = getOffsetsAttr();
3260 auto strides = getStridesAttr();
3261 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3263 "expected offsets of same size as destination vector rank");
3264 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3265 return emitOpError(
"expected strides of same size as source vector rank");
3266 if (sourceVectorType.getRank() > destVectorType.getRank())
3268 "expected source rank to be no greater than destination rank");
3270 auto sourceShape = sourceVectorType.getShape();
3271 auto destShape = destVectorType.getShape();
3273 destShape.size() - sourceShape.size(), 0);
3274 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3275 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3276 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3285 offName,
"source vector shape",
3289 unsigned rankDiff = destShape.size() - sourceShape.size();
3290 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3291 if (sourceVectorType.getScalableDims()[idx] !=
3292 destVectorType.getScalableDims()[idx + rankDiff]) {
3293 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3296 if (sourceVectorType.getScalableDims()[idx]) {
3297 auto sourceSize = sourceShape[idx];
3298 auto destSize = destShape[idx + rankDiff];
3299 if (sourceSize != destSize) {
3300 return emitOpError(
"expected size at idx=")
3302 << (
" to match the corresponding base size from the input "
3304 << sourceSize << (
" vs ") << destSize << (
")");
3315 class FoldInsertStridedSliceSplat final
3320 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3323 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3325 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3327 if (!srcSplatOp || !destSplatOp)
3330 if (srcSplatOp.getInput() != destSplatOp.getInput())
3333 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3340 class FoldInsertStridedSliceOfExtract final
3345 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3347 auto extractStridedSliceOp =
3348 insertStridedSliceOp.getSource()
3349 .getDefiningOp<vector::ExtractStridedSliceOp>();
3351 if (!extractStridedSliceOp)
3354 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3358 if (extractStridedSliceOp.getStrides() !=
3359 insertStridedSliceOp.getStrides() ||
3360 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3363 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3370 class InsertStridedSliceConstantFolder final
3377 static constexpr int64_t vectorSizeFoldThreshold = 256;
3388 VectorType destTy = destVector.getType();
3389 if (destTy.isScalable())
3393 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3394 !destVector.hasOneUse())
3403 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3407 if (op.hasNonUnitStrides())
3410 VectorType sliceVecTy = sourceValue.getType();
3412 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3422 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3423 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3424 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
3425 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3428 currDestPosition.begin() + rankDifference, currDestPosition.end());
3432 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3433 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3434 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
3435 "Invalid slice element");
3436 newValues[linearizedPosition] = *sliceValuesIt;
3449 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3451 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3452 InsertStridedSliceConstantFolder>(context);
3455 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3456 if (getSourceVectorType() == getDestVectorType())
3473 p <<
" " << getLhs() <<
", " << getRhs();
3475 p <<
", " << getAcc();
3478 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
3489 if (operandsInfo.size() < 2)
3491 "expected at least 2 operands");
3492 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3493 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3496 "expected vector type for operand #1");
3501 vRHS.getScalableDims()[0]};
3503 vLHS.getElementType(), scalableDimsRes);
3507 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3513 OuterProductOp::getKindAttrName(result.
name),
3515 OuterProductOp::getDefaultKind()));
3521 (operandsInfo.size() > 2 &&
3527 Type tRHS = getOperandTypeRHS();
3528 VectorType vLHS = getOperandVectorTypeLHS(),
3529 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3530 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3532 if (vLHS.getRank() != 1)
3533 return emitOpError(
"expected 1-d vector for operand #1");
3537 if (vRHS.getRank() != 1)
3538 return emitOpError(
"expected 1-d vector for operand #2");
3539 if (vRES.getRank() != 2)
3540 return emitOpError(
"expected 2-d vector result");
3541 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3542 return emitOpError(
"expected #1 operand dim to match result dim #1");
3543 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3544 return emitOpError(
"expected #2 operand dim to match result dim #2");
3545 if (vLHS.isScalable() && !vRHS.isScalable()) {
3549 "expected either both or only #2 operand dim to be scalable");
3553 if (vRES.getRank() != 1)
3554 return emitOpError(
"expected 1-d vector result");
3555 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3556 return emitOpError(
"expected #1 operand dim to match result dim #1");
3559 if (vACC && vACC != vRES)
3560 return emitOpError(
"expected operand #3 of same type as result type");
3564 return emitOpError(
"unsupported outerproduct type");
3573 Type OuterProductOp::getExpectedMaskType() {
3574 auto vecType = this->getResultVectorType();
3577 vecType.getScalableDims());
3589 ArrayAttr offsets, ArrayAttr sizes,
3590 ArrayAttr strides) {
3591 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3593 shape.reserve(vectorType.getRank());
3595 for (
unsigned e = offsets.size(); idx < e; ++idx)
3596 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3597 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3598 shape.push_back(vectorType.getShape()[idx]);
3601 vectorType.getScalableDims());
3614 offsetsAttr, sizesAttr, stridesAttr));
3615 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
3619 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
3624 auto type = getSourceVectorType();
3625 auto offsets = getOffsetsAttr();
3626 auto sizes = getSizesAttr();
3627 auto strides = getStridesAttr();
3628 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3630 "expected offsets, sizes and strides attributes of same size");
3632 auto shape = type.getShape();
3633 auto offName = getOffsetsAttrName();
3634 auto sizesName = getSizesAttrName();
3635 auto stridesName = getStridesAttrName();
3651 shape, offName, sizesName,
3656 offsets, sizes, strides);
3657 if (getResult().
getType() != resultType)
3658 return emitOpError(
"expected result type to be ") << resultType;
3660 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
3661 if (type.getScalableDims()[idx]) {
3662 auto inputDim = type.getShape()[idx];
3663 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3664 if (inputDim != inputSize)
3665 return emitOpError(
"expected size at idx=")
3667 << (
" to match the corresponding base size from the input "
3669 << inputSize << (
" vs ") << inputDim << (
")");
3679 static LogicalResult
3682 auto getElement = [](ArrayAttr array,
int idx) {
3683 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3685 ArrayAttr extractOffsets = op.getOffsets();
3687 ArrayAttr extractSizes = op.getSizes();
3688 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3690 if (op.getSourceVectorType().getRank() !=
3691 insertOp.getSourceVectorType().getRank())
3693 ArrayAttr insertOffsets = insertOp.getOffsets();
3694 ArrayAttr insertStrides = insertOp.getStrides();
3697 if (extractOffsets.size() > insertOffsets.size())
3699 bool patialoverlap =
false;
3700 bool disjoint =
false;
3702 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3703 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3705 int64_t start = getElement(insertOffsets, dim);
3706 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3707 int64_t offset = getElement(extractOffsets, dim);
3708 int64_t size = getElement(extractSizes, dim);
3710 if (start <= offset && offset < end) {
3713 if (offset + size > end)
3714 patialoverlap =
true;
3715 offsetDiffs.push_back(offset - start);
3722 if (!disjoint && !patialoverlap) {
3723 op.setOperand(insertOp.getSource());
3732 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3742 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3743 if (getSourceVectorType() == getResult().
getType())
3758 class StridedSliceConstantMaskFolder final
3763 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3767 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3768 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3769 if (!constantMaskOp)
3772 if (extractStridedSliceOp.hasNonUnitStrides())
3785 sliceMaskDimSizes.reserve(maskDimSizes.size());
3786 for (
auto [maskDimSize, sliceOffset, sliceSize] :
3787 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3788 int64_t sliceMaskDimSize =
std::max(
3789 static_cast<int64_t
>(0),
3790 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3791 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3794 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3795 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3796 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3799 if (llvm::is_contained(sliceMaskDimSizes, 0))
3800 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3805 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3812 class StridedSliceSplatConstantFolder final
3817 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3821 Value sourceVector = extractStridedSliceOp.getVector();
3826 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3840 class StridedSliceNonSplatConstantFolder final
3845 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3849 Value sourceVector = extractStridedSliceOp.getVector();
3855 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3856 if (!dense || dense.isSplat())
3860 if (extractStridedSliceOp.hasNonUnitStrides())
3863 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3867 VectorType sliceVecTy = extractStridedSliceOp.getType();
3869 int64_t sliceRank = sliceVecTy.getRank();
3881 auto denseValuesBegin = dense.value_begin<
Attribute>();
3883 sliceValues.reserve(sliceVecTy.getNumElements());
3886 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3887 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3889 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3893 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3894 sliceVecTy.getNumElements() &&
3895 "Invalid number of slice elements");
3905 class StridedSliceBroadcast final
3917 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3918 auto dstVecType = llvm::cast<VectorType>(op.getType());
3919 unsigned dstRank = dstVecType.getRank();
3920 unsigned rankDiff = dstRank - srcRank;
3924 bool lowerDimMatch =
true;
3925 for (
unsigned i = 0; i < srcRank; i++) {
3926 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3927 lowerDimMatch =
false;
3936 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3937 if (!lowerDimMatch && !isScalarSrc) {
3938 source = rewriter.
create<ExtractStridedSliceOp>(
3939 op->getLoc(), source,
3950 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3956 auto splat = op.getVector().getDefiningOp<SplatOp>();
3980 class ContiguousExtractStridedSliceToExtract final
3987 if (op.hasNonUnitStrides())
3989 Value source = op.getOperand();
3990 auto sourceType = cast<VectorType>(source.
getType());
3991 if (sourceType.isScalable() || sourceType.getRank() == 0)
4000 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4001 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4008 if (numOffsets == 0)
4013 if (numOffsets == sourceType.getRank() &&
4014 static_cast<int>(sizes.size()) == sourceType.getRank())
4018 for (
int i = 0; i < numOffsets; ++i) {
4026 while (numOffsets <
static_cast<int>(sizes.size()) - 1 &&
4027 sizes[numOffsets] == 1) {
4032 auto extractOffsets =
ArrayRef(offsets).take_front(numOffsets);
4033 Value extract = rewriter.
create<vector::ExtractOp>(op->getLoc(), source,
4042 void ExtractStridedSliceOp::getCanonicalizationPatterns(
4046 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
4047 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
4048 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4058 VectorType vectorType,
Value source,
4059 ValueRange indices, AffineMapAttr permutationMapAttr,
4060 ArrayAttr inBoundsAttr) {
4061 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4062 Value padding = builder.
create<arith::ConstantOp>(
4064 build(builder, result, vectorType, source, indices, permutationMapAttr,
4065 padding,
Value(), inBoundsAttr);
4070 VectorType vectorType,
Value source,
4074 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4078 build(builder, result, vectorType, source, indices, permutationMapAttr,
4084 VectorType vectorType,
Value source,
4088 llvm::cast<ShapedType>(source.
getType()), vectorType);
4090 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4094 build(builder, result, vectorType, source, indices, permutationMapAttr,
4096 Value(), inBoundsAttr);
4102 VectorType vectorType,
Value source,
4105 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4106 Value padding = builder.
create<arith::ConstantOp>(
4108 build(builder, result, vectorType, source, indices, padding, inBounds);
4111 template <
typename EmitFun>
4113 EmitFun emitOpError) {
4115 for (
auto expr : permutationMap.
getResults()) {
4116 auto dim = dyn_cast<AffineDimExpr>(expr);
4117 auto zero = dyn_cast<AffineConstantExpr>(expr);
4119 if (zero.getValue() != 0) {
4121 "requires a projected permutation_map (at most one dim or the zero "
4122 "constant can appear in each result)");
4127 return emitOpError(
"requires a projected permutation_map (at most one "
4128 "dim or the zero constant can appear in each result)");
4130 if (seen[dim.getPosition()]) {
4132 "requires a permutation_map that is a permutation (found one dim "
4133 "used more than once)");
4135 seen[dim.getPosition()] =
true;
4140 static LogicalResult
4142 VectorType vectorType, VectorType maskType,
4143 VectorType inferredMaskType,
AffineMap permutationMap,
4144 ArrayAttr inBounds) {
4145 if (op->hasAttr(
"masked")) {
4146 return op->emitOpError(
"masked attribute has been removed. "
4147 "Use in_bounds instead.");
4150 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4151 return op->emitOpError(
4152 "requires source to be a memref or ranked tensor type");
4154 auto elementType = shapedType.getElementType();
4155 DataLayout dataLayout = DataLayout::closest(op);
4156 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4158 unsigned sourceVecSize =
4160 vectorElementType.getShape().back();
4161 unsigned resultVecSize =
4163 vectorType.getShape().back();
4164 if (resultVecSize % sourceVecSize != 0)
4165 return op->emitOpError(
4166 "requires the bitwidth of the minor 1-D vector to be an integral "
4167 "multiple of the bitwidth of the minor 1-D vector of the source");
4169 unsigned sourceVecEltRank = vectorElementType.getRank();
4170 unsigned resultVecRank = vectorType.getRank();
4171 if (sourceVecEltRank > resultVecRank)
4172 return op->emitOpError(
4173 "requires source vector element and vector result ranks to match.");
4174 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4177 return op->emitOpError(
"requires a permutation_map with result dims of "
4178 "the same rank as the vector type");
4181 return op->emitOpError(
"does not support masks with vector element type");
4184 unsigned minorSize =
4185 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4186 unsigned resultVecSize =
4189 return op->emitOpError(
4190 "requires the bitwidth of the minor 1-D vector to be an integral "
4191 "multiple of the bitwidth of the source element type");
4195 return op->emitOpError(
"requires a permutation_map with result dims of "
4196 "the same rank as the vector type");
4200 return op->emitOpError(
"requires permutation_map without symbols");
4202 if (permutationMap.
getNumInputs() != shapedType.getRank())
4203 return op->emitOpError(
"requires a permutation_map with input dims of the "
4204 "same rank as the source type");
4206 if (maskType && maskType != inferredMaskType)
4207 return op->emitOpError(
"inferred mask type (")
4208 << inferredMaskType <<
") and mask operand type (" << maskType
4211 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
4212 return op->emitOpError(
"expects the in_bounds attr of same rank "
4213 "as permutation_map results: ")
4215 <<
" vs inBounds of size: " << inBounds.size();
4222 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4223 if (op.getPermutationMap().isMinorIdentity())
4224 elidedAttrs.push_back(op.getPermutationMapAttrName());
4226 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
4227 elidedAttrs.push_back(op.getInBoundsAttrName());
4232 p <<
" " << getSource() <<
"[" <<
getIndices() <<
"], " << getPadding();
4234 p <<
", " << getMask();
4243 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
4248 if (maskShape.empty())
4249 maskShape.push_back(1);
4271 if (hasMask.succeeded()) {
4278 if (types.size() != 2)
4279 return parser.
emitError(typesLoc,
"requires two types");
4281 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4282 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4283 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4284 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4286 return parser.
emitError(typesLoc,
"requires vector type");
4287 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
4294 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4296 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.
name);
4298 if (!inBoundsAttr) {
4308 if (hasMask.succeeded()) {
4309 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4311 maskInfo.
location,
"does not support masks with vector element type");
4314 "expected the same rank for the vector and the "
4315 "results of the permutation map");
4323 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4325 {1, static_cast<int32_t>(indexInfo.size()), 1,
4326 static_cast<int32_t>(hasMask.succeeded())}));
4332 ShapedType shapedType = getShapedType();
4334 VectorType maskType = getMaskType();
4335 auto paddingType = getPadding().getType();
4336 auto permutationMap = getPermutationMap();
4337 VectorType inferredMaskType =
4340 auto sourceElementType = shapedType.getElementType();
4342 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
4343 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4345 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4346 shapedType, vectorType, maskType,
4347 inferredMaskType, permutationMap, getInBounds())))
4350 if (
auto sourceVectorElementType =
4351 llvm::dyn_cast<VectorType>(sourceElementType)) {
4354 if (sourceVectorElementType != paddingType)
4356 "requires source element type and padding type to match.");
4360 if (!VectorType::isValidElementType(paddingType))
4361 return emitOpError(
"requires valid padding vector elemental type");
4364 if (paddingType != sourceElementType)
4366 "requires formal padding and source of the same elemental type");
4370 [&](Twine t) {
return emitOpError(t); });
4377 Type TransferReadOp::getExpectedMaskType() {
4381 template <
typename TransferOp>
4382 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4385 if (op.getShapedType().isDynamicDim(indicesIdx))
4387 Value index = op.getIndices()[indicesIdx];
4389 if (!cstOp.has_value())
4392 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4393 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4395 return cstOp.value() + vectorSize <= sourceSize;
4398 template <
typename TransferOp>
4402 if (op.getTransferRank() == 0)
4407 newInBounds.reserve(op.getTransferRank());
4412 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4414 if (op.isDimInBounds(i)) {
4415 newInBounds.push_back(
true);
4420 bool inBounds =
false;
4421 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4424 dimExpr.getPosition());
4425 nonBcastDims.push_back(i);
4428 newInBounds.push_back(inBounds);
4436 bool allNonBcastDimsInBounds = llvm::all_of(
4437 nonBcastDims, [&newInBounds](
unsigned idx) {
return newInBounds[idx]; });
4438 if (allNonBcastDimsInBounds) {
4441 newInBounds[idx] =
true;
4453 template <
typename TransferOp>
4455 auto mask = op.getMask();
4462 op.getMaskMutable().clear();
4476 static Value foldRAW(TransferReadOp readOp) {
4477 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4479 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
4482 return defWrite.getVector();
4484 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4485 cast<VectorTransferOpInterface>(readOp.getOperation())))
4487 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4493 if (
Value vec = foldRAW(*
this))
4507 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4511 void TransferReadOp::getEffects(
4514 if (llvm::isa<MemRefType>(getShapedType()))
4520 if (hasPureTensorSemantics())
4548 struct TransferReadAfterWriteToBroadcast
4554 if (readOp.hasOutOfBoundsDim() ||
4555 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4557 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4562 if (readOp.getTransferChunkAccessed() !=
4563 defWrite.getTransferChunkAccessed())
4570 if (readOp.getIndices() != defWrite.getIndices() ||
4571 readOp.getMask() != defWrite.getMask())
4573 Value vec = defWrite.getVector();
4595 broadcastShape[pos.value()] = destShape[pos.index()];
4596 broadcastScalableFlags[pos.value()] =
4597 readOp.getVectorType().getScalableDims()[pos.index()];
4600 broadcastShape, defWrite.getVectorType().getElementType(),
4601 broadcastScalableFlags);
4602 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
4613 results.
add<TransferReadAfterWriteToBroadcast>(context);
4623 AffineMapAttr permutationMapAttr,
4625 ArrayAttr inBoundsAttr) {
4626 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
4627 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4628 mask, inBoundsAttr);
4634 AffineMapAttr permutationMapAttr,
4635 ArrayAttr inBoundsAttr) {
4636 build(builder, result, vector, dest, indices, permutationMapAttr,
4637 Value(), inBoundsAttr);
4648 (inBounds && !inBounds.value().empty())
4651 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
4652 build(builder, result, vector, dest, indices, permutationMapAttr,
4653 Value(), inBoundsAttr);
4661 auto vectorType = llvm::cast<VectorType>(vector.
getType());
4663 llvm::cast<ShapedType>(dest.
getType()), vectorType);
4664 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4680 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
4685 if (types.size() != 2)
4686 return parser.
emitError(typesLoc,
"requires two types");
4688 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4690 return parser.
emitError(typesLoc,
"requires vector type");
4691 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4692 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4693 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4694 auto permMapAttrName =
4695 TransferWriteOp::getPermutationMapAttrName(result.
name);
4702 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4704 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.
name);
4706 if (!inBoundsAttr) {
4715 if (hasMask.succeeded()) {
4716 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4718 maskInfo.
location,
"does not support masks with vector element type");
4721 "expected the same rank for the vector and the "
4722 "results of the permutation map");
4728 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4730 {1, 1, static_cast<int32_t>(indexInfo.size()),
4731 static_cast<int32_t>(hasMask.succeeded())}));
4732 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4737 p <<
" " << getVector() <<
", " << getSource() <<
"[" <<
getIndices() <<
"]";
4739 p <<
", " << getMask();
4746 ShapedType shapedType = getShapedType();
4748 VectorType maskType = getMaskType();
4749 auto permutationMap = getPermutationMap();
4750 VectorType inferredMaskType =
4754 if (llvm::size(
getIndices()) != shapedType.getRank())
4755 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4759 if (hasBroadcastDim())
4760 return emitOpError(
"should not have broadcast dimensions");
4762 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4763 shapedType, vectorType, maskType,
4764 inferredMaskType, permutationMap, getInBounds())))
4768 [&](Twine t) {
return emitOpError(t); });
4775 Type TransferWriteOp::getExpectedMaskType() {
4796 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4800 if (write.getTransferRank() == 0)
4802 auto rankedTensorType =
4803 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4805 if (!rankedTensorType)
4808 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4812 if (read.getTransferRank() == 0)
4815 if (!read.getPermutationMap().isMinorIdentity() ||
4816 !write.getPermutationMap().isMinorIdentity())
4819 if (read.getTransferRank() != write.getTransferRank())
4822 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4825 if (read.getSource().getType() != rankedTensorType)
4828 if (read.getVectorType() != write.getVectorType())
4831 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4834 auto isNotConstantZero = [](
Value v) {
4836 return !cstOp.has_value() || cstOp.value() != 0;
4838 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4839 llvm::any_of(write.getIndices(), isNotConstantZero))
4842 results.push_back(read.getSource());
4846 static bool checkSameValueWAR(vector::TransferReadOp read,
4847 vector::TransferWriteOp write) {
4848 return read.getSource() == write.getSource() &&
4849 read.getIndices() == write.getIndices() &&
4850 read.getPermutationMap() == write.getPermutationMap() &&
4851 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4868 static LogicalResult foldWAR(TransferWriteOp write,
4870 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4872 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4876 if (!checkSameValueWAR(read, write))
4878 results.push_back(read.getSource());
4882 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4884 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4886 if (succeeded(foldWAR(*
this, results)))
4895 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4899 void TransferWriteOp::getEffects(
4902 if (llvm::isa<MemRefType>(getShapedType()))
4908 if (hasPureTensorSemantics())
4943 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4945 vector::TransferWriteOp writeToModify = writeOp;
4948 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4952 writeToModify.getSourceMutable().assign(defWrite.getSource());
4957 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4958 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4962 if (!defWrite->hasOneUse())
4964 writeToModify = defWrite;
4965 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4994 struct SwapExtractSliceOfTransferWrite
5001 if (!insertOp.hasUnitStride())
5004 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5005 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5007 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5008 if (!transferOp || !transferOp->hasOneUse())
5013 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5015 "use-def chain is rank-reducing");
5019 if (!extractOp.hasZeroOffset()) {
5021 "ExtractSliceOp has non-zero offset");
5025 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
5029 "TranferWriteOp has non-zero offset");
5033 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5035 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
5038 for (
auto [insertSize, extractSize] :
5039 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5042 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
5047 assert(transferOp.getVectorType().hasStaticShape() &&
5048 "expected vector to have a static shape");
5051 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5052 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
5054 insertOp,
"TransferWriteOp may not write the full tensor.");
5060 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
5061 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
5062 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
5063 insertOp.getMixedStrides());
5064 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
5065 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
5066 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
5069 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5079 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5086 static LogicalResult verifyLoadStoreMemRefLayout(
Operation *op,
5088 MemRefType memRefTy) {
5091 if (!vecTy.isScalable() &&
5092 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5095 if (!memRefTy.isLastDimUnitStride())
5096 return op->
emitOpError(
"most minor memref dim must have unit stride");
5104 if (failed(verifyLoadStoreMemRefLayout(*
this, resVecTy, memRefTy)))
5108 Type memElemTy = memRefTy.getElementType();
5109 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5110 if (memVecTy != resVecTy)
5111 return emitOpError(
"base memref and result vector types should match");
5112 memElemTy = memVecTy.getElementType();
5115 if (resVecTy.getElementType() != memElemTy)
5116 return emitOpError(
"base and result element types should match");
5117 if (llvm::size(
getIndices()) != memRefTy.getRank())
5118 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5136 if (failed(verifyLoadStoreMemRefLayout(*
this, valueVecTy, memRefTy)))
5140 Type memElemTy = memRefTy.getElementType();
5141 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5142 if (memVecTy != valueVecTy)
5144 "base memref and valueToStore vector types should match");
5145 memElemTy = memVecTy.getElementType();
5148 if (valueVecTy.getElementType() != memElemTy)
5149 return emitOpError(
"base and valueToStore element type should match");
5150 if (llvm::size(
getIndices()) != memRefTy.getRank())
5151 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
5155 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5165 VectorType maskVType = getMaskVectorType();
5166 VectorType passVType = getPassThruVectorType();
5170 if (resVType.getElementType() != memType.getElementType())
5171 return emitOpError(
"base and result element type should match");
5172 if (llvm::size(
getIndices()) != memType.getRank())
5173 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5174 if (resVType.getShape() != maskVType.getShape())
5175 return emitOpError(
"expected result shape to match mask shape");
5176 if (resVType != passVType)
5177 return emitOpError(
"expected pass_thru of same type as result type");
5190 load, load.getType(), load.getBase(), load.getIndices());
5193 rewriter.
replaceOp(load, load.getPassThru());
5198 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
5205 results.
add<MaskedLoadFolder>(context);
5219 VectorType maskVType = getMaskVectorType();
5223 if (valueVType.getElementType() != memType.getElementType())
5224 return emitOpError(
"base and valueToStore element type should match");
5225 if (llvm::size(
getIndices()) != memType.getRank())
5226 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5227 if (valueVType.getShape() != maskVType.getShape())
5228 return emitOpError(
"expected valueToStore shape to match mask shape");
5241 store, store.getValueToStore(), store.getBase(), store.getIndices());
5249 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
5256 results.
add<MaskedStoreFolder>(context);
5259 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5269 VectorType indVType = getIndexVectorType();
5270 VectorType maskVType = getMaskVectorType();
5272 ShapedType baseType = getBaseType();
5274 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5275 return emitOpError(
"requires base to be a memref or ranked tensor type");
5277 if (resVType.getElementType() != baseType.getElementType())
5278 return emitOpError(
"base and result element type should match");
5279 if (llvm::size(
getIndices()) != baseType.getRank())
5280 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
5281 if (resVType.getShape() != indVType.getShape())
5282 return emitOpError(
"expected result dim to match indices dim");
5283 if (resVType.getShape() != maskVType.getShape())
5284 return emitOpError(
"expected result dim to match mask dim");
5285 if (resVType != getPassThruVectorType())
5286 return emitOpError(
"expected pass_thru of same type as result type");
5294 Type GatherOp::getExpectedMaskType() {
5295 auto vecType = this->getIndexVectorType();
5298 vecType.getScalableDims());
5301 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5306 static LogicalResult isZeroBasedContiguousSeq(
Value indexVec) {
5307 auto vecType = dyn_cast<VectorType>(indexVec.
getType());
5308 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5319 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5332 rewriter.
replaceOp(gather, gather.getPassThru());
5337 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
5348 if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5352 op.getIndices(), op.getMask(),
5361 results.
add<GatherFolder, FoldContiguousGather>(context);
5369 VectorType indVType = getIndexVectorType();
5370 VectorType maskVType = getMaskVectorType();
5374 if (valueVType.getElementType() != memType.getElementType())
5375 return emitOpError(
"base and valueToStore element type should match");
5376 if (llvm::size(
getIndices()) != memType.getRank())
5377 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5378 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5379 return emitOpError(
"expected valueToStore dim to match indices dim");
5380 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5381 return emitOpError(
"expected valueToStore dim to match mask dim");
5400 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
5411 if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5415 op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5423 results.
add<ScatterFolder, FoldContiguousScatter>(context);
5431 VectorType maskVType = getMaskVectorType();
5432 VectorType passVType = getPassThruVectorType();
5436 if (resVType.getElementType() != memType.getElementType())
5437 return emitOpError(
"base and result element type should match");
5438 if (llvm::size(
getIndices()) != memType.getRank())
5439 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5440 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5441 return emitOpError(
"expected result dim to match mask dim");
5442 if (resVType != passVType)
5443 return emitOpError(
"expected pass_thru of same type as result type");
5456 expand, expand.getType(), expand.getBase(), expand.getIndices());
5459 rewriter.
replaceOp(expand, expand.getPassThru());
5464 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
5471 results.
add<ExpandLoadFolder>(context);
5479 VectorType maskVType = getMaskVectorType();
5483 if (valueVType.getElementType() != memType.getElementType())
5484 return emitOpError(
"base and valueToStore element type should match");
5485 if (llvm::size(
getIndices()) != memType.getRank())
5486 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5487 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5488 return emitOpError(
"expected valueToStore dim to match mask dim");
5493 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
5501 compress, compress.getValueToStore(), compress.getBase(),
5502 compress.getIndices());
5510 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
5517 results.
add<CompressStoreFolder>(context);
5526 setResultRanges(getResult(), argRanges.front());
5532 unsigned rankA = a.size();
5533 unsigned rankB = b.size();
5534 assert(rankA < rankB);
5536 auto isOne = [](int64_t v) {
return v == 1; };
5540 if (rankA == 0 && llvm::all_of(b, isOne))
5545 while (i < rankA &&
j < rankB) {
5546 int64_t dimA = a[i];
5548 while (dimB < dimA &&
j < rankB)
5556 if (i < rankA && llvm::all_of(a.slice(i), isOne))
5558 if (
j < rankB && llvm::all_of(b.slice(
j), isOne))
5562 return i == rankA &&
j == rankB;
5565 static LogicalResult verifyVectorShapeCast(
Operation *op,
5566 VectorType sourceVectorType,
5567 VectorType resultVectorType) {
5569 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5570 return op->
emitOpError(
"source/result vectors must have same element type");
5571 auto sourceShape = sourceVectorType.getShape();
5572 auto resultShape = resultVectorType.getShape();
5575 int64_t sourceDimProduct = std::accumulate(
5576 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5577 int64_t resultDimProduct = std::accumulate(
5578 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5579 if (sourceDimProduct != resultDimProduct)
5580 return op->
emitOpError(
"source/result number of elements must match");
5583 unsigned sourceRank = sourceVectorType.getRank();
5584 unsigned resultRank = resultVectorType.getRank();
5585 if (sourceRank < resultRank) {
5586 if (!isValidShapeCast(sourceShape, resultShape))
5588 }
else if (sourceRank > resultRank) {
5589 if (!isValidShapeCast(resultShape, sourceShape))
5594 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5595 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5596 if (sourceNScalableDims != resultNScalableDims)
5597 return op->
emitOpError(
"different number of scalable dims at source (")
5598 << sourceNScalableDims <<
") and result (" << resultNScalableDims
5600 sourceVectorType.getNumDynamicDims();
5606 auto sourceVectorType =
5607 llvm::dyn_cast_or_null<VectorType>(getSource().
getType());
5608 auto resultVectorType =
5609 llvm::dyn_cast_or_null<VectorType>(getResult().
getType());
5612 if (sourceVectorType && resultVectorType)
5613 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
5624 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5625 if (getResult().
getType() == otherOp.getSource().getType())
5626 return otherOp.getSource();
5629 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5630 VectorType resultType = llvm::cast<VectorType>(getResult().
getType());
5631 if (srcType.getRank() < resultType.getRank()) {
5632 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5634 }
else if (srcType.getRank() > resultType.getRank()) {
5635 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5641 setOperand(otherOp.getSource());
5646 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5647 if (bcastOp.getSourceType() ==
getType())
5648 return bcastOp.getSource();
5656 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
5663 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5667 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5683 static VectorType trimTrailingOneDims(VectorType oldType) {
5690 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5691 newShape = newShape.drop_back(1);
5692 newScalableDims = newScalableDims.drop_back(1);
5697 if (newShape.empty()) {
5698 newShape = oldShape.take_back();
5699 newScalableDims = oldScalableDims.take_back();
5702 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5717 class ShapeCastCreateMaskFolderTrailingOneDim final
5724 Value shapeOpSrc = shapeOp->getOperand(0);
5725 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
5726 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
5727 if (!createMaskOp && !constantMaskOp)
5730 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5731 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5733 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5734 if (newVecType != shapeOpResTy)
5737 auto numDimsToDrop =
5738 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5745 auto maskOperands = createMaskOp.getOperands();
5746 auto numMaskOperands = maskOperands.size();
5749 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5751 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5752 if (!constant || (constant.value() != 1))
5756 maskOperands.drop_back(numDimsToDrop);
5763 if (constantMaskOp) {
5764 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5765 auto numMaskOperands = maskDimSizes.size();
5768 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5770 if (maskDimSizes[i] != 1)
5774 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5789 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
5796 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5801 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5802 broadcastSourceShape = srcType.getShape();
5804 shapeCastOp.getResultVectorType().getShape();
5808 if (broadcastSourceShape ==
5809 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5811 shapeCastOp, shapeCastOp.getResultVectorType(),
5812 broadcastOp.getSource());
5818 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5819 if (srcType.getNumElements() ==
5820 shapeCastOp.getResultVectorType().getNumElements()) {
5822 shapeCastOp, shapeCastOp.getResultVectorType(),
5823 broadcastOp.getSource());
5836 results.
add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5837 ShapeCastBroadcastFolder>(context);
5845 auto sourceVectorType = getSourceVectorType();
5846 auto resultVectorType = getResultVectorType();
5848 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5849 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5850 return emitOpError(
"dimension size mismatch at: ") << i;
5853 DataLayout dataLayout = DataLayout::closest(*
this);
5854 auto sourceElementBits =
5856 auto resultElementBits =
5859 if (sourceVectorType.getRank() == 0) {
5860 if (sourceElementBits != resultElementBits)
5861 return emitOpError(
"source/result bitwidth of the 0-D vector element "
5862 "types must be equal");
5863 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
5864 resultElementBits * resultVectorType.getShape().back()) {
5866 "source/result bitwidth of the minor 1-D vectors must be equal");
5878 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5879 if (getResult().
getType() == otherOp.getSource().getType())
5880 return otherOp.getSource();
5882 setOperand(otherOp.getSource());
5886 Attribute sourceConstant = adaptor.getSource();
5887 if (!sourceConstant)
5890 Type srcElemType = getSourceVectorType().getElementType();
5891 Type dstElemType = getResultVectorType().getElementType();
5893 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5894 if (floatPack.isSplat()) {
5895 auto splat = floatPack.getSplatValue<FloatAttr>();
5898 if (srcElemType.
isF16() && dstElemType.
isF32()) {
5899 uint32_t bits =
static_cast<uint32_t
>(
5900 splat.getValue().bitcastToAPInt().getZExtValue());
5902 bits = (bits << 16) | (bits & 0xffff);
5903 APInt intBits(32, bits);
5904 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5910 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5911 if (intPack.isSplat()) {
5912 auto splat = intPack.getSplatValue<IntegerAttr>();
5914 if (llvm::isa<IntegerType>(dstElemType)) {
5919 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5920 APInt intBits = splat.getValue().zext(dstBitWidth);
5923 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5924 intBits = (intBits << srcBitWidth) | intBits;
5939 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5942 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5951 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
5952 VectorType vectorType =
5956 memRefType.getMemorySpace()));
5960 MemRefType canonicalType =
getMemRefType().canonicalizeStridedLayout();
5961 if (!canonicalType.getLayout().isIdentity())
5962 return emitOpError(
"expects operand to be a memref with identity layout");
5963 if (!getResultMemRefType().getLayout().isIdentity())
5964 return emitOpError(
"expects result to be a memref with identity layout");
5965 if (getResultMemRefType().getMemorySpace() !=
5967 return emitOpError(
"expects result in same memory space");
5970 auto resultType = getResultMemRefType();
5974 "expects result and operand with same underlying scalar type: ")
5976 if (extractShape(sourceType) != extractShape(resultType))
5978 "expects concatenated result and operand shapes to be equal: ")
5989 VectorType vt = llvm::cast<VectorType>(vector.
getType());
5992 for (
unsigned i = 0; i < permutation.size(); ++i) {
5993 transposedShape[i] = vt.getShape()[permutation[i]];
5994 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5999 transposedScalableDims));
6004 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6007 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
6009 return attr.reshape(getResultVectorType());
6017 for (int64_t i = 0, e = perm.size(); i < e; i++) {
6026 VectorType vectorType = getSourceVectorType();
6027 VectorType resultType = getResultVectorType();
6028 int64_t rank = resultType.getRank();
6029 if (vectorType.getRank() != rank)
6030 return emitOpError(
"vector result rank mismatch: ") << rank;
6033 int64_t size = perm.size();
6035 return emitOpError(
"transposition length mismatch: ") << size;
6038 if (ta.value() < 0 || ta.value() >= rank)
6039 return emitOpError(
"transposition index out of range: ") << ta.value();
6040 if (seen[ta.value()])
6041 return emitOpError(
"duplicate position index: ") << ta.value();
6042 seen[ta.value()] =
true;
6043 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6044 return emitOpError(
"dimension size mismatch at: ") << ta.value();
6049 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6050 return llvm::to_vector<4>(getResultVectorType().
getShape());
6056 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
6066 for (
auto index : permutation2)
6067 result.push_back(permutation1[index]);
6072 vector::TransposeOp parentTransposeOp =
6073 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6074 if (!parentTransposeOp)
6078 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6081 transposeOp, transposeOp.getResult().getType(),
6082 parentTransposeOp.getVector(), permutation);
6088 struct FoldTransposedScalarBroadcast final
6094 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6098 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6099 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6101 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6116 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6121 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());