39 #include "llvm/ADT/ArrayRef.h"
40 #include "llvm/ADT/STLExtras.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/StringSet.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/bit.h"
50 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
// NOTE(review): fragmentary extraction — interior lines are missing (the
// embedded original line numbers jump). This appears to classify how a mask
// value was produced: a dense-constant i1 attribute, a ConstantMaskOp, or a
// CreateMaskOp with constant operands. Confirm against the full file.
73 if (
auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75 for (
bool b : denseElts.getValues<
bool>())
78 else if (!b && val <= 0)
// ConstantMaskOp case: each constant mask size is compared against the
// corresponding vector dimension size.
92 auto shape = m.getType().getShape();
95 for (
auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96 if (maskIdx < dimSize)
// CreateMaskOp case: only operands that are defined by arith.constant are
// inspected; the constant value is read out of the IntegerAttr.
109 auto maskOperands = m.getOperands();
110 for (
Value operand : maskOperands) {
111 if (
auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
// NOTE(review): fragments of three helpers (interior lines missing): a region
// body builder that terminates the block with vector.yield, a predicate
// mapping a CombiningKind to the element types it supports, and the minor
// identity permutation-map helper for transfer ops. Verify against the full
// file.
126 builder.
create<vector::YieldOp>(loc);
// Integer-style kinds (add/mul/min/max/and/or/xor) fall through together;
// the *F kinds are valid only for float element types (see line 148).
132 switch (combiningKind) {
133 case CombiningKind::ADD:
134 case CombiningKind::MUL:
137 case CombiningKind::MINSI:
138 case CombiningKind::MAXUI:
139 case CombiningKind::MAXSI:
140 case CombiningKind::AND:
141 case CombiningKind::OR:
142 case CombiningKind::XOR:
144 case CombiningKind::MINNUMF:
145 case CombiningKind::MAXNUMF:
146 case CombiningKind::MINIMUMF:
147 case CombiningKind::MAXIMUMF:
148 return llvm::isa<FloatType>(elementType);
// Minor identity map: a vector-of-vector element type contributes its own
// rank, which is subtracted from the transfer vector's rank.
154 VectorType vectorType) {
155 int64_t elementVectorRank = 0;
156 VectorType elementVectorType =
157 llvm::dyn_cast<VectorType>(shapedType.getElementType());
158 if (elementVectorType)
159 elementVectorRank += elementVectorType.getRank();
162 if (shapedType.getRank() == 0 &&
168 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
169 shapedType.getContext());
// NOTE(review): fragments (interior lines missing) of transfer-op forwarding
// checks: a helper that tests whether a read and a write could share the same
// splat-constant mask, then read-after-write and write-after-write value
// equivalence predicates.
176 vector::TransferReadOp read) {
177 auto readMask = read.getMask();
178 auto writeMask = write.getMask();
// Same-splat is only possible when the read is masked and the write is
// either unmasked or uses the identical mask value.
184 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
185 if (!couldBeSameSplat)
190 m_Constant<DenseElementsAttr>(&splatAttr)) ||
// RAW: in-bounds write with identical indices, vector type, permutation map,
// and compatible masks lets the read see the written value.
202 vector::TransferReadOp read) {
203 return !defWrite.hasOutOfBoundsDim() &&
204 defWrite.getIndices() == read.getIndices() &&
205 defWrite.getVectorType() == read.getVectorType() &&
206 defWrite.getPermutationMap() == read.getPermutationMap() &&
207 ((!defWrite.getMask() && !read.getMask()) ||
// WAW: a later write with identical indices/mask/type/map fully shadows the
// prior write.
212 vector::TransferWriteOp priorWrite) {
213 return priorWrite.getIndices() == write.getIndices() &&
214 priorWrite.getMask() == write.getMask() &&
215 priorWrite.getVectorType() == write.getVectorType() &&
216 priorWrite.getPermutationMap() == write.getPermutationMap();
// NOTE(review): disjointness analysis fragment for two transfer ops (interior
// lines missing). Leading (non-transferred) indices must provably differ;
// transferred indices must be at least a vector-dim-size apart. Constant
// indices are compared directly; otherwise a value-bounds query is used when
// testDynamicValueUsingBounds is set.
220 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
221 bool testDynamicValueUsingBounds) {
223 if (transferA.getVectorType() != transferB.getVectorType())
225 unsigned rankOffset = transferA.getLeadingShapedRank();
226 for (
unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
227 Value indexA = transferA.getIndices()[i];
228 Value indexB = transferB.getIndices()[i];
// Leading dims: any proven inequality makes the accesses disjoint.
232 if (i < rankOffset) {
235 if (cstIndexA.has_value() && cstIndexB.has_value()) {
236 if (*cstIndexA != *cstIndexB)
240 if (testDynamicValueUsingBounds) {
243 FailureOr<uint64_t> delta =
245 if (succeeded(delta) && *delta != 0)
248 FailureOr<bool> testEqual =
250 if (succeeded(testEqual) && !testEqual.value())
// Transferred dims: indices further apart than the vector size along that
// dim cannot overlap.
256 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
257 if (cstIndexA.has_value() && cstIndexB.has_value()) {
258 int64_t distance =
std::abs(*cstIndexA - *cstIndexB);
259 if (distance >= vectorDim)
263 if (testDynamicValueUsingBounds) {
266 FailureOr<int64_t> delta =
268 if (succeeded(delta) &&
std::abs(*delta) >= vectorDim)
271 FailureOr<int64_t> computeDelta =
273 if (succeeded(computeDelta)) {
274 if (
std::abs(computeDelta.value()) >= vectorDim)
// Wrapper: transfers on different sources are handled before delegating to
// the per-index comparison above.
284 VectorTransferOpInterface transferB,
285 bool testDynamicValueUsingBounds) {
286 if (transferA.getSource() != transferB.getSource())
289 testDynamicValueUsingBounds);
// NOTE(review): fragments (interior lines missing): an odometer-style
// increment of a multi-dimensional slice position within offset bounds,
// followed by conversions between Value lists, OpFoldResult lists, and
// int64_t index vectors. Confirm helper names against the full file.
299 for (
auto [posInDim, dimSize, offsetInDim] :
300 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302 if (posInDim < dimSize + offsetInDim)
// Carry: reset this dim to its base offset and continue to the next digit.
306 posInDim = offsetInDim;
// Values -> ints: every value is asserted to fold to a constant index.
316 llvm::transform(values, std::back_inserter(ints), [](
Value value) {
318 assert(constOp &&
"Unexpected non-constant index");
319 return constOp.value();
// OpFoldResults -> ints: every fold result must already be an Attribute.
329 foldResults, std::back_inserter(ints), [](
OpFoldResult foldResult) {
330 assert(foldResult.is<
Attribute>() &&
"Unexpected non-constant index");
331 return cast<IntegerAttr>(foldResult.get<
Attribute>()).getInt();
// OpFoldResults -> Values: attributes are materialized as constants.
341 llvm::transform(foldResults, std::back_inserter(values),
343 if (
auto attr = foldResult.dyn_cast<
Attribute>())
346 loc, cast<IntegerAttr>(attr).getInt())
349 return foldResult.get<
Value>();
// NOTE(review): dialect initialization fragment (interior lines missing):
// registers TableGen-generated attributes and ops, the inliner interface,
// and "promised" interfaces whose implementations live in other libraries.
// The trailing line is from materializeConstant, which delegates to arith.
397 void VectorDialect::initialize() {
399 #define GET_ATTRDEF_LIST
400 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
405 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
408 addInterfaces<VectorInlinerInterface>();
410 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
411 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
413 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
415 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
416 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
424 return arith::ConstantOp::materialize(builder, value, type, loc);
// NOTE(review): vector.multi_reduction fragments (interior lines missing):
// a builder that collects reduction dims, a fold for the trivial rank-1
// non-reduced case, the unroll shape, the verifier that recomputes the
// expected destination type, and the expected mask type.
440 void vector::MultiDimReductionOp::build(
OpBuilder &builder,
443 CombiningKind kind) {
447 reductionDims.push_back(en.index());
448 build(builder, result, kind, source, acc, reductionDims);
451 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
453 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
458 std::optional<SmallVector<int64_t, 4>>
459 MultiDimReductionOp::getShapeForUnroll() {
460 return llvm::to_vector<4>(getSourceVectorType().
getShape());
// Verifier: the non-reduced dims (and their scalability flags) make up the
// inferred destination shape.
466 Type inferredReturnType;
467 auto sourceScalableDims = getSourceVectorType().getScalableDims();
468 for (
auto [dimIdx, dimSize] :
470 if (!llvm::any_of(getReductionDims(),
471 [dimIdx = dimIdx](int64_t reductionDimIdx) {
472 return reductionDimIdx ==
static_cast<int64_t
>(dimIdx);
474 targetShape.push_back(dimSize);
475 scalableDims.push_back(sourceScalableDims[dimIdx]);
// Reducing every dim yields a scalar of the source element type.
478 if (targetShape.empty())
479 inferredReturnType = getSourceVectorType().getElementType();
482 targetShape, getSourceVectorType().
getElementType(), scalableDims);
483 if (
getType() != inferredReturnType)
484 return emitOpError() <<
"destination type " <<
getType()
485 <<
" is incompatible with source type "
486 << getSourceVectorType();
// Expected mask: built from the source vector's shape/scalability (i1
// elements — fragment; confirm in the full file).
492 Type MultiDimReductionOp::getExpectedMaskType() {
493 auto vecType = getSourceVectorType();
496 vecType.getScalableDims());
// NOTE(review): canonicalization fragment (interior lines missing): when all
// reduced dims are unit, the multi_reduction is replaced by a shape_cast (or
// extract, for a scalar destination), threading through an enclosing
// vector.mask region if present.
505 struct ElideUnitDimsInMultiDimReduction
509 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
512 for (
const auto &dim :
enumerate(shape)) {
513 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
// When the reduction is masked, the masking op is the root that gets
// replaced; the inner mask value is reshaped/extracted alongside the source.
521 if (reductionOp.isMasked()) {
523 rootOp = reductionOp.getMaskingOp();
524 mask = reductionOp.getMaskingOp().getMask();
526 rootOp = reductionOp;
529 Location loc = reductionOp.getLoc();
530 Value acc = reductionOp.getAcc();
532 if (
auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
534 VectorType newMaskType =
536 dstVecType.getScalableDims());
537 mask = rewriter.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
539 cast = rewriter.
create<vector::ShapeCastOp>(
540 loc, reductionOp.getDestType(), reductionOp.getSource());
// Scalar destination: extract the single element (and single mask bit).
546 mask = rewriter.
create<vector::ExtractOp>(loc, mask, zeroIdx);
547 cast = rewriter.
create<vector::ExtractOp>(loc, reductionOp.getSource(),
553 cast,
nullptr, mask);
560 void MultiDimReductionOp::getCanonicalizationPatterns(
562 results.
add<ElideUnitDimsInMultiDimReduction>(context);
// NOTE(review): vector.reduction fragments (interior lines missing): the
// accumulator-less builder delegates with a null Value() acc; a second
// builder derives the result type from the source element type; then the
// verifier (rank and kind/element-type checks) and the expected mask type.
570 CombiningKind kind,
Value vector,
571 arith::FastMathFlags fastMathFlags) {
572 build(builder, result, kind, vector,
Value(), fastMathFlags);
577 arith::FastMathFlags fastMathFlags) {
578 build(builder, result,
579 llvm::cast<VectorType>(vector.
getType()).getElementType(), kind, vector,
// Verifier: only low-rank sources are accepted, and the element type must
// be supported by the combining kind.
585 int64_t rank = getSourceVectorType().getRank();
587 return emitOpError(
"unsupported reduction rank: ") << rank;
590 Type eltType = getDest().getType();
592 return emitOpError(
"unsupported reduction type '")
593 << eltType <<
"' for kind '" << stringifyCombiningKind(getKind())
602 Type ReductionOp::getExpectedMaskType() {
603 auto vecType = getSourceVectorType();
606 vecType.getScalableDims());
// NOTE(review): fragment mapping arith::AtomicRMWKind cases to the matching
// vector.reduction CombiningKind (interior lines missing — e.g. the
// CombiningKind argument for the `minu` case is not visible here).
613 case arith::AtomicRMWKind::addf:
614 case arith::AtomicRMWKind::addi:
615 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
616 CombiningKind::ADD, vector);
617 case arith::AtomicRMWKind::mulf:
618 case arith::AtomicRMWKind::muli:
619 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
620 CombiningKind::MUL, vector);
621 case arith::AtomicRMWKind::minimumf:
622 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
623 CombiningKind::MINIMUMF, vector);
624 case arith::AtomicRMWKind::mins:
625 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
626 CombiningKind::MINSI, vector);
627 case arith::AtomicRMWKind::minu:
628 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
630 case arith::AtomicRMWKind::maximumf:
631 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
632 CombiningKind::MAXIMUMF, vector);
633 case arith::AtomicRMWKind::maxs:
634 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
635 CombiningKind::MAXSI, vector);
636 case arith::AtomicRMWKind::maxu:
637 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
638 CombiningKind::MAXUI, vector);
639 case arith::AtomicRMWKind::andi:
640 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
641 CombiningKind::AND, vector);
642 case arith::AtomicRMWKind::ori:
643 return builder.
create<vector::ReductionOp>(vector.
getLoc(),
644 CombiningKind::OR, vector);
// Unroll shape for vector.reduction is simply the source vector shape.
653 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
654 return llvm::to_vector<4>(getSourceVectorType().
getShape());
// NOTE(review): canonicalization fragment (interior lines missing): a
// reduction over a 0-D or single-element 1-D vector becomes an extract
// (combined with the accumulator in code not visible here), honoring an
// enclosing vector.mask.
661 LogicalResult matchAndRewrite(ReductionOp reductionOp,
666 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
// When masked, replace the masking op as the root and extract the mask bit.
669 if (maskableOp.isMasked()) {
671 rootOp = maskableOp.getMaskingOp();
672 mask = maskableOp.getMaskingOp().getMask();
674 rootOp = reductionOp;
677 auto vectorType = reductionOp.getSourceVectorType();
678 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
681 Location loc = reductionOp.getLoc();
// 0-D uses extractelement; 1-D uses extract at position 0.
683 if (vectorType.getRank() == 0) {
685 mask = rewriter.
create<ExtractElementOp>(loc, mask);
686 result = rewriter.
create<ExtractElementOp>(loc, reductionOp.getVector());
689 mask = rewriter.
create<ExtractOp>(loc, mask, 0);
690 result = rewriter.
create<ExtractOp>(loc, reductionOp.getVector(), 0);
693 if (
Value acc = reductionOp.getAcc())
696 reductionOp.getFastmathAttr(), mask);
706 results.
add<ElideSingleElementReduction>(context);
// NOTE(review): vector.contract build/parse/print fragments (many interior
// lines missing). The custom syntax packs indexing_maps/iterator_types/kind
// into a leading dictionary attribute; parsing symbolizes legacy string
// iterator types into enum attributes and validates the optional mask
// operands; printing reconstructs the dictionary and filters trait attrs.
720 getIndexingMapsAttrName(result.
name),
724 getIteratorTypesAttrName(result.
name),
727 return IteratorTypeAttr::get(builder.getContext(), t);
733 ArrayAttr indexingMaps,
734 ArrayAttr iteratorTypes) {
735 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
736 ContractionOp::getDefaultKind());
741 ArrayAttr indexingMaps,
742 ArrayAttr iteratorTypes, CombiningKind kind) {
// Parser: iterator types written as strings are converted via
// symbolizeIteratorType; unknown strings are a parse error.
759 DictionaryAttr dictAttr;
774 dictAttr.getValue().end());
780 ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
785 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
786 auto maybeIteratorType = symbolizeIteratorType(s);
787 if (!maybeIteratorType.has_value())
788 return parser.
emitError(loc) <<
"unexpected iterator_type (" << s <<
")";
790 iteratorTypeAttrs.push_back(
798 getKindAttrName(result.
name),
800 ContractionOp::getDefaultKind()));
// Optional masks: either none or exactly two (lhs and rhs) are accepted.
802 if (masksInfo.empty())
804 if (masksInfo.size() != 2)
806 "expected zero or exactly 2 vector mask operands");
807 auto lhsType = llvm::cast<VectorType>(types[0]);
808 auto rhsType = llvm::cast<VectorType>(types[1]);
810 std::array<VectorType, 2> maskTypes = {
// Printer: trait attributes go into the leading dictionary; iterator types
// are printed in their string form; other attributes stay inline.
820 auto attrNames = getTraitAttrNames();
822 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
824 for (
auto attr : (*this)->getAttrs()) {
825 if (attr.getName() == getIteratorTypesAttrName()) {
827 llvm::cast<ArrayAttr>(attr.getValue())
828 .getAsValueRange<IteratorTypeAttr, IteratorType>();
834 llvm::map_range(iteratorTypes, [&](IteratorType t) ->
Attribute {
838 attrs.emplace_back(getIteratorTypesAttrName(),
840 }
else if (traitAttrsSet.count(attr.getName().strref()) > 0)
841 attrs.push_back(attr);
845 p <<
" " << dictAttr <<
" " << getLhs() <<
", ";
846 p << getRhs() <<
", " << getAcc();
849 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType() <<
" into "
// NOTE(review): contraction verification helpers (interior lines missing):
// verifyDimMap checks each (lhs, rhs) dim pair is in range with equal sizes;
// verifyOutputShape recomputes the expected accumulator/result shape from
// the free (non-contracting, non-batch) dims and the indexing maps.
854 const std::vector<std::pair<int64_t, int64_t>> &map) {
855 for (
auto &dimPair : map) {
856 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
857 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
858 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
865 ContractionOp op, VectorType lhsType, VectorType rhsType,
Type accType,
867 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
868 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
// Collect the contracted dims of lhs/rhs and the rhs batch dims, so the
// remaining ("free") dims can be appended to the expected result shape.
871 for (
auto &dimPair : contractingDimMap) {
872 lhsContractingDimSet.insert(dimPair.first);
873 rhsContractingDimSet.insert(dimPair.second);
876 for (
auto &dimPair : batchDimMap)
877 rhsBatchDimSet.insert(dimPair.second);
881 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
882 if (lhsContractingDimSet.count(i) > 0)
884 expectedResultDims.push_back(lhsType.getDimSize(i));
888 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
889 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
891 expectedResultDims.push_back(rhsType.getDimSize(i))
895 if (expectedResultDims.empty()) {
897 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
898 return op.
emitOpError(
"invalid accumulator/result vector shape");
901 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
902 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
903 if (!resVectorType || !accVectorType)
904 return op.
emitOpError(
"invalid accumulator/result vector shape");
// Infer a constant extent for every iteration dim from the lhs/rhs maps,
// then compare the map-projected expected shape against acc/result.
910 AffineMap lhsMap = op.getIndexingMapsArray()[0];
911 AffineMap rhsMap = op.getIndexingMapsArray()[1];
914 "expected all dimensions to be either a LHS or a RHS dimension");
917 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
918 VectorType v = pair.first;
919 auto map = pair.second;
920 for (
unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
921 unsigned pos = map.getDimPosition(idx);
926 if (!llvm::all_of(extents, [](
AffineExpr e) {
return e; }))
927 return op.
emitOpError(
"expected all dimensions to get an extent as "
928 "either a LHS or a RHS dimension");
930 AffineMap resMap = op.getIndexingMapsArray()[2];
936 llvm::IsaPred<AffineConstantExpr>) &&
937 "expected constant extent along all dimensions.");
939 auto expectedShape = llvm::to_vector<4>(
941 return cast<AffineConstantExpr>(e).getValue();
945 resVectorType.getScalableDims());
946 if (resVectorType != expected || accVectorType != expected)
948 "invalid accumulator/result vector shape, expected: ")
// NOTE(review): ContractionOp::verify fragment (interior lines missing):
// checks signless integer elements, exactly three indexing maps that are
// symbol-free projected permutations over the iterator dims, at least one
// contracting pair, valid dim maps, the output shape, and a supported
// element type for the combining kind.
955 VectorType lhsType = getLhsType();
956 VectorType rhsType = getRhsType();
957 Type accType = getAccType();
958 Type resType = getResultType();
960 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
961 if (!lhsType.getElementType().isSignlessInteger())
962 return emitOpError(
"only supports signless integer types");
966 if (getIndexingMapsArray().size() != 3)
967 return emitOpError(
"expected an indexing map for each vector operand");
// Per-map checks: no symbols, numDims == iterator count, numResults ==
// operand rank, and projected-permutation structure.
972 unsigned numIterators = getIteratorTypes().getValue().size();
974 auto index = it.index();
975 auto map = it.value();
976 if (map.getNumSymbols() != 0)
977 return emitOpError(
"expected indexing map ")
978 << index <<
" to have no symbols";
979 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).
getType());
980 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
983 if (map.getNumDims() != numIterators)
984 return emitOpError(
"expected indexing map ")
985 << index <<
" to have " << numIterators <<
" number of inputs";
986 if (map.getNumResults() != rank)
987 return emitOpError(
"expected indexing map ")
988 << index <<
" to have " << rank <<
" number of outputs";
989 if (!map.isProjectedPermutation())
990 return emitOpError(
"expected indexing map ")
991 << index <<
" to be a projected permutation of its inputs";
994 auto contractingDimMap = getContractingDimMap();
995 auto batchDimMap = getBatchDimMap();
998 if (contractingDimMap.empty())
999 return emitOpError(
"expected at least one contracting dimension pair");
1002 if (!
verifyDimMap(lhsType, rhsType, contractingDimMap))
1003 return emitOpError(
"invalid contracting dimension map");
1007 return emitOpError(
"invalid batch dimension map");
1011 contractingDimMap, batchDimMap)))
1015 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1016 auto elementType = vectorType ? vectorType.getElementType() : resType;
1018 return emitOpError(
"unsupported contraction type");
// Expected mask type: one i1 dim per iteration dim, with size/scalability
// taken from the lhs/rhs dims mapping to it (fragment).
1027 Type ContractionOp::getExpectedMaskType() {
1028 auto indexingMaps = this->getIndexingMapsArray();
1031 VectorType lhsType = this->getLhsType();
1032 VectorType rhsType = this->getRhsType();
1034 unsigned numVecDims = lhsIdxMap.
getNumDims();
1043 lhsType.getScalableDims()[dimIdx];
1048 rhsType.getScalableDims()[dimIdx];
1051 assert(!ShapedType::isDynamicShape(maskShape) &&
1052 "Mask shape couldn't be computed");
1056 maskShapeScalableDims);
// Trait attribute names (fragment): iterator_types and kind are among them.
1061 getIteratorTypesAttrName(), getKindAttrName()};
// NOTE(review): fragments (interior lines missing): getDimMap pairs lhs/rhs
// dims that share an iterator of the target type; getIterationBounds reads
// reduction bounds from the lhs shape and parallel bounds from the result
// shape; getIterationIndexMap inverts each indexing map per operand.
1071 static std::vector<std::pair<int64_t, int64_t>>
1073 IteratorType targetIteratorType,
MLIRContext *context) {
1074 std::vector<std::pair<int64_t, int64_t>> dimMap;
1076 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1077 if (iteratorType != targetIteratorType)
// Only iterator dims that appear in both lhs and rhs maps form a pair.
1083 if (lhsDim >= 0 && rhsDim >= 0)
1084 dimMap.emplace_back(lhsDim, rhsDim);
1089 void ContractionOp::getIterationBounds(
1091 auto lhsShape = getLhsType().getShape();
1092 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1098 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1099 if (iteratorType == IteratorType::reduction) {
1101 int64_t lhsDimIndex =
getResultIndex(indexingMaps[0], targetExpr);
1102 assert(lhsDimIndex >= 0);
1103 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1107 int64_t resDimIndex =
getResultIndex(indexingMaps[2], targetExpr);
1108 assert(resDimIndex >= 0);
1109 assert(resVectorType !=
nullptr);
1110 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1114 void ContractionOp::getIterationIndexMap(
1116 unsigned numMaps = getIndexingMapsArray().size();
1117 iterationIndexMap.resize(numMaps);
1119 auto index = it.index();
1120 auto map = it.value();
1121 for (
unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1122 auto dim = cast<AffineDimExpr>(map.getResult(i));
1123 iterationIndexMap[index][dim.getPosition()] = i;
// Contracting pairs come from reduction iterators; batch pairs from
// parallel iterators shared by lhs and rhs.
1128 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1130 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1134 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1136 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1140 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1142 getIterationBounds(shape);
// NOTE(review): pattern fragment (interior lines missing): rewrites
// add(contract(a, b, zero-constant-acc), other) into contract(a, b, other)
// by cloning the contraction with the add's other operand mapped in as the
// accumulator. Parameterized over the add op type (int/float).
1164 template <
typename AddOpType>
1170 auto canonicalize = [&](
Value maybeContraction,
1171 Value otherOperand) -> vector::ContractionOp {
1172 vector::ContractionOp contractionOp =
1173 dyn_cast_or_null<vector::ContractionOp>(
1176 return vector::ContractionOp();
// Only fire when the contraction's accumulator is the zero constant of its
// type; otherwise the add cannot be absorbed.
1177 if (
auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1178 contractionOp.getAcc().getDefiningOp())) {
1179 if (maybeZero.getValue() ==
1180 rewriter.
getZeroAttr(contractionOp.getAcc().getType())) {
1182 bvm.
map(contractionOp.getAcc(), otherOperand);
1183 auto newContraction =
1184 cast<vector::ContractionOp>(rewriter.
clone(*contractionOp, bvm));
1185 rewriter.
replaceOp(addOp, newContraction.getResult());
1186 return newContraction;
1189 return vector::ContractionOp();
// The add is commutative: try both operand orders (second try not visible).
1192 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1193 vector::ContractionOp
contract = canonicalize(a, b);
1195 return contract ? success() : failure();
// NOTE(review): vector.extractelement fragments (interior lines missing):
// builder sets the result type to the source element type; the verifier
// constrains position presence for 0-D vs 1-D sources; fold handles splat
// and broadcast sources and dense-constant sources with constant positions.
1212 result.
addTypes(llvm::cast<VectorType>(source.
getType()).getElementType());
1216 VectorType vectorType = getSourceVectorType();
1217 if (vectorType.getRank() == 0) {
1219 return emitOpError(
"expected position to be empty with 0-D vector");
1222 if (vectorType.getRank() != 1)
1223 return emitOpError(
"unexpected >1 vector rank");
1225 return emitOpError(
"expected position for 1-D vector");
1229 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1231 if (!adaptor.getPosition())
// Extracting from a splat/broadcast yields the splatted input directly.
1235 if (
auto splat = getVector().getDefiningOp<vector::SplatOp>())
1236 return splat.getInput();
1239 if (
auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
// Dense-constant source + constant in-bounds position folds to the element.
1243 auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1244 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1248 auto srcElements = src.getValues<
Attribute>();
1250 uint64_t posIdx = pos.getInt();
1251 if (posIdx >= srcElements.size())
1254 return srcElements[posIdx];
// NOTE(review): vector.extract fragments (interior lines missing): builder,
// result-type inference (element type when fully indexed, otherwise the
// remaining suffix shape), vector<1xT>/T return-type compatibility, the
// verifier for mixed static/dynamic positions, and an ArrayAttr->int helper.
1262 Value source, int64_t position) {
1282 build(builder, result, source, dynamicPos,
1287 ExtractOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
1288 ExtractOp::Adaptor adaptor,
1290 auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1291 if (
static_cast<int64_t
>(adaptor.getStaticPosition().size()) ==
1292 vectorType.getRank()) {
1293 inferredReturnTypes.push_back(vectorType.getElementType());
// Partial indexing: the result keeps the trailing dims and scalability.
1295 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1296 vectorType.getRank());
1298 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1299 vectorType.getScalableDims().drop_front(n)));
// vector<1xT> and scalar T are treated as compatible return types.
1307 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1308 return vectorType && vectorType.getShape().equals({1}) &&
1309 vectorType.getElementType() == r.front();
1311 if (l.size() == 1 && r.size() == 1 &&
1312 (isCompatible(l, r) || isCompatible(r, l)))
// Verifier: every kDynamic marker needs a dynamic operand; static positions
// must be non-negative and within their dimension.
1319 auto dynamicMarkersCount =
1320 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1321 if (
static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1323 "mismatch between dynamic and static positions (kDynamic marker but no "
1324 "corresponding dynamic position) -- this can only happen due to an "
1325 "incorrect fold/rewrite");
1326 auto position = getMixedPosition();
1327 if (position.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
1329 "expected position attribute of rank no greater than vector rank");
1332 int64_t constIdx = cast<IntegerAttr>(pos.get<
Attribute>()).getInt();
1333 if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
1334 return emitOpError(
"expected position attribute #")
1336 <<
" to be a non-negative integer smaller than the "
1337 "corresponding vector dimension";
// Helper: ArrayAttr of IntegerAttr -> SmallVector<IntType>.
1344 template <
typename IntType>
1346 return llvm::to_vector<4>(llvm::map_range(
1347 arrayAttr.getAsRange<IntegerAttr>(),
1348 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
// NOTE(review): fragment (interior lines missing): collapses a chain of
// extract ops with static positions into a single extract by concatenating
// the positions (appended in reverse while walking producers, then reversed
// once at the end).
1354 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1358 if (extractOp.hasDynamicPosition())
1362 ExtractOp currentOp = extractOp;
1364 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1365 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
// Any dynamic position in the chain stops the walk.
1368 if (currentOp.hasDynamicPosition())
1371 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1373 extractOp.setOperand(0, currentOp.getVector());
1376 std::reverse(globalPosition.begin(), globalPosition.end());
1377 extractOp.setStaticPosition(globalPosition);
// NOTE(review): fragment of a helper class that folds an extract through a
// chain of insert/transpose producers (interior lines missing). Negative
// sentinel values pad the extract position up to the source rank so that
// prefix/containment matching can be done positionally.
1389 class ExtractFromInsertTransposeChainState {
1391 ExtractFromInsertTransposeChainState(ExtractOp e);
// True when `a` is a prefix-length containment of `b` (elementwise equal on
// the first a.size() entries).
1400 template <
typename ContainerA,
typename ContainerB>
1401 bool isContainedWithin(
const ContainerA &a,
const ContainerB &b) {
1402 return a.size() <= b.size() &&
1403 std::equal(a.begin(), a.begin() + a.size(), b.begin());
// Compares a and b only where both entries are non-negative (sentinels are
// negative and are skipped).
1410 template <
typename ContainerA,
typename ContainerB>
1411 bool intersectsWhereNonNegative(
const ContainerA &a,
const ContainerB &b) {
1412 for (
auto [elemA, elemB] : llvm::zip(a, b)) {
1413 if (elemA < 0 || elemB < 0)
1428 void updateStateForNextIteration(
Value v) {
1435 LogicalResult handleTransposeOp();
1438 LogicalResult handleInsertOpWithMatchingPos(
Value &res);
1453 LogicalResult handleInsertOpWithPrefixPos(
Value &res);
1458 Value tryToFoldExtractOpInPlace(
Value source);
1460 ExtractOp extractOp;
1462 int64_t extractedRank;
1464 InsertOp nextInsertOp;
1465 TransposeOp nextTransposeOp;
// Ctor: record ranks and fill the position tail with distinct negative
// sentinels (-1, -2, ...) up to the source vector rank.
1480 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1482 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1483 extractedRank(extractOp.getNumIndices()) {
1484 assert(vectorRank >= extractedRank &&
"Extracted position overflow");
1485 sentinels.reserve(vectorRank - extractedRank);
1486 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1487 sentinels.push_back(-(i + 1));
1489 extractOp.getStaticPosition().end());
// NOTE(review): chain-state methods (interior lines missing). Each bails out
// early on dynamic positions. handleTransposeOp permutes the tracked
// position by the transpose permutation; the insert handlers fold when the
// insert position matches or is a prefix of the extract position; fold()
// drives the walk over the producer chain.
1495 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1497 if (extractOp.hasDynamicPosition())
1500 if (!nextTransposeOp)
1503 nextTransposeOp.getPermutation(), extractOp.getContext()));
1510 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1513 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1520 res = nextInsertOp.getSource();
1522 return success(canFold());
1529 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(
Value &res) {
1531 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1544 res = nextInsertOp.getSource();
// In-place fold: repoint the extract at `source` with the updated position;
// a no-op (same vector) or non-foldable state returns without rewriting.
1552 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1555 if (extractOp.hasDynamicPosition())
1559 bool nothingToFold = (source == extractOp.getVector());
1560 if (nothingToFold || !canFold())
1565 extractOp.setStaticPosition(
1567 extractOp.getVectorMutable().assign(source);
1568 return extractOp.getResult();
// Driver: follow transpose/insert producers until neither applies, then try
// the in-place fold on whatever value remains.
1572 Value ExtractFromInsertTransposeChainState::fold() {
1574 if (extractOp.hasDynamicPosition())
1577 Value valueToExtractFrom = extractOp.getVector();
1578 updateStateForNextIteration(valueToExtractFrom);
1579 while (nextInsertOp || nextTransposeOp) {
1582 if (succeeded(handleTransposeOp())) {
1583 valueToExtractFrom = nextTransposeOp.getVector();
1584 updateStateForNextIteration(valueToExtractFrom);
1590 if (succeeded(handleInsertOpWithMatchingPos(result)))
1595 if (succeeded(handleInsertOpWithPrefixPos(result)))
1596 return tryToFoldExtractOpInPlace(result);
1606 valueToExtractFrom = nextInsertOp.getDest();
1607 updateStateForNextIteration(valueToExtractFrom);
1610 return tryToFoldExtractOpInPlace(valueToExtractFrom);
// NOTE(review): fragment folding extract-of-broadcast/splat (interior lines
// missing): when the extract's positions only touch broadcasted unit dims,
// the extract can read directly from the broadcast source with a trimmed
// position vector.
1615 auto hasZeroDimVectorType = [](
Type type) ->
bool {
1616 auto vecType = dyn_cast<VectorType>(type);
1617 return vecType && vecType.getRank() == 0;
1627 if (extractOp.hasDynamicPosition())
1630 Operation *defOp = extractOp.getVector().getDefiningOp();
1631 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1635 if (extractOp.getType() == source.
getType())
1637 auto getRank = [](
Type type) {
1638 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1643 unsigned broadcastSrcRank = getRank(source.
getType());
1644 if (broadcastSrcRank == 0 && source.
getType() == extractOp.getType())
// The extract must produce something of rank lower than the broadcast
// source, with a shape matching the source's trailing dims.
1647 unsigned extractResultRank = getRank(extractOp.getType());
1648 if (extractResultRank >= broadcastSrcRank)
1651 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1652 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.
getType());
1653 if (extractVecType && broadcastVecType &&
1654 extractVecType.getShape() !=
1655 broadcastVecType.getShape().take_back(extractResultRank))
// Positions that land on broadcasted unit dims are the foldable ones.
1658 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1659 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1665 broadcastOp.computeBroadcastedUnitDims();
1667 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1668 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1669 if (broadcastedUnitDims.contains(i))
// Drop the leading positions corresponding to broadcast-added dims.
1673 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1674 extractPos.erase(extractPos.begin(),
1675 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1678 extractOp.setOperand(0, source);
1679 extractOp.setStaticPosition(extractPos);
1680 return extractOp.getResult();
// NOTE(review): fragment folding extract-of-shape_cast (interior lines
// missing): checks that the trailing (kept) dims agree between the cast's
// source and the extract's destination, linearizes the extract position with
// destination-space strides, then delinearizes it with source-space strides.
1686 if (extractOp.hasDynamicPosition())
1689 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
// Helper: the n-th dimension size counted from the innermost dim.
1699 auto getDimReverse = [](VectorType type, int64_t n) {
1700 return type.getShape().take_back(n + 1).front();
1702 int64_t destinationRank =
1703 llvm::isa<VectorType>(extractOp.getType())
1704 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1706 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1708 if (destinationRank > 0) {
1709 auto destinationType =
1710 llvm::cast<VectorType>(extractOp.getResult().getType());
1711 for (int64_t i = 0; i < destinationRank; i++) {
1715 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1716 getDimReverse(destinationType, i))
// Linearize the extract position using the extract source's strides.
1723 std::reverse(extractedPos.begin(), extractedPos.end());
1726 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1727 strides.push_back(stride);
1729 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1732 int64_t position =
linearize(extractedPos, strides);
// Delinearize against the shape_cast source's strides to get the new
// position, then retarget the extract at the cast's source.
1736 int64_t numDimension =
1737 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1739 for (int64_t i = 0; i < numDimension; i++) {
1740 newStrides.push_back(stride);
1742 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1744 std::reverse(newStrides.begin(), newStrides.end());
1748 extractOp.setStaticPosition(newPosition);
1749 extractOp.setOperand(0, shapeCastOp.getSource());
1750 return extractOp.getResult();
// NOTE(review): fragment folding extract-of-extract_strided_slice (interior
// lines missing): requires unit strides; trailing offsets that are zero over
// a full-size dim are dropped, then the remaining slice offsets are added to
// the extract position and the extract is retargeted at the slice's source.
1756 if (extractOp.hasDynamicPosition())
1759 auto extractStridedSliceOp =
1760 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1761 if (!extractStridedSliceOp)
1770 if (extractStridedSliceOp.hasNonUnitStrides())
// Drop trailing offsets that do not actually slice their dim (offset 0 and
// identical source/result dim size).
1775 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1776 while (!sliceOffsets.empty()) {
1777 size_t lastOffset = sliceOffsets.size() - 1;
1778 if (sliceOffsets.back() != 0 ||
1779 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1780 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1782 sliceOffsets.pop_back();
1784 unsigned destinationRank = 0;
1785 if (
auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1786 destinationRank = vecType.getRank();
// The extract must consume at least all dims that the slice offsets touch.
1789 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1790 sliceOffsets.size())
1794 assert(extractedPos.size() >= sliceOffsets.size());
1795 for (
size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1796 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1797 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1801 extractOp.setStaticPosition(extractedPos);
1802 return extractOp.getResult();
// NOTE(review): fragment folding an extract through a chain of
// insert_strided_slice ops (interior lines missing): walks insert
// destinations, requiring unit strides, and per dim classifies the extract
// offset as inside the inserted window or disjoint from it; on a full match
// it retargets the extract at the insert's source.
1808 if (extractOp.hasDynamicPosition())
1811 int64_t destinationRank =
1812 llvm::isa<VectorType>(extractOp.getType())
1813 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1815 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1825 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1826 insertOp.getSourceVectorType().getRank();
1827 if (destinationRank > insertOp.getSourceVectorType().getRank())
1829 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1832 if (llvm::any_of(insertOp.getStrides(), [](
Attribute attr) {
1833 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
// Per-dim window test: [start, end) is the inserted range along this dim.
1836 bool disjoint =
false;
1838 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1839 int64_t start = insertOffsets[dim];
1841 (dim < insertRankDiff)
1843 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1844 int64_t end = start + size;
1845 int64_t offset = extractOffsets[dim];
1847 if (start <= offset && offset < end) {
1848 if (dim >= insertRankDiff)
1849 offsetDiffs.push_back(offset - start);
// On a containing insert: trailing dims must keep their full sizes before
// retargeting the extract at the insert source.
1859 int64_t srcRankDiff =
1860 insertOp.getSourceVectorType().getRank() - destinationRank;
1861 for (int64_t i = 0; i < destinationRank; i++) {
1862 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1863 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1867 extractOp.getVectorMutable().assign(insertOp.getSource());
1870 extractOp.setStaticPosition(offsetDiffs);
1871 return extractOp.getResult();
// Disjoint insert: keep searching through the insert's destination chain.
1875 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1888 if (extractOp.hasDynamicPosition())
1892 auto fromElementsOp = extractOp.getVector().
getDefiningOp<FromElementsOp>();
1893 if (!fromElementsOp)
1897 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1898 if (vecType.isScalable())
1902 int64_t rank = vecType.getRank();
1904 if (extractOp.getType() != vecType.getElementType())
1906 assert(
static_cast<int64_t
>(indices.size()) == rank &&
1907 "unexpected number of indices");
1912 for (
int i = rank - 1; i >= 0; --i) {
1913 flatIndex += indices[i] * stride;
1914 stride *= vecType.getDimSize(i);
1916 return fromElementsOp.getElements()[flatIndex];
1923 if (getNumIndices() == 0 && getVector().
getType() == getResult().
getType())
1927 if (
auto res = ExtractFromInsertTransposeChainState(*this).fold())
1951 Operation *defOp = extractOp.getVector().getDefiningOp();
1952 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1956 if (extractOp.getType() == source.
getType())
1958 auto getRank = [](
Type type) {
1959 return llvm::isa<VectorType>(type)
1960 ? llvm::cast<VectorType>(type).getRank()
1963 unsigned broadcastSrcRank = getRank(source.
getType());
1964 unsigned extractResultRank = getRank(extractOp.getType());
1968 if (extractResultRank < broadcastSrcRank)
1972 if (extractResultRank == 0) {
1973 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.
getType()));
1978 extractOp, extractOp.getType(), source);
1984 class ExtractOpSplatConstantFolder final :
public OpRewritePattern<ExtractOp> {
1992 Value sourceVector = extractOp.getVector();
1996 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1999 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2000 if (
auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2008 class ExtractOpNonSplatConstantFolder final
2016 if (extractOp.hasDynamicPosition())
2021 Value sourceVector = extractOp.getVector();
2026 auto vecTy = llvm::cast<VectorType>(sourceVector.
getType());
2027 if (vecTy.isScalable())
2031 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2032 if (!dense || dense.isSplat())
2038 copy(extractOp.getStaticPosition(), completePositions.begin());
2039 int64_t elemBeginPosition =
2041 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2044 if (
auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2046 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2049 newAttr = *denseValuesBegin;
2065 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2069 VectorType extractedMaskType =
2070 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2072 if (!extractedMaskType)
2075 auto maskOperands = createMaskOp.getOperands();
2077 VectorType maskType = createMaskOp.getVectorType();
2079 bool containsUnknownDims =
false;
2082 for (
size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2084 int64_t pos = extractOpPos[dimIdx];
2085 Value operand = maskOperands[dimIdx];
2086 auto constantOp = operand.
getDefiningOp<arith::ConstantOp>();
2089 containsUnknownDims =
true;
2093 int64_t createMaskBound =
2094 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2096 if (pos != ShapedType::kDynamic) {
2099 allFalse |= pos >= createMaskBound;
2100 }
else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2104 containsUnknownDims =
true;
2111 }
else if (!containsUnknownDims) {
2113 extractOp, extractedMaskType,
2114 maskOperands.drop_front(extractOpPos.size()));
2124 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2126 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2130 VectorType sourceType = castOp.getSourceVectorType();
2131 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2135 if (sourceType.getNumElements() != targetType.getNumElements())
2139 castOp.getSource());
2149 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2152 if (extractOp.hasDynamicPosition())
2156 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2161 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2162 if (!fromElementsOp)
2164 VectorType inputType = fromElementsOp.getType();
2167 if (resultType.isScalable() || inputType.isScalable())
2173 llvm::to_vector(extractOp.getStaticPosition());
2174 firstElementPos.append(resultType.getRank(), 0);
2177 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2178 flatIndex += firstElementPos[i] * stride;
2179 stride *= inputType.getDimSize(i);
2184 extractOp, resultType,
2185 fromElementsOp.getElements().slice(flatIndex,
2186 resultType.getNumElements()));
2193 results.
add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2194 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2195 results.
add(foldExtractFromShapeCastToShapeCast);
2196 results.
add(foldExtractFromFromElements);
2201 for (
auto attr : arrayAttr)
2202 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2209 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2224 if (!llvm::all_equal(fromElementsOp.getElements()))
2227 fromElementsOp.getElements().front());
2245 int64_t rankDiff = dstShape.size() - srcShape.size();
2246 int64_t dstDim = rankDiff;
2248 for (
auto [s1, s2] :
2249 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2251 assert(s1 == 1 &&
"expected dim-1 broadcasting");
2261 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2280 Value BroadcastOp::createOrFoldBroadcastOp(
2283 assert(!dstShape.empty() &&
"unexpected empty dst shape");
2287 for (
int i = 0, e = dstShape.size(); i < e; ++i) {
2288 if (broadcastedDims.contains(i))
2290 checkShape.push_back(dstShape[i]);
2292 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2293 "ill-formed broadcastedDims contains values not confined to "
2298 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.
getType());
2302 if (!srcVectorType) {
2303 assert(checkShape.empty() &&
2304 "ill-formed createOrFoldBroadcastOp arguments");
2305 return b.
createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2308 assert(srcVectorType.getShape().equals(checkShape) &&
2309 "ill-formed createOrFoldBroadcastOp arguments");
2320 broadcastShape.reserve(dstShape.size());
2336 int64_t nextSrcShapeDim = broadcastedDims.size();
2337 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2338 if (broadcastedDims.contains(i)) {
2343 broadcastShape.push_back(dstShape[i]);
2344 permutation[i] = broadcastShape.size() - 1;
2350 permutation[i] = nextSrcShapeDim++;
2354 llvm::append_range(broadcastShape, srcVectorType.getShape());
2359 "unexpected dim-1 broadcast");
2361 VectorType broadcastType =
VectorType::get(broadcastShape, elementType);
2363 vector::BroadcastableToResult::Success &&
2364 "must be broadcastable");
2368 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2369 if (permutation[i] != i)
2370 return b.
createOrFold<vector::TransposeOp>(loc, res, permutation);
2376 Type srcType, VectorType dstVectorType,
2377 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2381 return BroadcastableToResult::Success;
2383 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2385 return BroadcastableToResult::SourceTypeNotAVector;
2387 int64_t srcRank = srcVectorType.getRank();
2388 int64_t dstRank = dstVectorType.getRank();
2389 if (srcRank > dstRank)
2390 return BroadcastableToResult::SourceRankHigher;
2393 int64_t lead = dstRank - srcRank;
2394 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2397 bool foundMismatchingDims =
false;
2400 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2401 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2402 if (srcDim != 1 && srcDim != dstDim)
2403 foundMismatchingDims =
true;
2406 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2407 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2408 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2411 (srcDimScalableFlag != dstDimScalableFlag &&
2412 (srcDim != 1 || srcDimScalableFlag)))
2413 foundMismatchingDims =
true;
2415 if (foundMismatchingDims) {
2416 if (mismatchingDims !=
nullptr) {
2417 mismatchingDims->first.dim = srcDim;
2418 mismatchingDims->first.isScalable = srcDimScalableFlag;
2420 mismatchingDims->second.dim = dstDim;
2421 mismatchingDims->second.isScalable = dstDimScalableFlag;
2423 return BroadcastableToResult::DimensionMismatch;
2427 return BroadcastableToResult::Success;
2431 std::pair<VectorDim, VectorDim> mismatchingDims;
2433 getSourceType(), getResultVectorType(), &mismatchingDims);
2434 if (res == BroadcastableToResult::Success)
2436 if (res == BroadcastableToResult::SourceRankHigher)
2437 return emitOpError(
"source rank higher than destination rank");
2438 if (res == BroadcastableToResult::DimensionMismatch) {
2439 return emitOpError(
"dimension mismatch (")
2440 << (mismatchingDims.first.isScalable ?
"[" :
"")
2441 << mismatchingDims.first.dim
2442 << (mismatchingDims.first.isScalable ?
"]" :
"") <<
" vs. "
2443 << (mismatchingDims.second.isScalable ?
"[" :
"")
2444 << mismatchingDims.second.dim
2445 << (mismatchingDims.second.isScalable ?
"]" :
"") <<
")";
2447 if (res == BroadcastableToResult::SourceTypeNotAVector)
2448 return emitOpError(
"source type is not a vector");
2449 llvm_unreachable(
"unexpected vector.broadcast op error");
2453 if (getSourceType() == getResultVectorType())
2455 if (!adaptor.getSource())
2457 auto vectorType = getResultVectorType();
2458 if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2460 if (
auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2473 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2477 broadcastOp.getResultVectorType(),
2478 srcBroadcast.getSource());
2488 results.
add<BroadcastFolder>(context);
2496 VectorType resultType = getResultVectorType();
2497 VectorType v1Type = getV1VectorType();
2498 VectorType v2Type = getV2VectorType();
2500 int64_t resRank = resultType.getRank();
2501 int64_t v1Rank = v1Type.getRank();
2502 int64_t v2Rank = v2Type.getRank();
2503 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2504 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2505 if (!wellFormed0DCase && !wellFormedNDCase)
2506 return emitOpError(
"rank mismatch");
2509 for (int64_t r = 1; r < v1Rank; ++r) {
2510 int64_t resDim = resultType.getDimSize(r);
2511 int64_t v1Dim = v1Type.getDimSize(r);
2512 int64_t v2Dim = v2Type.getDimSize(r);
2513 if (resDim != v1Dim || v1Dim != v2Dim)
2514 return emitOpError(
"dimension mismatch");
2518 int64_t maskLength = mask.size();
2519 if (maskLength <= 0)
2520 return emitOpError(
"invalid mask length");
2521 if (maskLength != resultType.getDimSize(0))
2522 return emitOpError(
"mask length mismatch");
2524 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2525 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2527 if (maskPos < 0 || maskPos >= indexSize)
2528 return emitOpError(
"mask index #") << (idx + 1) <<
" out of range";
2534 ShuffleOp::inferReturnTypes(
MLIRContext *, std::optional<Location>,
2535 ShuffleOp::Adaptor adaptor,
2537 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2538 auto v1Rank = v1Type.getRank();
2542 shape.reserve(v1Rank);
2543 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2546 llvm::append_range(shape, v1Type.getShape().drop_front());
2547 inferredReturnTypes.push_back(
2552 template <
typename T>
2555 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2556 return value == expected++;
2560 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2561 VectorType v1Type = getV1VectorType();
2564 if (v1Type.getRank() == 0)
2568 if (!v1Type.isScalable() &&
2572 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2574 getV2VectorType().getDimSize(0)))
2577 Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2582 llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).
getType());
2585 if (lhsType.getRank() != 1)
2587 int64_t lhsSize = lhsType.getDimSize(0);
2590 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<
Attribute>();
2591 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<
Attribute>();
2592 for (int64_t i : this->getMask()) {
2594 results.push_back(rhsElements[i - lhsSize]);
2596 results.push_back(lhsElements[i]);
2612 VectorType v1VectorType = shuffleOp.getV1VectorType();
2614 if (v1VectorType.getRank() > 0)
2616 if (mask.size() != 1)
2636 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2637 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2639 if (!v1Splat || !v2Splat)
2642 if (v1Splat.getInput() != v2Splat.getInput())
2658 VectorType resultType = op.getResultVectorType();
2659 if (resultType.isScalable())
2661 op,
"ShuffleOp can't represent a scalable interleave");
2663 if (resultType.getRank() != 1)
2665 op,
"ShuffleOp can't represent an n-D interleave");
2667 VectorType sourceType = op.getV1VectorType();
2668 if (sourceType != op.getV2VectorType() ||
2669 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2671 op,
"ShuffleOp types don't match an interleave");
2675 int64_t resultVectorSize = resultType.getNumElements();
2676 for (
int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2677 int64_t maskValueA = shuffleMask[i * 2];
2678 int64_t maskValueB = shuffleMask[(i * 2) + 1];
2679 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2681 "ShuffleOp mask not interleaving");
2693 results.
add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2703 build(builder, result, source, dest, {});
2707 auto dstVectorType = getDestVectorType();
2708 if (dstVectorType.getRank() == 0) {
2710 return emitOpError(
"expected position to be empty with 0-D vector");
2713 if (dstVectorType.getRank() != 1)
2714 return emitOpError(
"unexpected >1 vector rank");
2716 return emitOpError(
"expected position for 1-D vector");
2720 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2722 if (!adaptor.getPosition())
2725 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2726 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2727 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2728 if (!src || !dst || !pos)
2734 auto dstElements = dst.getValues<
Attribute>();
2738 uint64_t posIdx = pos.getInt();
2739 if (posIdx >= results.size())
2741 results[posIdx] = src;
2751 Value source,
Value dest, int64_t position) {
2764 posVals.reserve(position.size());
2765 llvm::transform(position, std::back_inserter(posVals),
2767 build(builder, result, source, dest, posVals);
2776 build(builder, result, source, dest, dynamicPos,
2782 auto destVectorType = getDestVectorType();
2783 if (position.size() >
static_cast<unsigned>(destVectorType.getRank()))
2785 "expected position attribute of rank no greater than dest vector rank");
2786 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2787 if (srcVectorType &&
2788 (
static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2789 static_cast<unsigned>(destVectorType.getRank())))
2790 return emitOpError(
"expected position attribute rank + source rank to "
2791 "match dest vector rank");
2792 if (!srcVectorType &&
2793 (position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
2795 "expected position attribute rank to match the dest vector rank");
2797 if (
auto attr = pos.dyn_cast<
Attribute>()) {
2798 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2799 if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2800 return emitOpError(
"expected position attribute #")
2802 <<
" to be a non-negative integer smaller than the "
2804 "dest vector dimension";
2821 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2822 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2823 srcVecType.getNumElements())
2826 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2838 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2839 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2841 if (!srcSplat || !dstSplat)
2844 if (srcSplat.getInput() != dstSplat.getInput())
2859 static constexpr int64_t vectorSizeFoldThreshold = 256;
2864 if (op.hasDynamicPosition())
2873 auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
2877 VectorType destTy = destVector.getType();
2878 if (destTy.isScalable())
2882 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2883 !destVector.hasOneUse())
2886 Value sourceValue = op.getSource();
2894 copy(op.getStaticPosition(), completePositions.begin());
2895 int64_t insertBeginPosition =
2899 Type destEltType = destTy.getElementType();
2904 if (
auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2905 for (
auto value : denseSource.getValues<
Attribute>())
2911 auto allValues = llvm::to_vector(denseDest.getValues<
Attribute>());
2912 copy(insertedValues, allValues.begin() + insertBeginPosition);
2923 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
2924 if (intAttr.getType() != expectedType)
2935 results.
add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2936 InsertOpConstantFolder>(context);
2942 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2943 if (getNumIndices() == 0)
2967 template <
typename OpType>
2969 ArrayAttr arrayAttr,
2971 StringRef attrName) {
2972 if (arrayAttr.size() > shape.size())
2974 << attrName <<
" attribute of rank no greater than vector rank";
2981 template <
typename OpType>
2982 static LogicalResult
2984 int64_t
max, StringRef attrName,
2985 bool halfOpen =
true) {
2986 for (
auto attr : arrayAttr) {
2987 auto val = llvm::cast<IntegerAttr>(attr).getInt();
2991 if (val < min || val >= upper)
2992 return op.
emitOpError(
"expected ") << attrName <<
" to be confined to ["
2993 <<
min <<
", " << upper <<
")";
3001 template <
typename OpType>
3002 static LogicalResult
3005 bool halfOpen =
true, int64_t
min = 0) {
3006 for (
auto [index, attrDimPair] :
3008 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3009 int64_t
max = std::get<1>(attrDimPair);
3012 if (val < min || val >=
max)
3014 << attrName <<
" dimension " << index <<
" to be confined to ["
3015 <<
min <<
", " <<
max <<
")";
3025 template <
typename OpType>
3027 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3029 bool halfOpen =
true, int64_t
min = 1) {
3030 assert(arrayAttr1.size() <= shape.size());
3031 assert(arrayAttr2.size() <= shape.size());
3032 for (
auto [index, it] :
3034 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3035 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3036 int64_t
max = std::get<2>(it);
3039 if (val1 + val2 < 0 || val1 + val2 >=
max)
3041 << attrName1 <<
", " << attrName2 <<
") dimension " << index
3042 <<
" to be confined to [" <<
min <<
", " <<
max <<
")";
3049 auto attrs = llvm::map_range(values, [context](int64_t v) ->
Attribute {
3056 auto sourceVectorType = getSourceVectorType();
3057 auto destVectorType = getDestVectorType();
3058 auto offsets = getOffsetsAttr();
3059 auto strides = getStridesAttr();
3060 if (offsets.size() !=
static_cast<unsigned>(destVectorType.getRank()))
3062 "expected offsets of same size as destination vector rank");
3063 if (strides.size() !=
static_cast<unsigned>(sourceVectorType.getRank()))
3064 return emitOpError(
"expected strides of same size as source vector rank");
3065 if (sourceVectorType.getRank() > destVectorType.getRank())
3067 "expected source rank to be no greater than destination rank");
3069 auto sourceShape = sourceVectorType.getShape();
3070 auto destShape = destVectorType.getShape();
3072 destShape.size() - sourceShape.size(), 0);
3073 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3074 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3075 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3084 offName,
"source vector shape",
3088 unsigned rankDiff = destShape.size() - sourceShape.size();
3089 for (
unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3090 if (sourceVectorType.getScalableDims()[idx] !=
3091 destVectorType.getScalableDims()[idx + rankDiff]) {
3092 return emitOpError(
"mismatching scalable flags (at source vector idx=")
3095 if (sourceVectorType.getScalableDims()[idx]) {
3096 auto sourceSize = sourceShape[idx];
3097 auto destSize = destShape[idx + rankDiff];
3098 if (sourceSize != destSize) {
3099 return emitOpError(
"expected size at idx=")
3101 << (
" to match the corresponding base size from the input "
3103 << sourceSize << (
" vs ") << destSize << (
")");
3114 class FoldInsertStridedSliceSplat final
3119 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3122 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3124 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3126 if (!srcSplatOp || !destSplatOp)
3129 if (srcSplatOp.getInput() != destSplatOp.getInput())
3132 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3139 class FoldInsertStridedSliceOfExtract final
3144 LogicalResult
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3146 auto extractStridedSliceOp =
3147 insertStridedSliceOp.getSource()
3148 .getDefiningOp<vector::ExtractStridedSliceOp>();
3150 if (!extractStridedSliceOp)
3153 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3157 if (extractStridedSliceOp.getStrides() !=
3158 insertStridedSliceOp.getStrides() ||
3159 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3162 rewriter.
replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3169 class InsertStridedSliceConstantFolder final
3176 static constexpr int64_t vectorSizeFoldThreshold = 256;
3187 VectorType destTy = destVector.getType();
3188 if (destTy.isScalable())
3192 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3193 !destVector.hasOneUse())
3196 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3204 if (op.hasNonUnitStrides())
3207 VectorType sliceVecTy = sourceValue.getType();
3209 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3219 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3220 auto sliceValuesIt = denseSlice.value_begin<
Attribute>();
3221 auto newValues = llvm::to_vector(denseDest.getValues<
Attribute>());
3224 currDestPosition.begin() + rankDifference, currDestPosition.end());
3228 int64_t linearizedPosition =
linearize(currDestPosition, destStrides);
3229 assert(linearizedPosition < destTy.getNumElements() &&
"Invalid index");
3230 assert(sliceValuesIt != denseSlice.value_end<
Attribute>() &&
3231 "Invalid slice element");
3232 newValues[linearizedPosition] = *sliceValuesIt;
3245 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3247 results.
add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3248 InsertStridedSliceConstantFolder>(context);
3251 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3252 if (getSourceVectorType() == getDestVectorType())
3269 p <<
" " << getLhs() <<
", " << getRhs();
3271 p <<
", " << getAcc();
3274 p <<
" : " << getLhs().getType() <<
", " << getRhs().getType();
3285 if (operandsInfo.size() < 2)
3287 "expected at least 2 operands");
3288 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3289 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3292 "expected vector type for operand #1");
3297 vRHS.getScalableDims()[0]};
3299 vLHS.getElementType(), scalableDimsRes);
3303 resType =
VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3309 OuterProductOp::getKindAttrName(result.
name),
3311 OuterProductOp::getDefaultKind()));
3317 (operandsInfo.size() > 2 &&
3323 Type tRHS = getOperandTypeRHS();
3324 VectorType vLHS = getOperandVectorTypeLHS(),
3325 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3326 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3328 if (vLHS.getRank() != 1)
3329 return emitOpError(
"expected 1-d vector for operand #1");
3333 if (vRHS.getRank() != 1)
3334 return emitOpError(
"expected 1-d vector for operand #2");
3335 if (vRES.getRank() != 2)
3336 return emitOpError(
"expected 2-d vector result");
3337 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3338 return emitOpError(
"expected #1 operand dim to match result dim #1");
3339 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3340 return emitOpError(
"expected #2 operand dim to match result dim #2");
3341 if (vLHS.isScalable() && !vRHS.isScalable()) {
3345 "expected either both or only #2 operand dim to be scalable");
3349 if (vRES.getRank() != 1)
3350 return emitOpError(
"expected 1-d vector result");
3351 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3352 return emitOpError(
"expected #1 operand dim to match result dim #1");
3355 if (vACC && vACC != vRES)
3356 return emitOpError(
"expected operand #3 of same type as result type");
3360 return emitOpError(
"unsupported outerproduct type");
3369 Type OuterProductOp::getExpectedMaskType() {
3370 auto vecType = this->getResultVectorType();
3373 vecType.getScalableDims());
3385 ArrayAttr offsets, ArrayAttr sizes,
3386 ArrayAttr strides) {
3387 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3389 shape.reserve(vectorType.getRank());
3391 for (
unsigned e = offsets.size(); idx < e; ++idx)
3392 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3393 for (
unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3394 shape.push_back(vectorType.getShape()[idx]);
3397 vectorType.getScalableDims());
3410 offsetsAttr, sizesAttr, stridesAttr));
3411 result.
addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.
name),
3415 result.
addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.
name),
3420 auto type = getSourceVectorType();
3421 auto offsets = getOffsetsAttr();
3422 auto sizes = getSizesAttr();
3423 auto strides = getStridesAttr();
3424 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3426 "expected offsets, sizes and strides attributes of same size");
3428 auto shape = type.getShape();
3429 auto offName = getOffsetsAttrName();
3430 auto sizesName = getSizesAttrName();
3431 auto stridesName = getStridesAttrName();
3447 shape, offName, sizesName,
3452 offsets, sizes, strides);
3453 if (getResult().
getType() != resultType)
3454 return emitOpError(
"expected result type to be ") << resultType;
3456 for (
unsigned idx = 0; idx < sizes.size(); ++idx) {
3457 if (type.getScalableDims()[idx]) {
3458 auto inputDim = type.getShape()[idx];
3459 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3460 if (inputDim != inputSize)
3461 return emitOpError(
"expected size at idx=")
3463 << (
" to match the corresponding base size from the input "
3465 << inputSize << (
" vs ") << inputDim << (
")");
3475 static LogicalResult
3478 auto getElement = [](ArrayAttr array,
int idx) {
3479 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3481 ArrayAttr extractOffsets = op.getOffsets();
3483 ArrayAttr extractSizes = op.getSizes();
3484 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3486 if (op.getSourceVectorType().getRank() !=
3487 insertOp.getSourceVectorType().getRank())
3489 ArrayAttr insertOffsets = insertOp.getOffsets();
3490 ArrayAttr insertStrides = insertOp.getStrides();
3493 if (extractOffsets.size() > insertOffsets.size())
3495 bool patialoverlap =
false;
3496 bool disjoint =
false;
3498 for (
unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3499 if (getElement(
extractStrides, dim) != getElement(insertStrides, dim))
3501 int64_t start = getElement(insertOffsets, dim);
3502 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3503 int64_t offset = getElement(extractOffsets, dim);
3504 int64_t size = getElement(extractSizes, dim);
3506 if (start <= offset && offset < end) {
3509 if (offset + size > end)
3510 patialoverlap =
true;
3511 offsetDiffs.push_back(offset - start);
3518 if (!disjoint && !patialoverlap) {
3528 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3538 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3539 if (getSourceVectorType() == getResult().
getType())
3554 class StridedSliceConstantMaskFolder final
3559 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3563 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3564 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3565 if (!constantMaskOp)
3568 if (extractStridedSliceOp.hasNonUnitStrides())
3581 sliceMaskDimSizes.reserve(maskDimSizes.size());
3582 for (
auto [maskDimSize, sliceOffset, sliceSize] :
3583 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3584 int64_t sliceMaskDimSize =
std::max(
3585 static_cast<int64_t
>(0),
3586 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3587 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3590 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3591 for (
size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3592 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3595 if (llvm::is_contained(sliceMaskDimSizes, 0))
3596 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3601 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3608 class StridedSliceSplatConstantFolder final
3613 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3617 Value sourceVector = extractStridedSliceOp.getVector();
3622 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3636 class StridedSliceNonSplatConstantFolder final
3641 LogicalResult
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3645 Value sourceVector = extractStridedSliceOp.getVector();
3651 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3652 if (!dense || dense.isSplat())
3656 if (extractStridedSliceOp.hasNonUnitStrides())
3659 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.
getType());
3663 VectorType sliceVecTy = extractStridedSliceOp.getType();
3665 int64_t sliceRank = sliceVecTy.getRank();
3677 auto denseValuesBegin = dense.value_begin<
Attribute>();
3679 sliceValues.reserve(sliceVecTy.getNumElements());
3682 int64_t linearizedPosition =
linearize(currSlicePosition, sourceStrides);
3683 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3685 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3689 assert(
static_cast<int64_t
>(sliceValues.size()) ==
3690 sliceVecTy.getNumElements() &&
3691 "Invalid number of slice elements");
3701 class StridedSliceBroadcast final
3708 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3713 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3714 auto dstVecType = llvm::cast<VectorType>(op.getType());
3715 unsigned dstRank = dstVecType.getRank();
3716 unsigned rankDiff = dstRank - srcRank;
3720 bool lowerDimMatch =
true;
3721 for (
unsigned i = 0; i < srcRank; i++) {
3722 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3723 lowerDimMatch =
false;
3732 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3733 if (!lowerDimMatch && !isScalarSrc) {
3734 source = rewriter.
create<ExtractStridedSliceOp>(
3746 class StridedSliceSplat final :
public OpRewritePattern<ExtractStridedSliceOp> {
3752 auto splat = op.getVector().getDefiningOp<SplatOp>();
3762 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3766 results.
add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3767 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3768 StridedSliceSplat>(context);
3777 VectorType vectorType,
Value source,
3778 ValueRange indices, AffineMapAttr permutationMapAttr,
3779 ArrayAttr inBoundsAttr) {
3780 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3781 Value padding = builder.
create<arith::ConstantOp>(
3783 build(builder, result, vectorType, source, indices, permutationMapAttr,
3784 padding,
Value(), inBoundsAttr);
3789 VectorType vectorType,
Value source,
3793 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3797 build(builder, result, vectorType, source, indices, permutationMapAttr,
3803 VectorType vectorType,
Value source,
3807 llvm::cast<ShapedType>(source.
getType()), vectorType);
3809 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3813 build(builder, result, vectorType, source, indices, permutationMapAttr,
3815 Value(), inBoundsAttr);
3821 VectorType vectorType,
Value source,
3824 Type elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3825 Value padding = builder.
create<arith::ConstantOp>(
3827 build(builder, result, vectorType, source, indices, padding, inBounds);
3830 template <
typename EmitFun>
3832 EmitFun emitOpError) {
3834 for (
auto expr : permutationMap.
getResults()) {
3835 auto dim = dyn_cast<AffineDimExpr>(expr);
3836 auto zero = dyn_cast<AffineConstantExpr>(expr);
3838 if (zero.getValue() != 0) {
3840 "requires a projected permutation_map (at most one dim or the zero "
3841 "constant can appear in each result)");
3846 return emitOpError(
"requires a projected permutation_map (at most one "
3847 "dim or the zero constant can appear in each result)");
3849 if (seen[dim.getPosition()]) {
3851 "requires a permutation_map that is a permutation (found one dim "
3852 "used more than once)");
3854 seen[dim.getPosition()] =
true;
3859 static LogicalResult
3861 VectorType vectorType, VectorType maskType,
3862 VectorType inferredMaskType,
AffineMap permutationMap,
3863 ArrayAttr inBounds) {
3865 return op->
emitOpError(
"masked attribute has been removed. "
3866 "Use in_bounds instead.");
3869 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3871 "requires source to be a memref or ranked tensor type");
3873 auto elementType = shapedType.getElementType();
3874 DataLayout dataLayout = DataLayout::closest(op);
3875 if (
auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3877 unsigned sourceVecSize =
3879 vectorElementType.getShape().back();
3880 unsigned resultVecSize =
3882 vectorType.getShape().back();
3883 if (resultVecSize % sourceVecSize != 0)
3885 "requires the bitwidth of the minor 1-D vector to be an integral "
3886 "multiple of the bitwidth of the minor 1-D vector of the source");
3888 unsigned sourceVecEltRank = vectorElementType.getRank();
3889 unsigned resultVecRank = vectorType.getRank();
3890 if (sourceVecEltRank > resultVecRank)
3892 "requires source vector element and vector result ranks to match.");
3893 unsigned rankOffset = resultVecRank - sourceVecEltRank;
3896 return op->
emitOpError(
"requires a permutation_map with result dims of "
3897 "the same rank as the vector type");
3900 return op->
emitOpError(
"does not support masks with vector element type");
3903 unsigned minorSize =
3904 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3905 unsigned resultVecSize =
3909 "requires the bitwidth of the minor 1-D vector to be an integral "
3910 "multiple of the bitwidth of the source element type");
3914 return op->
emitOpError(
"requires a permutation_map with result dims of "
3915 "the same rank as the vector type");
3919 return op->
emitOpError(
"requires permutation_map without symbols");
3921 if (permutationMap.
getNumInputs() != shapedType.getRank())
3922 return op->
emitOpError(
"requires a permutation_map with input dims of the "
3923 "same rank as the source type");
3925 if (maskType && maskType != inferredMaskType)
3927 << inferredMaskType <<
") and mask operand type (" << maskType
3930 if (permutationMap.
getNumResults() !=
static_cast<int64_t
>(inBounds.size()))
3931 return op->
emitOpError(
"expects the in_bounds attr of same rank "
3932 "as permutation_map results: ")
3934 <<
" vs inBounds of size: " << inBounds.size();
3935 for (
unsigned int i = 0, e = permutationMap.
getNumResults(); i < e; ++i)
3936 if (isa<AffineConstantExpr>(permutationMap.
getResult(i)) &&
3937 !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3938 return op->
emitOpError(
"requires broadcast dimensions to be in-bounds");
3945 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3946 if (op.getPermutationMap().isMinorIdentity())
3947 elidedAttrs.push_back(op.getPermutationMapAttrName());
3949 if (llvm::none_of(op.getInBoundsValues(), [](
bool b) { return b; }))
3950 elidedAttrs.push_back(op.getInBoundsAttrName());
3955 p <<
" " << getSource() <<
"[" <<
getIndices() <<
"], " << getPadding();
3957 p <<
", " << getMask();
3966 assert(invPermMap &&
"Inversed permutation map couldn't be computed");
3989 if (hasMask.succeeded()) {
3996 if (types.size() != 2)
3997 return parser.
emitError(typesLoc,
"requires two types");
3999 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4000 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4001 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4002 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4004 return parser.
emitError(typesLoc,
"requires vector type");
4005 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.
name);
4012 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4014 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.
name);
4016 if (!inBoundsAttr) {
4026 if (hasMask.succeeded()) {
4027 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4029 maskInfo.
location,
"does not support masks with vector element type");
4032 "expected the same rank for the vector and the "
4033 "results of the permutation map");
4041 result.
addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4043 {1, static_cast<int32_t>(indexInfo.size()), 1,
4044 static_cast<int32_t>(hasMask.succeeded())}));
4050 ShapedType shapedType = getShapedType();
4052 VectorType maskType = getMaskType();
4053 auto paddingType = getPadding().getType();
4054 auto permutationMap = getPermutationMap();
4055 VectorType inferredMaskType =
4058 auto sourceElementType = shapedType.getElementType();
4060 if (
static_cast<int64_t
>(
getIndices().size()) != shapedType.getRank())
4061 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4063 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4064 shapedType, vectorType, maskType,
4065 inferredMaskType, permutationMap, getInBounds())))
4068 if (
auto sourceVectorElementType =
4069 llvm::dyn_cast<VectorType>(sourceElementType)) {
4072 if (sourceVectorElementType != paddingType)
4074 "requires source element type and padding type to match.");
4078 if (!VectorType::isValidElementType(paddingType))
4079 return emitOpError(
"requires valid padding vector elemental type");
4082 if (paddingType != sourceElementType)
4084 "requires formal padding and source of the same elemental type");
4088 [&](Twine t) {
return emitOpError(t); });
4095 Type TransferReadOp::getExpectedMaskType() {
4099 template <
typename TransferOp>
4100 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4103 if (op.getShapedType().isDynamicDim(indicesIdx))
4105 Value index = op.getIndices()[indicesIdx];
4107 if (!cstOp.has_value())
4110 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4111 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4113 return cstOp.value() + vectorSize <= sourceSize;
4116 template <
typename TransferOp>
4120 if (op.getTransferRank() == 0)
4122 AffineMap permutationMap = op.getPermutationMap();
4123 bool changed =
false;
4125 newInBounds.reserve(op.getTransferRank());
4126 for (
unsigned i = 0; i < op.getTransferRank(); ++i) {
4128 if (op.isDimInBounds(i)) {
4129 newInBounds.push_back(
true);
4134 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.
getResult(i));
4135 assert(dimExpr &&
"Broadcast dims must be in-bounds");
4138 newInBounds.push_back(inBounds);
4140 changed |= inBounds;
4150 template <
typename TransferOp>
4152 auto mask = op.getMask();
4159 op.getMaskMutable().clear();
4173 static Value foldRAW(TransferReadOp readOp) {
4174 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4176 auto defWrite = readOp.getSource().
getDefiningOp<vector::TransferWriteOp>();
4179 return defWrite.getVector();
4181 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4182 cast<VectorTransferOpInterface>(readOp.getOperation())))
4184 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4190 if (
Value vec = foldRAW(*
this))
4204 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4208 void TransferReadOp::getEffects(
4211 if (llvm::isa<MemRefType>(getShapedType()))
4239 struct TransferReadAfterWriteToBroadcast
4245 if (readOp.hasOutOfBoundsDim() ||
4246 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4248 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4253 if (readOp.getTransferChunkAccessed() !=
4254 defWrite.getTransferChunkAccessed())
4261 if (readOp.getIndices() != defWrite.getIndices() ||
4262 readOp.getMask() != defWrite.getMask())
4264 Value vec = defWrite.getVector();
4286 broadcastShape[pos.value()] = destShape[pos.index()];
4287 broadcastScalableFlags[pos.value()] =
4288 readOp.getVectorType().getScalableDims()[pos.index()];
4291 broadcastShape, defWrite.getVectorType().getElementType(),
4292 broadcastScalableFlags);
4293 vec = rewriter.
create<vector::BroadcastOp>(loc, broadcastedType, vec);
4304 results.
add<TransferReadAfterWriteToBroadcast>(context);
4314 AffineMapAttr permutationMapAttr,
4316 ArrayAttr inBoundsAttr) {
4317 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.
getType());
4318 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4319 mask, inBoundsAttr);
4325 AffineMapAttr permutationMapAttr,
4326 ArrayAttr inBoundsAttr) {
4327 build(builder, result, vector, dest, indices, permutationMapAttr,
4328 Value(), inBoundsAttr);
4339 (inBounds && !inBounds.value().empty())
4342 llvm::cast<VectorType>(vector.
getType()).getRank(),
false));
4343 build(builder, result, vector, dest, indices, permutationMapAttr,
4344 Value(), inBoundsAttr);
4352 auto vectorType = llvm::cast<VectorType>(vector.
getType());
4354 llvm::cast<ShapedType>(dest.
getType()), vectorType);
4355 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4371 if (hasMask.succeeded() && parser.
parseOperand(maskInfo))
4376 if (types.size() != 2)
4377 return parser.
emitError(typesLoc,
"requires two types");
4379 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4381 return parser.
emitError(typesLoc,
"requires vector type");
4382 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4383 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4384 return parser.
emitError(typesLoc,
"requires memref or ranked tensor type");
4385 auto permMapAttrName =
4386 TransferWriteOp::getPermutationMapAttrName(result.
name);
4393 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4395 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.
name);
4397 if (!inBoundsAttr) {
4406 if (hasMask.succeeded()) {
4407 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4409 maskInfo.
location,
"does not support masks with vector element type");
4412 "expected the same rank for the vector and the "
4413 "results of the permutation map");
4419 result.
addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4421 {1, 1, static_cast<int32_t>(indexInfo.size()),
4422 static_cast<int32_t>(hasMask.succeeded())}));
4423 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4428 p <<
" " << getVector() <<
", " << getSource() <<
"[" <<
getIndices() <<
"]";
4430 p <<
", " << getMask();
4437 ShapedType shapedType = getShapedType();
4439 VectorType maskType = getMaskType();
4440 auto permutationMap = getPermutationMap();
4441 VectorType inferredMaskType =
4445 if (llvm::size(
getIndices()) != shapedType.getRank())
4446 return emitOpError(
"requires ") << shapedType.getRank() <<
" indices";
4450 if (hasBroadcastDim())
4451 return emitOpError(
"should not have broadcast dimensions");
4453 if (failed(
verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4454 shapedType, vectorType, maskType,
4455 inferredMaskType, permutationMap, getInBounds())))
4459 [&](Twine t) {
return emitOpError(t); });
4466 Type TransferWriteOp::getExpectedMaskType() {
4487 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4491 if (write.getTransferRank() == 0)
4493 auto rankedTensorType =
4494 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4496 if (!rankedTensorType)
4499 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4503 if (read.getTransferRank() == 0)
4506 if (!read.getPermutationMap().isMinorIdentity() ||
4507 !write.getPermutationMap().isMinorIdentity())
4510 if (read.getTransferRank() != write.getTransferRank())
4513 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4516 if (read.getSource().getType() != rankedTensorType)
4519 if (read.getVectorType() != write.getVectorType())
4522 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4525 auto isNotConstantZero = [](
Value v) {
4527 return !cstOp.has_value() || cstOp.value() != 0;
4529 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4530 llvm::any_of(write.getIndices(), isNotConstantZero))
4533 results.push_back(read.getSource());
4537 static bool checkSameValueWAR(vector::TransferReadOp read,
4538 vector::TransferWriteOp write) {
4539 return read.getSource() == write.getSource() &&
4540 read.getIndices() == write.getIndices() &&
4541 read.getPermutationMap() == write.getPermutationMap() &&
4542 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4559 static LogicalResult foldWAR(TransferWriteOp write,
4561 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4563 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4567 if (!checkSameValueWAR(read, write))
4569 results.push_back(read.getSource());
4573 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4575 if (succeeded(foldReadInitWrite(*
this, adaptor.getOperands(), results)))
4577 if (succeeded(foldWAR(*
this, results)))
4586 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4590 void TransferWriteOp::getEffects(
4593 if (llvm::isa<MemRefType>(getShapedType()))
4628 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4630 vector::TransferWriteOp writeToModify = writeOp;
4633 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4637 writeToModify.getSourceMutable().assign(defWrite.getSource());
4642 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4643 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4647 if (!defWrite->hasOneUse())
4649 writeToModify = defWrite;
4650 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4679 struct SwapExtractSliceOfTransferWrite
4686 if (!insertOp.hasUnitStride())
4689 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4690 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4692 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4693 if (!transferOp || !transferOp->hasOneUse())
4698 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4700 "use-def chain is rank-reducing");
4704 if (!extractOp.hasZeroOffset()) {
4706 "ExtractSliceOp has non-zero offset");
4710 if (!llvm::all_of(transferOp.getIndices(), [](
Value value) {
4714 "TranferWriteOp has non-zero offset");
4718 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4720 insertOp,
"InsertSliceOp and ExtractSliceOp ranks differ");
4723 for (
auto [insertSize, extractSize] :
4724 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4727 insertOp,
"InsertSliceOp and ExtractSliceOp sizes differ");
4732 assert(transferOp.getVectorType().hasStaticShape() &&
4733 "expected vector to have a static shape");
4736 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4737 if (transferOp.getMask() || !
vectorShape.equals(resultShape)) {
4739 insertOp,
"TransferWriteOp may not write the full tensor.");
4745 auto newExtractOp = rewriter.
create<tensor::ExtractSliceOp>(
4746 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4747 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4748 insertOp.getMixedStrides());
4749 auto newTransferWriteOp = rewriter.
create<TransferWriteOp>(
4750 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4751 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4754 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4764 results.
add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4771 static LogicalResult verifyLoadStoreMemRefLayout(
Operation *op,
4772 MemRefType memRefTy) {
4774 return op->
emitOpError(
"most minor memref dim must have unit stride");
4782 if (failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4786 Type memElemTy = memRefTy.getElementType();
4787 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4788 if (memVecTy != resVecTy)
4789 return emitOpError(
"base memref and result vector types should match");
4790 memElemTy = memVecTy.getElementType();
4793 if (resVecTy.getElementType() != memElemTy)
4794 return emitOpError(
"base and result element types should match");
4795 if (llvm::size(
getIndices()) != memRefTy.getRank())
4796 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4814 if (failed(verifyLoadStoreMemRefLayout(*
this, memRefTy)))
4818 Type memElemTy = memRefTy.getElementType();
4819 if (
auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4820 if (memVecTy != valueVecTy)
4822 "base memref and valueToStore vector types should match");
4823 memElemTy = memVecTy.getElementType();
4826 if (valueVecTy.getElementType() != memElemTy)
4827 return emitOpError(
"base and valueToStore element type should match");
4828 if (llvm::size(
getIndices()) != memRefTy.getRank())
4829 return emitOpError(
"requires ") << memRefTy.getRank() <<
" indices";
4833 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4843 VectorType maskVType = getMaskVectorType();
4844 VectorType passVType = getPassThruVectorType();
4848 if (resVType.getElementType() != memType.getElementType())
4849 return emitOpError(
"base and result element type should match");
4850 if (llvm::size(
getIndices()) != memType.getRank())
4851 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4852 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4853 return emitOpError(
"expected result dim to match mask dim");
4854 if (resVType != passVType)
4855 return emitOpError(
"expected pass_thru of same type as result type");
4868 load, load.getType(), load.getBase(), load.getIndices());
4871 rewriter.
replaceOp(load, load.getPassThru());
4876 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedLoad");
4883 results.
add<MaskedLoadFolder>(context);
4897 VectorType maskVType = getMaskVectorType();
4901 if (valueVType.getElementType() != memType.getElementType())
4902 return emitOpError(
"base and valueToStore element type should match");
4903 if (llvm::size(
getIndices()) != memType.getRank())
4904 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
4905 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4906 return emitOpError(
"expected valueToStore dim to match mask dim");
4919 store, store.getValueToStore(), store.getBase(), store.getIndices());
4927 llvm_unreachable(
"Unexpected 1DMaskFormat on MaskedStore");
4934 results.
add<MaskedStoreFolder>(context);
4937 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4947 VectorType indVType = getIndexVectorType();
4948 VectorType maskVType = getMaskVectorType();
4950 ShapedType baseType = getBaseType();
4952 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4953 return emitOpError(
"requires base to be a memref or ranked tensor type");
4955 if (resVType.getElementType() != baseType.getElementType())
4956 return emitOpError(
"base and result element type should match");
4957 if (llvm::size(
getIndices()) != baseType.getRank())
4958 return emitOpError(
"requires ") << baseType.getRank() <<
" indices";
4959 if (resVType.getShape() != indVType.getShape())
4960 return emitOpError(
"expected result dim to match indices dim");
4961 if (resVType.getShape() != maskVType.getShape())
4962 return emitOpError(
"expected result dim to match mask dim");
4963 if (resVType != getPassThruVectorType())
4964 return emitOpError(
"expected pass_thru of same type as result type");
4972 Type GatherOp::getExpectedMaskType() {
4973 auto vecType = this->getIndexVectorType();
4976 vecType.getScalableDims());
4979 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4993 rewriter.
replaceOp(gather, gather.getPassThru());
4998 llvm_unreachable(
"Unexpected 1DMaskFormat on GatherFolder");
5005 results.
add<GatherFolder>(context);
5013 VectorType indVType = getIndexVectorType();
5014 VectorType maskVType = getMaskVectorType();
5018 if (valueVType.getElementType() != memType.getElementType())
5019 return emitOpError(
"base and valueToStore element type should match");
5020 if (llvm::size(
getIndices()) != memType.getRank())
5021 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5022 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5023 return emitOpError(
"expected valueToStore dim to match indices dim");
5024 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5025 return emitOpError(
"expected valueToStore dim to match mask dim");
5044 llvm_unreachable(
"Unexpected 1DMaskFormat on ScatterFolder");
5051 results.
add<ScatterFolder>(context);
5059 VectorType maskVType = getMaskVectorType();
5060 VectorType passVType = getPassThruVectorType();
5064 if (resVType.getElementType() != memType.getElementType())
5065 return emitOpError(
"base and result element type should match");
5066 if (llvm::size(
getIndices()) != memType.getRank())
5067 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5068 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5069 return emitOpError(
"expected result dim to match mask dim");
5070 if (resVType != passVType)
5071 return emitOpError(
"expected pass_thru of same type as result type");
5084 expand, expand.getType(), expand.getBase(), expand.getIndices());
5087 rewriter.
replaceOp(expand, expand.getPassThru());
5092 llvm_unreachable(
"Unexpected 1DMaskFormat on ExpandLoadFolder");
5099 results.
add<ExpandLoadFolder>(context);
5107 VectorType maskVType = getMaskVectorType();
5111 if (valueVType.getElementType() != memType.getElementType())
5112 return emitOpError(
"base and valueToStore element type should match");
5113 if (llvm::size(
getIndices()) != memType.getRank())
5114 return emitOpError(
"requires ") << memType.getRank() <<
" indices";
5115 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5116 return emitOpError(
"expected valueToStore dim to match mask dim");
5121 class CompressStoreFolder final :
public OpRewritePattern<CompressStoreOp> {
5129 compress, compress.getValueToStore(), compress.getBase(),
5130 compress.getIndices());
5138 llvm_unreachable(
"Unexpected 1DMaskFormat on CompressStoreFolder");
5145 results.
add<CompressStoreFolder>(context);
5155 unsigned rankA = a.size();
5156 unsigned rankB = b.size();
5157 assert(rankA < rankB);
5159 auto isOne = [](int64_t v) {
return v == 1; };
5163 if (rankA == 0 && llvm::all_of(b, isOne))
5168 while (i < rankA &&
j < rankB) {
5169 int64_t dimA = a[i];
5171 while (dimB < dimA &&
j < rankB)
5179 if (i < rankA && llvm::all_of(a.slice(i), isOne))
5181 if (
j < rankB && llvm::all_of(b.slice(
j), isOne))
5185 return i == rankA &&
j == rankB;
5188 static LogicalResult verifyVectorShapeCast(
Operation *op,
5189 VectorType sourceVectorType,
5190 VectorType resultVectorType) {
5192 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5193 return op->
emitOpError(
"source/result vectors must have same element type");
5194 auto sourceShape = sourceVectorType.getShape();
5195 auto resultShape = resultVectorType.getShape();
5198 int64_t sourceDimProduct = std::accumulate(
5199 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5200 int64_t resultDimProduct = std::accumulate(
5201 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5202 if (sourceDimProduct != resultDimProduct)
5203 return op->
emitOpError(
"source/result number of elements must match");
5206 unsigned sourceRank = sourceVectorType.getRank();
5207 unsigned resultRank = resultVectorType.getRank();
5208 if (sourceRank < resultRank) {
5209 if (!isValidShapeCast(sourceShape, resultShape))
5211 }
else if (sourceRank > resultRank) {
5212 if (!isValidShapeCast(resultShape, sourceShape))
5217 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5218 int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5219 if (sourceNScalableDims != resultNScalableDims)
5220 return op->
emitOpError(
"different number of scalable dims at source (")
5221 << sourceNScalableDims <<
") and result (" << resultNScalableDims
5223 sourceVectorType.getNumDynamicDims();
5229 auto sourceVectorType =
5230 llvm::dyn_cast_or_null<VectorType>(getSource().
getType());
5231 auto resultVectorType =
5232 llvm::dyn_cast_or_null<VectorType>(getResult().
getType());
5235 if (sourceVectorType && resultVectorType)
5236 return verifyVectorShapeCast(*
this, sourceVectorType, resultVectorType);
5247 if (
auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5248 if (getResult().
getType() == otherOp.getSource().getType())
5249 return otherOp.getSource();
5252 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5253 VectorType resultType = llvm::cast<VectorType>(getResult().
getType());
5254 if (srcType.getRank() < resultType.getRank()) {
5255 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5257 }
else if (srcType.getRank() > resultType.getRank()) {
5258 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5264 setOperand(otherOp.getSource());
5269 if (
auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5270 if (bcastOp.getSourceType() ==
getType())
5271 return bcastOp.getSource();
5279 class ShapeCastConstantFolder final :
public OpRewritePattern<ShapeCastOp> {
5286 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5290 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5306 static VectorType trimTrailingOneDims(VectorType oldType) {
5313 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5314 newShape = newShape.drop_back(1);
5315 newScalableDims = newScalableDims.drop_back(1);
5320 if (newShape.empty()) {
5321 newShape = oldShape.take_back();
5322 newScalableDims = oldScalableDims.take_back();
5325 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5340 class ShapeCastCreateMaskFolderTrailingOneDim final
5347 Value shapeOpSrc = shapeOp->getOperand(0);
5348 auto createMaskOp = shapeOpSrc.
getDefiningOp<vector::CreateMaskOp>();
5349 auto constantMaskOp = shapeOpSrc.
getDefiningOp<vector::ConstantMaskOp>();
5350 if (!createMaskOp && !constantMaskOp)
5353 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5354 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5356 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5357 if (newVecType != shapeOpResTy)
5360 auto numDimsToDrop =
5361 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5368 auto maskOperands = createMaskOp.getOperands();
5369 auto numMaskOperands = maskOperands.size();
5372 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5374 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5375 if (!constant || (constant.value() != 1))
5379 maskOperands.drop_back(numDimsToDrop);
5386 if (constantMaskOp) {
5387 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5388 auto numMaskOperands = maskDimSizes.size();
5391 for (
size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5393 if (maskDimSizes[i] != 1)
5397 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5412 class ShapeCastBroadcastFolder final :
public OpRewritePattern<ShapeCastOp> {
5419 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5424 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5425 broadcastSourceShape = srcType.getShape();
5427 shapeCastOp.getResultVectorType().getShape();
5431 if (broadcastSourceShape ==
5432 shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5434 shapeCastOp, shapeCastOp.getResultVectorType(),
5435 broadcastOp.getSource());
5441 if (
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5442 if (srcType.getNumElements() ==
5443 shapeCastOp.getResultVectorType().getNumElements()) {
5445 shapeCastOp, shapeCastOp.getResultVectorType(),
5446 broadcastOp.getSource());
5459 results.
add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5460 ShapeCastBroadcastFolder>(context);
5468 auto sourceVectorType = getSourceVectorType();
5469 auto resultVectorType = getResultVectorType();
5471 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5472 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5473 return emitOpError(
"dimension size mismatch at: ") << i;
5476 DataLayout dataLayout = DataLayout::closest(*
this);
5477 auto sourceElementBits =
5479 auto resultElementBits =
5482 if (sourceVectorType.getRank() == 0) {
5483 if (sourceElementBits != resultElementBits)
5484 return emitOpError(
"source/result bitwidth of the 0-D vector element "
5485 "types must be equal");
5486 }
else if (sourceElementBits * sourceVectorType.getShape().back() !=
5487 resultElementBits * resultVectorType.getShape().back()) {
5489 "source/result bitwidth of the minor 1-D vectors must be equal");
// NOTE(review): fragment of BitCastOp::fold (doxygen listing with gaps).
// Visible folds: (1) bitcast(bitcast(x)) where the outer result type equals
// the inner source type folds to x, otherwise the operand is rewired to skip
// the inner bitcast; (2) splat constants are re-encoded across the cast.
5501 if (
auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5502 if (getResult().
getType() == otherOp.getSource().getType())
5503 return otherOp.getSource();
// Not a round-trip: still skip the intermediate bitcast in place.
5505 setOperand(otherOp.getSource());
5509 Attribute sourceConstant = adaptor.getSource();
5510 if (!sourceConstant)
5513 Type srcElemType = getSourceVectorType().getElementType();
5514 Type dstElemType = getResultVectorType().getElementType();
// Splat float constant: visible special case packs two f16 bit patterns
// into one f32 word (low half duplicated into the high half).
5516 if (
auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5517 if (floatPack.isSplat()) {
5518 auto splat = floatPack.getSplatValue<FloatAttr>();
5521 if (srcElemType.
isF16() && dstElemType.
isF32()) {
5522 uint32_t bits =
static_cast<uint32_t
>(
5523 splat.getValue().bitcastToAPInt().getZExtValue());
5525 bits = (bits << 16) | (bits & 0xffff);
5526 APInt intBits(32, bits);
5527 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
// Splat integer constant: when the destination integer is a whole multiple
// of the source width, replicate the source bit pattern to fill it.
5533 if (
auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5534 if (intPack.isSplat()) {
5535 auto splat = intPack.getSplatValue<IntegerAttr>();
5537 if (llvm::isa<IntegerType>(dstElemType)) {
5542 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5543 APInt intBits = splat.getValue().zext(dstBitWidth);
// Concatenate dst/src copies of the source pattern into the wider integer.
5546 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5547 intBits = (intBits << srcBitWidth) | intBits;
// NOTE(review): fragment of a shape-extraction helper (listing with gaps).
// Appears to append a vector-typed memref element's shape onto an
// accumulated shape vector — confirm against the full source.
5562 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5565 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
// NOTE(review): fragments of TypeCastOp build/verify (doxygen listing with
// gaps). Visible checks: operand and result memrefs must both have identity
// layout, share a memory space, share the underlying scalar element type,
// and have matching concatenated (memref + vector-element) shapes.
5574 MemRefType memRefType = llvm::cast<MemRefType>(source.
getType());
5575 VectorType vectorType =
5579 memRefType.getMemorySpace()));
5584 if (!canonicalType.getLayout().isIdentity())
5585 return emitOpError(
"expects operand to be a memref with identity layout");
5586 if (!getResultMemRefType().getLayout().isIdentity())
5587 return emitOpError(
"expects result to be a memref with identity layout");
5588 if (getResultMemRefType().getMemorySpace() !=
5590 return emitOpError(
"expects result in same memory space");
5593 auto resultType = getResultMemRefType();
5597 "expects result and operand with same underlying scalar type: ")
// Shapes are compared after flattening vector element dims via extractShape.
5599 if (extractShape(sourceType) != extractShape(resultType))
5601 "expects concatenated result and operand shapes to be equal: ")
// NOTE(review): fragment of the transpose result-type inference helper
// (listing with gaps). Permutes both the shape and the scalable-dims flags
// of the input vector type according to `permutation`.
5612 VectorType vt = llvm::cast<VectorType>(vector.
getType());
5615 for (
unsigned i = 0; i < permutation.size(); ++i) {
5616 transposedShape[i] = vt.getShape()[permutation[i]];
5617 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5622 transposedScalableDims));
// NOTE(review): fragment of TransposeOp::fold (listing with gaps). Visible
// fold: a constant (DenseElementsAttr) vector operand is reshaped to the
// result type; the trailing loop presumably checks the permutation for an
// order-preserving/identity case — confirm against the full source.
5627 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5630 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5632 return attr.reshape(getResultVectorType());
5640 for (int64_t i = 0, e = perm.size(); i < e; i++) {
// NOTE(review): fragment of TransposeOp::verify (listing with gaps).
// Visible checks: source/result ranks match; permutation length matches the
// rank; each index is in [0, rank) and appears exactly once (tracked via
// `seen`); each result dim size equals the permuted source dim size.
5649 VectorType vectorType = getSourceVectorType();
5650 VectorType resultType = getResultVectorType();
5651 int64_t rank = resultType.getRank();
5652 if (vectorType.getRank() != rank)
5653 return emitOpError(
"vector result rank mismatch: ") << rank;
5656 int64_t size = perm.size();
5658 return emitOpError(
"transposition length mismatch: ") << size;
// `ta` appears to be an enumerate() element: index = result dim,
// value = permuted source dim — confirm against the full source.
5661 if (ta.value() < 0 || ta.value() >= rank)
5662 return emitOpError(
"transposition index out of range: ") << ta.value();
5663 if (seen[ta.value()])
5664 return emitOpError(
"duplicate position index: ") << ta.value();
5665 seen[ta.value()] =
true;
5666 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5667 return emitOpError(
"dimension size mismatch at: ") << ta.value();
// Unroll shape for TransposeOp: simply the result vector's shape.
// (Closing brace lost in the listing capture.)
5672 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5673 return llvm::to_vector<4>(getResultVectorType().
getShape());
// NOTE(review): fragment of the TransposeFolder rewrite pattern (listing
// with gaps). Folds transpose(transpose(x)) by composing the two
// permutations (result[i] = permutation1[permutation2[i]]) and rewriting to
// a single transpose of the grandparent value.
5679 class TransposeFolder final :
public OpRewritePattern<vector::TransposeOp> {
// Compose the two permutations.
5689 for (
auto index : permutation2)
5690 result.push_back(permutation1[index]);
5695 vector::TransposeOp parentTransposeOp =
5696 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5697 if (!parentTransposeOp)
5701 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5704 transposeOp, transposeOp.getResult().getType(),
5705 parentTransposeOp.getVector(), permutation);
// NOTE(review): fragments of two patterns (listing with gaps):
// FoldTransposedScalarBroadcast — transpose(broadcast(scalar-or-1-element))
// becomes a direct broadcast to the transposed type; and a splat variant
// that rewrites transpose(splat(x)) the same way.
5711 struct FoldTransposedScalarBroadcast final
5717 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
// Only scalar or single-element sources are order-independent.
5721 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5722 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5724 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
// FoldTransposeSplat fragment: a splat is invariant under transposition.
5739 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5744 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
// NOTE(review): fragment of FoldTransposeCreateMask (listing with gaps).
// Folds a transpose of create_mask/constant_mask by permuting the mask's
// operands (or constant dim sizes) instead of transposing the mask value.
5750 class FoldTransposeCreateMask final :
public OpRewritePattern<TransposeOp> {
5756 Value transposeSrc = transpOp.getVector();
5757 auto createMaskOp = transposeSrc.
getDefiningOp<vector::CreateMaskOp>();
5758 auto constantMaskOp = transposeSrc.
getDefiningOp<vector::ConstantMaskOp>();
// Only applies when the transposed value is one of the two mask ops.
5759 if (!createMaskOp && !constantMaskOp)
// create_mask case: permute the dynamic size operands.
5767 auto maskOperands = createMaskOp.getOperands();
5772 transpOp, transpOp.getResultVectorType(), newOperands);
// constant_mask case: permute the static mask dim sizes.
5777 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5781 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
// Registers TransposeOp canonicalization patterns (fragment; listing gaps).
5788 void vector::TransposeOp::getCanonicalizationPatterns(
5790 results.
add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5791 TransposeFolder, FoldTransposeSplat>(context);
// NOTE(review): fragment of a ConstantMaskOp::build overload (listing with
// gaps). Only AllTrue/AllFalse kinds are accepted; delegates to the main
// build with sizes derived from the kind (continuation lines missing).
5800 assert(kind == ConstantMaskKind::AllTrue ||
5801 kind == ConstantMaskKind::AllFalse);
5802 build(builder, result, type,
5803 kind == ConstantMaskKind::AllTrue
// NOTE(review): fragment of ConstantMaskOp::verify (listing with gaps).
// Visible checks: 0-D vectors take exactly one dim size of 0 or 1; otherwise
// the number of mask dim sizes equals the result rank; each size is within
// [0, dim]; scalable dims must be all-set or none-set; and if any size is
// zero, all must be zero (a zero dim masks everything).
5809 auto resultType = llvm::cast<VectorType>(getResult().
getType());
// Special-case 0-D masks.
5811 if (resultType.getRank() == 0) {
5812 if (getMaskDimSizes().size() != 1)
5813 return emitError(
"array attr must have length 1 for 0-D vectors");
5814 auto dim = getMaskDimSizes()[0];
5815 if (dim != 0 && dim != 1)
5816 return emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
5821 if (
static_cast<int64_t
>(getMaskDimSizes().size()) != resultType.getRank())
5823 "must specify array attr of size equal vector result rank");
5826 auto resultShape = resultType.getShape();
5827 auto resultScalableDims = resultType.getScalableDims();
5829 for (
const auto [index, maskDimSize] :
llvm::enumerate(maskDimSizes)) {
5830 if (maskDimSize < 0 || maskDimSize > resultShape[index])
5832 "array attr of size out of bounds of vector result dimension size");
// Scalable dims: runtime-sized, so only 0 ("none set") or the full static
// size ("all set") is representable.
5833 if (resultScalableDims[index] && maskDimSize != 0 &&
5834 maskDimSize != resultShape[index])
5836 "only supports 'none set' or 'all set' scalable dimensions");
5840 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5841 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) {
return s == 0; });
5842 if (anyZeros && !allZeros)
5843 return emitOpError(
"expected all mask dim sizes to be zeros, "
5844 "as a result of conjunction with zero mask dim");
// NOTE(review): fragment of ConstantMaskOp::isAllOnesMask (listing with
// gaps; the declaration of `resultType` and the final returns are missing).
// True iff every mask dim size covers its full result dimension.
5848 bool ConstantMaskOp::isAllOnesMask() {
5851 if (resultType.getRank() == 0) {
5852 assert(getMaskDimSizes().size() == 1 &&
"invalid sizes for zero rank mask");
5853 return getMaskDimSizes()[0] == 1;
// Any dim size smaller than the result dim means not all-ones.
5855 for (
const auto [resultSize, maskDimSize] :
5856 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5857 if (maskDimSize < resultSize)
// NOTE(review): fragments of CreateMaskOp build/verify (listing with gaps).
// Verify: a 0-D result takes exactly one size operand; otherwise one operand
// per result vector dimension.
5872 build(builder, result, type, operands);
5876 auto vectorType = llvm::cast<VectorType>(getResult().
getType());
5878 if (vectorType.getRank() == 0) {
5879 if (getNumOperands() != 1)
5881 "must specify exactly one operand for 0-D create_mask");
5882 }
else if (getNumOperands() !=
5883 llvm::cast<VectorType>(getResult().
getType()).getRank()) {
5885 "must specify an operand for each result vector dimension");
// NOTE(review): fragment of the CreateMaskFolder pattern (listing with
// gaps). Folds create_mask with constant (or vscale-multiple) operands into
// a constant_mask: sizes are collected, clamped to the vector's dims, and
// forced all-zero if any dim is zero. The mul lhs/rhs lines appear to
// belong to a vscale * constant matcher — confirm against the full source.
5896 auto lhs = mul.getLhs();
5897 auto rhs = mul.getRhs();
5936 VectorType maskType = createMaskOp.getVectorType();
5938 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
// 0-D masks are treated as a single non-scalable dim of size 1.
5941 constexpr std::array<int64_t, 1> rankZeroShape{1};
5942 constexpr std::array<bool, 1> rankZeroScalableDims{
false};
5943 if (maskType.getRank() == 0) {
5944 maskTypeDimSizes = rankZeroShape;
5945 maskTypeDimScalableFlags = rankZeroScalableDims;
// Every operand must resolve to a constant (or vscale multiple) size.
5951 for (
auto [i, dimSize] :
llvm::enumerate(createMaskOp.getOperands())) {
5956 if (maskTypeDimScalableFlags[i] && intSize >= 0)
5958 constantDims.push_back(*intSize);
5962 if (vscaleMultiplier < maskTypeDimSizes[i])
5964 constantDims.push_back(*vscaleMultiplier);
// Clamp each size into [0, dim].
5971 for (
auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
5972 value = std::clamp<int64_t>(value, 0, maskDimSize);
// A zero anywhere zeroes the whole mask (conjunction semantics).
5975 if (llvm::is_contained(constantDims, 0))
5976 constantDims.assign(constantDims.size(), 0);
// Registers the CreateMaskOp canonicalization pattern (fragment).
5989 results.
add<CreateMaskFolder>(context);
// NOTE(review): fragments of MaskOp build overloads and its custom parser
// (listing with gaps). Builds require a region-builder callback; one
// overload forwards with an empty passthru Value(). Parser fragments handle
// the optional passthru operand, terminator insertion, and result types.
6000 assert(maskRegionBuilder &&
6001 "builder callback for 'maskRegion' must be present");
6007 maskRegionBuilder(builder, maskableOp);
6014 build(builder, result, resultTypes, mask,
Value(), maskableOp,
6022 build(builder, result, mask, maskableOp, maskRegionBuilder);
// Parser: optional ", passthru" operand.
6043 if (parsePassthru.succeeded() && parser.
parseOperand(passthru))
// Ensure the mask region ends with the implicit terminator.
6050 MaskOp::ensureTerminator(maskRegion, builder, result.
location);
6064 result.
types.append(resultTypes);
6070 if (parsePassthru.succeeded())
6078 p <<
" " << getMask();
6080 p <<
", " << getPassthru();
6084 Block *singleBlock = &getMaskRegion().getBlocks().
front();
6091 p <<
" : " << getMask().getType();
6092 if (getNumResults() > 0)
6093 p <<
" -> " << getResultTypes();