#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"

#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
  if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
    // Count up while bits are set, down while bits are cleared; a mix of set
    // and cleared bits means the mask format is unknown.
    int64_t val = 0;
    for (bool b : denseElts.getValues<bool>())
      if (b && val >= 0)
        val++;
      else if (!b && val <= 0)
        val--;
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
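// Illustrative note (not from the original source): mask classification
// applies to IR such as the following -- exact syntax may vary across MLIR
// versions:
//
//   %m0 = vector.constant_mask [4] : vector<4xi1>   // all-true
//   %m1 = vector.constant_mask [0] : vector<4xi1>   // all-false
//   %m2 = vector.create_mask %n : vector<4xi1>      // unknown for dynamic %n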
  builder.create<vector::YieldOp>(loc);
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
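// Illustrative summary (not from the original source): the integer kinds
// (and/or/xor, signed/unsigned min/max) reject float elements, while the *F
// kinds require them. For example, the reduction below is valid only because
// the element type is f32 -- IR syntax may vary across MLIR versions:
//
//   %r = vector.reduction <maximumf>, %v : vector<8xf32> into f32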
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(),
      getEffectiveVectorRankForXferOp(shapedType, vectorType),
      shapedType.getContext());
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for a constant splat written by the transfer_write.
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat())
    return false;
                              vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
                              vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfers of the same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, indices that are provably different mean the
      // accessed slices are disjoint.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For trailing dimensions, make sure the accessed intervals do not
      // overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
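// Illustrative example (not from the original source): two transfers on the
// same memref are provably disjoint when their constant indices are farther
// apart than the vector length along the sliced dimension -- IR syntax may
// vary across MLIR versions:
//
//   %r = vector.transfer_read %A[%c0], %pad : memref<16xf32>, vector<4xf32>
//   vector.transfer_write %v, %A[%c8] : vector<4xf32>, memref<16xf32>
//
// Here |8 - 0| >= 4, so the accessed intervals [0, 4) and [8, 12) do not
// overlap.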
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();
    // Carry the overflow over to the next loop iteration.
    posInDim = offsetInDim;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });

  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();
                    return cast<Value>(foldResult);
                  });
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
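// Illustrative example (not from the original source): a scalable quantity
// such as "4 x vscale" typically appears as -- IR syntax may vary:
//
//   %vscale = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n = arith.muli %vscale, %c4 : index
//
// For %n, the constant vscale multiplier recovered above is 4.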
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // A single parallel dim makes this a no-op.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
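// Illustrative example (not from the original source): reducing dim 1 of a
// 2-D vector -- IR syntax may vary across MLIR versions:
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x8xf32> to vector<4xf32>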
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // All reduction dimensions have unit size, so a simple extract suffices.
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind,
        vector, acc, fastMathFlags);
}
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
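// Illustrative example (not from the original source) -- IR syntax may vary
// across MLIR versions:
//
//   %s = vector.reduction <add>, %v : vector<8xf32> into f32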
  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = rewriter.create<ExtractOp>(loc, mask);
    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  DictionaryAttr dictAttr;

  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }

  result.addAttribute(
      getKindAttrName(result.name),
      CombiningKindAttr::get(result.getContext(),
                             ContractionOp::getDefaultKind()));

  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into their string representation.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) !=
            rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer the expected result vector type from the lhs/rhs maps and vector
    // types, which must already have been verified to be well-formed.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the indexing maps, pick each iteration-space dimension size (and
  // scalability) from the operand that carries it.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}
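// Illustrative note (not from the original source): for a matmul-style
// contraction with iterators (m, n, k), lhs vector<4x8xf32> mapped as (m, k)
// and rhs vector<8x16xf32> mapped as (k, n), the expected mask type computed
// above is vector<4x16x8xi1> -- one i1 per point of the (m, n, k) iteration
// space.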
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>>
ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
  setResultRanges(getResult(), argRanges.front());

  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement (broadcast X with scalar source) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !pos)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}
  return index == poisonValue || (index >= 0 && index < maxIndex);
  setResultRanges(getResult(), argRanges.front());

  auto vectorTy = cast<VectorType>(source.getType());

                              Value source, int64_t position) {

  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over the insert/transpose chain and try to fold the extraction.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`, i.e. `a` is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the positions intersect, treating negative entries as
  /// wildcards.
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Update the chain-walk state for the next iteration.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<InsertOp>();
    nextTransposeOp = v.getDefiningOp<TransposeOp>();
  }

  LogicalResult handleTransposeOp();

  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
      extractedRank(extractOp.getNumIndices()) {
  assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
      nextTransposeOp.getPermutation(), extractOp.getContext()));
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  res = nextInsertOp.getValueToStore();
  // Case 2.b. If an internal transposition is present, canFold is false.
  return success(canFold());
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  res = nextInsertOp.getValueToStore();
  return success();
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  // Nothing to fold, or folding is blocked by an internal transposition.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();

  // Otherwise, fold by updating the op in place and return its result.
  extractOp.setStaticPosition(
      ArrayRef(extractPosition).take_front(extractedRank));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2. The positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3. The inserted position is a prefix of the extract position.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Otherwise, keep walking up the insert chain.
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
  auto hasZeroDimVectorType = [](Type type) -> bool {
    auto vecType = dyn_cast<VectorType>(type);
    return vecType && vecType.getRank() == 0;
  };

static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();

  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                       : 0;
  };

  // If the splat or broadcast source is scalar, return it directly.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank > broadcastSrcRank)
    return Value();
  // Check that the result dimensions have not been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Positions that come from "dim-1" broadcasted dims are reset to 0 when
  // extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // The leading `rankDiff` dimensions are new broadcasted dims; drop the
  // matching extract positions.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
  // Dynamic positions are not folded as the resulting code would be more
  // complex than the input code.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
  if (!shuffleOp)
    return Value();

  // TODO: 0-D or multi-dimensional vectors not supported yet.
  if (shuffleOp.getResultVectorType().getRank() != 1)
    return Value();

  int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
  auto shuffleMask = shuffleOp.getMask();
  int64_t extractIdx = extractOp.getStaticPosition()[0];
  int64_t shuffleIdx = shuffleMask[extractIdx];

  // Find the shuffled vector to extract from based on the shuffle index.
  if (shuffleIdx < inputVecSize) {
    extractOp.setOperand(0, shuffleOp.getV1());
    extractOp.setStaticPosition({shuffleIdx});
  } else {
    extractOp.setOperand(0, shuffleOp.getV2());
    extractOp.setStaticPosition({shuffleIdx - inputVecSize});
  }

  return extractOp.getResult();
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from the lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimensions of the destination must match the lowest
      // dimensions of the shape_cast source.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Compute the strides of the extract position in the origin shape.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Compute the strides of the shape_cast source and delinearize the
  // position with them.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();

  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
                            sliceOffsets.size())
    return Value();

  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());

  extractOp.setStaticPosition(extractedPos);
  return extractOp.getResult();
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
  if (!insertOp)
    return Value();

  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();

    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted chunk overlaps with the inserted vector.
    if (!disjoint) {
      // If any inner dimension is only partially inserted, the overlap is
      // partial and we cannot fold.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      extractOp.getVectorMutable().assign(insertOp.getValueToStore());
      extractOp.setStaticPosition(offsetDiffs);
      return extractOp.getResult();
    }
    // The extracted chunk is disjoint from the inserted chunk; keep looking
    // up the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute the flattened/linearized index and fold to the matching operand.
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
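// Illustrative example (not from the original source) -- IR syntax may vary:
//
//   %v = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
//   %e = vector.extract %v[1, 0] : f32 from vector<2x2xf32>
//
// folds to %c, since the linearized index is 1 * 2 + 0 = 2.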
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` is used to iterate over the `dynamicPosition`.
  unsigned index = 0;

  // `opChange` tracks whether `op` must be updated in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    if (!ShapedType::isDynamic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of range.
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    return op.getResult();
  }
  return {};
}
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
                                                ArrayRef<int64_t> staticPos,
                                                int64_t poisonVal) {
  if (!is_contained(staticPos, poisonVal))
    return {};
  return ub::PoisonAttr::get(context);
}

static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
    return srcAttr;
  return {};
}

  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (!denseAttr)
    return {};

  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    return newAttr;
  }

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};

  if (extractOp.hasDynamicPosition()) {
    return {};
  }

  // Calculate the linearized position of the contiguous chunk of elements to
  // extract.
  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
  copy(extractOp.getStaticPosition(), completePositions.begin());
  int64_t startPos =
      linearize(completePositions, computeStrides(vecTy.getShape()));
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  Attribute newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return llvm::isa<VectorType>(type)
                 ? llvm::cast<VectorType>(type).getRank()
                 : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // Only handle the case where the source rank is no greater than the
    // extract result rank; the other cases are handled by the folders.
    if (extractResultRank < broadcastSrcRank)
      return failure();
    // A scalar result can only come from a rank-0 vector, which the folder
    // handles.
    if (extractResultRank == 0)
      return failure();

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = false;

    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim are unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If the position is outside the range of the create_mask, the
        // extracted mask is all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // The dim is not all-true, and with a dynamic index we cannot tell
        // whether the extraction lands in the true or the false region.
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                  PatternRewriter &rewriter) {
  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
  if (!castOp)
    return failure();

  VectorType sourceType = castOp.getSourceVectorType();
  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
  if (!targetType)
    return failure();

  if (sourceType.getNumElements() != targetType.getNumElements())
    return failure();

  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
                                                   castOp.getSource());
  return success();
}

LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of the first extracted element and linearize it.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(resultType.getRank(), 0);
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the original extract with a new from_elements op.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {

    // All operands must be equal to turn a from_elements into a splat.
    if (!llvm::all_equal(fromElementsOp.getElements()))
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(
        fromElementsOp, fromElementsOp.getType(),
        fromElementsOp.getElements().front());
  setResultRanges(getResult(), argRanges.front());

  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");

  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Well-formedness: the non-broadcasted dims of dstShape must match the
  // source shape.
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // If the source is a scalar, broadcast in one step.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // vector.broadcast can only create new leading dims, so first bring the
  // broadcasted dims to the front, then transpose them into place.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // A broadcasted dim goes to the head of broadcastShape.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // A non-broadcasted dim comes from the source shape.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // If any dimension needs permuting, emit a vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  return res;
}
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast a scalar to a vector of the same element type.
  if (srcType.isIntOrIndexOrFloat() &&
      getElementTypeOrSelf(srcType) == dstVectorType.getElementType())
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // The source needs an exact match or a singleton value for all trailing
  // dimensions; leading dimensions are simply duplicated.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics) been
    // encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine; everything else is rejected when mixing
        // fixed-width and scalable dims.
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
  std::pair<VectorDim, VectorDim> mismatchingDims;
  BroadcastableToResult res = isBroadcastableTo(
      getSourceType(), getResultVectorType(), &mismatchingDims);
  if (res == BroadcastableToResult::Success)
    return success();
  if (res == BroadcastableToResult::SourceRankHigher)
    return emitOpError("source rank higher than destination rank");
  if (res == BroadcastableToResult::DimensionMismatch) {
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  if (res == BroadcastableToResult::SourceTypeNotAVector)
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
    if (vectorType.getElementType() != attr.getType())
      return {};
    return DenseElementsAttr::get(vectorType, attr);
  }
  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
    return ub::PoisonAttr::get(getContext());
    auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
    if (!srcBroadcast)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, broadcastOp.getResultVectorType(),
        srcBroadcast.getSource());
    return success();

  results.add<BroadcastFolder>(context);
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError("rank mismatch");

  // Verify all but the leading dimension sizes.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(r);
    int64_t v1Dim = v1Type.getDimSize(r);
    int64_t v2Dim = v2Type.getDimSize(r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError("dimension mismatch");
  }
  // Verify mask length.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError("invalid mask length");
  if (maskLength != resultType.getDimSize(0))
    return emitOpError("mask length mismatch");
  // Verify all indices.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
      return emitOpError("mask index #") << (idx + 1) << " out of range";
  }
  return success();
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ShuffleOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
  auto v1Rank = v1Type.getRank();
  // The leading dimension matches the mask length; all trailing dimensions
  // match the operands.
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
  // In the 0-D case there is no trailing shape to append.
  if (v1Rank > 0)
    llvm::append_range(shape, v1Type.getShape().drop_front());
  inferredReturnTypes.push_back(
      VectorType::get(shape, v1Type.getElementType()));
template <typename T>
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
  T expected = begin;
  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
           return value == expected++;
         });
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // A 0-D shuffle has a 1-D result type; rewriting it is a canonicalization
  // into vector.broadcast, not a fold.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only 1-D is supported to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // If either operand is poison, pick an arbitrary sentinel element from the
  // other operand.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    v2Elements =
        to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    v1Elements =
        to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }
    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
    VectorType v1VectorType = shuffleOp.getV1VectorType();
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    if (v1VectorType.getRank() > 0)
      return failure();
    if (mask.size() != 1)
      return failure();

    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

    if (!v1Splat || !v2Splat)
      return failure();

    if (v1Splat.getInput() != v2Splat.getInput())
      return failure();
    VectorType resultType = op.getResultVectorType();
    if (resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent a scalable interleave");

    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp can't represent an n-D interleave");

    VectorType sourceType = op.getV1VectorType();
    if (sourceType != op.getV2VectorType() ||
        sourceType.getNumElements() * 2 != resultType.getNumElements()) {
      return rewriter.notifyMatchFailure(
          op, "ShuffleOp types don't match an interleave");
    }

    ArrayRef<int64_t> shuffleMask = op.getMask();
    int64_t resultVectorSize = resultType.getNumElements();
    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
      int64_t maskValueA = shuffleMask[i * 2];
      int64_t maskValueB = shuffleMask[(i * 2) + 1];
      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
        return rewriter.notifyMatchFailure(op,
                                           "ShuffleOp mask not interleaving");
    }
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);

  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
  build(builder, result, source, dest, {});

LogicalResult vector::InsertElementOp::verify() {
  auto dstVectorType = getDestVectorType();
  if (dstVectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (dstVectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}
OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here.
  if (!adaptor.getPosition())
    return {};

  auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
  auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!src || !dst || !pos)
    return {};

  auto dstElements = dst.getValues<Attribute>();

  SmallVector<Attribute> results(dstElements);

  uint64_t posIdx = pos.getInt();
  if (posIdx >= results.size())
    return {};
  results[posIdx] = src;

  return DenseElementsAttr::get(getDestVectorType(), results);
}
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));

  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));

                             Value source, Value dest, int64_t position) {

  posVals.reserve(position.size());
  llvm::transform(position, std::back_inserter(posVals),
                  [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
  build(builder, result, source, dest, posVals);

  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
    auto srcVecType =
        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                           srcVecType.getNumElements())
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
    return success();

    auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

    if (!srcSplat || !dstSplat)
      return failure();

    if (srcSplat.getInput() != dstSplat.getInput())
      return failure();
static Attribute foldDenseElementsAttrDestInsertOp(
    InsertOp insertOp, Attribute srcAttr, Attribute dstAttr,
    int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  copy(insertOp.getStaticPosition(), completePositions.begin());
  int64_t insertBeginPosition =
      linearize(completePositions, computeStrides(destTy.getShape()));

  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  // Convert integer attributes to the destination element type where needed.
  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  };

  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertIntegerAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
  }

  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);

  return DenseElementsAttr::get(destTy, allValues);
}
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);

OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>"
  // (no-op if the types match).
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
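// Illustrative example (not from the original source) -- IR syntax may vary:
//
//   %cst = arith.constant dense<0.0> : vector<4xf32>
//   %v   = vector.insert %f, %cst [1] : f32 into vector<4xf32>
//
// With constant source and destination, the folder above rewrites %v into a
// single dense constant, provided the result stays under the 256-element
// threshold or has a single use.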
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(
    OpType op, ArrayAttr arrayAttr, ArrayRef<int64_t> shape,
    StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
// Returns success if all integers in `arrayAttr` are confined to [min, max).
// With `halfOpen` false, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
// Returns success if, for each index i, the sum arrayAttr1[i] + arrayAttr2[i]
// is confined to [min, max) against shape[i] (or [min, max] when `halfOpen`
// is false).
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}

  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
class FoldInsertStridedSliceSplat final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto srcSplatOp =
        insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
    auto destSplatOp =
        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
    if (!srcSplatOp || !destSplatOp)
      return failure();
    if (srcSplatOp.getInput() != destSplatOp.getInput())
      return failure();
    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
/// to dst.
class FoldInsertStridedSliceOfExtract final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    auto extractStridedSliceOp =
        insertStridedSliceOp.getValueToStore()
            .getDefiningOp<vector::ExtractStridedSliceOp>();
    if (!extractStridedSliceOp)
      return failure();
    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
      return failure();
    // Check if the extract and the insert have the same strides and offsets.
    if (extractStridedSliceOp.getStrides() !=
            insertStridedSliceOp.getStrides() ||
        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
      return failure();
    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
    return success();
  }
};

/// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
/// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source and destination vectors get folded away.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if the operands are not defined by compatible vector constants.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();
    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();
    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();
    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
      return failure();
    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Copy the slice values into the destination constant. Slice positions
    // are enumerated in lexicographic order and linearized, which yields
    // monotonically increasing destination indices.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt++;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, DenseElementsAttr::get(destTy, newValues));
    return success();
  }
};
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}

OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getDestVectorType())
    return getValueToStore();
  return {};
}
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}

ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand.
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}

LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what is currently supported by the
      // scalable-vector lowerings.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
/// Return the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type OuterProductOp::getExpectedMaskType() {
  auto vecType = this->getResultVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType(),
                         vectorType.getScalableDims());
}
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}

LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape,
                                                sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType =
      inferStridedSliceOpResultType(getSourceVectorType(), offsets, sizes,
                                    strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }
  return success();
}
// Replace an ExtractStridedSliceOp with the slice it extracts when it is fed
// by an InsertStridedSliceOp chain whose insertion fully covers the extracted
// chunk.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
  auto getElement = [](ArrayAttr array, int idx) {
    return llvm::cast<IntegerAttr>(array[idx]).getInt();
  };
  ArrayAttr extractOffsets = op.getOffsets();
  ArrayAttr extractStrides = op.getStrides();
  ArrayAttr extractSizes = op.getSizes();
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    if (op.getSourceVectorType().getRank() !=
        insertOp.getSourceVectorType().getRank())
      return failure();
    ArrayAttr insertOffsets = insertOp.getOffsets();
    ArrayAttr insertStrides = insertOp.getStrides();
    if (extractOffsets.size() > insertOffsets.size())
      return failure();
    bool partialOverlap = false;
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
        return failure();
      int64_t start = getElement(insertOffsets, dim);
      int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
      int64_t offset = getElement(extractOffsets, dim);
      int64_t size = getElement(extractSizes, dim);
      // Check if the extract offset starts within the inserted interval.
      if (start <= offset && offset < end) {
        // The extract interval overlaps the insert but may not be fully
        // contained in it: record a potential partial overlap.
        if (offset + size > end)
          partialOverlap = true;
        offsetDiffs.push_back(offset - start);
      } else {
        disjoint = true;
        break;
      }
    }
    // The extracted chunk is fully contained in the inserted chunk: rebase
    // the extract onto the inserted value.
    if (!disjoint && !partialOverlap) {
      op.setOperand(insertOp.getValueToStore());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
      return success();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    if (disjoint)
      insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
    else
      return failure();
  }
  return failure();
}
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
  if (!dense)
    return {};
  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets and sizes to match the vector rank.
  SmallVector<int64_t, 4> offsets(rank, 0);
  copy(getI64SubArray(op.getOffsets()), offsets.begin());
  SmallVector<int64_t, 4> sizes(sourceShape);
  copy(getI64SubArray(op.getSizes()), sizes.begin());

  // Calculate the slice elements by enumerating all slice positions and
  // linearizing them. The enumeration order is lexicographic.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
  } while (succeeded(
      incSlicePosition(currSlicePosition, sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
             sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(sliceVecTy, sliceValues);
}

OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  if (getSourceVectorType() == getResult().getType())
    return getVector();
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();

  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
}
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if the operand is not defined by a ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes, slice offsets and sizes.
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute the slice of the vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any of 'sliceMaskDimSizes' is zero, set all of them to zero (the
    // masked region is a conjunction of the per-dimension intervals).
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'extractStridedSliceOp' with a ConstantMaskOp over the sliced
    // mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStridedSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Check if the innermost dimensions of the broadcast source match the
    // destination of the extract. If so, the original dimensions are
    // untouched and a plain broadcast suffices.
    bool lowerDimMatch = true;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        lowerDimMatch = false;
        break;
      }
    }
    Value source = broadcast.getSource();
    // If the inner dimensions don't match, extract from the broadcast source
    // first and then broadcast the extracted value. A scalar-like source
    // (0-d vector or all-unit dims) needs no extract.
    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
    if (!lowerDimMatch && !isScalarSrc) {
      source = rewriter.create<ExtractStridedSliceOp>(
          op->getLoc(), source,
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff),
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};

/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto splat = op.getVector().getDefiningOp<SplatOp>();
    if (!splat)
      return failure();
    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
    return success();
  }
};

/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build: walk the
    // dimensions from innermost out, and stop when the next slice dimension
    // is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the innermost dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions, which would
    // hit bad generic fallback patterns for the shape_cast created below.
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
                                                       extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
/// 1. Builder that sets padding to zero and an empty mask (attribute variant).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMapAttr permutationMapAttr,
                           ArrayAttr inBoundsAttr) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 2. Builder that sets padding to zero and an empty mask (value variant).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        inBoundsAttr);
}

/// 3. Builder that sets the permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, Value padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        padding, /*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that sets padding to zero and the permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices,
                           std::optional<ArrayRef<bool>> inBounds) {
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  Value padding = builder.create<arith::ConstantOp>(
      result.location, elemType, builder.getZeroAttr(elemType));
  build(builder, result, vectorType, source, indices, padding, inBounds);
}
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    auto zero = dyn_cast<AffineConstantExpr>(expr);
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
    // Memref or tensor has vector element type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that the permutation map results match the rank of the vector
    // type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError("expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
  if (op.getPermutationMap().isMinorIdentity())
    elidedAttrs.push_back(op.getPermutationMapAttrName());
  // Elide in_bounds attribute if all dims are out-of-bounds.
  if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
    elidedAttrs.push_back(op.getInBoundsAttrName());
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}

VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

  // 0-D masks are not supported at the moment.
  if (maskShape.empty())
    maskShape.push_back(1);

  SmallVector<bool> scalableDims =
      applyPermutationMap(invPermMap, vecType.getScalableDims());

  return VectorType::get(maskShape, i1Type, scalableDims);
}
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(sourceElementType)) {
    // Source has vector element type: padding type must match it.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");
  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type TransferReadOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

VectorType TransferReadOp::getVectorType() {
  return cast<VectorType>(getVector().getType());
}
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  Value index = op.getIndices()[indicesIdx];
  std::optional<int64_t> cstOp = getConstantIntValue(index);
  if (!cstOp.has_value())
    return false;

  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);

  return cstOp.value() + vectorSize <= sourceSize;
}
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  // TODO: Be less conservative.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Indices of non-broadcast dims, used when analysing broadcast dims below.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims.
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is in-bounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }
    changed |= inBounds;
    newInBounds.push_back(inBounds);
  }

  // 2. Handle broadcast dims: if all non-broadcast dims are in-bounds, then
  // all broadcast dims are in-bounds as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build a BoolArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}

template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();
  if (getMaskFormat(mask) != MaskFormat::AllTrue)
    return failure();
  op.getMaskMutable().clear();
  return success();
}
/// Fold a read-after-write: if the transfer_read reads back exactly what a
/// preceding transfer_write stored, forward the written vector.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    if (!isDisjointTransferIndices(
            cast<VectorTransferOpInterface>(defWrite.getOperation()),
            cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}

OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  if (Value vec = foldRAW(*this))
    return vec;
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}

std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferReadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
/// Store-to-load forwarding for transfer operations with permutation maps.
/// Even if the permutation maps differ, the store can still be propagated
/// into the load if the sizes of the dimensions read and written match; the
/// transfer_read is then replaced by vector.broadcast + vector.transpose.
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.hasOutOfBoundsDim() ||
        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
      return failure();
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e., a unit dim that is rank reduced).
    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
      return failure();
    // This pattern should only catch the broadcast case; the non-broadcast
    // case is handled separately to keep application conditions clean.
    if (readOp.getIndices() != defWrite.getIndices() ||
        readOp.getMask() != defWrite.getMask())
      return failure();
    Value vec = defWrite.getVector();
    // TODO: loop through the chain of transfer_write if we can prove that
    // they don't overlap with the transfer_read.
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the stored vector to the
    // read vector.
    SmallVector<unsigned> permutation;
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to
    // the final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType(),
        broadcastScalableFlags);
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};

void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<TransferReadAfterWriteToBroadcast>(context);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ Value mask,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
        mask, inBoundsAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMapAttr permutationMapAttr,
                            /*optional*/ ArrayAttr inBoundsAttr) {
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            AffineMap permutationMap,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  auto inBoundsAttr =
      (inBounds && !inBounds.value().empty())
          ? builder.getBoolArrayAttr(inBounds.value())
          : builder.getBoolArrayAttr(SmallVector<bool>(
                llvm::cast<VectorType>(vector.getType()).getRank(), false));
  build(builder, result, vector, dest, indices, permutationMapAttr,
        /*mask=*/Value(), inBoundsAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                            Value vector, Value dest, ValueRange indices,
                            std::optional<ArrayRef<bool>> inBounds) {
  auto vectorType = llvm::cast<VectorType>(vector.getType());
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(dest.getType()), vectorType);
  build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}

void TransferWriteOp::print(OpAsmPrinter &p) {
  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, *this);
  p << " : " << getVectorType() << ", " << getShapedType();
}
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();

  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, getInBounds())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type TransferWriteOp::getExpectedMaskType() {
  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}

Value TransferWriteOp::getVector() { return getOperand(0); }

VectorType TransferWriteOp::getVectorType() {
  return cast<VectorType>(getValueToStore().getType());
}
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity maps.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, bail.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success.
  results.push_back(read.getBase());
  return success();
}
static bool checkSameValueWAR(vector::TransferReadOp read,
                              vector::TransferWriteOp write) {
  return read.getBase() == write.getBase() &&
         read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
         !write.getMask();
}

/// Fold a write-after-read into a no-op: writing back what was just read
/// from the same location leaves the tensor unchanged.
static LogicalResult foldWAR(TransferWriteOp write,
                             SmallVectorImpl<OpFoldResult> &results) {
  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
    return failure();
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  if (!checkSameValueWAR(read, write))
    return failure();
  results.push_back(read.getBase());
  return success();
}

LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
    return success();
  if (succeeded(foldWAR(*this, results)))
    return success();
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return success();
  if (succeeded(foldTransferFullMask(*this)))
    return success();
  return memref::foldMemRefCast(*this);
}
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}

void TransferWriteOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (llvm::isa<MemRefType>(getShapedType()))
    effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
                         SideEffects::DefaultResource::get());
}

Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}

/// Removes a dead transfer_write from the SSA chain when it is overwritten by
/// a later write-after-write, so DCE can eliminate it.
struct FoldWaw final : public OpRewritePattern<TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely
      // look at the write before it to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
/// Swaps tensor.extract_slice(vector.transfer_write) so that the
/// transfer_write writes directly into the destination slice, when the write
/// covers the full slice.
struct SwapExtractSliceOfTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    if (!insertOp.hasUnitStride())
      return failure();
    auto extractOp =
        insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
      return failure();
    auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!transferOp || !transferOp->hasOneUse())
      return failure();

    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
    // rank-reducing.
    if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "use-def chain is rank-reducing");
    }

    // Fail if tensor::ExtractSliceOp has non-zero offset.
    if (!extractOp.hasZeroOffset()) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "ExtractSliceOp has non-zero offset");
    }

    // Fail if the TransferWriteOp has non-zero indices.
    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
          return getConstantIntValue(value) == static_cast<int64_t>(0);
        })) {
      return rewriter.notifyMatchFailure(insertOp,
                                         "TransferWriteOp has non-zero offset");
    }

    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp ranks differ.
    if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
      return rewriter.notifyMatchFailure(
          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
    }

    for (auto [insertSize, extractSize] :
         llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
      if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
        return rewriter.notifyMatchFailure(
            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
      }
    }

    // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
    assert(transferOp.getVectorType().hasStaticShape() &&
           "expected vector to have a static shape");
    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
    SmallVector<int64_t> resultShape = applyPermutationMap(
        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
      return rewriter.notifyMatchFailure(
          insertOp, "TransferWriteOp may not write the full tensor.");
    }

    // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
        extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
        insertOp.getMixedStrides());
    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(SmallVector<bool>(vectorShape.size(), true)));
    rewriter.modifyOpInPlace(insertOp, [&]() {
      insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
    });
    return success();
  }
};

void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
                                                 VectorType vecTy,
                                                 MemRefType memRefTy) {
  // A rank-0 vector, or a vector with a single element, touches only one
  // memref element, so any layout is fine; otherwise the most minor memref
  // dimension must have unit stride.
  if (!vecTy.isScalable() &&
      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
    return success();

  if (!memRefTy.isLastDimUnitStride())
    return op->emitOpError("most minor memref dim must have unit stride");
  return success();
}

LogicalResult vector::LoadOp::verify() {
  VectorType resVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < resVecTy.getRank())
    return emitOpError(
        "destination memref has lower rank than the result vector");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != resVecTy)
      return emitOpError("base memref and result vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (resVecTy.getElementType() != memElemTy)
    return emitOpError("base and result element types should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}
LogicalResult vector::StoreOp::verify() {
  VectorType valueVecTy = getVectorType();
  MemRefType memRefTy = getMemRefType();

  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
    return failure();

  if (memRefTy.getRank() < valueVecTy.getRank())
    return emitOpError("source memref has lower rank than the vector to store");

  // Checks for vector memrefs.
  Type memElemTy = memRefTy.getElementType();
  if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
    if (memVecTy != valueVecTy)
      return emitOpError(
          "base memref and valueToStore vector types should match");
    memElemTy = memVecTy.getElementType();
  }

  if (valueVecTy.getElementType() != memElemTy)
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memRefTy.getRank())
    return emitOpError("requires ") << memRefTy.getRank() << " indices";
  return success();
}

LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
LogicalResult MaskedLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result shape to match mask shape");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};

void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
LogicalResult MaskedStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore shape to match mask shape");
  return success();
}

class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};

void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}

LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
LogicalResult GatherOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType resVType = getVectorType();
  ShapedType baseType = getBaseType();

  if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
    return emitOpError("requires base to be a memref or ranked tensor type");

  if (resVType.getElementType() != baseType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != baseType.getRank())
    return emitOpError("requires ") << baseType.getRank() << " indices";
  if (resVType.getShape() != indVType.getShape())
    return emitOpError("expected result dim to match indices dim");
  if (resVType.getShape() != maskVType.getShape())
    return emitOpError("expected result dim to match mask dim");
  if (resVType != getPassThruVectorType())
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type GatherOp::getExpectedMaskType() {
  auto vecType = this->getIndexVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
/// Cheap check for an index vector of 0-based, monotonically increasing,
/// contiguous indices [0, 1, ..., n-1]. Such an index vector makes a
/// gather/scatter equivalent to a contiguous masked load/store.
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  if (indexVec.getDefiningOp<StepOp>())
    return success();

  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};

/// Fold gathers with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedload. Only 1-D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be a memref");
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getIndices(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<GatherFolder, FoldContiguousGather>(context);
}
LogicalResult ScatterOp::verify() {
  VectorType indVType = getIndexVectorType();
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getShape() != indVType.getShape())
    return emitOpError("expected valueToStore dim to match indices dim");
  if (valueVType.getShape() != maskVType.getShape())
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};

/// Fold scatters with consecutive offsets [0, 1, 2, ...] into a contiguous
/// maskedstore. Only 1-D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
    return success();
  }
};

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
LogicalResult ExpandLoadOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType passVType = getPassThruVectorType();
  VectorType resVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (resVType.getElementType() != memType.getElementType())
    return emitOpError("base and result element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return emitOpError("expected pass_thru of same type as result type");
  return success();
}

class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};

void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
LogicalResult CompressStoreOp::verify() {
  VectorType maskVType = getMaskVectorType();
  VectorType valueVType = getVectorType();
  MemRefType memType = getMemRefType();

  if (valueVType.getElementType() != memType.getElementType())
    return emitOpError("base and valueToStore element type should match");
  if (llvm::size(getIndices()) != memType.getRank())
    return emitOpError("requires ") << memType.getRank() << " indices";
  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return emitOpError("expected valueToStore dim to match mask dim");
  return success();
}

class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};

void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

LogicalResult ShapeCastOp::verify() {
  VectorType sourceType = getSourceVectorType();
  VectorType resultType = getResultVectorType();

  // Check that element type is preserved.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitOpError("has different source and result element types");

  // Check that the number of elements is preserved.
  int64_t sourceNElms = sourceType.getNumElements();
  int64_t resultNElms = resultType.getNumElements();
  if (sourceNElms != resultNElms) {
    return emitOpError() << "has different number of elements at source ("
                         << sourceNElms << ") and result (" << resultNElms
                         << ")";
  }

  // Check that (non-)scalability is preserved.
  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
  int64_t resultNScalableDims = resultType.getNumScalableDims();
  if (sourceNScalableDims != resultNScalableDims)
    return emitOpError() << "has different number of scalable dims at source ("
                         << sourceNScalableDims << ") and result ("
                         << resultNScalableDims << ")";

  return success();
}
/// Return true if `transpose` does not permute any pair of non-unit dims.
/// By "order preserving" we mean that the flattened versions of the input and
/// output vectors are (numerically) identical, i.e. the transpose is
/// effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current) {
        return false;
      }
      current = p;
    }
  }
  return true;
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x), when the transpose preserves
  // element order.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // Y = shape_cast(broadcast(X)) -> X, if X and Y have the same type.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(splat constant) -> splat constant of the result shape.
  if (auto splatAttr =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return splatAttr.reshape(getType());

  // shape_cast(poison) -> poison.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
    return ub::PoisonAttr::get(getContext());
  }

  return {};
}
/// Helper function that computes a new vector type based on the input vector
/// type by removing trailing unit dims:
///
///   vector<4x1x1xi1> --> vector<4xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
    newShape = newShape.drop_back(1);
    newScalableDims = newScalableDims.drop_back(1);
  }

  // Make sure we have at least 1 dimension.
  // TODO: Add support for 0-D vectors.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
/// Folds qualifying shape_cast(create_mask) into a new create_mask.
///
/// Looks at `vector.shape_cast` ops that simply "drop" trailing unit
/// dimensions. If the corresponding mask operands are the constant 1, it is
/// safe to fold the shape_cast into the mask op:
///
/// BEFORE:
///    %1 = vector.create_mask %c3, %c1 : vector<8x[1]xi1>
///    %2 = vector.shape_cast %1 : vector<8x[1]xi1> to vector<8xi1>
/// AFTER:
///    %0 = vector.create_mask %c3 : vector<8xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop.
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check that every mask operand to drop is a constant 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check that every mask dim size to drop is 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};

/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast, when broadcasting
/// the source directly to the result type is expressive enough.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    bool srcIsScalar = !srcVectorType;

    // shape_cast(broadcast(X)) -> shape_cast(X) when the element count is
    // unchanged by the broadcast.
    if (srcVectorType) {
      if (srcVectorType.getNumElements() ==
          shapeCastOp.getResultVectorType().getNumElements()) {
        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
            shapeCastOp, shapeCastOp.getResultVectorType(),
            broadcastOp.getSource());
        return success();
      }
    }

    // shape_cast(broadcast(X)) -> broadcast(X) when X (scalar or vector) is
    // broadcastable to the result type.
    VectorType dstVectorType = shapeCastOp.getResultVectorType();
    if (srcIsScalar ||
        vector::isBroadcastableTo(broadcastOp.getSourceType(), dstVectorType) ==
            vector::BroadcastableToResult::Success) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          shapeCastOp, dstVectorType, broadcastOp.getSource());
      return success();
    }
    return failure();
  }
};

void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results
      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
          context);
}
LogicalResult BitCastOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto resultVectorType = getResultVectorType();

  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
    if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
      return emitOpError("dimension size mismatch at: ") << i;
  }

  DataLayout dataLayout = DataLayout::closest(*this);
  auto sourceElementBits =
      dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
  auto resultElementBits =
      dataLayout.getTypeSizeInBits(resultVectorType.getElementType());

  if (sourceVectorType.getRank() == 0) {
    if (sourceElementBits != resultElementBits)
      return emitOpError("source/result bitwidth of the 0-D vector element "
                         "types must be equal");
  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
             resultElementBits * resultVectorType.getShape().back()) {
    return emitOpError(
        "source/result bitwidth of the minor 1-D vectors must be equal");
  }

  return success();
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern into both halves of the 32-bit word.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width: replicate the splat pattern.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower-width element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
  auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
  SmallVector<int64_t, 8> res(memRefType.getShape());
  if (vectorType)
    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
  return res;
}

/// Build the canonical memRefType with a single vector,
/// e.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4x5x6xf32>>.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}

LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
                                Value vector, ArrayRef<int64_t> permutation) {
  VectorType vt = llvm::cast<VectorType>(vector.getType());
  SmallVector<int64_t, 4> transposedShape(vt.getRank());
  SmallVector<bool, 4> transposedScalableDims(vt.getRank());
  for (unsigned i = 0; i < permutation.size(); ++i) {
    transposedShape[i] = vt.getShape()[permutation[i]];
    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
  }

  result.addOperands(vector);
  result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                  transposedScalableDims));
  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
                      builder.getDenseI64ArrayAttr(permutation));
}

OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
    return splat.reshape(getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
    return ub::PoisonAttr::get(getContext());

  // Eliminate identity transposes, and more generally any transpose that
  // preserves the order of the elements.
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(*this))
    return getVector();

  return {};
}
LogicalResult vector::TransposeOp::verify() {
  VectorType vectorType = getSourceVectorType();
  VectorType resultType = getResultVectorType();
  int64_t rank = resultType.getRank();
  if (vectorType.getRank() != rank)
    return emitOpError("vector result rank mismatch: ") << rank;
  // Verify the transposition array.
  ArrayRef<int64_t> perm = getPermutation();
  int64_t size = perm.size();
  if (rank != size)
    return emitOpError("transposition length mismatch: ") << size;
  SmallVector<bool, 8> seen(rank, false);
  for (const auto &ta : llvm::enumerate(perm)) {
    if (ta.value() < 0 || ta.value() >= rank)
      return emitOpError("transposition index out of range: ") << ta.value();
    if (seen[ta.value()])
      return emitOpError("duplicate position index: ") << ta.value();
    seen[ta.value()] = true;
    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
      return emitOpError("dimension size mismatch at: ") << ta.value();
  }
  return success();
}

std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getResultVectorType().getShape());
}
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Composes two permutations: result[i] = permutation1[permutation2[i]].
    auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                  ArrayRef<int64_t> permutation2) {
      SmallVector<int64_t, 4> result;
      for (auto index : permutation2)
        result.push_back(permutation1[index]);
      return result;
    };

    // Return if the input of 'transposeOp' is not defined by another
    // transpose.
    vector::TransposeOp parentTransposeOp =
        transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!parentTransposeOp)
      return failure();

    SmallVector<int64_t, 4> permutation = composePermutations(
        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
    // Replace 'transposeOp' with a new transpose operation.
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        parentTransposeOp.getVector(), permutation);
    return success();
  }
};

// Folds transpose(splat x : src_type) to splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<vector::SplatOp>(
        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
    return success();
  }
};
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          transpOp, transpOp.getResultVectorType(), newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
    return success();
  }
};

/// Folds transpose(shape_cast) into a new shape_cast, when the transpose is
/// order preserving (i.e. effectively itself a shape cast).
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // We don't need to check isValidShapeCast at this point, because it is
    // guaranteed that merging the transpose into the shape_cast is valid.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};
/// Folds transpose(broadcast(x)) to broadcast(x), when the transpose only
/// permutes dims that the broadcast introduced or unit dims, i.e. when x is
/// directly broadcastable to the transposed type.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {
    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    int64_t deltaRank = outputRank - inputRank;

    // Groups of consecutive non-unit input dims must only be permuted within
    // their group (after accounting for the broadcast's added leading dims).
    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Fail if any permutation destination for indices in [low, high) lies
        // outside [low, high).
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    // If the groups above are all local, x broadcasts directly to the
    // transposed output type.
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());
    return success();
  }
};

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeBroadcast>(context);
}
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  build(builder, result, type,
        kind == ConstantMaskKind::AllTrue
            ? type.getShape()
            : SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that the array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of the corresponding
  // vector result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all are (the mask region
  // is a conjunction of the per-dimension intervals).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
bool ConstantMaskOp::isAllOnesMask() {
  auto resultType = getVectorType();
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
    return getMaskDimSizes()[0] == 1;
  }
  for (const auto [resultSize, maskDimSize] :
       llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
    if (maskDimSize < resultSize)
      return false;
  }
  return true;
}
void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
                         VectorType type,
                         ArrayRef<OpFoldResult> mixedOperands) {
  SmallVector<Value> operands =
      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
  build(builder, result, type, operands);
}

LogicalResult CreateMaskOp::verify() {
  auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() !=
             llvm::cast<VectorType>(getResult().getType()).getRank()) {
    return emitOpError(
        "must specify an operand for each result vector dimension");
  }
  return success();
}
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp when all of the
/// mask operands are statically known.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: rank-zero shape.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the constant mask dimension sizes.
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value: any value is fine for a non-scalable dim; for a
        // scalable dim only zero (all-false) is supported.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale): must cover the whole
        // dimension (all-true) to fold.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of the dims is zero, set all dims to zero.
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with a ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};

void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(mask);
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(maskRegion);
  maskRegionBuilder(builder, maskableOp);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}

void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}

ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // ... (mask operand and type parsing partially elided in this extract) ...
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse the body region and make sure it is properly terminated.
  Region *maskRegion = result.addRegion();
  if (parser.parseRegion(*maskRegion, /*arguments=*/{}))
    return failure();
  MaskOp::ensureTerminator(*maskRegion, builder, result.location);

  // ... (operand resolution elided in this extract) ...
  result.types.append(resultTypes);

  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }
  return success();
}

void MaskOp::print(OpAsmPrinter &p) {
  p << " " << getMask();
  if (getPassthru())
    p << ", " << getPassthru();

  // Print single masked operation and skip terminator.
  p << " { ";
  Block *singleBlock = &getMaskRegion().getBlocks().front();
  if (singleBlock && !singleBlock->getOperations().empty())
    p.printCustomOrGenericOp(&singleBlock->front());
  p << " }";

  p.printOptionalAttrDict(getOperation()->getAttrs());

  p << " : " << getMask().getType();
  if (getNumResults() > 0)
    p << " -> " << getResultTypes();
}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // Replace the default yield terminator with one that returns the results of
  // the masked operation.
  Operation *maskedOp = &region.front().front();
  Operation *oldYieldOp = region.front().getTerminator();
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

  // Empty vector.mask op, or the masked op is the terminator itself.
  if (maskedOp == oldYieldOp)
    return;

  OpBuilder opBuilder(builder.getContext());
  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
  oldYieldOp->dropAllReferences();
  oldYieldOp->erase();
}
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask: nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  // Result checks.
  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
/// Folds a fully transparent `vector.mask`, exposing the masked operation's
/// results directly to the users of the mask op.
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  // ... (all-true mask precondition elided in this extract) ...
  Operation *maskableOp = getMaskableOp();
  if (!maskableOp)
    return failure();
  llvm::append_range(results, maskableOp->getResults());
  return success();
}

/// Canonicalize empty `vector.mask` operations that can't be handled in
/// MaskOp::fold as they require replacing the whole op.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
    if (maskingOp.getMaskableOp())
      return failure();

    if (!maskOp.isEmpty())
      return failure();

    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(block->front());
    if (terminator.getNumOperands() == 0) {
      rewriter.eraseOp(maskOp);
      return success();
    }

    rewriter.replaceOp(maskOp, terminator.getOperands());
    return success();
  }
};

void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ElideEmptyMaskOp>(context);
}

// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  if (block->getOperations().size() < 2)
    return nullptr;
  return &block->front();
}

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check that the reduction dimension is smaller than the source rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError("reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError("initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check that the shapes of initial value and src match, dropping the
  // reduction dimension.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(srcShape[i]);
  }
  if (!llvm::equal(initialValueShapes, expectedShape)) {
    return emitOpError("incompatible input/initial value shapes");
  }

  // Verify the supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  }

  assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
}

/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as the masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Move the masked operation into the region's block and terminate the block
  // with a vector.yield of its results.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
}

/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the provided mask is valid; otherwise, returns the
/// maskable operation itself.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  if (!mask)
    return maskableOp;
  if (passthru)
    return builder.create<MaskOp>(maskableOp->getLoc(),
                                  maskableOp->getResultTypes(), mask, passthru,
                                  maskableOp, createMaskOpRegion);
  return builder.create<MaskOp>(maskableOp->getLoc(),
                                maskableOp->getResultTypes(), mask, maskableOp,
                                createMaskOpRegion);
}

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
                                         mask, newValue, passthru);
}
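
// For illustration (example IR; names are hypothetical): selectPassthru emits
//
//   %sel = arith.select %mask, %newValue, %passthru
//            : vector<8xi1>, vector<8xf32>
//
// so masked-off lanes of the result carry the passthru values.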

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"